Subspace-Aware Sparse Autoencoders for Effective Mechanistic Interpretability
For researchers in mechanistic interpretability of large language models, this work addresses a fundamental limitation of sparse autoencoders (SAEs), namely their single-direction decoders, by proposing a method that captures the multi-dimensional nature of features, improving both interpretability and training efficiency.
Sparse Autoencoders (SAEs) used for mechanistic interpretability assume that features are one-dimensional, but model features are multi-dimensional, and this mismatch causes feature splitting. The authors propose Subspace-Aware SAEs (SASA), which use learned decoder subspaces and block sparsity to consolidate each multi-dimensional feature into a single group, reducing splitting and improving interpretability. On GPT-2 and Mistral-7B, SASA matches or exceeds standard SAEs while training on roughly half the token budget.
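Feature splitting is easiest to see for a feature of intrinsic dimension 2. The following is an illustration under our own assumptions, not the paper's proof: the feature traces the unit circle, each input may use one unit-norm atom scaled by a nonnegative coefficient (a 1-sparse, single-direction code), and atoms are evenly spaced. The helpers `atoms_needed_on_circle` and `worst_case_error` are hypothetical. With $K$ atoms the worst-case angle to the nearest atom is $\pi/K$, leaving squared error $\sin^2(\pi/K)$, so shrinking the error tolerance $\varepsilon$ forces more and more near-collinear atoms. For a feature on a $d$-dimensional sphere, standard covering arguments give a count that grows roughly like $(1/\sqrt{\varepsilon})^{d-1}$, i.e. exponentially in $d$.

```python
import math


def atoms_needed_on_circle(eps: float) -> int:
    """Hypothetical helper: smallest number K of evenly spaced unit atoms such that
    every point on the unit circle is reconstructed with squared error <= eps
    using a single atom scaled by a nonnegative coefficient."""
    # The worst-case angle to the nearest atom is pi / K, and projecting onto
    # that atom leaves squared error sin^2(pi / K); solve sin^2(pi / K) <= eps.
    return math.ceil(math.pi / math.asin(math.sqrt(eps)))


def worst_case_error(K: int, samples: int = 10_000) -> float:
    """Hypothetical helper: empirical worst-case squared error over sampled
    points on the unit circle with K evenly spaced atoms (assumes K >= 3)."""
    atoms = [(math.cos(2 * math.pi * k / K), math.sin(2 * math.pi * k / K)) for k in range(K)]
    worst = 0.0
    for i in range(samples):
        t = 2 * math.pi * i / samples
        x = (math.cos(t), math.sin(t))
        # Best 1-sparse code: the largest projection onto any single atom.
        best = max(a[0] * x[0] + a[1] * x[1] for a in atoms)
        # Residual of projecting a unit vector onto a unit atom: 1 - cos^2.
        worst = max(worst, 1.0 - best * best)
    return worst


print(atoms_needed_on_circle(1e-2))  # 32
print(atoms_needed_on_circle(1e-4))  # 315
```

Cutting the tolerance by a factor of 100 multiplies the number of atoms by about 10, and every one of those atoms is a separate latent describing the same underlying feature.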
Sparse Autoencoders (SAEs) are widely used for mechanistic interpretability in large language models, yet their formulation assigns each latent feature a single decoder direction, implicitly assuming that features are one-dimensional. We show that this assumption conflicts with the multi-dimensional structure of model features and provably induces feature splitting through two distinct mechanisms. Geometrically, reconstructing a feature of intrinsic dimension $d_i \ge 2$ to error $\varepsilon$ with single-direction decoders requires a number of atoms that is exponential in $d_i$. From an end-to-end optimization perspective, this splitting is not merely possible but actively preferred: we prove that there exists a continuous path from the true $d_i$-dimensional basis to a configuration with strictly lower risk under the $\ell_1$-regularized SAE objective, and that the descent directions along this path drive any trained dictionary into that exponential regime. A single coherent feature is therefore fragmented across many near-collinear latents, producing spurious multiplicity and obscuring the intrinsic geometry. Motivated by this, we introduce Subspace-Aware Sparse Autoencoders (SASA), which replace single-vector decoders with learned decoder subspaces, enforce block sparsity via Top-$s$ group gating, and adapt each group's effective rank with a nuclear-norm regularizer. We then show that once the block size satisfies $r \ge d_i$, a single group can represent the entire feature slice, and this single-group solution is the global minimizer of the SASA objective. This consolidation yields a sample complexity polynomial, rather than exponential, in $d_i$, a decisive advantage given that every training activation costs an LLM forward pass. Empirically, on GPT-2 and Mistral-7B, SASA reduces feature splitting and absorption, improves monosemanticity and interpretability, and matches or exceeds standard SAEs while training on roughly half the token budget.
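The three SASA components named in the abstract (decoder subspaces, Top-$s$ group gating, and a nuclear-norm penalty) can be sketched for a single activation vector. This is a minimal NumPy sketch under our own assumptions, not the authors' implementation: we assume a linear encoder that produces an $r$-dimensional code per group, a gate that ranks groups by the $\ell_2$ norm of their codes, and a nuclear-norm penalty applied to each $D \times r$ decoder block. The functions `sasa_forward` and `sasa_loss`, their parameters, and all shapes are hypothetical.

```python
import numpy as np


def sasa_forward(x, W_enc, b_enc, W_dec, b_dec, s):
    """Hypothetical SASA forward pass for one activation x of shape (D,).

    W_enc: (G, r, D) encoder blocks, b_enc: (G, r) encoder biases
    W_dec: (G, D, r) decoder subspaces, b_dec: (D,) decoder bias
    s: number of groups kept active (Top-s group gating)
    """
    # Per-group codes z_g in R^r (assumed linear encoder).
    z = np.einsum("grd,d->gr", W_enc, x - b_dec) + b_enc  # (G, r)
    # Assumed gating score: L2 norm of each group's code; keep the s largest.
    scores = np.linalg.norm(z, axis=1)  # (G,)
    mask = np.zeros(len(scores), dtype=bool)
    mask[np.argsort(scores)[-s:]] = True
    # Block sparsity: inactive groups are zeroed as whole blocks.
    z = z * mask[:, None]
    # Each active group reconstructs within its own r-dim decoder subspace.
    x_hat = np.einsum("gdr,gr->d", W_dec, z) + b_dec
    return x_hat, z, mask


def sasa_loss(x, x_hat, W_dec, lam):
    """Squared reconstruction error plus lam times the sum of per-group
    nuclear norms (sum of singular values), which favors low effective rank."""
    recon = np.sum((x - x_hat) ** 2)
    nuc = sum(np.linalg.norm(W_dec[g], ord="nuc") for g in range(W_dec.shape[0]))
    return recon + lam * nuc
```

Because the gate acts on whole groups, every point of a 2-D feature lying in one group's subspace (for example, any angle on a circle) activates that same single group when $r \ge 2$, rather than whichever of many near-collinear single-direction atoms happens to be closest. The nuclear-norm term is one way to let a group use fewer than $r$ effective dimensions; the paper's exact gating score and regularizer placement may differ from this sketch.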