A Theory of Learning with Autoregressive Chain of Thought
This work provides a theoretical foundation for chain-of-thought reasoning, of interest to machine learning researchers and practitioners: it formalizes the sample and computational complexity of learning with chain-of-thought and shows that attention arises naturally in the construction.
The paper studies the problem of learning prompt-to-answer mappings produced by an autoregressive chain-of-thought process. It analyzes sample and computational complexity both when the chain is observed and when it is latent, and presents a base class that allows universal representability and computationally tractable learning, with sample complexity independent of chain length.
For a given base class of sequence-to-next-token generators, we consider learning prompt-to-answer mappings obtained by iterating a fixed, time-invariant generator for multiple steps, thus generating a chain-of-thought, and then taking the final token as the answer. We formalize the learning problems both when the chain-of-thought is observed and when training only on prompt-answer pairs, with the chain-of-thought latent. We analyze the sample and computational complexity both in terms of general properties of the base class (e.g. its VC dimension) and for specific base classes such as linear thresholds. We present a simple base class that allows for universal representability and computationally tractable chain-of-thought learning. Central to our development is that time invariance allows for sample complexity that is independent of the length of the chain-of-thought. Attention arises naturally in our construction.
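The setup described above, iterating one fixed generator and reading off the last token, can be illustrated with a minimal sketch. This is an illustration only, not the paper's construction: the function names (`chain_of_thought`, `answer`, `make_threshold_generator`), the binary token alphabet, and the particular weights and threshold are all assumptions made for the example. The generator here is a toy linear threshold over the most recent tokens, loosely echoing the linear-threshold base classes mentioned in the abstract.

```python
from typing import Callable, List, Sequence

Token = int
# A sequence-to-next-token generator: reads the whole sequence so far, emits one token.
Generator = Callable[[Sequence[Token]], Token]


def chain_of_thought(f: Generator, prompt: Sequence[Token], steps: int) -> List[Token]:
    """Iterate the same generator f for `steps` steps and return the generated tokens."""
    seq = list(prompt)
    for _ in range(steps):
        # Time invariance: the identical f is applied at every step.
        seq.append(f(seq))
    return seq[len(prompt):]


def answer(f: Generator, prompt: Sequence[Token], steps: int) -> Token:
    """Prompt-to-answer mapping: the final token of the chain (requires steps >= 1)."""
    return chain_of_thought(f, prompt, steps)[-1]


def make_threshold_generator(weights: Sequence[float], threshold: float) -> Generator:
    """Hypothetical toy generator: a linear threshold over the most recent tokens.

    weights[0] multiplies the last token, weights[1] the one before it, and so on.
    """
    def f(seq: Sequence[Token]) -> Token:
        recent = list(reversed(seq))[: len(weights)]
        score = sum(w * x for w, x in zip(weights, recent))
        return 1 if score >= threshold else 0
    return f


# Assumed example values, chosen only for illustration.
f = make_threshold_generator(weights=[1.0, -1.0], threshold=1.0)
chain = chain_of_thought(f, prompt=[0, 1], steps=4)  # the chain-of-thought tokens
final = answer(f, prompt=[0, 1], steps=4)            # the last token of that chain
```

In the observed setting, a learner would see `chain` along with the prompt; in the latent setting, it would see only the pair (prompt, `final`). The number of parameters of `f` does not grow with `steps`, which is the structural feature the abstract ties to sample complexity that is independent of chain length.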