Mitigating Spurious Correlations in LLMs via Causality-Aware Post-Training
This work addresses poor generalization in LLMs, a problem for users who rely on robust AI systems. It presents a novel method for a known bottleneck rather than an incremental improvement.
The paper tackles spurious correlations in large language models (LLMs), which cause failures on out-of-distribution (OOD) samples, by introducing causality-aware post-training (CAPT), a method that reduces pre-training biases and improves generalization. Experiments show that 3B-scale models fine-tuned with CAPT outperform both traditional supervised fine-tuning and larger LLMs on in-distribution (ID) and OOD tasks using only 100 ID fine-tuning samples.
While large language models (LLMs) have demonstrated remarkable capabilities in language modeling, recent studies reveal that they often fail on out-of-distribution (OOD) samples due to spurious correlations acquired during pre-training. Here, we aim to mitigate such spurious correlations through causality-aware post-training (CAPT). By decomposing a biased prediction into two unbiased steps, known as \textit{event estimation} and \textit{event intervention}, we reduce LLMs' pre-training biases without incurring additional fine-tuning biases, thus enhancing the model's generalization ability. Experiments on the formal causal inference benchmark CLadder and the logical reasoning dataset PrOntoQA show that 3B-scale language models fine-tuned with CAPT can outperform both traditional supervised fine-tuning (SFT) and larger LLMs on in-distribution (ID) and OOD tasks using only 100 ID fine-tuning samples, demonstrating the effectiveness and sample efficiency of CAPT.
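
The two-step decomposition named in the abstract can be pictured with a small toy sketch. This is an illustration under stated assumptions, not the paper's implementation: the abstract does not specify how either step is carried out. Here we assume that event estimation amounts to identifying event phrases in a question, and that event intervention amounts to replacing those phrases with randomly assigned abstract symbols, so that a downstream prediction cannot lean on what the model memorized about the specific event names. The helper names `estimate_events` and `intervene_on_events`, and the `known_events` list, are hypothetical.

```python
import random
import re
import string


def estimate_events(question: str, known_events: list[str]) -> list[str]:
    """Step 1 (illustrative): find which candidate event phrases occur in the text.

    A real system would presumably use the LLM itself for this step; simple
    substring matching against a given list is an assumption of this sketch.
    """
    lowered = question.lower()
    found = [e for e in known_events if e.lower() in lowered]
    # Order events by where they first appear in the question.
    return sorted(found, key=lambda e: lowered.index(e.lower()))


def intervene_on_events(
    question: str, events: list[str], rng: random.Random
) -> tuple[str, dict[str, str]]:
    """Step 2 (illustrative): replace each event phrase with a random symbol."""
    if not events:
        return question, {}
    # Draw distinct uppercase letters, one per event.
    symbols = rng.sample(string.ascii_uppercase, len(events))
    mapping = dict(zip(events, symbols))
    by_lower = {e.lower(): s for e, s in mapping.items()}
    # Single-pass substitution; longer phrases come first in the alternation
    # so they win over any shorter phrase they contain.
    pattern = "|".join(
        re.escape(e) for e in sorted(events, key=len, reverse=True)
    )
    abstracted = re.sub(
        pattern,
        lambda m: by_lower[m.group(0).lower()],
        question,
        flags=re.IGNORECASE,
    )
    return abstracted, mapping


question = "Does smoking cause lung cancer?"
events = estimate_events(question, ["lung cancer", "smoking", "exercise"])
abstracted, mapping = intervene_on_events(question, events, random.Random(0))
print(events)      # event phrases found, in order of appearance
print(abstracted)  # the question with event names replaced by symbols
```

In this toy setup, the abstracted question keeps the causal structure ("Does X cause Y?") while removing the surface event names, which is one plausible reading of how a prediction could be made less sensitive to spurious associations tied to specific events.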