EMC$^2$: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence
This work addresses a key computational bottleneck in contrastive learning that matters to machine learning practitioners. It offers a globally convergent method that is effective with small batch sizes, though the contribution is incremental, as it builds on prior frameworks such as SogCLR.
The paper tackles the challenge of efficiently generating negative samples in contrastive learning by proposing EMC$^2$, an efficient Markov Chain Monte Carlo method that uses an adaptive Metropolis-Hastings subroutine for online, hardness-aware sampling. It proves that EMC$^2$ finds an $\mathcal{O}(1/\sqrt{T})$-stationary point in $T$ iterations at low computational cost.
A key challenge in contrastive learning is to generate negative samples from a large sample set to contrast with positive samples, in order to learn better encodings of the data. These negative samples are often drawn from a softmax distribution that is dynamically updated during training. However, sampling from this distribution is non-trivial due to the high computational cost of computing its partition function. In this paper, we propose an Efficient Markov Chain Monte Carlo negative sampling method for Contrastive learning (EMC$^2$). We follow the global contrastive learning loss introduced in SogCLR and design EMC$^2$ to use an adaptive Metropolis-Hastings subroutine that generates hardness-aware negative samples online during optimization. We prove that EMC$^2$ finds an $\mathcal{O}(1/\sqrt{T})$-stationary point of the global contrastive loss in $T$ iterations. Compared to prior work, EMC$^2$ is the first algorithm that converges globally (to stationarity) regardless of the batch size while incurring low computation and memory costs. Numerical experiments show that EMC$^2$ is effective with small-batch training and achieves comparable or better performance than baseline algorithms. We report results for pre-training image encoders on STL-10 and ImageNet-100.
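The reason Metropolis-Hastings sidesteps the partition function is that its acceptance test only uses the *ratio* of two unnormalized softmax weights, so the normalizer cancels. The following is a minimal NumPy sketch of that idea, not the authors' EMC$^2$ implementation. It assumes inner-product similarity between unit-norm embeddings, a temperature `tau`, and a uniform independent proposal over the candidate set; the function name `mh_negative_chain` and all parameters are hypothetical.

```python
import numpy as np


def mh_negative_chain(anchor, candidates, tau, num_steps, rng, init=0):
    """Sample candidate indices j with a Metropolis-Hastings chain whose
    stationary distribution is pi(j) proportional to exp(<anchor, x_j> / tau).

    Assumptions (illustrative, not taken from the paper): inner-product
    similarity, temperature tau, and a uniform independent proposal.
    """
    n = candidates.shape[0]
    # Pre-draw proposals and uniforms; each step still touches only one candidate.
    proposals = rng.integers(n, size=num_steps)
    uniforms = rng.random(num_steps)
    state = init
    log_w_state = float(candidates[state] @ anchor) / tau
    samples = np.empty(num_steps, dtype=np.int64)
    for t in range(num_steps):
        prop = int(proposals[t])
        log_w_prop = float(candidates[prop] @ anchor) / tau
        # The uniform proposal is symmetric, so the acceptance probability is
        # min(1, w_prop / w_state); the partition function Z cancels.
        if uniforms[t] < np.exp(min(0.0, log_w_prop - log_w_state)):
            state, log_w_state = prop, log_w_prop
        samples[t] = state
    return samples


# Usage: compare chain frequencies with the exact softmax on a small problem.
rng = np.random.default_rng(0)
d, n, tau = 8, 20, 0.5
anchor = rng.normal(size=d)
anchor /= np.linalg.norm(anchor)
candidates = rng.normal(size=(n, d))
candidates /= np.linalg.norm(candidates, axis=1, keepdims=True)

samples = mh_negative_chain(anchor, candidates, tau, num_steps=200_000, rng=rng)
empirical = np.bincount(samples, minlength=n) / len(samples)

# Exact softmax: requires the full O(n) partition function, which MH avoids.
exact = np.exp(candidates @ anchor / tau)
exact /= exact.sum()
print("max |empirical - exact|:", np.abs(empirical - exact).max())
```

Each MH step costs one inner product instead of the $\mathcal{O}(n)$ sum needed to normalize the softmax. Because the encoder changes during training, one natural way to keep the chain "online" in this sketch would be to warm-start the next call with `init=samples[-1]` after each model update; the update rule actually used by EMC$^2$ may differ.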