Implicit Chain of Thought Reasoning via Knowledge Distillation
This work addresses inefficient reasoning in language models on tasks such as math problems; it is incremental, building on existing chain-of-thought methods.
The paper tackles the problem of enabling language models to reason without explicit chain-of-thought steps by performing implicit reasoning in their internal hidden states, distilled from a teacher model. It finds that this approach solves tasks previously unsolvable without explicit chain-of-thought, at a speed comparable to using no chain-of-thought.
To augment language models with the ability to reason, researchers usually prompt or finetune them to produce chain-of-thought reasoning steps before producing the final answer. However, although people use natural language to reason effectively, LMs may reason more effectively with intermediate computation that is not in natural language. In this work, we explore an alternative reasoning approach: instead of explicitly producing the chain-of-thought reasoning steps, we use the language model's internal hidden states to perform implicit reasoning. The implicit reasoning steps are distilled from a teacher model trained on explicit chain-of-thought reasoning, and instead of doing reasoning "horizontally" by producing intermediate words one by one, we distill it such that the reasoning happens "vertically" among the hidden states in different layers. We conduct experiments on a multi-digit multiplication task and a grade school math problem dataset and find that this approach enables solving tasks previously not solvable without explicit chain-of-thought, at a speed comparable to no chain-of-thought.
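The shift from "horizontal" reasoning (one hidden state per generated chain-of-thought token) to "vertical" reasoning (one hidden state per layer) can be illustrated with a toy distillation loss. This is a minimal sketch under stated assumptions, not the paper's exact procedure: it assumes each teacher chain-of-thought step yields one hidden-state vector, that teacher steps are assigned to student layers by evenly spaced indices, and that the student is supervised with mean-squared error. The function names `select_teacher_states` and `vertical_distillation_loss` are hypothetical.

```python
from typing import List

Vector = List[float]


def select_teacher_states(teacher_states: List[Vector], num_layers: int) -> List[Vector]:
    """Map teacher CoT steps (horizontal axis) onto student layers (vertical axis).

    ASSUMPTION: teacher_states[t] is one hidden-state vector for the t-th
    chain-of-thought step, and steps are spread over layers with evenly
    spaced integer indices. This assignment rule is illustrative only.
    """
    if not teacher_states:
        raise ValueError("need at least one teacher state")
    if num_layers < 1:
        raise ValueError("need at least one student layer")
    num_steps = len(teacher_states)
    if num_layers == 1:
        return [teacher_states[-1]]
    # Layer 0 gets the first step, the last layer gets the last step.
    indices = [(layer * (num_steps - 1)) // (num_layers - 1) for layer in range(num_layers)]
    return [teacher_states[i] for i in indices]


def mse(a: Vector, b: Vector) -> float:
    """Mean-squared error between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("dimension mismatch")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def vertical_distillation_loss(student_layer_states: List[Vector],
                               teacher_states: List[Vector]) -> float:
    """Average per-layer MSE between student layers and assigned teacher steps.

    ASSUMPTION: student_layer_states[l] is the student's hidden state at layer l
    (at a single fixed position); MSE is a stand-in objective for illustration.
    """
    targets = select_teacher_states(teacher_states, len(student_layer_states))
    losses = [mse(s, t) for s, t in zip(student_layer_states, targets)]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # Five hypothetical teacher CoT steps, three student layers.
    teacher = [[float(t), float(t)] for t in range(5)]
    student = [[1.0, 1.0], [2.0, 2.0], [4.0, 4.0]]
    print(select_teacher_states(teacher, 3))
    print(vertical_distillation_loss(student, teacher))
```

With five teacher steps and three student layers, the layers are matched to steps 0, 2, and 4, so the student never emits intermediate tokens yet each layer is pushed toward a different stage of the teacher's reasoning. The even spacing keeps the mapping monotone (earlier layers track earlier steps); other assignments are equally plausible under these assumptions.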