MechRL: Reinforcement Learning Agents Perform Circuit Discovery for Mechanistic Interpretability
This work provides a general, automated method for circuit discovery in transformer language models, reducing the need for bespoke analytical pipelines for each new task.
The authors reformulate circuit discovery in transformer language models as a reinforcement learning problem, in which a PPO agent learns to identify causally important attention heads by performing zero-ablation actions. The agent achieves oracle-level performance on the training tasks (induction and indirect object identification, IOI) and recovers 96% of the oracle ceiling on a held-out task (docstring completion) without any task-specific signal, while its discovered heads align with known mechanistic circuits.
Mechanistic interpretability has identified small sets of attention heads that implement specific behaviours in transformer language models, but recovering these circuits typically requires a bespoke analytical pipeline for each new task. We recast circuit discovery as a reinforcement-learning problem. An agent treats the 144 attention heads of GPT-2 small as a discrete action space; each action zero-ablates a head and yields a contrastive reward, computed as the ablation's damage to the target task minus its damage to general next-token prediction. A single PPO policy, trained on two tasks (induction and IOI) in a vectorised multi-task environment, attains the per-episode oracle on both training tasks and on a held-out third task (docstring completion). Its preferred heads coincide with the canonical heads of the established literature on precisely the axes those papers identify as causally non-redundant under single-head ablation; the categories they identify as redundant are correctly de-prioritised by the agent. On the held-out task, best-of-five planning recovers 96% of the oracle ceiling with no task signal supplied at evaluation. These results indicate that reinforcement learning over causal interventions is a viable, transferable substrate for identifying the single-head bottlenecks of mechanistic circuits, complementary to existing path-patching approaches.
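The action space and contrastive reward described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the layer-major ordering of the 144 actions, the use of loss differences as the measure of "damage", and all function names below are assumptions made for illustration only.

```python
from typing import Tuple

# GPT-2 small has 12 layers with 12 attention heads each: 144 actions in total.
N_LAYERS = 12
N_HEADS = 12


def action_to_head(action: int) -> Tuple[int, int]:
    """Map a flat action index in [0, 144) to a (layer, head) pair.

    Assumption: actions are laid out layer-major; the abstract does not say.
    """
    if not 0 <= action < N_LAYERS * N_HEADS:
        raise ValueError(f"action must be in [0, {N_LAYERS * N_HEADS}), got {action}")
    return divmod(action, N_HEADS)


def contrastive_reward(
    task_loss_clean: float,
    task_loss_ablated: float,
    general_loss_clean: float,
    general_loss_ablated: float,
) -> float:
    """Damage to the target task minus damage to general next-token prediction.

    Assumption: "damage" is the increase in loss caused by zero-ablating the
    chosen head; the abstract does not specify the metric.
    """
    task_damage = task_loss_ablated - task_loss_clean
    general_damage = general_loss_ablated - general_loss_clean
    return task_damage - general_damage


if __name__ == "__main__":
    # Hypothetical loss values, chosen only to show the arithmetic.
    layer, head = action_to_head(17)
    reward = contrastive_reward(2.0, 3.5, 3.0, 3.25)
    print(f"layer={layer} head={head} reward={reward}")
```

Under this formulation, a head whose ablation hurts the target task far more than it hurts general prediction earns a high reward, whereas a head that is important for every input earns a reward near zero. This is one way to read the abstract's claim that the agent prioritises task-specific bottlenecks.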