Offline Reinforcement Learning for LLM Multi-Step Reasoning
This work addresses the challenge of adapting LLMs to complex multi-step reasoning tasks. The contribution is incremental: it builds on maximum entropy reinforcement learning to overcome limitations of methods such as DPO.
The paper proposes OREO, an offline reinforcement learning method for improving multi-step reasoning in large language models (LLMs). OREO reduces the need for paired preference data and enables better credit assignment, and it outperforms existing offline learning methods on benchmarks such as GSM8K, MATH, and ALFWorld.
Improving the multi-step reasoning ability of large language models (LLMs) with offline reinforcement learning (RL) is essential for quickly adapting them to complex tasks. While Direct Preference Optimization (DPO) has shown promise in aligning LLMs with human preferences, it is less suitable for multi-step reasoning tasks because (1) DPO relies on paired preference data, which is not readily available for multi-step reasoning tasks, and (2) it treats all tokens uniformly, making it ineffective for credit assignment in multi-step reasoning tasks, which often come with sparse rewards. In this work, we propose OREO (Offline Reasoning Optimization), an offline RL method for enhancing LLM multi-step reasoning. Building on insights from previous work on maximum entropy reinforcement learning, it jointly learns a policy model and a value function by optimizing the soft Bellman equation. We show in principle that it reduces the need to collect pairwise data and enables better credit assignment. Empirically, OREO surpasses existing offline learning methods on multi-step reasoning benchmarks, including mathematical reasoning tasks (GSM8K, MATH) and embodied agent control (ALFWorld). The approach can be extended to a multi-iteration framework when additional resources are available. Furthermore, the learned value function can be used to guide tree search at no additional training cost, which can further boost performance at test time.
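To make the idea of jointly fitting a policy and a value function to the soft Bellman equation more concrete, the sketch below computes per-step consistency residuals for a single trajectory. It assumes the standard KL-regularized form from maximum entropy RL, V(s_t) - V(s_{t+1}) = r_t - beta * log(pi(a_t|s_t) / pi_ref(a_t|s_t)), with deterministic transitions, no discounting, and a terminal value of zero. The function names, the list-based inputs, and the plain mean-squared loss are illustrative assumptions, not the paper's exact objective.

```python
from typing import List, Sequence


def soft_bellman_residuals(
    values: Sequence[float],    # V(s_t) for t = 0..T-1 (hypothetical value-model outputs)
    rewards: Sequence[float],   # r_t for t = 0..T-1 (often zero except at the last step)
    logp: Sequence[float],      # log pi(a_t | s_t) under the policy being trained
    logp_ref: Sequence[float],  # log pi_ref(a_t | s_t) under the frozen reference policy
    beta: float,                # KL-regularization strength (assumed hyperparameter)
) -> List[float]:
    """Return V(s_t) - V(s_{t+1}) - r_t + beta * (logp_t - logp_ref_t) for each step."""
    horizon = len(rewards)
    if not (len(values) == len(logp) == len(logp_ref) == horizon):
        raise ValueError("all inputs must have one entry per step")
    residuals = []
    for t in range(horizon):
        # Assumption: the value after the final step is zero.
        next_value = values[t + 1] if t + 1 < horizon else 0.0
        log_ratio = logp[t] - logp_ref[t]
        residuals.append(values[t] - next_value - rewards[t] + beta * log_ratio)
    return residuals


def consistency_loss(values, rewards, logp, logp_ref, beta) -> float:
    """Mean squared residual; zero exactly when every step satisfies the equation."""
    res = soft_bellman_residuals(values, rewards, logp, logp_ref, beta)
    return sum(r * r for r in res) / len(res)


if __name__ == "__main__":
    # A three-step trajectory with a sparse reward on the final step only.
    rewards = [0.0, 0.0, 1.0]
    logp = [-1.0, -2.0, -0.5]
    logp_ref = [-1.2, -1.6, -0.5]
    values = [1.1, 1.2, 1.0]
    print(consistency_loss(values, rewards, logp, logp_ref, beta=0.5))
```

Because the residual is defined per step and needs only a single trajectory with its reward, a loss of this shape does not require paired preference data, and the value differences V(s_t) - V(s_{t+1}) give a step-level signal even when the reward arrives only at the end. The same value estimates are what a test-time tree search could use to rank partial solutions.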