Gaussian Embeddings: How JEPAs Secretly Learn Your Data Density
The result gives an efficient way to extract density estimates from already-trained self-supervised learning models, with direct uses in data curation, outlier detection, and density estimation.
The paper reveals that the anti-collapse term in Joint Embedding Predictive Architectures (JEPAs) implicitly estimates data density, enabling trained JEPAs to compute sample probabilities for tasks like outlier detection and density estimation, with validation across diverse datasets and models.
Joint Embedding Predictive Architectures (JEPAs) learn representations that can solve numerous downstream tasks out of the box. JEPAs combine two objectives: (i) a latent-space prediction term, i.e., the representation of a slightly perturbed sample must be predictable from the original sample's representation, and (ii) an anti-collapse term, i.e., not all samples should have the same representation. While (ii) is often considered an obvious remedy to representation collapse, we uncover that JEPAs' anti-collapse term does much more: it provably estimates the data density. In short, any successfully trained JEPA can be used to get sample probabilities, e.g., for data curation, outlier detection, or simply for density estimation. Our theoretical finding is agnostic to the dataset and architecture used: in every case, one can compute the learned probability of a sample $x$ efficiently and in closed form using the model's Jacobian matrix at $x$. Our findings are empirically validated across datasets (synthetic, controlled, and ImageNet), across different self-supervised learning methods in the JEPA family (I-JEPA and DINOv2), and on multimodal models such as MetaCLIP. We call the method that extracts the JEPA-learned density {\bf JEPA-SCORE}.
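To make the "closed form using the model's Jacobian matrix at $x$" statement concrete, the following is a minimal sketch of a Jacobian-based per-sample score. The abstract does not give the exact formula, so the choice made here (the sum of log singular values of the encoder Jacobian) is an assumption for illustration, as are the toy encoder and every function name. How such a value maps to a probability (sign, normalization) is a matter for the full paper and is not asserted here.

```python
import numpy as np


def toy_encoder(x, W1, W2):
    """Hypothetical stand-in for a trained JEPA encoder: a two-layer tanh MLP."""
    return W2 @ np.tanh(W1 @ x)


def numerical_jacobian(f, x, eps=1e-5):
    """Central finite-difference Jacobian of f at x, shape (out_dim, in_dim)."""
    x = np.asarray(x, dtype=float)
    cols = []
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        cols.append((f(x + step) - f(x - step)) / (2.0 * eps))
    return np.stack(cols, axis=1)


def jacobian_log_volume(f, x, tol=1e-12):
    """Assumed score: sum of log singular values of the Jacobian of f at x.

    Singular values are clipped at `tol` so the logarithm stays finite.
    """
    J = numerical_jacobian(f, x)
    s = np.linalg.svd(J, compute_uv=False)
    return float(np.sum(np.log(np.clip(s, tol, None))))


# Illustrative usage with random weights (not a trained model).
rng = np.random.default_rng(0)
W1 = rng.normal(size=(8, 4))
W2 = rng.normal(size=(3, 8))


def encoder(x):
    return toy_encoder(x, W1, W2)


samples = rng.normal(size=(5, 4))
scores = [jacobian_log_volume(encoder, x) for x in samples]
print(scores)
```

With a real encoder, the finite-difference Jacobian would normally be replaced by automatic differentiation (for example `torch.autograd.functional.jacobian` in PyTorch), which is both more accurate and far cheaper for high-dimensional inputs. Samples can then be ranked by score for curation or outlier detection.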