Discriminative Policy Optimization for Token-Level Reward Models
This work addresses training instability and inefficiency in reinforcement learning for LLMs, particularly on complex reasoning tasks, though it is an incremental improvement over existing PRM methods.
The paper tackles inaccurate credit assignment in token-level process reward models (PRMs) for LLMs by proposing the Q-function Reward Model (Q-RM), a token-level reward model obtained by optimizing a discriminative policy, which decouples reward modeling from language generation. The results show that Q-RM consistently outperforms baselines: it improves average Pass@1 by up to 5.85 points on mathematical reasoning tasks (with PPO, relative to the ORM baseline) and converges up to 12 times faster (relative to ORM on GSM8K).
Process reward models (PRMs) provide more nuanced supervision compared to outcome reward models (ORMs) for optimizing policy models, positioning them as a promising approach to enhancing the capabilities of LLMs in complex reasoning tasks. Recent efforts have advanced PRMs from step-level to token-level granularity by integrating reward modeling into the training of generative models, with reward scores derived from token generation probabilities. However, the conflict between generative language modeling and reward modeling may introduce instability and lead to inaccurate credit assignments. To address this challenge, we revisit token-level reward assignment by decoupling reward modeling from language generation and derive a token-level reward model through the optimization of a discriminative policy, termed the Q-function Reward Model (Q-RM). We theoretically demonstrate that Q-RM explicitly learns token-level Q-functions from preference data without relying on fine-grained annotations. In our experiments, Q-RM consistently outperforms all baseline methods across various benchmarks. For example, when integrated into PPO/REINFORCE algorithms, Q-RM enhances the average Pass@1 score by 5.85/4.70 points on mathematical reasoning tasks compared to the ORM baseline, and by 4.56/5.73 points compared to the token-level PRM counterpart. Moreover, reinforcement learning with Q-RM significantly enhances training efficiency, achieving convergence 12 times faster than ORM on GSM8K and 11 times faster than step-level PRM on MATH. Code and data are available at https://github.com/homzer/Q-RM.
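To make the idea of learning token-level scores from response-level preference data more concrete, here is a minimal sketch. It is not the Q-RM objective from the paper. It assumes, purely for illustration, that per-token scores are averaged into one sequence score and that chosen/rejected pairs are compared with a standard Bradley-Terry loss. The function names `sequence_score` and `preference_loss` are hypothetical.

```python
import math


def sequence_score(token_scores):
    """Average per-token scores into one sequence-level score.

    Mean aggregation is an assumption of this sketch, not the paper's rule.
    """
    return sum(token_scores) / len(token_scores)


def preference_loss(chosen_token_scores, rejected_token_scores):
    """Bradley-Terry negative log-likelihood that the chosen response wins.

    Only a response-level label (chosen vs. rejected) is needed; no
    per-token annotation enters the loss.
    """
    margin = sequence_score(chosen_token_scores) - sequence_score(rejected_token_scores)
    # -log(sigmoid(margin)), written in a numerically stable form.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


if __name__ == "__main__":
    chosen = [0.9, 0.4, 1.2]     # hypothetical per-token scores
    rejected = [0.1, -0.3, 0.2]  # hypothetical per-token scores
    print(preference_loss(chosen, rejected))
```

In a real training loop the per-token scores would come from a model and the loss would be minimized by gradient descent, so the gradient of a single pairwise label is spread across every token of both responses. That is the sense in which token-level signals can be learned without fine-grained annotations.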