LG · Nov 4, 2024

SALSA: Soup-based Alignment Learning for Stronger Adaptation in RLHF

UW
arXiv:2411.01798v1 · 2 citations · h-index: 37
Originality: Incremental advance
AI Analysis

This work addresses suboptimal alignment in LLMs, a problem relevant to developers and researchers, with a method that improves exploration and performance. It is an incremental advance because it builds on existing RLHF techniques.

The paper tackles a limitation of RLHF for LLMs: the traditional KL divergence constraint against a frozen initial policy restricts exploration. It introduces SALSA, which averages the weights of two independent SFT models to build a more flexible reference model. The reported results are higher rewards and improved performance on benchmarks such as MT-Bench and Arena-Hard.

In Large Language Model (LLM) development, Reinforcement Learning from Human Feedback (RLHF) is crucial for aligning models with human values and preferences. RLHF traditionally relies on the Kullback-Leibler (KL) divergence between the current policy and a frozen initial policy as a reference, which is added as a penalty in policy optimization algorithms like Proximal Policy Optimization (PPO). While this constraint prevents models from deviating too far from the initial checkpoint, it limits exploration of the reward landscape, reducing the model's ability to discover higher-quality solutions. As a result, policy optimization is often trapped in a narrow region of the parameter space, leading to suboptimal alignment and performance. This paper presents SALSA (Soup-based Alignment Learning for Stronger Adaptation), a novel approach designed to overcome these limitations by creating a more flexible and better-located reference model through weight-space averaging of two independent supervised fine-tuned (SFT) models. This model soup allows for a larger deviation in KL divergence and for exploration of a promising region of the solution space without sacrificing stability. By leveraging this more robust reference model, SALSA fosters better exploration, achieving higher rewards and improving model robustness, out-of-distribution generalization, and performance. We validate the effectiveness of SALSA through extensive experiments on popular open models (Llama2-7B, Mistral-7B, and Gemma-2B) across various benchmarks (MT-Bench, Arena-Hard, UltraFeedback), where it consistently surpasses PPO by fostering deeper exploration and achieving superior alignment in LLMs.
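
The two ingredients described in the abstract, a weight-space average of two SFT checkpoints and a KL penalty measured against that averaged reference, can be illustrated with a toy sketch. This is not the authors' implementation. The function names (`soup`, `kl_divergence`, `penalized_reward`), the interpolation weight `alpha`, the penalty coefficient `beta`, and the tiny parameter dictionaries are all assumptions made for illustration. The KL term is computed exactly over a small discrete distribution, whereas PPO-based RLHF typically estimates it from per-token log-probabilities.

```python
import math


def soup(weights_a, weights_b, alpha=0.5):
    """Element-wise interpolation alpha * a + (1 - alpha) * b of two parameter dicts.

    alpha=0.5 gives a plain average; other values are a hypothetical knob.
    """
    if weights_a.keys() != weights_b.keys():
        raise ValueError("both models must expose the same parameter names")
    merged = {}
    for name in weights_a:
        a, b = weights_a[name], weights_b[name]
        if len(a) != len(b):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        merged[name] = [alpha * x + (1.0 - alpha) * y for x, y in zip(a, b)]
    return merged


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same outcomes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def penalized_reward(reward, policy_probs, reference_probs, beta=0.1):
    """Reward minus a KL penalty against the reference distribution."""
    return reward - beta * kl_divergence(policy_probs, reference_probs)


# Toy stand-ins for two independently fine-tuned SFT checkpoints.
sft_a = {"layer.weight": [0.2, -0.4, 1.0], "layer.bias": [0.0]}
sft_b = {"layer.weight": [0.6, 0.0, 0.0], "layer.bias": [0.5]}

# The averaged parameters would be loaded into the reference model.
reference_weights = soup(sft_a, sft_b)

# Toy next-token distributions from the policy and from the soup reference.
policy_probs = [0.5, 0.5]
reference_probs = [0.25, 0.75]
shaped = penalized_reward(1.0, policy_probs, reference_probs, beta=0.1)
```

In this sketch only the reference changes relative to standard KL-regularized PPO: the penalty is taken against the averaged model rather than against a single frozen SFT checkpoint, which is the substitution the abstract describes.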

