One-shot World Models Using a Transformer Trained on a Synthetic Prior
This work addresses the need for more flexible world models in reinforcement learning by enabling them to be learned from synthetic data; its contribution is incremental, however, since it shows limited success in complex environments.
The authors tackle the dependence of world models on real-environment data by proposing OSWM, a transformer trained on synthetic data sampled from a prior distribution. Given 1k transition steps as context, OSWM adapts to simple environments such as CartPole and can then be used to train agent policies, although transfer to complex environments remains challenging.
A World Model is a compressed spatial and temporal representation of a real-world environment that allows one to train an agent or execute planning methods. However, world models are typically trained on observations from the real environment, and they usually do not enable learning policies for other real environments. We propose One-Shot World Model (OSWM), a transformer world model that is learned in an in-context fashion from purely synthetic data sampled from a prior distribution. Our prior is composed of multiple randomly initialized neural networks, where each network models the dynamics of one state dimension, or the reward, of a desired target environment. We adopt the supervised learning procedure of Prior-Fitted Networks by masking the next state and reward at random context positions and querying OSWM to make probabilistic predictions for them based on the remaining transition context. At inference time, OSWM is able to quickly adapt to the dynamics of a simple grid world, the CartPole gym environment, and a custom control environment when given 1k transition steps as context, and it can then be used to successfully train agent policies that solve these environments. However, transferring to more complex environments currently remains a challenge. Despite these limitations, we see this work as an important stepping stone in the pursuit of learning world models purely from synthetic data.
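To make the prior and the training objective more concrete, here is a minimal NumPy sketch under stated assumptions; it is not the authors' implementation. It samples one synthetic environment as a set of randomly initialized networks (one per state dimension plus one for the reward), rolls it out, and then hides the next state and reward at random positions, in the spirit of Prior-Fitted Network training. The two-layer tanh MLPs, the network inputs (full state concatenated with the action), the random initial state and uniformly random actions, the 20% query fraction, and all names (`RandomMLP`, `sample_prior_env`, `rollout`, `make_pfn_example`) are hypothetical choices made for illustration.

```python
import numpy as np


class RandomMLP:
    """Hypothetical prior network: a randomly initialized two-layer tanh MLP."""

    def __init__(self, in_dim: int, hidden: int, rng: np.random.Generator):
        # Weights are drawn once and then frozen: the network itself is the sampled dynamics.
        self.w1 = rng.normal(0.0, 1.0 / np.sqrt(in_dim), size=(in_dim, hidden))
        self.b1 = rng.normal(0.0, 0.1, size=hidden)
        self.w2 = rng.normal(0.0, 1.0 / np.sqrt(hidden), size=(hidden, 1))

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Maps a concatenated (state, action) vector to one scalar output.
        return (np.tanh(x @ self.w1 + self.b1) @ self.w2)[..., 0]


def sample_prior_env(state_dim, action_dim, rng, hidden=16):
    """Sample one synthetic environment: one network per state dimension plus one for reward."""
    in_dim = state_dim + action_dim  # assumption: every network sees the full state and action
    state_nets = [RandomMLP(in_dim, hidden, rng) for _ in range(state_dim)]
    reward_net = RandomMLP(in_dim, hidden, rng)
    return state_nets, reward_net


def rollout(state_nets, reward_net, action_dim, steps, rng):
    """Generate transitions (s, a, s', r); random start state and uniform actions are assumptions."""
    s = rng.normal(size=len(state_nets))
    S, A, S_next, R = [], [], [], []
    for _ in range(steps):
        a = rng.uniform(-1.0, 1.0, size=action_dim)
        x = np.concatenate([s, a])
        s_next = np.array([net(x) for net in state_nets])  # each dimension from its own network
        S.append(s)
        A.append(a)
        S_next.append(s_next)
        R.append(reward_net(x))
        s = s_next
    return np.array(S), np.array(A), np.array(S_next), np.array(R)


def make_pfn_example(S, A, S_next, R, rng, query_frac=0.2):
    """PFN-style split: mask (s', r) at random positions and keep them as prediction targets."""
    n = len(S)
    n_query = max(1, int(query_frac * n))
    perm = rng.permutation(n)
    query_idx, ctx_idx = perm[:n_query], perm[n_query:]
    # Context rows keep the full transition (s, a, s', r).
    context = np.concatenate([S[ctx_idx], A[ctx_idx], S_next[ctx_idx], R[ctx_idx, None]], axis=1)
    # Query rows expose only (s, a); the masked (s', r) become the targets.
    queries = np.concatenate([S[query_idx], A[query_idx]], axis=1)
    targets = np.concatenate([S_next[query_idx], R[query_idx, None]], axis=1)
    return context, queries, targets


rng = np.random.default_rng(0)
state_nets, reward_net = sample_prior_env(state_dim=4, action_dim=1, rng=rng)
S, A, S_next, R = rollout(state_nets, reward_net, action_dim=1, steps=1000, rng=rng)
context, queries, targets = make_pfn_example(S, A, S_next, R, rng)
print(context.shape, queries.shape, targets.shape)  # (800, 10) (200, 5) (200, 5)
```

Training would repeat this over many freshly sampled environments, so that the model can only predict the targets by inferring the dynamics from the context; the transformer itself and its probabilistic loss are omitted from the sketch.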