Task-Aware Virtual Training: Enhancing Generalization in Meta-Reinforcement Learning for Out-of-Distribution Tasks
This work addresses a key limitation of meta-RL in robotics and simulation applications, though it appears incremental because it builds on existing context-based methods.
The paper tackles the difficulty meta-reinforcement learning has with out-of-distribution (OOD) tasks by proposing Task-Aware Virtual Training (TAVT), which uses metric-based representation learning and state regularization to improve generalization, and reports significant improvements across MuJoCo and MetaWorld environments.
Meta-reinforcement learning aims to develop policies that generalize to unseen tasks sampled from a task distribution. While context-based meta-RL methods improve task representations using task latents, they often struggle with out-of-distribution (OOD) tasks. To address this, we propose Task-Aware Virtual Training (TAVT), a novel algorithm that accurately captures task characteristics for both training and OOD scenarios using metric-based representation learning. Our method successfully preserves task characteristics in virtual tasks and employs a state regularization technique to mitigate overestimation errors in state-varying environments. Numerical results demonstrate that TAVT significantly enhances generalization to OOD tasks across various MuJoCo and MetaWorld environments. Our code is available at https://github.com/JM-Kim-94/tavt.git.
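The abstract does not specify TAVT's metric or loss. As a hedged illustration only, one generic form of metric-based representation learning trains a task encoder so that distances between task latents track a given task-dissimilarity measure. The sketch below computes such a distance-matching objective in NumPy. The function name `metric_alignment_loss` and the `task_dist` input are hypothetical, and this is not the TAVT loss (the linked repository is the authority on that). It computes the objective value only; actual training would need an autodiff framework.

```python
import numpy as np


def metric_alignment_loss(latents: np.ndarray, task_dist: np.ndarray) -> float:
    """Generic distance-matching objective (hypothetical, not TAVT's loss).

    latents:   (N, D) array, one latent vector per task.
    task_dist: (N, N) array of target task dissimilarities, supplied by the caller.
    Returns the mean squared error between pairwise Euclidean latent distances
    and the target dissimilarities, ignoring each task's distance to itself.
    """
    diff = latents[:, None, :] - latents[None, :, :]  # (N, N, D) pairwise differences
    latent_dist = np.linalg.norm(diff, axis=-1)        # (N, N) Euclidean distances
    n = latents.shape[0]
    mask = ~np.eye(n, dtype=bool)                       # drop self-pairs on the diagonal
    return float(np.mean((latent_dist[mask] - task_dist[mask]) ** 2))


# Usage: three tasks whose (hypothetical) goal positions are 0, 1 and 3.
goals = np.array([0.0, 1.0, 3.0])
task_dist = np.abs(goals[:, None] - goals[None, :])    # target dissimilarity |g_i - g_j|
latents = goals[:, None]                                # 1-D latents that mirror the goals
print(metric_alignment_loss(latents, task_dist))        # latents match the metric exactly
```

A latent space that reproduces the target distances yields zero loss, so tasks that differ more end up farther apart in latent space. That is the property a metric-based encoder relies on when it must place an unseen task relative to the training tasks.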