Enhancing Distributional Stability among Sub-populations
This work addresses distributional shifts in machine learning with the goal of improving model robustness, though it appears incremental because it builds on existing causal and invariant learning approaches.
The paper tackles the problem of Out-of-Distribution (OOD) generalization by introducing a 'distributional stability' notion to quantify prediction stability among sub-populations, proposing a stable risk minimization (SRM) algorithm that shows effectiveness in experiments.
Enhancing the stability of machine learning algorithms under distributional shifts is at the heart of the Out-of-Distribution (OOD) Generalization problem. Derived from causal learning, recent work on invariant learning pursues strict invariance across multiple training environments. Although intuitively reasonable, these methods make strong assumptions on the availability and quality of environments in order to learn the strict invariance property. In this work, we introduce the ``distributional stability'' notion to mitigate such limitations. It quantifies the stability of prediction mechanisms among sub-populations down to a prescribed scale. Based on this, we propose the learnability assumption and derive the generalization error bound under distribution shifts. Inspired by theoretical analyses, we propose our novel stable risk minimization (SRM) algorithm to enhance the model's stability w.r.t. shifts in prediction mechanisms ($Y|X$-shifts). Experimental results are consistent with our intuition and validate the effectiveness of our algorithm. The code can be found at https://github.com/LJSthu/SRM.