Robust Representation Learning through Explicit Environment Modeling
For machine learning practitioners dealing with distribution shifts across environments, this work offers a practical alternative to invariant learning when its key causal assumption, that the environment has no direct effect on the target, is violated.
The paper addresses robust prediction in unseen environments when the environment directly affects the target, a case that violates the core assumption of causal invariant-representation methods. The proposed method, based on generalized random-intercept models, outperforms invariant-learning methods across a range of challenging settings.
We consider learning from labeled data collected across multiple environments, whose data distributions may differ from one environment to another. This problem is commonly approached from a causal perspective, seeking invariant representations that retain causal factors while discarding spurious ones. However, this framework assumes that the environment has no direct effect on the target. In contrast, we consider settings in which this assumption fails, but still aim to learn representations that support robust prediction on average across previously unseen environments. To this end, we study representations learned by explicitly modeling variation across environments and then marginalizing that variation out. We analyze the resulting representations and characterize when they are preferable to those learned by causal invariant-representation methods. We propose a concrete method based on generalized random-intercept models, a class of predictors in which such marginalization is possible, and study their generalization properties. Empirically, we show that these models outperform invariant-learning methods across a range of challenging settings.
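To make "explicitly modeling variation across environments and then marginalizing it out" concrete, here is a minimal sketch of a textbook generalized random-intercept model for binary labels. It is an illustration under stated assumptions, not the paper's exact formulation, which the abstract does not specify. We assume (i) a shared score f(x), (ii) an environment-specific intercept b_e ~ N(0, sigma^2), and (iii) P(y = 1 | x, e) = g(f(x) + b_e) for a link function g. A prediction for an unseen environment averages over the intercept: P(y = 1 | x) = E_b[g(f(x) + b)]. For the probit link g = Phi this integral has the exact closed form Phi(f(x) / sqrt(1 + sigma^2)). For the logistic link there is no closed form, so the sketch compares numerical integration with the standard probit-based approximation sigmoid(f(x) / sqrt(1 + pi * sigma^2 / 8)). All function names and parameter values below are hypothetical.

```python
import math
from typing import Callable


def sigmoid(z: float) -> float:
    """Logistic link g(z) = 1 / (1 + exp(-z)), written to avoid overflow."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def probit(z: float) -> float:
    """Probit link g(z) = Phi(z), the standard normal CDF."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def normal_pdf(b: float, sigma: float) -> float:
    """Density of the (assumed Gaussian) random intercept b ~ N(0, sigma^2)."""
    return math.exp(-0.5 * (b / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def marginal_prob(f_x: float, sigma: float, link: Callable[[float], float],
                  n_grid: int = 4001, width: float = 10.0) -> float:
    """P(y=1 | x) = E_b[link(f_x + b)] with b ~ N(0, sigma^2) integrated out
    numerically by the trapezoidal rule on [-width*sigma, width*sigma]."""
    lo, hi = -width * sigma, width * sigma
    step = (hi - lo) / (n_grid - 1)
    total = 0.0
    for i in range(n_grid):
        b = lo + i * step
        weight = 0.5 if i in (0, n_grid - 1) else 1.0  # trapezoid end weights
        total += weight * link(f_x + b) * normal_pdf(b, sigma)
    return total * step


def marginal_probit_closed_form(f_x: float, sigma: float) -> float:
    """Exact marginal for the probit link: Phi(f_x / sqrt(1 + sigma^2))."""
    return probit(f_x / math.sqrt(1.0 + sigma ** 2))


def marginal_logit_approx(f_x: float, sigma: float) -> float:
    """Probit-based approximation for the logistic link:
    sigmoid(f_x / sqrt(1 + pi * sigma^2 / 8))."""
    return sigmoid(f_x / math.sqrt(1.0 + math.pi * sigma ** 2 / 8.0))


if __name__ == "__main__":
    f_x, sigma = 1.5, 2.0  # hypothetical shared score and between-environment std
    print("probit, b = 0:      ", probit(f_x))
    print("probit, marginal:   ", marginal_prob(f_x, sigma, probit))
    print("probit, closed form:", marginal_probit_closed_form(f_x, sigma))
    print("logit, b = 0:       ", sigmoid(f_x))
    print("logit, marginal:    ", marginal_prob(f_x, sigma, sigmoid))
    print("logit, approx:      ", marginal_logit_approx(f_x, sigma))
```

In this sketch the marginal prediction is pulled toward 0.5 relative to the prediction for an environment with b = 0, and the pull grows with the between-environment variance sigma^2. The environment-free predictor thus becomes less confident exactly when environments disagree more.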