StateX: Enhancing RNN Recall via Post-training State Expansion
This addresses a bottleneck for users of RNNs in tasks requiring long-context recall, offering an incremental improvement over existing methods.
The paper tackles the difficulty that recurrent neural networks (RNNs) have with accurate recall from long contexts, which stems from their limited state size. It introduces StateX, a post-training pipeline that efficiently expands the state size of pre-trained RNNs, improving recall and in-context learning in models of up to 1.3B parameters without high post-training costs or loss of other capabilities.
While Transformer-based models have demonstrated remarkable language modeling performance, their attention cost grows quadratically with context length, which makes processing long contexts expensive. In contrast, recurrent neural networks (RNNs) such as linear attention and state space models have gained popularity due to their constant per-token complexity. However, these recurrent models struggle with tasks that require accurate recall of information from long contexts, because all contextual information is compressed into a constant-size recurrent state. Prior work has shown that recall ability is positively correlated with recurrent state size, yet directly training RNNs with larger recurrent states incurs high training costs. In this paper, we introduce StateX, a training pipeline that efficiently expands the states of pre-trained RNNs through post-training. For two popular classes of RNNs, linear attention and state space models, we design post-training architectural modifications that scale up the state size with no increase, or only a negligible increase, in model parameters. Experiments on models of up to 1.3B parameters demonstrate that StateX efficiently enhances the recall and in-context learning ability of RNNs without incurring high post-training costs or compromising other capabilities.
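The claim that all context is compressed into a constant-size recurrent state can be made concrete with the simplest form of linear attention. The sketch below is an illustrative assumption, not StateX's code: it uses unnormalized, ungated causal linear attention (practical variants often add gating or normalization), shows that the state stays a fixed `(d_k, d_v)` matrix however many tokens are processed, and checks the recurrent form against the equivalent parallel form. The `state_elements` helper counts state entries per layer for different head splits, illustrating one way the state can grow while the q/k/v projection matrices keep the same size; the source does not specify which modifications StateX actually makes. All function names here are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def linear_attention_recurrent(
    qs: List[Vector], ks: List[Vector], vs: List[Vector]
) -> Tuple[List[Vector], Matrix]:
    """Causal, unnormalized linear attention computed as an RNN.

    The state S has shape (d_k, d_v) regardless of sequence length:
    S_t = S_{t-1} + k_t v_t^T, and the output is o_t = S_t^T q_t.
    """
    d_k, d_v = len(ks[0]), len(vs[0])
    state: Matrix = [[0.0] * d_v for _ in range(d_k)]
    outputs: List[Vector] = []
    for q, k, v in zip(qs, ks, vs):
        # Write: add the outer product k v^T into the fixed-size state.
        for i in range(d_k):
            for j in range(d_v):
                state[i][j] += k[i] * v[j]
        # Read: query the compressed state with q.
        outputs.append(
            [sum(state[i][j] * q[i] for i in range(d_k)) for j in range(d_v)]
        )
    return outputs, state


def linear_attention_parallel(
    qs: List[Vector], ks: List[Vector], vs: List[Vector]
) -> List[Vector]:
    """Reference parallel form: o_t = sum over s <= t of (q_t . k_s) v_s."""
    outputs: List[Vector] = []
    for t, q in enumerate(qs):
        o = [0.0] * len(vs[0])
        for s in range(t + 1):
            score = sum(a * b for a, b in zip(q, ks[s]))
            for j, vj in enumerate(vs[s]):
                o[j] += score * vj
        outputs.append(o)
    return outputs


def state_elements(model_dim: int, num_heads: int) -> int:
    """Recurrent-state entries per layer when model_dim is split evenly
    into num_heads heads, each holding a head_dim x head_dim state.
    The q/k/v projections are model_dim x model_dim in either case."""
    head_dim = model_dim // num_heads
    return num_heads * head_dim * head_dim


if __name__ == "__main__":
    qs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    ks = [[1.0, 2.0], [0.5, -1.0], [0.0, 1.0]]
    vs = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [3.0, 0.0, 0.0]]
    rec, state = linear_attention_recurrent(qs, ks, vs)
    print(rec == linear_attention_parallel(qs, ks, vs))
    print(state_elements(2048, 16), state_elements(2048, 1))
```

With `model_dim = 2048`, 16 heads give 262,144 state entries per layer, while a single head gives 4,194,304, a 16x larger state from projection matrices of the same size.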