Parallel Token Prediction for Language Models
This work addresses the inference-speed bottleneck of large language models and offers a practical speedup.
The paper tackles slow autoregressive decoding in language models by proposing Parallel Token Prediction (PTP), a framework that predicts multiple tokens in a single forward pass and achieves a 2.4x speedup on a diverse-task speculative decoding benchmark.
Autoregressive decoding in language models is inherently slow, generating only one token per forward pass. We propose Parallel Token Prediction (PTP), a general-purpose framework for predicting multiple tokens in a single model call. PTP moves the source of randomness from post-hoc sampling to random input variables, making future tokens deterministic functions of those inputs and thus jointly predictable in a single forward pass. We prove that a single PTP call can represent arbitrary dependencies between tokens. PTP is trained by distilling an existing model or through inverse autoregressive training without a teacher. Experimentally, PTP achieves a 2.4x speedup on a diverse-task speculative decoding benchmark. We provide code and checkpoints at https://github.com/mandt-lab/ptp.
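The key idea, moving randomness into the inputs, can be illustrated with inverse-transform sampling. If each token is drawn from a uniform variable u_i that is fixed in advance, the whole continuation becomes a deterministic function of the prompt and (u_1, …, u_k), so a model that receives those u_i as inputs can in principle predict all k tokens at once. The Python sketch below shows this under stated assumptions and is not the PTP implementation from the repository: `toy_teacher`, `teacher_rollout`, and `verify` are hypothetical names, the draft tokens are passed in directly rather than produced by a trained PTP model, and the accept-longest-matching-prefix rule is our assumption of how a draft built on shared uniforms can be checked, not necessarily the paper's exact procedure.

```python
import bisect
import itertools
from typing import Callable, List, Sequence

# Hypothetical "teacher": maps a token prefix to a categorical distribution.
Teacher = Callable[[Sequence[int]], List[float]]


def inverse_cdf_sample(probs: Sequence[float], u: float) -> int:
    """Inverse-transform sampling: pick the token whose CDF interval contains u.

    For a fixed u in [0, 1) the result is a deterministic function of probs.
    """
    cdf = list(itertools.accumulate(probs))
    # First index with cdf[i] > u; clamp guards against round-off in the last bin.
    return min(bisect.bisect_right(cdf, u), len(probs) - 1)


def teacher_rollout(teacher: Teacher, prefix: Sequence[int],
                    us: Sequence[float]) -> List[int]:
    """Plain autoregressive decoding with pre-drawn uniforms: one call per token."""
    seq = list(prefix)
    for u in us:
        seq.append(inverse_cdf_sample(teacher(seq), u))
    return seq[len(prefix):]


def verify(teacher: Teacher, prefix: Sequence[int], draft: Sequence[int],
           us: Sequence[float]) -> List[int]:
    """Keep the draft tokens the teacher itself would emit under the same uniforms.

    Returns the longest matching prefix of draft plus the teacher's own token at
    the first mismatch, so the result is always a prefix of teacher_rollout(...).
    The teacher is only queried on prefix + draft[:i], so a real model could
    score all draft positions in one batched forward pass.
    """
    seq = list(prefix)
    accepted: List[int] = []
    for d, u in zip(draft, us):
        t = inverse_cdf_sample(teacher(seq), u)
        accepted.append(t)
        seq.append(t)
        if t != d:  # first disagreement: keep the teacher's token and stop
            break
    return accepted


def toy_teacher(seq: Sequence[int]) -> List[float]:
    """Hypothetical 3-token model that tends to repeat the previous token."""
    probs = [0.1, 0.1, 0.1]
    probs[seq[-1] if seq else 0] = 0.8
    return probs


us = [0.5, 0.95, 0.05, 0.3]  # uniforms drawn once, up front
print(teacher_rollout(toy_teacher, [0], us))       # [0, 2, 0, 0]
print(verify(toy_teacher, [0], [0, 2, 1, 0], us))  # [0, 2, 0]
```

In this sketch every accepted token matches what `teacher_rollout` produces with the same uniforms, so the draft affects only speed, not the output. A draft that sees the uniforms, as PTP's random input variables allow, would simply get more tokens accepted per teacher call.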