Post-Training with Policy Gradients: Optimality and the Base Model Barrier
This research is significant for developers of large language models: it identifies a fundamental limitation in the ability of post-training to make models generalize beyond what their base model already supports, and it offers a potential remedy. The contribution is best viewed as an incremental improvement.
This paper investigates post-training of linear autoregressive models using policy gradients with both outcome and process rewards. It shows that a policy gradient variant can reach a likelihood of 1 - ε with a near-optimal number of reward queries when the base model assigns the correct response a non-trivial likelihood, but that it faces a barrier exponential in the response length N for samples outside the base model's support. The study also shows that process rewards can overcome this exponential dependence on N by relying on a token-level likelihood quantile instead.
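The exponential barrier follows from basic arithmetic. An autoregressive model assigns a sequence the product of its token probabilities, so a base model that gives each correct token probability p assigns the full length-N response likelihood α = p^N. The sketch below plugs this α into the order of the outcome-reward bound quoted in the abstract, with log factors dropped. The "process" figure uses a hypothetical per-token accounting (N times the same bound evaluated at token-level likelihood p). It is chosen only to illustrate what dependence on a token-level quantity buys and is not the paper's bound. The values of p, ε and γ are arbitrary assumptions, and α here is a per-sample likelihood, not the distribution-level Likelihood Quantile.

```python
def outcome_reward_queries(alpha, eps, gamma):
    """Order of the bound O~((1/alpha + 1/eps) / gamma^2), log factors dropped."""
    return (1.0 / alpha + 1.0 / eps) / gamma**2

def sequence_likelihood(token_prob, n):
    """Autoregressive likelihood: product of per-token probabilities,
    here with every correct token receiving the same probability."""
    return token_prob ** n

gamma, eps, p = 0.5, 0.1, 0.5  # illustrative values (assumptions)
for n in (4, 8, 16, 32):
    alpha = sequence_likelihood(p, n)
    outcome = outcome_reward_queries(alpha, eps, gamma)
    # Hypothetical token-level accounting: n positions, each at likelihood p.
    process = n * outcome_reward_queries(p, eps, gamma)
    print(f"N={n:2d}  alpha={alpha:.2e}  outcome~{outcome:.3g}  process~{process:.3g}")
```

The outcome column grows like 2^N, while the token-level column grows only linearly in N under the stated assumption.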
We study post-training linear autoregressive models with outcome and process rewards. Given a context $\boldsymbol{x}$, the model must predict the response $\boldsymbol{y} \in Y^N$, a sequence of length $N$ that satisfies a $\gamma$-margin condition, an extension of standard separability to sequences. We prove that on test samples where the base model achieves a non-trivial likelihood $\alpha$, a variant of policy gradient (PG) can achieve likelihood $1 - \varepsilon$ with an essentially minimax optimal number of reward queries $\tilde{O}((\alpha^{-1} + \varepsilon^{-1})/\gamma^2)$. However, a barrier arises for going beyond the support of the base model. We prove that the overall expected error after post-training with outcome rewards is governed by a property of the base model called the Likelihood Quantile (LQ), and that variants of PG, while minimax optimal, may require a number of reward queries exponential in $N$ to go beyond this support, regardless of the pre-training algorithm. To overcome this barrier, we study post-training with a process reward model and demonstrate how PG variants in this setting avoid the curse of dimensionality in $N$ via dependence on a token-level LQ. Along the way, we prove that under the margin condition, SGD with an adaptive learning rate (LR) achieves a near-optimal test error for statistical learning, and PG with adaptive LR achieves a near-optimal number of mistakes for online learning while being computationally efficient whenever possible, both of which may be of independent interest.
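The abstract does not specify the PG variant, so the following is only a generic REINFORCE sketch of post-training a linear autoregressive model, under several assumptions that are not from the paper. The feature map `features` (context placed in a per-position slot plus a one-hot of the previous token) is hypothetical. The learning rate is fixed rather than the adaptive LR the paper analyzes. The outcome reward is 1 if and only if the whole sampled response equals the target. The process reward is assumed to score each token 1 when the prefix up to and including it is correct. The base model is uniform (all-zero weights).

```python
import math
import random

V, N, D = 3, 4, 2  # vocabulary size, response length, context dim (hypothetical)

def features(x, prev, i):
    """Hypothetical phi(x, y_<i): context in the slot for position i,
    plus a one-hot of the previous token (index V means 'start')."""
    phi = [0.0] * (D * N + V + 1)
    for k in range(D):
        phi[i * D + k] = x[k]
    phi[D * N + prev] = 1.0
    return phi

def softmax(z):
    m = max(z)
    e = [math.exp(t - m) for t in z]
    s = sum(e)
    return [t / s for t in e]

def token_probs(W, phi):
    # Linear logits: one weight vector per vocabulary item.
    return softmax([sum(w * f for w, f in zip(W[v], phi)) for v in range(V)])

def sample(W, x, rng):
    """Sample a response token by token, recording (phi, probs, token)."""
    y, prev, steps = [], V, []
    for i in range(N):
        phi = features(x, prev, i)
        p = token_probs(W, phi)
        tok = rng.choices(range(V), weights=p)[0]
        steps.append((phi, p, tok))
        y.append(tok)
        prev = tok
    return y, steps

def likelihood(W, x, y):
    prob, prev = 1.0, V
    for i, tok in enumerate(y):
        prob *= token_probs(W, features(x, prev, i))[tok]
        prev = tok
    return prob

def pg_step(W, x, target, rng, lr=0.5, process=False):
    """One REINFORCE update with a single sampled response."""
    y, steps = sample(W, x, rng)
    outcome = float(y == target)
    for i, (phi, p, tok) in enumerate(steps):
        # Assumed process reward: 1 iff the prefix through token i is correct.
        r = float(y[: i + 1] == target[: i + 1]) if process else outcome
        if r == 0.0:
            continue
        # Gradient of log-softmax w.r.t. W[v] is (1[v == tok] - p[v]) * phi.
        for v in range(V):
            g = (1.0 if v == tok else 0.0) - p[v]
            for k, f in enumerate(phi):
                W[v][k] += lr * r * g * f

rng = random.Random(0)
x, target = [1.0, -0.5], [2, 0, 1, 2]
W_out = [[0.0] * (D * N + V + 1) for _ in range(V)]
W_proc = [[0.0] * (D * N + V + 1) for _ in range(V)]
base = likelihood(W_out, x, target)  # uniform base model: (1/3)^N
for _ in range(200):
    pg_step(W_out, x, target, rng, process=False)
    pg_step(W_proc, x, target, rng, process=True)
print(base, likelihood(W_out, x, target), likelihood(W_proc, x, target))
```

With outcome rewards, an update happens only when a sample hits the full target, which under the uniform base model occurs with probability (1/3)^N. With the assumed process reward, each position receives signal as soon as its prefix is right, which is the intuition behind replacing a sequence-level quantity with a token-level one.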