The Belief State Transformer
This work addresses inefficient goal-conditioned decoding and poor test-time inference in transformers and is relevant to NLP researchers and practitioners. It appears incremental, however, since it builds on existing transformer architectures and its main contribution is a novel training objective.
The paper targets tasks that conventional forward-only transformers struggle with. It introduces the Belief State Transformer, which uses a compact belief state to predict both the next token and the previous token, and reports that it outperforms the Fill-in-the-Middle method on story writing with both known and unknown goals.
We introduce the "Belief State Transformer", a next-token predictor that takes both a prefix and a suffix as inputs, with a novel objective of predicting both the next token for the prefix and the previous token for the suffix. The Belief State Transformer effectively learns to solve challenging problems that conventional forward-only transformers struggle with, in a domain-independent fashion. Key to this success is learning a compact belief state that captures all relevant information necessary for accurate predictions. Empirical ablations show that each component of the model is essential in difficult scenarios where standard Transformers fall short. For the task of story writing with known prefixes and suffixes, our approach outperforms the Fill-in-the-Middle method for reaching known goals and demonstrates improved performance even when the goals are unknown. Altogether, the Belief State Transformer enables more efficient goal-conditioned decoding, better test-time inference, and high-quality text representations on small-scale problems.
Website: https://edwhu.github.io/bst-website
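To make the two-part objective concrete, here is a minimal sketch of how one training sequence could be turned into (prefix, suffix) examples, each with a next-token target for the prefix and a previous-token target for the suffix, and how the two negative log-likelihood terms combine into one loss. This is not the authors' implementation. It assumes 0-indexed tokens, allows empty prefixes and suffixes, omits any special start/end tokens, and replaces the transformer itself with a hypothetical `predict` callable. The names `BSTExample`, `make_bst_examples`, and `bst_loss` are illustrative only.

```python
import math
from typing import Callable, Dict, List, NamedTuple, Sequence, Tuple

Tokens = Tuple[str, ...]
# Hypothetical model interface: (prefix, suffix) -> (next-token probs, prev-token probs)
Predictor = Callable[[Tokens, Tokens], Tuple[Dict[str, float], Dict[str, float]]]


class BSTExample(NamedTuple):
    """One illustrative training example for a prefix/suffix objective."""
    prefix: Tokens      # tokens the model reads from the start of the sequence
    suffix: Tokens      # tokens the model reads from the end of the sequence
    next_token: str     # target: the token right after the prefix
    prev_token: str     # target: the token right before the suffix


def make_bst_examples(tokens: Sequence[str]) -> List[BSTExample]:
    """Enumerate every prefix tokens[:t] and suffix tokens[k:] with t < k.

    At least one token (tokens[t:k]) is hidden between them, so both the
    next-token and previous-token targets are always defined. A sequence of
    length n yields n * (n + 1) // 2 examples under this assumption.
    """
    examples = []
    n = len(tokens)
    for t in range(n):                 # prefix length, may be 0
        for k in range(t + 1, n + 1):  # suffix start, k == n means empty suffix
            examples.append(BSTExample(
                prefix=tuple(tokens[:t]),
                suffix=tuple(tokens[k:]),
                next_token=tokens[t],
                prev_token=tokens[k - 1],
            ))
    return examples


def bst_loss(examples: List[BSTExample], predict: Predictor) -> float:
    """Mean over examples of -log p(next | prefix, suffix) - log p(prev | prefix, suffix)."""
    if not examples:
        raise ValueError("need at least one example")
    total = 0.0
    for ex in examples:
        p_next, p_prev = predict(ex.prefix, ex.suffix)
        total += -math.log(p_next[ex.next_token]) - math.log(p_prev[ex.prev_token])
    return total / len(examples)


if __name__ == "__main__":
    story = ["once", "upon", "a", "time"]
    examples = make_bst_examples(story)
    print(len(examples))  # 10, i.e. 4 * 5 // 2
    # A uniform "model" over the 4-token vocabulary gives ln 4 per term.
    uniform = {tok: 1.0 / len(set(story)) for tok in set(story)}
    print(round(bst_loss(examples, lambda p, s: (uniform, uniform)), 4))  # 2.7726
```

With k equal to t + 1 the prefix and suffix are adjacent and both targets are the same hidden token; larger gaps make the model predict the two tokens bordering a missing span, which is the situation that goal-conditioned infilling relies on.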