Beyond Uniform Credit: Causal Credit Assignment for Policy Optimization
This work addresses a specific bottleneck in policy optimization for language models (uniform credit assignment across generated tokens) and offers an incremental improvement over existing methods.
The paper tackles uniform credit assignment in policy gradient methods for language model reasoning, where all tokens receive equal gradient updates regardless of their importance. Its proposed method, counterfactual importance weighting, improves performance on GSM8K across three models from the Qwen and Llama families and converges faster to equivalent accuracy.
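To make the uniform-credit problem concrete, the sketch below computes GRPO-style group-relative advantages for a group of sampled responses and copies each response's scalar advantage to every one of its tokens. It is a simplified illustration, not a reference implementation. It assumes binary correctness rewards and uses the population standard deviation plus a small epsilon. Real GRPO and DAPO implementations differ in normalization details and add clipped probability ratios, which are omitted here. The function name `grpo_token_advantages` is hypothetical.

```python
from statistics import mean, pstdev
from typing import List


def grpo_token_advantages(
    rewards: List[float], lengths: List[int], eps: float = 1e-6
) -> List[List[float]]:
    """Group-relative advantage per response, copied to every token of that response."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use the sample std
    seq_adv = [(r - mu) / (sigma + eps) for r in rewards]
    # Uniform credit: every token of response i receives exactly seq_adv[i],
    # whether it is filler ("Let me think") or a calculation ("23 + 45 = 68").
    return [[a] * n for a, n in zip(seq_adv, lengths)]


# A group of four sampled solutions: two correct (reward 1), two wrong (reward 0).
rewards = [1.0, 0.0, 1.0, 0.0]
lengths = [5, 3, 4, 6]  # number of tokens in each response
token_adv = grpo_token_advantages(rewards, lengths)
print(token_adv[0])  # one identical value repeated for all five tokens
```

Because the advantage is constant within a response, the per-token gradient differs only through the token's own log-probability gradient, not through any estimate of how much that token mattered.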
Policy gradient methods for language model reasoning, such as GRPO and DAPO, assign uniform credit to all generated tokens: the filler phrase "Let me think" receives the same gradient update as the critical calculation "23 + 45 = 68." We propose counterfactual importance weighting: mask reasoning spans, measure the resulting drop in answer probability, and upweight tokens accordingly during policy gradient updates. Our method requires no auxiliary models or external annotation; instead, importance is estimated directly from the policy model's own probability shifts. Experiments on GSM8K with three models spanning the Qwen and Llama families show consistent improvements over uniform baselines and faster convergence to equivalent accuracy. Inverting the importance signal hurts performance, confirming that the weights capture genuine causal structure rather than noise. Analysis shows that the method correctly prioritizes calculation steps over scaffolding text. We view these findings as establishing counterfactual importance weighting as a foundation for further research rather than a complete solution.
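The three steps named above (mask a reasoning span, measure the drop in answer probability, upweight the span's tokens) can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's implementation. Spans are given as token index ranges, and masking means deleting the span. Importance is the drop in answer probability clipped at zero. Span scores are normalized to mean 1 and mixed with uniform weights through a hypothetical `alpha`. The loss is a simplified, unclipped REINFORCE-style surrogate. `toy_answer_prob` is a hypothetical stand-in for the policy model's probability of the final answer given the (possibly masked) reasoning, and all function names are illustrative.

```python
from typing import Callable, List, Sequence, Tuple

Span = Tuple[int, int]  # [start, end) token indices of one reasoning span


def counterfactual_span_importance(
    tokens: Sequence[str],
    spans: Sequence[Span],
    answer_prob: Callable[[List[str]], float],
) -> List[float]:
    """Score each span by how much the answer probability drops when it is masked."""
    base = answer_prob(list(tokens))
    scores = []
    for start, end in spans:
        # Assumption: "masking" a span means deleting its tokens from the context.
        masked = list(tokens[:start]) + list(tokens[end:])
        # Assumption: spans whose removal raises the probability get zero importance.
        scores.append(max(base - answer_prob(masked), 0.0))
    return scores


def token_weights(
    n_tokens: int, spans: Sequence[Span], scores: Sequence[float], alpha: float = 0.5
) -> List[float]:
    """Copy span scores to their tokens, normalize to mean 1, mix with uniform weights."""
    raw = [0.0] * n_tokens  # tokens outside every span keep a raw score of 0
    for (start, end), score in zip(spans, scores):
        for t in range(start, end):
            raw[t] = score
    mean_raw = sum(raw) / n_tokens
    if mean_raw == 0.0:
        return [1.0] * n_tokens  # no signal: fall back to uniform credit
    # Hypothetical alpha in [0, 1]: 0 = uniform baseline, 1 = pure importance weights.
    return [(1.0 - alpha) + alpha * r / mean_raw for r in raw]


def weighted_pg_loss(
    token_logprobs: Sequence[float], advantage: float, weights: Sequence[float]
) -> float:
    """Simplified REINFORCE-style surrogate: -(1/T) * sum_t w_t * A * log pi(y_t)."""
    total = sum(w * advantage * lp for w, lp in zip(weights, token_logprobs))
    return -total / len(token_logprobs)


def toy_answer_prob(tokens: List[str]) -> float:
    """Hypothetical stand-in for the policy's probability of the correct final answer."""
    p = 0.1
    if "68" in tokens:
        p += 0.7  # the calculation carries most of the evidence for the answer
    if "think" in tokens:
        p += 0.05  # filler contributes almost nothing
    return p


tokens = ["Let", "me", "think", "23", "+", "45", "=", "68"]
spans = [(0, 3), (3, 8)]  # filler span, calculation span
scores = counterfactual_span_importance(tokens, spans, toy_answer_prob)
weights = token_weights(len(tokens), spans, scores)
print(scores)   # the calculation span scores far higher than the filler span
print(weights)  # calculation tokens get weight > 1, filler tokens < 1
```

With all weights equal to 1, `weighted_pg_loss` reduces to the uniform baseline, so the weighting only redistributes credit across tokens. Normalizing to mean 1 keeps the average per-token weight at 1, which keeps the overall loss scale comparable to the uniform case.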