TAPIOCA: Why Task-Aware Pruning Improves OOD Model Capability
Provides a geometric explanation for why task-aware pruning benefits OOD generalization, addressing a key challenge in deploying models on shifted distributions.
Task-aware layer pruning consistently improves out-of-distribution (OOD) accuracy but yields no benefit on in-distribution (ID) data, across controlled polynomial regression tasks and large language models. The improvement is attributed to removing layers that distort the task-adapted geometry for OOD inputs.
Recent work, notably TALE, has promoted task-aware layer pruning as a way to improve model performance on particular tasks. In this paper, we investigate when such improvements occur and why. We first show that, across controlled polynomial regression tasks and large language models, such pruning yields no benefit on in-distribution (ID) data but consistently improves out-of-distribution (OOD) accuracy. We further show empirically that OOD inputs induce layerwise norm and pairwise-distance profiles that deviate from the corresponding ID profiles. This leads to a geometric explanation of task-aware pruning: each task induces a task-adapted geometry, characterized empirically by the representation profiles observed on ID inputs, and OOD inputs can induce a distorted version of this geometry. Task-aware pruning identifies layers that create or amplify this distortion; removing them shifts the norms and pairwise distances of OOD representations toward those observed on the adapted (ID) distribution. This realigns OOD inputs with the model's task-adapted geometry and improves performance. We provide causal evidence through controlled distribution shifts and residual-scaling interventions, and demonstrate consistent behavior across model scales.
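The layerwise norm and pairwise-distance profiles described above can be illustrated with a minimal sketch. It assumes hidden states have already been extracted as nested lists indexed by layer, then input, then dimension (a hypothetical layout; for a language model one might store one pooled vector per input per layer). The summary statistics used here (mean L2 norm, mean pairwise Euclidean distance, and relative per-layer deviation of OOD from ID) are illustrative assumptions, not necessarily the exact quantities used in the paper, and the function names are hypothetical.

```python
import math
from itertools import combinations
from typing import List, Sequence

# Hypothetical layout: states[layer][input] is one representation vector.
States = Sequence[Sequence[Sequence[float]]]


def norm_profile(states: States) -> List[float]:
    """Mean L2 norm of the representations at each layer."""
    return [sum(math.hypot(*h) for h in layer) / len(layer) for layer in states]


def pairwise_distance_profile(states: States) -> List[float]:
    """Mean Euclidean distance over all pairs of inputs at each layer."""
    profile = []
    for layer in states:
        pairs = list(combinations(layer, 2))
        if not pairs:
            raise ValueError("need at least two inputs per layer")
        profile.append(sum(math.dist(a, b) for a, b in pairs) / len(pairs))
    return profile


def relative_deviation(ood: Sequence[float], ref: Sequence[float],
                       eps: float = 1e-12) -> List[float]:
    """Per-layer |OOD - ID| / |ID|; eps guards against a zero reference."""
    return [abs(o - r) / max(abs(r), eps) for o, r in zip(ood, ref)]


def mean_deviation(ood_states: States, id_states: States) -> float:
    """Single illustrative distortion score: average of both profile deviations."""
    dev_norm = relative_deviation(norm_profile(ood_states), norm_profile(id_states))
    dev_dist = relative_deviation(pairwise_distance_profile(ood_states),
                                  pairwise_distance_profile(id_states))
    devs = dev_norm + dev_dist
    return sum(devs) / len(devs)
```

In a toy case where OOD inputs match ID at the first layer but have twice the norm at the second, the per-layer deviations are roughly `[0.0, 1.0]` for both profiles. Under the geometric explanation, removing a layer that causes such amplification would be expected to lower these deviations on OOD inputs.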