Model Predictive Control with Differentiable World Models for Offline Reinforcement Learning
This work addresses the challenge of improving offline RL policy performance without additional environment interaction, although the contribution is incremental, since it builds on existing world-model and MPC methods.
The paper tackles the problem of offline reinforcement learning by introducing an inference-time adaptation framework that uses a differentiable world model with model predictive control to optimize policy parameters on the fly, showing consistent gains over strong baselines on D4RL benchmarks.
Offline Reinforcement Learning (RL) aims to learn optimal policies from fixed offline datasets, without further interactions with the environment. Such methods train an offline policy (or value function) and apply it at inference time without further refinement. We introduce an inference-time adaptation framework inspired by model predictive control (MPC) that utilizes a pretrained policy along with a learned world model of state transitions and rewards. While existing world-model and diffusion-planning methods use learned dynamics to generate imagined trajectories during training or to sample candidate plans at inference time, they do not use inference-time information to optimize the policy parameters on the fly. In contrast, our design is a Differentiable World Model (DWM) pipeline that enables end-to-end gradient computation through imagined rollouts for policy optimization at inference time based on MPC. We evaluate our algorithm on D4RL continuous-control benchmarks (MuJoCo locomotion tasks and AntMaze) and show that exploiting inference-time information to optimize the policy parameters yields consistent gains over strong offline RL baselines.
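The core inference-time loop described above is: imagine a short rollout with the world model, differentiate the imagined return with respect to the policy parameters, take a few gradient steps, then act. It can be sketched as follows. This is a minimal toy under stated assumptions, not the DWM implementation. The "learned" dynamics `0.9*s + 0.5*a` and reward `-(s^2 + 0.1*a^2)` are hypothetical scalar stand-ins for neural networks, and the policy is linear. Gradients come from hand-rolled forward-mode dual numbers instead of a deep-learning framework. The horizon, discount, step size, and the choice to restart adaptation from the pretrained parameters at every step are illustrative guesses.

```python
"""Toy sketch of MPC-style inference-time policy adaptation by
differentiating an imagined rollout through a (hypothetical) world model.

Assumptions: scalar state/action, fixed toy dynamics and reward standing in
for learned networks, a linear policy, and forward-mode autodiff via duals.
"""
from dataclasses import dataclass


@dataclass
class Dual:
    """Dual number val + der*eps for forward-mode differentiation."""
    val: float
    der: float = 0.0

    def __add__(self, other):
        other = lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __neg__(self):
        return Dual(-self.val, -self.der)

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (u v)' = u' v + u v'.
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def lift(x):
    """Promote plain numbers to constant duals (derivative 0)."""
    return x if isinstance(x, Dual) else Dual(float(x))


def world_model(s, a):
    """Hypothetical learned model: returns (next_state, reward)."""
    next_s = 0.9 * s + 0.5 * a
    reward = -(s * s + 0.1 * a * a)
    return next_s, reward


def policy(theta, s):
    """Linear policy a = k*s + c with theta = (k, c)."""
    k, c = theta
    return k * s + c


def imagined_return(theta, s0, horizon=5, gamma=0.99):
    """Discounted return of an H-step rollout inside the world model."""
    s, total, discount = s0, 0.0, 1.0
    for _ in range(horizon):
        a = policy(theta, s)
        s, r = world_model(s, a)
        total = total + discount * r
        discount *= gamma
    return total


def return_and_grad(theta, s0):
    """Forward mode: one rollout per parameter, seeding that parameter's der=1."""
    value, grads = None, []
    for i in range(len(theta)):
        seeded = [Dual(p, 1.0 if j == i else 0.0) for j, p in enumerate(theta)]
        out = imagined_return(seeded, s0)
        value = out.val
        grads.append(out.der)
    return value, grads


def adapt(theta0, s0, steps=50, lr=0.02):
    """Gradient ascent on the imagined return, starting from theta0."""
    theta = list(theta0)
    for _ in range(steps):
        _, g = return_and_grad(theta, s0)
        theta = [p + lr * gi for p, gi in zip(theta, g)]
    return theta


def mpc_step(theta_pretrained, s):
    """Adapt at the current state, then act with the adapted policy."""
    theta = adapt(theta_pretrained, s)
    return policy(theta, s), theta


if __name__ == "__main__":
    theta_pre = [0.0, 0.0]  # stand-in for offline-pretrained parameters
    s = 1.0
    for t in range(3):
        a, theta = mpc_step(theta_pre, s)
        s, _ = world_model(s, a)  # toy: real env assumed equal to the model
        print(t, round(a, 3), round(s, 3))
```

Re-adapting from the pretrained parameters at each state keeps the offline policy as an anchor; warm-starting from the previous step's adapted parameters is another option that this sketch does not explore.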