Modeling Multi-Task Model Merging as Adaptive Projective Gradient Descent
This work addresses task conflicts in model merging for multi-task learning with a data-free solution that improves performance, though the contribution is incremental because it builds on existing merging methods.
The paper tackles the problem of merging multiple expert models for multi-task learning without access to their original data by framing it as a constrained optimization problem. The result is a plug-and-play approach that is reported to achieve state-of-the-art performance across diverse architectures and tasks in both vision and NLP domains.
Merging multiple expert models offers a promising approach for performing multi-task learning without accessing their original data. Existing methods attempt to alleviate task conflicts by sparsifying task vectors or promoting orthogonality among them. However, they overlook the fundamental target of model merging: the merged model should perform as closely as possible to the task-specific models on their respective tasks. We find that these methods inevitably discard task-specific information that, while causing conflicts, is crucial for performance. Based on our findings, we frame model merging as a constrained optimization problem ($\textit{i.e.}$, minimizing the gap between the merged model and individual models, subject to the constraint of retaining shared knowledge) and solve it via adaptive projective gradient descent. Specifically, we align the merged model with individual models by decomposing and reconstituting the loss function, alleviating conflicts through $\textit{data-free}$ optimization of task vectors. To retain shared knowledge, we optimize this objective by projecting gradients within a $\textit{shared subspace}$ spanning all tasks. Moreover, we view merging coefficients as adaptive learning rates and propose a task-aware, training-free strategy. Experiments show that our plug-and-play approach consistently outperforms previous methods, achieving state-of-the-art results across diverse architectures and tasks in both vision and NLP domains.
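To make the idea of data-free, projected optimization of task vectors more concrete, the following is a minimal NumPy sketch for a single weight matrix. It is not the paper's implementation. Several choices are assumptions made only for illustration: the gap to each expert is measured by a plain squared Frobenius distance between task vectors, the shared subspace is taken to be the span of the top singular vectors of the concatenated task vectors, the merging coefficients are fixed inputs rather than the adaptive, task-aware ones the abstract describes, and the names `task_vectors`, `shared_basis`, and `merge_with_projected_gd` are hypothetical.

```python
import numpy as np


def task_vectors(pretrained, experts):
    """Task vector = fine-tuned expert weights minus pretrained weights."""
    return [w - pretrained for w in experts]


def shared_basis(taus, rank):
    """Orthonormal basis for an assumed 'shared subspace'.

    Assumption: the subspace is spanned by the top-`rank` left singular
    vectors of the task vectors concatenated along the input dimension.
    """
    stacked = np.concatenate(taus, axis=1)  # shape (d_out, T * d_in)
    u, _, _ = np.linalg.svd(stacked, full_matrices=False)
    return u[:, :rank]


def merge_with_projected_gd(pretrained, experts, coeffs, rank=2, lr=0.1, steps=100):
    """Refine a weighted sum of task vectors with projected gradient steps.

    Surrogate loss (assumption): 0.5 * sum_i ||tau_merged + delta - tau_i||_F^2.
    No training data is used; only the weights themselves enter the loss.
    """
    taus = task_vectors(pretrained, experts)
    basis = shared_basis(taus, rank)
    projector = basis @ basis.T  # orthogonal projector onto the subspace
    base = sum(c * t for c, t in zip(coeffs, taus))  # fixed-coefficient merge
    delta = np.zeros_like(pretrained)  # learnable correction to the merge

    def loss(d):
        return 0.5 * sum(np.sum((base + d - t) ** 2) for t in taus)

    losses = [loss(delta)]
    for _ in range(steps):
        grad = sum(base + delta - t for t in taus)  # gradient w.r.t. delta
        delta = delta - lr * (projector @ grad)     # projected update
        losses.append(loss(delta))

    return pretrained + base + delta, delta, losses


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pre = rng.normal(size=(6, 4))
    experts = [pre + 0.1 * rng.normal(size=(6, 4)) for _ in range(3)]
    merged, delta, losses = merge_with_projected_gd(pre, experts, [0.3, 0.3, 0.3])
    print(merged.shape, losses[0] >= losses[-1])
```

Because every update is multiplied by the same projector, the correction `delta` never leaves the chosen subspace, which is the mechanical effect of a projected gradient step. With three experts and `lr=0.1` the surrogate loss cannot increase from one step to the next; a larger learning rate or a different loss would need its own step-size check.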