Understanding Forgetting in LLM Supervised Fine-Tuning and Preference Learning -- A Convex Optimization Perspective
The paper addresses a critical inefficiency in LLM post-training, offering practitioners a practical way to improve model performance without significant extra cost, though the contribution is incremental, since it builds on existing post-training methods.
The paper tackles the suboptimal trade-off between supervised fine-tuning (SFT) and preference learning (RLHF/DPO) in LLM post-training, where sequential training causes forgetting, and proposes a joint framework that achieves up to 23% overall performance improvement across benchmarks with minimal computational overhead.
The post-training of LLMs, which typically consists of the supervised fine-tuning (SFT) stage and the preference learning stage (RLHF or DPO), is crucial to effective and safe LLM applications. The approach widely adopted for post-training popular open-source LLMs is to perform SFT and RLHF/DPO sequentially. However, this is suboptimal in terms of the SFT and RLHF/DPO trade-off: the LLM gradually forgets what it learned in the first stage while undergoing the second stage's training. This sequential paradigm persists largely due to its simplicity and modularity, which make it easier to implement and manage at scale despite its limitations. We theoretically prove the sub-optimality of sequential post-training and propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms the sequential post-training framework, with up to 23% overall performance improvement across multiple LLM evaluation benchmarks, while incurring minimal computational overhead. Our code is available at https://github.com/heshandevaka/XRIGHT.
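To make the forgetting argument concrete from the convex optimization perspective in the title, here is a minimal toy sketch in pure Python. It is not the paper's algorithm: the two convex quadratics `f_sft` and `f_pref` are hypothetical stand-ins for the SFT and preference losses, and the "joint" variant simply runs gradient descent on a fixed weighted sum of the two objectives, a swapped-in illustration rather than the framework in the linked repository. Sequential training minimizes `f_sft` first and then `f_pref`, ending near the preference optimum while losing the SFT progress; the joint variant settles at a compromise point with a lower combined loss.

```python
"""Toy illustration (not the paper's method): forgetting under sequential
vs. joint gradient descent on two convex quadratic objectives."""

from typing import Callable, Tuple

Vec = Tuple[float, float]


def make_quadratic(center: Vec) -> Tuple[Callable[[Vec], float], Callable[[Vec], Vec]]:
    """Return f(x) = ||x - center||^2 and its gradient 2 * (x - center)."""
    def f(x: Vec) -> float:
        return (x[0] - center[0]) ** 2 + (x[1] - center[1]) ** 2

    def grad(x: Vec) -> Vec:
        return (2 * (x[0] - center[0]), 2 * (x[1] - center[1]))

    return f, grad


def gradient_descent(x: Vec, grad: Callable[[Vec], Vec], steps: int, lr: float) -> Vec:
    """Plain fixed-step gradient descent starting from x."""
    for _ in range(steps):
        g = grad(x)
        x = (x[0] - lr * g[0], x[1] - lr * g[1])
    return x


# Hypothetical stand-ins: f_sft plays the SFT loss, f_pref the preference loss.
f_sft, g_sft = make_quadratic((1.0, 0.0))
f_pref, g_pref = make_quadratic((0.0, 1.0))


def sequential(steps: int = 200, lr: float = 0.1) -> Vec:
    x = gradient_descent((0.0, 0.0), g_sft, steps, lr)  # stage 1: SFT objective only
    return gradient_descent(x, g_pref, steps, lr)       # stage 2: preference objective only


def joint(steps: int = 200, lr: float = 0.1, weight: float = 0.5) -> Vec:
    # Assumed joint scheme: descend on weight * f_sft + (1 - weight) * f_pref.
    def g_mix(x: Vec) -> Vec:
        a, b = g_sft(x), g_pref(x)
        return (weight * a[0] + (1 - weight) * b[0],
                weight * a[1] + (1 - weight) * b[1])

    return gradient_descent((0.0, 0.0), g_mix, steps, lr)


if __name__ == "__main__":
    for name, x in (("sequential", sequential()), ("joint", joint())):
        total = f_sft(x) + f_pref(x)
        print(f"{name:>10}: f_sft={f_sft(x):.3f} f_pref={f_pref(x):.3f} total={total:.3f}")
```

Running it should print `f_sft=2.000 f_pref=0.000 total=2.000` for the sequential run and `f_sft=0.500 f_pref=0.500 total=1.000` for the joint run: the second stage drives the SFT loss back up from near zero, whereas optimizing both objectives together lands at their midpoint.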