Latent Thought Models with Variational Bayes Inference-Time Computation
This work targets the sample and parameter efficiency of language models. It proposes an approach that adds scaling dimensions beyond model size, with the aim of improving text generation and few-shot reasoning.
The authors introduce Latent Thought Models (LTMs), which pair latent thought vectors with variational Bayes inference-time computation. LTMs show better sample and parameter efficiency than autoregressive and discrete diffusion models, outperforming both in validation perplexity and on zero-shot language modeling tasks.
We propose a novel class of language models, Latent Thought Models (LTMs), which incorporate explicit latent thought vectors that follow an explicit prior model in latent space. These latent thought vectors guide the autoregressive generation of ground tokens through a Transformer decoder. Training employs a dual-rate optimization process within the classical variational Bayes framework: fast learning of local variational parameters for the posterior distribution of latent vectors (inference-time computation), and slow learning of global decoder parameters. Empirical studies reveal that LTMs possess additional scaling dimensions beyond traditional Large Language Models (LLMs), such as the number of iterations in inference-time computation and number of latent thought vectors. Higher sample efficiency can be achieved by increasing training compute per token, with further gains possible by trading model size for more inference steps. Designed based on these scaling properties, LTMs demonstrate superior sample and parameter efficiency compared to autoregressive models and discrete diffusion models. They significantly outperform these counterparts in validation perplexity and zero-shot language modeling tasks. Additionally, LTMs exhibit emergent few-shot in-context reasoning capabilities that scale with model size, and achieve competitive performance in conditional and unconditional text generation.
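The dual-rate scheme described above can be illustrated with a minimal sketch. This is not the authors' implementation: it swaps the Transformer decoder for a toy linear-Gaussian decoder x = W z + noise with a standard normal prior on z and a diagonal Gaussian posterior, so that the ELBO and its gradients are available in closed form. All names (`elbo`, `infer_posterior`, `train`), the step sizes, and the synthetic data are assumptions made for illustration. The fast loop fits per-example variational parameters with the decoder frozen; the slow loop then takes one gradient step on the shared decoder weights.

```python
import numpy as np

def elbo(X, W, Mu, LogS, sigma2=1.0):
    """Mean per-example ELBO for the toy model x = W z + noise, z ~ N(0, I),
    with posterior q(z) = N(mu, diag(s^2)) and noise variance sigma2."""
    S2 = np.exp(2.0 * LogS)
    resid = X - Mu @ W.T
    col_sq = (W ** 2).sum(axis=0)
    D = X.shape[1]
    recon = (-0.5 / sigma2 * ((resid ** 2).sum(axis=1) + S2 @ col_sq)
             - 0.5 * D * np.log(2.0 * np.pi * sigma2))
    kl = 0.5 * (Mu ** 2 + S2 - 1.0 - 2.0 * LogS).sum(axis=1)
    return float((recon - kl).mean())

def infer_posterior(X, W, n_steps, lr_fast=0.05, sigma2=1.0):
    """Fast loop: fit local (mu, log s) per example from scratch, decoder W frozen."""
    N, d = X.shape[0], W.shape[1]
    Mu, LogS = np.zeros((N, d)), np.zeros((N, d))
    col_sq = (W ** 2).sum(axis=0)
    for _ in range(n_steps):
        S2 = np.exp(2.0 * LogS)
        grad_mu = (X - Mu @ W.T) @ W / sigma2 - Mu       # d ELBO / d mu
        grad_logs = 1.0 - S2 * (col_sq / sigma2 + 1.0)   # d ELBO / d log s
        Mu = Mu + lr_fast * grad_mu
        LogS = LogS + lr_fast * grad_logs
    return Mu, LogS

def train(X, W, n_epochs, n_fast_steps, lr_slow=0.05, batch=64, sigma2=1.0, seed=0):
    """Slow loop: one gradient step on the global decoder W per mini-batch,
    taken after the fast loop has fitted that batch's local parameters."""
    rng = np.random.default_rng(seed)
    W = W.copy()
    for _ in range(n_epochs):
        for idx in np.array_split(rng.permutation(len(X)), len(X) // batch):
            Xb = X[idx]
            Mu, LogS = infer_posterior(Xb, W, n_fast_steps, sigma2=sigma2)
            S2 = np.exp(2.0 * LogS)
            grad_W = ((Xb - Mu @ W.T).T @ Mu / len(Xb) - W * S2.mean(axis=0)) / sigma2
            W = W + lr_slow * grad_W
    return W

# Synthetic data from a hypothetical ground-truth decoder (illustration only).
rng = np.random.default_rng(0)
D, d, N = 6, 2, 512
W_true = rng.normal(size=(D, d))
X = rng.normal(size=(N, d)) @ W_true.T + rng.normal(size=(N, D))

W_init = 0.1 * rng.normal(size=(D, d))
W_fit = train(X, W_init, n_epochs=20, n_fast_steps=16)

# ELBO before and after training, each with 16 fast inference steps.
elbo_before = elbo(X, W_init, *infer_posterior(X, W_init, 16))
elbo_after = elbo(X, W_fit, *infer_posterior(X, W_fit, 16))

# Same trained decoder, different amounts of inference-time computation.
elbo_2 = elbo(X, W_fit, *infer_posterior(X, W_fit, 2))
elbo_16 = elbo(X, W_fit, *infer_posterior(X, W_fit, 16))

print("ELBO before / after training:", elbo_before, elbo_after)
print("ELBO with 2 / 16 fast steps: ", elbo_2, elbo_16)
```

In this toy setting, the number of fast steps acts as a knob that is separate from the size of the decoder: with the decoder held fixed, running more inference steps tightens the bound. That mirrors, in a very reduced form, the additional scaling dimension the abstract attributes to inference-time computation.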