Demystifying Language Model Forgetting with Low-rank Example Associations
This work helps practitioners address forgetting in LLMs by offering a targeted mitigation approach, though the contribution is incremental in that it builds on existing fine-tuning and replay techniques.
The paper tackles the forgetting of upstream knowledge in large language models (LLMs) during fine-tuning by analyzing how forgotten upstream examples depend on newly learned tasks. It shows that these associations are low-rank, which allows forgetting to be predicted efficiently with matrix completion. The approach outperforms prior methods and reduces forgetting when its predictions are used to select examples for replay.
Large language models (LLMs) suffer from forgetting of upstream knowledge when fine-tuned. Despite efforts to mitigate forgetting, few have investigated how forgotten upstream examples depend on newly learned tasks. Insights into such dependencies enable efficient and targeted mitigation of forgetting. In this paper, we empirically analyze forgetting that occurs in $N$ upstream examples of language modeling or instruction-tuning after fine-tuning LLMs on one of $M$ new tasks, visualized in $M\times N$ matrices. We show that the matrices are often well-approximated with low-rank matrices, indicating the dominance of simple associations between the learned tasks and forgotten upstream examples. Leveraging the analysis, we predict forgetting of upstream examples when fine-tuning LLMs on unseen tasks with matrix completion over the empirical associations. This enables fast identification of the most forgotten examples without expensive inference on the entire upstream data. Despite its simplicity, the approach outperforms prior approaches that learn semantic relationships of learned tasks and upstream examples with LMs. We demonstrate the practical utility of our analysis by showing statistically significantly reduced forgetting as we upweight predicted examples for replay during fine-tuning. Code, data, and statistics collected: https://github.com/AuCson/low-rank-forgetting
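To make the low-rank idea concrete, the following is a minimal sketch on synthetic data, not the authors' implementation. It builds a toy $M\times N$ forgetting matrix, checks how well a truncated SVD approximates it, and then predicts forgetting for an unseen task from a few observed entries. The function names (`low_rank_fit`, `predict_new_task`), the ridge-regularized least-squares step, and the assumption that forgetting is measured on a small seed subset of upstream examples for the new task are all illustrative choices.

```python
import numpy as np

def low_rank_fit(Z, rank):
    """Best rank-`rank` approximation of Z via truncated SVD, plus its R^2."""
    U, s, Vt = np.linalg.svd(Z, full_matrices=False)
    Z_hat = (U[:, :rank] * s[:rank]) @ Vt[:rank]
    r2 = 1.0 - np.sum((Z - Z_hat) ** 2) / np.sum((Z - Z.mean()) ** 2)
    return Z_hat, r2

def predict_new_task(Z_train, observed_idx, observed_vals, rank, ridge=1e-3):
    """Predict forgetting on all N upstream examples for an unseen task.

    Assumption: forgetting of the new task has been measured only on the
    upstream examples in `observed_idx`. The new task's latent vector is fit
    by ridge regression against example factors learned from Z_train.
    """
    _, _, Vt = np.linalg.svd(Z_train, full_matrices=False)
    V = Vt[:rank].T                       # (N, rank) example factors
    A = V[observed_idx]                   # factors of the observed examples
    w = np.linalg.solve(A.T @ A + ridge * np.eye(rank), A.T @ observed_vals)
    return V @ w                          # predicted forgetting, shape (N,)

# Synthetic data: M seen tasks plus one held-out task, N upstream examples.
rng = np.random.default_rng(0)
M, N, true_rank = 20, 200, 3
task_factors = rng.normal(size=(M + 1, true_rank))
example_factors = rng.normal(size=(N, true_rank))
Z_full = task_factors @ example_factors.T + 0.05 * rng.normal(size=(M + 1, N))
Z_train, z_new = Z_full[:M], Z_full[M]

# How much of the seen-task matrix does a rank-3 approximation explain?
_, r2 = low_rank_fit(Z_train, rank=3)

# Observe forgetting on 20 of the 200 upstream examples, predict the rest.
observed_idx = rng.choice(N, size=20, replace=False)
z_pred = predict_new_task(Z_train, observed_idx, z_new[observed_idx], rank=3)

# Candidate examples to upweight for replay: the 10 with highest prediction.
top_k = np.argsort(-z_pred)[:10]
corr = np.corrcoef(z_pred, z_new)[0, 1]
print(f"rank-3 R^2 = {r2:.3f}, prediction correlation = {corr:.3f}")
```

In this sketch only 10% of the new task's row is measured, which mirrors the motivation of avoiding inference over the entire upstream data. Real forgetting matrices are noisier and only approximately low-rank, so the rank and regularization strength would need to be tuned on held-out tasks.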