Identifying Sparsely Active Circuits Through Local Loss Landscape Decomposition
This work gives mechanistic interpretability researchers a tool for uncovering the circuits a model uses to compute its features, though the contribution appears incremental because it builds on existing interpretability methods.
The paper addresses the problem of understanding neural network circuits by introducing Local Loss Landscape Decomposition (L3D), a method that identifies low-rank subnetworks in parameter space. It shows that L3D recovers known subnetworks in toy models and applies the method to a real-world transformer model and a convolutional neural network.
Much of mechanistic interpretability has focused on understanding the activation spaces of large neural networks. However, activation-space approaches reveal little about the underlying circuitry used to compute features. To better understand the circuits employed by models, we introduce a new decomposition method called Local Loss Landscape Decomposition (L3D). L3D identifies a set of low-rank subnetworks: directions in parameter space, a subset of which can reconstruct the gradient of the loss between any sample's output and a reference output vector. We design a series of progressively more challenging toy models with well-defined subnetworks and show that L3D can nearly perfectly recover the associated subnetworks. Additionally, we investigate the extent to which perturbing the model in the direction of a given subnetwork affects only the relevant subset of samples. Finally, we apply L3D to a real-world transformer model and a convolutional neural network, demonstrating its potential to identify interpretable and relevant circuits in parameter space.
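The core idea in the abstract, reconstructing each sample's loss gradient from a few parameter-space directions, can be illustrated with a minimal sketch. This is not the L3D training procedure: as a stand-in, the directions here come from a plain singular value decomposition of the stacked per-sample gradients, and the toy model, the squared-error loss, the zero reference output, and all names (`theta`, `per_sample_gradient`, `k`) are assumptions made for illustration. The low-rank structure of real subnetworks is also omitted, since the toy parameters form a flat vector.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model with two independent "circuits": y_i = theta_i * x_i.
theta = np.array([1.5, -0.7])

def per_sample_gradient(x, reference):
    """Gradient w.r.t. theta of 0.5 * ||theta * x - reference||^2."""
    return (theta * x - reference) * x

# Each sample activates exactly one input, so it exercises exactly one circuit.
n_samples = 200
active = rng.integers(0, 2, size=n_samples)
X = np.zeros((n_samples, 2))
X[np.arange(n_samples), active] = rng.uniform(0.5, 2.0, size=n_samples)

reference = np.zeros(2)  # assumed reference output vector
G = np.stack([per_sample_gradient(x, reference) for x in X])  # (n_samples, n_params)

# Stand-in for learned subnetworks: right singular vectors of the gradient matrix.
_, _, Vt = np.linalg.svd(G, full_matrices=False)
directions = Vt  # rows are orthonormal directions in parameter space

# Sparse reconstruction: keep only the k largest-magnitude coefficients per sample.
k = 1
coeffs = G @ directions.T
top = np.argsort(-np.abs(coeffs), axis=1)[:, :k]
mask = np.zeros_like(coeffs, dtype=bool)
mask[np.arange(n_samples)[:, None], top] = True
G_hat = (coeffs * mask) @ directions

rel_error = np.linalg.norm(G - G_hat) / np.linalg.norm(G)
print(f"relative reconstruction error with k={k}: {rel_error:.2e}")
```

Because each toy sample uses a single circuit, one active direction per sample is enough to reconstruct its gradient in this setting. In models where circuits overlap, SVD directions would generally not align with the subnetworks, which is the kind of case a dedicated decomposition is meant to handle.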