Metastable Dynamics of Chain-of-Thought Reasoning: Provable Benefits of Search, RL and Distillation
This work addresses the challenge of enhancing reasoning capabilities in AI systems, offering theoretical insights and practical methods for optimizing inference-time compute, though it builds incrementally on existing chain-of-thought and search paradigms.
The paper tackles the problem of improving reasoning in large language models by modeling chain-of-thought generation as a metastable Markov process. It proves that search protocols reduce the expected number of steps needed to reach new reasoning clusters, and it shows that the information gained by search can be used both to finetune the pretrained model and to distill it into a smaller, more efficient one.
A key paradigm to improve the reasoning capabilities of large language models (LLMs) is to allocate more inference-time compute to search against a verifier or reward model. This process can then be utilized to refine the pretrained model or distill its reasoning patterns into more efficient models. In this paper, we study inference-time compute by viewing chain-of-thought (CoT) generation as a metastable Markov process: easy reasoning steps (e.g., algebraic manipulations) form densely connected clusters, while hard reasoning steps (e.g., applying a relevant theorem) create sparse, low-probability edges between clusters, leading to phase transitions at longer timescales. Under this framework, we prove that implementing a search protocol that rewards sparse edges improves CoT by decreasing the expected number of steps to reach different clusters. In contrast, we establish a limit on reasoning capability when the model is restricted to local information of the pretrained graph. We also show that the information gained by search can be utilized to obtain a better reasoning model: (1) the pretrained model can be directly finetuned to favor sparse edges via policy gradient methods, and moreover (2) a compressed metastable representation of the reasoning dynamics can be distilled into a smaller, more efficient model.
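The cluster-and-sparse-edge picture above can be made concrete with a toy chain. The sketch below is an illustration under our own assumptions, not the paper's construction. It uses two clusters of n states each, with uniform moves inside a cluster. A single hypothetical bridge state (state 0) crosses to the other cluster with small probability eps. "Search that rewards sparse edges" is modeled, again as an assumption, by multiplying the sparse edge's weight by a bonus factor and renormalizing that row. The names `cluster_chain`, `reward_sparse_edge`, `solve`, and `expected_hitting_time` are illustrative. The expected number of steps to reach the other cluster is the standard hitting time, obtained by solving (I - Q)h = 1 on the non-target states.

```python
"""Toy metastable chain: two dense clusters joined by one sparse edge.

Illustrative assumptions only (not the paper's construction): states 0..n-1
form cluster A, states n..2n-1 form cluster B, and the hypothetical bridge
state 0 is the only state with an edge (0 -> n) leaving cluster A.
"""


def cluster_chain(n, eps):
    """Row-stochastic transition matrix over 2n states.

    Each state moves uniformly (self-loop included) within its own cluster,
    except the bridge state 0, which takes the sparse edge 0 -> n with
    probability eps and stays in cluster A otherwise.
    """
    size = 2 * n
    P = [[0.0] * size for _ in range(size)]
    for i in range(size):
        start = 0 if i < n else n
        for j in range(start, start + n):
            P[i][j] = 1.0 / n
    for j in range(n):
        P[0][j] *= 1.0 - eps  # shrink the bridge state's in-cluster moves
    P[0][n] += eps  # the sparse, low-probability "hard step"
    return P


def reward_sparse_edge(P, src, dst, bonus):
    """Hypothetical search-style reweighting: scale edge src -> dst by
    `bonus` and renormalize that row so it remains a distribution."""
    Q = [row[:] for row in P]
    Q[src][dst] *= bonus
    total = sum(Q[src])
    Q[src] = [w / total for w in Q[src]]
    return Q


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    m = len(A)
    M = [A[i][:] + [b[i]] for i in range(m)]
    for col in range(m):
        piv = max(range(col, m), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, m):
            f = M[r][col] / M[col][col]
            for c in range(col, m + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * m
    for r in range(m - 1, -1, -1):
        s = M[r][m] - sum(M[r][c] * x[c] for c in range(r + 1, m))
        x[r] = s / M[r][r]
    return x


def expected_hitting_time(P, start, target):
    """Expected number of steps from `start` until the chain enters `target`.

    Solves h = 1 + Q h on the non-target states, i.e. (I - Q) h = 1,
    where Q is P restricted to those states.
    """
    if start in target:
        return 0.0
    others = [i for i in range(len(P)) if i not in target]
    A = [[(1.0 if a == b else 0.0) - P[a][b] for b in others] for a in others]
    h = solve(A, [1.0] * len(others))
    return h[others.index(start)]


if __name__ == "__main__":
    n, eps = 4, 0.01
    P = cluster_chain(n, eps)
    cluster_b = set(range(n, 2 * n))
    base = expected_hitting_time(P, 1, cluster_b)
    boosted = expected_hitting_time(reward_sparse_edge(P, 0, n, 10.0), 1, cluster_b)
    print(f"{base:.1f} {boosted:.1f}")  # 401.0 44.6
```

With n = 4 and eps = 0.01, a bonus of 10 raises the bridge probability to 0.1/1.09 (about 0.092), and the expected time to leave cluster A from a non-bridge state drops from 401 to 44.6 steps, roughly a factor of nine. For this toy chain that hitting time works out to n + (1 + (1 - eps)(n - 1))/eps, so it grows like 1/eps as the sparse edge becomes rarer. The example is only meant to show this qualitative effect, not to reproduce the paper's bounds.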