Structured Pruning for Diverse Best-of-N Reasoning Optimization
This work addresses the challenge of improving the reasoning capabilities of language models on tasks such as mathematical problem-solving. Its approach could benefit AI applications that require stronger reasoning, although the contribution appears incremental because it builds on existing pruning techniques.
The paper tackles the problem of improving reasoning performance in transformer-based language models. It reports that selectively pruning certain attention heads improves reasoning, and it proposes SPRINT, a contrastive learning framework that dynamically selects which head and layer to prune at inference time. SPRINT achieves significant gains over best-of-N and random head selection baselines on the MATH500 and GSM8K datasets.
Model pruning in transformer-based language models, traditionally viewed as a means of achieving computational savings, can enhance the model's reasoning capabilities. In this work, we uncover a surprising phenomenon: the selective pruning of certain attention heads leads to improvements in reasoning performance, particularly on challenging tasks. Motivated by this observation, we propose SPRINT, a novel contrastive learning framework that dynamically selects the optimal head and layer to prune during inference. By aligning question embeddings with head embeddings, SPRINT identifies those pruned-head configurations that result in more accurate reasoning. Extensive experiments demonstrate that our method significantly outperforms traditional best-of-$N$ and random head selection strategies on the MATH500 and GSM8K datasets.
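The head-selection step described in the abstract can be illustrated with a minimal sketch. It assumes that a question encoder and one learned embedding per candidate (layer, head) pair already exist, and it represents both as plain Python lists. The names `select_head` and `contrastive_loss`, the cosine scoring, and the temperature `tau=0.1` are hypothetical choices, not SPRINT's published implementation. The loss is a standard multi-positive InfoNCE objective, used here as a stand-in for the paper's contrastive objective, whose exact form may differ.

```python
import math
from typing import Dict, Iterable, Sequence, Tuple

# Hypothetical identifier for a candidate head to prune: (layer index, head index).
HeadId = Tuple[int, int]


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def select_head(question_emb: Sequence[float],
                head_embs: Dict[HeadId, Sequence[float]]) -> HeadId:
    """Return the (layer, head) whose embedding is most aligned with the question.

    Assumption: alignment is measured by cosine similarity.
    """
    return max(head_embs, key=lambda h: cosine(question_emb, head_embs[h]))


def contrastive_loss(question_emb: Sequence[float],
                     head_embs: Dict[HeadId, Sequence[float]],
                     positives: Iterable[HeadId],
                     tau: float = 0.1) -> float:
    """Multi-positive InfoNCE (assumed stand-in for the paper's objective).

    Pulls the question embedding toward heads whose pruning gave a correct
    answer (positives) and pushes it away from the other candidate heads.
    """
    logits = {h: cosine(question_emb, e) / tau for h, e in head_embs.items()}
    m = max(logits.values())  # subtract the max logit for numerical stability
    denom = sum(math.exp(z - m) for z in logits.values())
    numer = sum(math.exp(logits[h] - m) for h in positives)
    return -math.log(numer / denom)


if __name__ == "__main__":
    # Toy 2-D embeddings, chosen only for illustration.
    heads = {(0, 3): [1.0, 0.0], (5, 7): [0.0, 1.0], (9, 1): [0.7, 0.7]}
    q = [0.9, 0.1]
    print(select_head(q, heads))  # (0, 3)
    print(contrastive_loss(q, heads, [(0, 3)]))
```

In this sketch, the positives for a training question would be the heads whose removal led the model to a correct answer, and at inference the highest-scoring head would be pruned (for example, by masking its output) before generation; neither step is shown here.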