Instance based Generalization in Reinforcement Learning
This paper addresses the generalization challenge in reinforcement learning (RL) for AI agents. It offers an incremental improvement that combines theoretical insights with a practical method.
The paper tackles the failure of deep reinforcement learning agents to generalize to unseen environments. It analyzes policy learning in POMDPs, formalizes training levels as instances, shows that reusing instances leads to instance-specific policies, and provides generalization bounds. It then proposes training a shared belief representation over an ensemble of specialized policies to improve generalization, achieving a 15% performance gain on unseen levels of the CoinRun benchmark.
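The claim that reusing a small, fixed set of instances rewards instance-specific behavior can be illustrated with a toy example. This example is not taken from the paper, and every name and parameter below is hypothetical. Each level hides a sequence of correct binary actions and offers a noisy per-step cue. The level id stands in for any visual feature that lets an agent recognize a training level. A "speedrunning" policy that memorizes each training level's actions beats a cue-following policy on the training set, but on unseen levels it can only guess.

```python
import random
from typing import Callable, Dict, List, Tuple

T = 20              # hypothetical episode length
CUE_ACCURACY = 0.8  # hypothetical: the local cue points to the right action 80% of the time

Instance = Tuple[List[int], List[int]]  # (correct actions, observed cues)
Policy = Callable[[int, int, int], int]  # (level_id, step, cue) -> action


def make_instance(seed: int) -> Instance:
    """A hypothetical 'level': a hidden sequence of correct actions plus a noisy per-step cue."""
    rng = random.Random(seed)
    correct = [rng.randint(0, 1) for _ in range(T)]
    cues = [a if rng.random() < CUE_ACCURACY else 1 - a for a in correct]
    return correct, cues


def episode_return(policy: Policy, seed: int) -> float:
    """Fraction of steps on which the policy picks the correct action.

    The policy sees (level_id, step, cue); level_id plays the role of any
    feature that lets the agent recognize which instance it is playing.
    """
    correct, cues = make_instance(seed)
    hits = sum(policy(seed, t, cues[t]) == correct[t] for t in range(T))
    return hits / T


def generalizing_policy(level_id: int, step: int, cue: int) -> int:
    # Uses only the local cue, so it behaves identically on seen and unseen levels.
    return cue


def make_speedrunner(train_seeds: List[int]) -> Policy:
    # Memorizes the exact action sequence of every training level (instance-specific policy).
    table: Dict[int, List[int]] = {s: make_instance(s)[0] for s in train_seeds}

    def policy(level_id: int, step: int, cue: int) -> int:
        if level_id in table:
            return table[level_id][step]
        return 0  # unseen level: the memorized table has nothing to offer
    return policy


def mean_return(policy: Policy, seeds: List[int]) -> float:
    return sum(episode_return(policy, s) for s in seeds) / len(seeds)


train_seeds = list(range(50))         # small, reused set of training instances
test_seeds = list(range(1000, 1200))  # unseen instances
speedrunner = make_speedrunner(train_seeds)

for name, pi in [("generalizing", generalizing_policy), ("speedrunning", speedrunner)]:
    print(f"{name:13s} train={mean_return(pi, train_seeds):.2f} "
          f"test={mean_return(pi, test_seeds):.2f}")
```

By construction, the speedrunner scores 1.0 on every training level, and its test return drops to roughly chance level. The cue-following policy scores about `CUE_ACCURACY` on both splits. A learner that only maximizes training reward therefore prefers the speedrunner.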
Agents trained via deep reinforcement learning (RL) routinely fail to generalize to unseen environments, even when these share the same underlying dynamics as the training levels. Understanding the generalization properties of RL is one of the challenges of modern machine learning. Towards this goal, we analyze policy learning in the context of Partially Observable Markov Decision Processes (POMDPs) and formalize the dynamics of training levels as instances. We prove that, independently of the exploration strategy, reusing instances introduces significant changes in the effective Markov dynamics the agent observes during training. Maximizing expected rewards affects the learned belief state of the agent by inducing undesired instance-specific speedrunning policies instead of generalizable ones, since the latter are suboptimal on the training set. We provide generalization bounds on the value gap between training and test environments based on the number of training instances, and use the resulting insights to improve performance on unseen levels. We propose training a shared belief representation over an ensemble of specialized policies, from which we compute a consensus policy that is used for data collection, preventing instance-specific exploitation. We experimentally validate our theory, observations, and the proposed computational solution on the CoinRun benchmark.
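The data-collection scheme described in the abstract can be sketched as follows. This is a minimal, hypothetical illustration in plain Python, not the authors' implementation. It makes four assumptions: a single shared encoder produces the belief representation, each ensemble member is one linear softmax head, each member is assigned a disjoint subset of training instances, and the consensus averages the members' action distributions. The paper's actual architecture, instance assignment, and aggregation rule may differ, and training is omitted entirely.

```python
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    """Dense matrix-vector product."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def random_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


class SharedBeliefEnsemble:
    """Hypothetical sketch: one shared encoder (belief representation) feeding K policy heads."""

    def __init__(self, obs_dim: int, belief_dim: int, n_actions: int, n_heads: int, seed: int = 0):
        rng = random.Random(seed)
        self.encoder = random_matrix(rng, belief_dim, obs_dim)  # shared by every head
        self.heads = [random_matrix(rng, n_actions, belief_dim) for _ in range(n_heads)]

    def belief(self, obs: Sequence[float]) -> List[float]:
        # Shared belief representation computed from the observation.
        return [math.tanh(z) for z in matvec(self.encoder, obs)]

    def head_policy(self, k: int, obs: Sequence[float]) -> List[float]:
        # Action distribution of ensemble member k (specialized on its own instances).
        return softmax(matvec(self.heads[k], self.belief(obs)))

    def consensus_policy(self, obs: Sequence[float]) -> List[float]:
        # Assumed aggregation: arithmetic mean of the members' action distributions.
        dists = [self.head_policy(k, obs) for k in range(len(self.heads))]
        n_actions = len(dists[0])
        return [sum(d[a] for d in dists) / len(dists) for a in range(n_actions)]

    def act(self, obs: Sequence[float], rng: random.Random) -> int:
        # Data collection samples from the consensus, not from the head tied to the current instance.
        probs = self.consensus_policy(obs)
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def assign_instances(instance_ids: Sequence[int], n_heads: int) -> List[List[int]]:
    # Hypothetical disjoint split: head k specializes on instances k, k + K, k + 2K, ...
    return [list(instance_ids[k::n_heads]) for k in range(n_heads)]


model = SharedBeliefEnsemble(obs_dim=4, belief_dim=8, n_actions=3, n_heads=4)
obs = [0.1, -0.3, 0.7, 0.0]
print("consensus:", model.consensus_policy(obs))
print("instances per head:", assign_instances(list(range(10)), 4))
print("sampled action:", model.act(obs, random.Random(1)))
```

Actions are sampled from the consensus rather than from the head specialized on the current instance. As a result, no single member can steer data collection toward a memorized, instance-specific trajectory. Under these assumptions, a full agent would update each head only on its own instances, while the shared encoder would receive gradients from all of them.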