Model-based Trajectory Stitching for Improved Offline Reinforcement Learning
For offline RL practitioners, this addresses the cost of collecting data in real-world applications, though the contribution appears incremental, since it builds on existing methods such as behavioral cloning.
The paper tackles the problem of limited data in offline reinforcement learning by proposing Trajectory Stitching, a model-based data augmentation method that generates higher-quality trajectories from sub-optimal historical data, leading to improvements over behavior-cloned policies.
In many real-world applications, collecting large and high-quality datasets may be too costly or impractical. Offline reinforcement learning (RL) aims to infer an optimal decision-making policy from a fixed set of data. Extracting as much information as possible from historical data is therefore vital for good performance once the policy is deployed. We propose a model-based data augmentation strategy, Trajectory Stitching (TS), to improve the quality of sub-optimal historical trajectories. TS introduces unseen actions joining previously disconnected states: using a probabilistic notion of state reachability, it effectively 'stitches' together parts of the historical demonstrations to generate new, higher-quality ones. A stitching event consists of a transition between a pair of observed states through a synthetic and highly probable action. New actions are introduced only when they are expected to be beneficial, according to an estimated state-value function. We show that using this data augmentation strategy jointly with behavioural cloning (BC) leads to improvements over the behaviour-cloned policy from the original dataset. The improved policy could then serve as a launchpad for online RL through planning and demonstration-guided RL.
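The stitching rule described in the abstract (jump to an observed state that is probably reachable, but only if its estimated value beats the logged next state) can be illustrated with a minimal sketch. Everything concrete here is an assumption, not the authors' implementation: states are 1-D floats, "reachable" is taken to mean that the candidate's log-likelihood under a Gaussian forward model p(s' | s) with a hypothetical mean function and fixed standard deviation exceeds a threshold, and the value function, inverse-dynamics model and reward model are simple stand-ins passed in as callables. The names `Trajectory`, `stitch`, `forward_mean`, `inverse_dynamics` and `reward_model` are hypothetical. Training BC on the augmented data is not shown.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Trajectory:
    states: List[float]   # s_0, ..., s_T (1-D states, an illustrative assumption)
    actions: List[float]  # a_0, ..., a_{T-1}
    rewards: List[float]  # r_0, ..., r_{T-1}


# (state, action, reward, next_state, is_synthetic)
Step = Tuple[float, float, float, float, bool]


def gaussian_log_pdf(x: float, mean: float, std: float) -> float:
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def stitch(
    trajs: List[Trajectory],
    start: int,
    forward_mean: Callable[[float], float],               # hypothetical mean of p(s' | s)
    forward_std: float,                                    # hypothetical fixed std of p(s' | s)
    value: Callable[[float], float],                       # estimated state-value function V(s)
    inverse_dynamics: Callable[[float, float], float],     # proposes an action for s -> s'
    reward_model: Callable[[float, float, float], float],  # estimates r(s, a, s')
    log_prob_threshold: float,
    max_steps: int = 100,
) -> List[Step]:
    """Walk one trajectory, jumping to better reachable observed states when possible."""
    i, t = start, 0
    out: List[Step] = []
    while t < len(trajs[i].actions) and len(out) < max_steps:
        s = trajs[i].states[t]
        best = (i, t + 1)  # default: follow the logged transition
        best_v = value(trajs[i].states[t + 1])
        mean = forward_mean(s)
        for j, traj in enumerate(trajs):
            for u, cand in enumerate(traj.states):
                if (j, u) in ((i, t), (i, t + 1)):
                    continue  # skip the current state and the logged next state
                reachable = gaussian_log_pdf(cand, mean, forward_std) >= log_prob_threshold
                if reachable and value(cand) > best_v:
                    best, best_v = (j, u), value(cand)
        j, u = best
        s_next = trajs[j].states[u]
        if best == (i, t + 1):
            out.append((s, trajs[i].actions[t], trajs[i].rewards[t], s_next, False))
        else:
            a = inverse_dynamics(s, s_next)  # synthetic action for the stitching event
            out.append((s, a, reward_model(s, a, s_next), s_next, True))
        i, t = j, u  # continue along whichever trajectory we landed in
    return out


# Toy usage: trajectory 0 rises then falls back; trajectory 1 keeps rising.
std = 0.25
trajs = [
    Trajectory([0.0, 1.0, 0.4], [1.0, -0.6], [1.0, -0.6]),
    Trajectory([1.2, 2.2, 3.2], [1.0, 1.0], [1.0, 1.0]),
]
# "Reachable" here means within 2 standard deviations of the predicted mean.
threshold = gaussian_log_pdf(2 * std, 0.0, std)
steps = stitch(
    trajs, 0,
    forward_mean=lambda s: s + 1.0,
    forward_std=std,
    value=lambda s: s,                      # stand-in V: higher state is better
    inverse_dynamics=lambda s, s2: s2 - s,  # stand-in inverse dynamics
    reward_model=lambda s, a, s2: s2 - s,   # stand-in reward model
    log_prob_threshold=threshold,
)
for step in steps:
    print(step)
```

In this toy run the first transition is stitched: from state 0.0 the sketch jumps to state 1.2 of the second trajectory (reachable, and more valuable than the logged next state 1.0), then follows that trajectory to 3.2. The stitched return is 3.2, against 0.4 for the original first trajectory.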