LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures
This work addresses the challenge of improving the efficiency and performance of LLM training for natural language processing tasks; it is an incremental step that applies vision-inspired methods to language.
The paper tackles the problem of improving large language model (LLM) training by adapting joint embedding predictive architectures (JEPAs) from vision to language, resulting in LLM-JEPA, which significantly outperforms standard LLM training objectives across multiple datasets and models while being robust to overfitting.
Large Language Model (LLM) pretraining, finetuning, and evaluation rely on input-space reconstruction and generative capabilities. Yet it has been observed in vision that embedding-space training objectives, e.g., those of Joint Embedding Predictive Architectures (JEPAs), are far superior to their input-space counterparts. This mismatch between how language and vision models are trained raises a natural question: can language training methods learn a few tricks from the vision ones? The absence of JEPA-style LLMs testifies to the challenge of designing such objectives for language. In this work, we take a first step in that direction by developing LLM-JEPA, a JEPA-based solution for LLMs applicable to both finetuning and pretraining. So far, LLM-JEPA outperforms the standard LLM training objectives by a significant margin across models, all while being robust to overfitting. These findings hold across numerous datasets (NL-RX, GSM8K, Spider, RottenTomatoes) and various models from the Llama3, OpenELM, Gemma2 and Olmo families. Code: https://github.com/rbalestr-lab/llm-jepa.
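To make the contrast between input-space and embedding-space objectives concrete, the following is a minimal sketch of a combined loss, not the released implementation. It assumes (our assumption, not stated above) that the objective adds a standard next-token cross-entropy term to a weighted embedding-space term, where a predicted embedding of one view of a sample (for instance, a natural-language description in NL-RX) is compared against the embedding of a second view (for instance, the corresponding regular expression) using cosine distance. The function names and the weight `lam` are hypothetical and used only for illustration.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Input-space term: negative log-likelihood of one target token.

    Uses the log-sum-exp trick for numerical stability.
    """
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """Embedding-space term: 1 - cosine similarity (assumed choice of metric)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("embeddings must be non-zero")
    return 1.0 - dot / (norm_u * norm_v)


def jepa_style_loss(
    token_logits: List[Sequence[float]],
    token_targets: List[int],
    pred_embedding: Sequence[float],
    target_embedding: Sequence[float],
    lam: float = 1.0,
) -> float:
    """Hypothetical combined objective: mean token cross-entropy plus
    lam times the cosine distance between predicted and target embeddings."""
    lm_loss = sum(
        cross_entropy(logits, t) for logits, t in zip(token_logits, token_targets)
    ) / len(token_targets)
    return lm_loss + lam * cosine_distance(pred_embedding, target_embedding)


if __name__ == "__main__":
    # Uniform logits over 4 tokens give a cross-entropy of log(4);
    # orthogonal embeddings give a cosine distance of 1.
    loss = jepa_style_loss([[0.0, 0.0, 0.0, 0.0]], [2], [1.0, 0.0], [0.0, 1.0], lam=0.5)
    print(loss)
```

Setting `lam=0` recovers a purely input-space (generative) objective, which makes it easy to compare the two regimes on the same batch; in a real model the embeddings would come from the LLM's hidden states rather than hand-written vectors.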