On the Out-of-Distribution Generalization of Self-Supervised Learning
This paper addresses poor out-of-distribution generalization in self-supervised learning, a problem relevant to machine learning practitioners, and offers an incremental improvement through a novel batch sampling method.
The paper tackles out-of-distribution (OOD) generalization in self-supervised learning (SSL). It identifies that SSL learns spurious correlations, which reduce OOD performance, and proposes a batch sampling strategy based on a post-intervention distribution to mitigate this, achieving improved results on various downstream OOD tasks.
In this paper, we focus on the out-of-distribution (OOD) generalization of self-supervised learning (SSL). By analyzing how mini-batches are constructed during SSL training, we first give a plausible explanation for why SSL exhibits OOD generalization at all. Then, from the perspectives of data generation and causal inference, our analysis concludes that SSL learns spurious correlations during training, which reduces its OOD generalization. To address this issue, we propose a post-intervention distribution (PID) grounded in the Structural Causal Model. Under PID, the spurious variable and the label variable are mutually independent. Moreover, we demonstrate that if each mini-batch during SSL training satisfies PID, the resulting SSL model can achieve optimal worst-case OOD performance. This motivates us to develop a batch sampling strategy that enforces PID constraints by learning a latent variable model. Through theoretical analysis, we establish the identifiability of the latent variable model and validate the proposed sampling strategy. Experiments on various downstream OOD tasks demonstrate its effectiveness.
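To make the PID constraint concrete, the sketch below shows one generic way to draw mini-batches in which a label variable y and a spurious variable s are independent. It is not the paper's algorithm: the paper infers the spurious variable with a learned latent variable model, whereas here s is assumed to be given as a discrete attribute per example. The sketch uses standard importance reweighting, w(y, s) = P(y) P(s) / P(y, s), estimated from counts; sampling in proportion to w makes the joint over (y, s) equal the product of the marginals, provided every (y, s) combination appears at least once in the data. The function names `pid_weights` and `sample_pid_batch` are hypothetical.

```python
import random
from collections import Counter


def pid_weights(labels, spurious):
    """Hypothetical helper: per-example weights w_i = P(y_i) P(s_i) / P(y_i, s_i).

    Probabilities are empirical frequencies. Summed over a (y, s) cell, the
    weights give mass n * P(y) * P(s), so sampling in proportion to them makes
    y and s independent while preserving each marginal. This assumes every
    (y, s) cell is non-empty; empty cells simply receive no mass.
    """
    n = len(labels)
    count_y = Counter(labels)
    count_s = Counter(spurious)
    count_ys = Counter(zip(labels, spurious))
    return [
        (count_y[y] / n) * (count_s[s] / n) / (count_ys[(y, s)] / n)
        for y, s in zip(labels, spurious)
    ]


def sample_pid_batch(labels, spurious, batch_size, rng=None):
    """Hypothetical helper: draw batch indices (with replacement) using PID weights."""
    rng = rng or random.Random()
    weights = pid_weights(labels, spurious)
    return rng.choices(range(len(labels)), weights=weights, k=batch_size)


if __name__ == "__main__":
    # Toy data where s is strongly correlated with y (6:2 split per label).
    labels = [0] * 8 + [1] * 8
    spurious = [0] * 6 + [1] * 2 + [0] * 2 + [1] * 6
    batch = sample_pid_batch(labels, spurious, batch_size=16, rng=random.Random(0))
    print(Counter((labels[i], spurious[i]) for i in batch))
```

In the toy example, each of the four (y, s) cells receives an expected sampling probability of 0.25, matching P(y) P(s), even though the raw data over-represent the (0, 0) and (1, 1) cells. An index-based sampler was chosen so the same weights could be plugged into any framework's weighted sampler.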