CL · AI · LG · Nov 1, 2024

Birdie: Advancing State Space Models with Reward-Driven Objectives and Curricula

arXiv:2411.01030v5 · 2 citations · h-index: 29 · Has Code
Originality: Highly original
AI Analysis

This addresses a key limitation in efficient SSMs for retrieval-intensive applications, offering a new training-based direction rather than architectural changes.

The paper addresses the difficulty state space models (SSMs) have with long-range in-context retrieval tasks such as text copying and question answering. It shows that Birdie, a new training procedure, substantially improves performance on these tasks while preserving the computational efficiency of SSMs, narrowing the gap with Transformers.

Efficient state space models (SSMs), such as linear recurrent neural networks and linear attention variants, offer computational advantages over Transformers but struggle with tasks requiring long-range in-context retrieval, such as text copying, associative recall, and question answering over long contexts. Previous efforts to address these challenges have focused on architectural modifications, often reintroducing computational inefficiencies. In this paper, we propose a novel training procedure, Birdie, that significantly enhances the in-context retrieval capabilities of SSMs without altering their architecture. Our approach combines bidirectional input processing with dynamic mixtures of specialized pre-training objectives, optimized via reinforcement learning. We introduce a new bidirectional SSM architecture that seamlessly transitions from bidirectional context processing to causal generation. Experimental evaluations demonstrate that Birdie markedly improves performance on retrieval-intensive tasks such as multi-number phone book lookup, long paragraph question-answering, and infilling. This narrows the performance gap with Transformers while retaining computational efficiency. Our findings highlight the importance of training procedures in leveraging the fixed-state capacity of SSMs, offering a new direction to advance their capabilities. All code and pre-trained models are available at https://www.github.com/samblouir/birdie, with support for JAX and PyTorch.
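To make the idea of "bidirectional context processing that transitions to causal generation" concrete, here is a minimal sketch using a scalar linear recurrence. It is an illustration under stated assumptions, not Birdie's actual layer: the function names `linear_scan` and `prefix_bidirectional_scan`, the scalar decays `a_fwd`/`a_bwd`, and the choice to keep separate forward and backward channels are all hypothetical. The forward channel is an ordinary causal scan; the backward channel reads only the prompt prefix, right to left, and is zero everywhere after it.

```python
from typing import List, Tuple


def linear_scan(xs: List[float], a: float, b: float) -> List[float]:
    """Causal scalar linear recurrence h_t = a * h_{t-1} + b * x_t with h_{-1} = 0."""
    h = 0.0
    out = []
    for x in xs:
        h = a * h + b * x
        out.append(h)
    return out


def prefix_bidirectional_scan(
    xs: List[float],
    prefix_len: int,
    a_fwd: float = 0.5,  # hypothetical forward decay
    a_bwd: float = 0.5,  # hypothetical backward decay
    b: float = 1.0,
) -> List[Tuple[float, float]]:
    """Return (forward, backward) state pairs for every position.

    The forward channel is a causal scan over the whole sequence.
    The backward channel scans only the prefix, right to left, so prefix
    positions also see later prefix tokens. Positions after the prefix get
    a backward state of 0.0, so they depend only on earlier tokens.
    """
    if not 0 <= prefix_len <= len(xs):
        raise ValueError("prefix_len must lie in [0, len(xs)]")
    fwd = linear_scan(xs, a_fwd, b)
    # Reverse the prefix, scan it causally, then reverse back.
    bwd_prefix = linear_scan(xs[:prefix_len][::-1], a_bwd, b)[::-1]
    bwd = bwd_prefix + [0.0] * (len(xs) - prefix_len)
    return list(zip(fwd, bwd))
```

For example, with `xs = [1.0, 0.0, 0.0, 0.0]` and `prefix_len=2`, the forward states are `[1.0, 0.5, 0.25, 0.125]` and the backward states are `[1.0, 0.0, 0.0, 0.0]`. Because the suffix positions never receive a backward contribution, generating a new token after the prefix only needs the last forward state, which is what keeps this kind of layout compatible with ordinary recurrent decoding.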
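The abstract says the mixture of pre-training objectives is dynamic and optimized via reinforcement learning, but does not spell out the controller here. As a rough illustration of the general idea only, the sketch below swaps in a much simpler technique: a softmax bandit that keeps an exponential moving average of the reward observed for each objective and samples objectives in proportion to a softmax over those averages, with a probability floor so every objective keeps being explored. The class `ObjectiveMixer`, its parameters, the objective names, and the notion of reward (for instance, a drop in validation loss) are all hypothetical.

```python
import math
import random
from typing import Dict, List


class ObjectiveMixer:
    """Hypothetical softmax-bandit stand-in for a reward-driven objective mixture."""

    def __init__(self, objectives: List[str], temperature: float = 0.1,
                 decay: float = 0.9, min_prob: float = 0.05, seed: int = 0):
        if min_prob * len(objectives) > 1.0:
            raise ValueError("min_prob * number of objectives must be <= 1")
        self.objectives = list(objectives)
        self.temperature = temperature
        self.decay = decay        # EMA factor for per-objective reward estimates
        self.min_prob = min_prob  # floor so every objective keeps being sampled
        self.values: Dict[str, float] = {o: 0.0 for o in self.objectives}
        self.rng = random.Random(seed)

    def probabilities(self) -> Dict[str, float]:
        # Numerically stable softmax over the current reward estimates.
        scaled = [self.values[o] / self.temperature for o in self.objectives]
        top = max(scaled)
        exps = [math.exp(s - top) for s in scaled]
        total = sum(exps)
        k = len(self.objectives)
        # Blend with a uniform floor; the result still sums to 1.
        return {o: self.min_prob + (1.0 - self.min_prob * k) * e / total
                for o, e in zip(self.objectives, exps)}

    def sample(self) -> str:
        probs = self.probabilities()
        weights = [probs[o] for o in self.objectives]
        return self.rng.choices(self.objectives, weights=weights)[0]

    def update(self, objective: str, reward: float) -> None:
        # Exponential moving average of the reward seen for this objective.
        self.values[objective] = (self.decay * self.values[objective]
                                  + (1.0 - self.decay) * reward)
```

In use, a training loop would call `sample()` to choose the objective for the next batch and then `update()` with whatever reward signal it measures. With all estimates at zero the mixture is uniform; objectives that keep earning higher reward are sampled more often, while the floor prevents any objective from disappearing entirely.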

Code Implementations: 1 repo
