fmxcoders: Factorized Masked Crosscoders for Cross-Layer Feature Discovery
For mechanistic interpretability researchers, this work addresses a key failure of cross-layer sparse autoencoders, enabling more faithful recovery of multi-layer features in pretrained transformers.
Standard crosscoders fail to learn cross-layer features because their latents are driven by only one or two layers. The proposed fmxcoders use low-rank tensor factorizations and stochastic layer masking to enforce cross-layer sharing, achieving 10-30 point F1 gains, 25-50% MSE reduction, and 3-13x more semantically coherent latents across four LLMs.
Many features in pretrained Transformers span multiple layers: they emerge through stages of inference, persist in the residual stream, or are built jointly by parallel MLPs. Crosscoders, sparse dictionaries trained jointly across layers, aim to recover these cross-layer features in a single shared latent space. We show that standard crosscoders largely fail at this goal. Although their decoder weight norms spread evenly across layers, a functional coherence metric we introduce reveals that each latent's activation is effectively driven by only one or two layers on average. Whereas functionally coherent latents act as human-interpretable concept detectors (e.g., US states and cities), the layer-localized latents that crosscoders predominantly learn collapse onto surface-level patterns such as digit detectors. We trace this failure to two structural limitations: unconstrained cross-layer parameterization and unregularized cross-layer dependence. We address both by introducing fmxcoders, which (i) replace the encoder and decoder with low-rank tensor factorizations that draw every latent's per-layer weights from a shared cross-layer basis, and (ii) apply stochastic layer masking, a denoising regularizer along the layer axis that penalizes latents whose contribution collapses when a single layer is masked. Across GPT2-Small, Pythia-410M, Pythia-1.4B, and Gemma2-2B, fmxcoders lift mean probing F1 by 10-30 points, surpassing per-layer SAE baselines that standard crosscoders fail to reach, reduce reconstruction MSE by 25-50%, and roughly double mean functional coherence. An LLM-as-a-judge evaluation further shows that fmxcoders recover 3-13$\times$ more semantically coherent latents than standard crosscoders across all four base LLMs.
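The two ingredients named in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the abstract does not specify the exact factorization, the activation function, or the masking schedule, so everything below is an assumption made for illustration. In particular, the per-latent mixing form `W[j] = sum_k coef[j, k] * basis[k]`, the ReLU encoder, the drop probability `p_drop`, and all names (`full_weights`, `sample_layer_mask`, `forward`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): layers, model width, latents, rank, batch.
L, d, m, r, batch = 4, 16, 32, 8, 5

# Shared cross-layer basis: r patterns, each spanning all L layers.
basis_enc = rng.normal(scale=0.1, size=(r, L, d))
basis_dec = rng.normal(scale=0.1, size=(r, L, d))
# Per-latent mixing coefficients over the shared basis.
coef_enc = rng.normal(size=(m, r))
coef_dec = rng.normal(size=(m, r))
b_enc = np.zeros(m)


def full_weights(coef, basis):
    """Expand the factorization into per-latent, per-layer weights (m, L, d)."""
    return np.einsum("mr,rld->mld", coef, basis)


def sample_layer_mask(n_examples, n_layers, p_drop, rng):
    """Return a 0/1 keep-mask of shape (n_examples, n_layers).

    Each layer is dropped independently with probability p_drop; a random
    layer is restored for any example that would otherwise lose all layers.
    """
    keep = rng.random((n_examples, n_layers)) >= p_drop
    rows = np.flatnonzero(~keep.any(axis=1))
    keep[rows, rng.integers(n_layers, size=rows.size)] = True
    return keep.astype(float)


def forward(x, keep):
    """Encode from the masked layers, then reconstruct every layer."""
    w_enc = full_weights(coef_enc, basis_enc)
    w_dec = full_weights(coef_dec, basis_dec)
    x_masked = x * keep[:, :, None]              # zero out dropped layers
    pre = np.einsum("bld,mld->bm", x_masked, w_enc) + b_enc
    z = np.maximum(pre, 0.0)                     # ReLU latents (assumption)
    x_hat = np.einsum("bm,mld->bld", z, w_dec)   # reconstruct all L layers
    return z, x_hat


# Stand-in for residual-stream activations collected at L layers.
x = rng.normal(size=(batch, L, d))
keep = sample_layer_mask(batch, L, p_drop=0.25, rng=rng)
z, x_hat = forward(x, keep)

# Denoising objective: the target is the unmasked input at every layer.
loss = np.mean((x_hat - x) ** 2)
```

Because the reconstruction target includes the layers that were hidden from the encoder, a latent that reads from a single layer loses its signal whenever that layer is masked, which is the kind of layer-localized dependence the regularizer is described as penalizing. The shared basis, in turn, caps the rank of the flattened weight matrix at `r`, so no latent can have fully independent weights at each layer.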