CausalVAE as a Plug-in for World Models: Towards Reliable Counterfactual Dynamics
This work addresses the need for more reliable and interpretable counterfactual predictions in latent world models, particularly in applications such as physics simulation. The contribution is incremental, since it builds on existing encoder-transition backbones.
The paper improves counterfactual dynamics in world models by introducing CausalVAE as a plug-in module. On the Physics benchmark, CF-H@1 improves by +102.5% on average over 8 paired baselines, and rises from 11.0 to 41.0 (+272.7%) in a representative GNN-NLL setting.
In this work, CausalVAE is introduced as a plug-in structural module for latent world models and attached to diverse encoder-transition backbones. Across the reported benchmarks, adding the plug-in preserves competitive factual prediction and improves intervention-aware counterfactual retrieval, suggesting stronger robustness under distribution shift and interventions. The largest gains appear on the Physics benchmark: averaged over 8 paired baselines, CF-H@1 improves by +102.5%. In a representative GNN-NLL setting on Physics, CF-H@1 rises from 11.0 to 41.0 (+272.7%). Causal analysis shows that the learned structural dependencies recover meaningful first-order physical interaction trends, supporting the interpretability of the learned latent causal structure.
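The relative gain quoted for the GNN-NLL setting can be checked by hand, and a retrieval score of this kind can be illustrated on toy data. The sketch below is illustrative only: it assumes CF-H@1 is a Hits@1 retrieval rate reported in percent, which the abstract does not spell out, and the function names `relative_gain` and `hits_at_k` and the toy scores are hypothetical, not taken from the paper.

```python
def relative_gain(before: float, after: float) -> float:
    """Percentage change from `before` to `after`."""
    if before == 0:
        raise ValueError("relative gain is undefined for a zero baseline")
    return (after - before) / before * 100.0


def hits_at_k(scores, target_idx, k=1):
    """Percentage of queries whose true candidate is among the k best-scored.

    ASSUMPTION: this is a generic Hits@k, used here as a stand-in for CF-H@1.
    scores: one list of candidate scores per query (higher is better).
    target_idx: index of the true candidate for each query.
    """
    hits = 0
    for row, target in zip(scores, target_idx):
        # Rank candidate indices by score, best first (ties keep lower index first).
        ranked = sorted(range(len(row)), key=lambda i: -row[i])
        hits += target in ranked[:k]
    return 100.0 * hits / len(scores)


# Figures reported in the abstract for GNN-NLL on Physics.
gain = relative_gain(11.0, 41.0)
print(f"{gain:.1f}%")  # 272.7%

# Toy retrieval problem (made-up numbers): 4 queries, 3 candidates each.
toy_scores = [
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.3, 0.4, 0.5],
    [0.6, 0.3, 0.1],
]
toy_targets = [0, 1, 0, 1]
print(hits_at_k(toy_scores, toy_targets, k=1))  # 50.0
```

The first print reproduces the +272.7% figure as (41.0 - 11.0) / 11.0. The +102.5% average cannot be reproduced from the abstract alone, since the per-baseline scores for the 8 pairs are not listed here.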