LG | Aug 19, 2024
Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models
Aviv Bick, Kevin Y. Li, Eric P. Xing et al.
Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially fewer computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea of our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens and a hybrid version (Hybrid Phi-Mamba) using 5B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.
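The "mixing matrix" view in this abstract can be made concrete with a small sketch. It assumes a single head, a scalar per-step decay in the style of Mamba-2, and a Frobenius-norm distance for the first (matrix-matching) stage; all function names are hypothetical and this is not the authors' implementation.

```python
import numpy as np

def causal_attention_matrix(Q, K):
    """Softmax attention weights with a causal mask: a (T, T) mixing matrix."""
    T, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    mask = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=1, keepdims=True)

def ssm_mixing_matrix(a, B, C):
    """Materialize the matrix M with y = M @ x for the recurrence below.

    a: (T,) scalar decay per step (assumed form); B, C: (T, N) projections.
    M[i, j] = (a[j+1] * ... * a[i]) * <C[i], B[j]> for j <= i, else 0.
    """
    T = a.shape[0]
    M = np.zeros((T, T))
    for i in range(T):
        for j in range(i + 1):
            decay = np.prod(a[j + 1 : i + 1])  # empty product is 1 when j == i
            M[i, j] = decay * (C[i] @ B[j])
    return M

def ssm_recurrence(a, B, C, x):
    """Same map computed step by step: h_t = a_t h_{t-1} + B_t x_t, y_t = <C_t, h_t>."""
    T, N = B.shape
    h = np.zeros(N)
    y = np.zeros(T)
    for t in range(T):
        h = a[t] * h + B[t] * x[t]
        y[t] = C[t] @ h
    return y

def matrix_matching_loss(M_teacher, M_student):
    """Assumed stage-1 objective: Frobenius distance between mixing matrices."""
    return np.linalg.norm(M_teacher - M_student, ord="fro")
```

Both architectures produce a lower-triangular (T, T) matrix applied to the token sequence, which is what makes a direct matrix-to-matrix comparison possible before matching hidden states and final predictions.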
LG | Mar 16
Mamba-3: Improved Sequence Modeling using State Space Principles
Aakash Lahoti, Kevin Y. Li, Berlin Chen et al.
Scaling inference-time compute has emerged as an important driver of LLM performance, making inference efficiency a central focus of model design alongside model quality. While the current Transformer-based models deliver strong model quality, their quadratic compute and linear memory make inference expensive. This has spurred the development of sub-quadratic models with reduced linear compute and constant memory requirements. However, many recent linear models trade off model quality and capability for algorithmic efficiency, failing on tasks such as state tracking. Moreover, their theoretically linear inference remains hardware-inefficient in practice. Guided by an inference-first perspective, we introduce three core methodological improvements inspired by the state space model (SSM) viewpoint of linear models. We combine: (1) a more expressive recurrence derived from SSM discretization, (2) a complex-valued state update rule that enables richer state tracking, and (3) a multi-input, multi-output (MIMO) formulation for better model performance without increasing decode latency. Together with architectural refinements, our Mamba-3 model achieves significant gains across retrieval, state-tracking, and downstream language modeling tasks. At the 1.5B scale, Mamba-3 improves average downstream accuracy by 0.6 percentage points compared to the next best model (Gated DeltaNet), with Mamba-3's MIMO variant further improving accuracy by another 1.2 points for a total 1.8 point gain. Across state-size experiments, Mamba-3 achieves comparable perplexity to Mamba-2 despite using half of its predecessor's state size. Our evaluations demonstrate Mamba-3's ability to advance the performance-efficiency Pareto frontier.
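The abstract does not spell out the complex-valued update rule, so the following is not Mamba-3's recurrence. It is only a toy illustration, under that assumption, of how a rotating complex state can track something like the parity of a bit stream; the function name is hypothetical.

```python
import cmath

def parity_via_rotation(bits):
    """Track parity with a one-dimensional complex state.

    Each 1-bit rotates the state by pi (multiplies it by -1);
    each 0-bit leaves it unchanged. The sign of the real part
    at the end encodes whether the number of 1-bits is even or odd.
    """
    h = 1 + 0j
    for b in bits:
        h *= cmath.exp(1j * cmath.pi * b)  # data-dependent rotation
    return 0 if h.real > 0 else 1
```

For example, `parity_via_rotation([1, 0, 1, 1])` sees three 1-bits, so the state ends near -1 and the function returns 1.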
CL | Feb 27, 2025
Thinking Slow, Fast: Scaling Inference Compute with Distilled Reasoners
Daniele Paliotta, Junxiong Wang, Matteo Pagliardini et al.
Recent advancements have demonstrated that the performance of large language models (LLMs) can be significantly enhanced by scaling computational resources at test time. A common strategy involves generating multiple Chain-of-Thought (CoT) trajectories and aggregating their outputs through various selection mechanisms. This raises a fundamental question: can models with lower complexity leverage their superior generation throughput to outperform similarly sized Transformers for a fixed computational budget? To address this question and overcome the lack of strong subquadratic reasoners, we distill pure and hybrid Mamba models from pretrained Transformers. Trained on only 8 billion tokens, our distilled models show strong performance and scaling on mathematical reasoning datasets while being much faster at inference for large batches and long sequences. Despite the zero-shot performance hit due to distillation, both pure and hybrid Mamba models can scale their coverage and accuracy performance past their Transformer teacher models under fixed time budgets, opening a new direction for scaling inference compute.
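One way to reason about "coverage under a fixed time budget" is sketched below. It assumes coverage is measured with the standard unbiased pass@k estimator and that the number of affordable samples is simply the budget divided by per-sample latency; the abstract does not state either detail, and both helper names are hypothetical.

```python
from math import comb

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate from n samples of which c are correct."""
    if n - c < k:
        return 1.0  # every size-k subset contains at least one correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

def samples_within_budget(budget_s, seconds_per_sample):
    """How many completions fit in a time budget (assumes constant latency)."""
    return int(budget_s // seconds_per_sample)
```

Under these assumptions, a faster model with lower single-sample accuracy can still reach higher coverage if its throughput lets it draw enough additional samples within the same budget.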
LG | Feb 20, 2025
Llamba: Scaling Distilled Recurrent Models for Efficient Language Processing
Aviv Bick, Tobias Katsch, Nimit Sohoni et al.
We introduce Llamba, a family of efficient recurrent language models distilled from Llama-3.x into the Mamba architecture. The series includes Llamba-1B, Llamba-3B, and Llamba-8B, which achieve higher inference throughput and handle significantly larger batch sizes than Transformer-based models while maintaining comparable benchmark performance. Furthermore, Llamba demonstrates the effectiveness of cross-architecture distillation using MOHAWK (Bick et al., 2024), achieving these results with less than 0.1% of the training data typically used for models of similar size. To take full advantage of their efficiency, we provide an optimized implementation of Llamba for resource-constrained devices such as smartphones and edge platforms, offering a practical and memory-efficient alternative to Transformers. Overall, Llamba improves the tradeoff between speed, memory efficiency, and performance, making high-quality language models more accessible.
LG | Apr 22, 2025
Understanding the Skill Gap in Recurrent Language Models: The Role of the Gather-and-Aggregate Mechanism
Aviv Bick, Eric Xing, Albert Gu
State-space models (SSMs) offer efficient alternatives to Transformers for long sequences, but their fixed-size recurrent state limits capability on algorithmic tasks, such as retrieving past context. In this work, we examine how in-context retrieval operates in Transformer- and SSM-based language models and find that both rely on a similar Gather-and-Aggregate (G&A) mechanism: a Gather Head extracts relevant information pieces from context, which an Aggregate Head integrates into a single representation. In both architectures, G&A concentrates in a few heads, forming critical bottlenecks even for simple retrieval. For example, we show that disabling a single Gather or Aggregate Head in a pruned Llama-3.1-8B impairs retrieving the correct answer letter in MMLU, reducing its accuracy from 66% to 25% (random guessing). Moreover, this retrieval bottleneck can obscure limited knowledge demands of tasks as the pruned model succeeds on MMLU with functioning G&A heads yet fails on other knowledge benchmarks. The bottleneck similarly extends to tasks where SSMs typically underperform, such as GSM8K, BBH, and dialogue comprehension. We show that SSMs' retrieval challenges manifest in these heads, creating smoother attention patterns instead of the sharp token transitions effective G&A requires. Thus, the Transformer-SSM retrieval gap exists in just a few heads, rather than the entire language model. This suggests a unified explanation for Transformer vs. SSM performance gap while showing how to merge their strengths. We find that pretrained hybrid models, where SSMs are combined with a few attention layers, delegate the role of Aggregate Heads to attention. Similarly, replacing a single G&A head in a pretrained SSM with an attention variant boosts retrieval and benchmark scores.
LG | Feb 11
Retrieval-Aware Distillation for Transformer-SSM Hybrids
Aviv Bick, Eric P. Xing, Albert Gu
State-space models (SSMs) offer efficient sequence modeling but lag behind Transformers on benchmarks that require in-context retrieval. Prior work links this gap to a small set of attention heads, termed Gather-and-Aggregate (G&A), which SSMs struggle to reproduce. We propose *retrieval-aware distillation*, which converts a pretrained Transformer into a hybrid student by preserving only these retrieval-critical heads and distilling the rest into recurrent heads. We identify the essential heads via ablation on a synthetic retrieval task, producing a hybrid with sparse, non-uniform attention placement. We show that preserving **just 2% of attention heads recovers over 95% of teacher performance on retrieval-heavy tasks** (10 heads in a 1B model), requiring far fewer heads than hybrids that retain at least 25%. We further find that large recurrent states often compensate for missing retrieval: once retrieval is handled by these heads, the SSM backbone can be simplified with limited loss, even with an $8\times$ reduction in state dimension. By reducing both the attention cache and the SSM state, the resulting hybrid is $5$--$6\times$ more memory-efficient than comparable hybrids, closing the Transformer--SSM gap at a fraction of the memory cost.
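The head-selection step described here (ablate, rank, keep a small fraction) might look roughly like the sketch below. The `evaluate` callback, which scores a model with only the given set of heads active, is a hypothetical stand-in for the synthetic retrieval task; nothing here reflects the authors' actual code.

```python
def rank_heads_by_ablation(num_heads, evaluate):
    """Rank heads by how much the score drops when each one is disabled."""
    all_on = frozenset(range(num_heads))
    base = evaluate(all_on)
    drops = {h: base - evaluate(all_on - {h}) for h in range(num_heads)}
    return sorted(drops, key=drops.get, reverse=True)

def select_heads(num_heads, evaluate, keep_fraction):
    """Keep the top fraction of heads as attention; the rest would be distilled."""
    k = max(1, round(num_heads * keep_fraction))
    return set(rank_heads_by_ablation(num_heads, evaluate)[:k])

def toy_evaluate(active_heads):
    """Toy score (assumption): only heads 3 and 7 matter for retrieval."""
    return 0.2 + 0.4 * (3 in active_heads) + 0.4 * (7 in active_heads)
```

With the toy scorer, `select_heads(100, toy_evaluate, 0.02)` keeps exactly the two heads whose removal hurts the score. As a rough consistency check on the abstract's figures, 10 heads at 2% implies on the order of 500 attention heads in the 1B model.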