Backward Lens: Projecting Language Model Gradients into the Vocabulary Space
This work provides incremental insights into the mechanics of information storage in language models, primarily benefiting researchers in deep learning interpretability.
The authors tackle the problem of understanding how Transformer-based language models learn and recall information. They extend vocabulary-projection interpretability methods from the forward pass to the backward pass. They prove that a gradient matrix can be expressed as a low-rank linear combination of its forward-pass and backward-pass inputs, and they use this result to project gradients into the vocabulary space.
Understanding how Transformer-based Language Models (LMs) learn and recall information is a key goal of the deep learning community. Recent interpretability methods project weights and hidden states obtained from the forward pass to the models' vocabularies, helping to uncover how information flows within LMs. In this work, we extend this methodology to LMs' backward pass and gradients. We first prove that a gradient matrix can be cast as a low-rank linear combination of its forward and backward passes' inputs. We then develop methods to project these gradients into vocabulary items and explore the mechanics of how new information is stored in the LMs' neurons.
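The low-rank claim can be made concrete with a minimal NumPy sketch. It assumes a single linear layer y_i = W x_i applied at each of n token positions, a toy squared-error loss, and a random matrix E standing in for the model's unembedding. For such a layer the chain rule gives dL/dW = Σ_i δ_i x_iᵀ. Here x_i is the forward-pass input and δ_i = ∂L/∂y_i is the backward-pass input, so the gradient's rank is at most n. The helper names (`gradient_from_passes`, `project_to_vocab`) are hypothetical. The projection is a plain logit-lens-style dot product with no final layer norm, so this is an illustration of the idea, not the authors' implementation.

```python
import numpy as np


def gradient_from_passes(X, Delta):
    """Gradient of a loss w.r.t. W for a linear map y_i = W @ x_i.

    X:     (n_tokens, d_in)  forward-pass inputs x_i (one row per token)
    Delta: (n_tokens, d_out) backward-pass inputs delta_i = dL/dy_i
    Returns dL/dW = sum_i outer(delta_i, x_i), shape (d_out, d_in),
    whose rank is at most n_tokens.
    """
    return Delta.T @ X


def project_to_vocab(vectors, E, k=5):
    """Logit-lens-style projection (hypothetical helper): score each
    d_model-sized vector against the rows of an unembedding matrix
    E of shape (vocab, d_model) and return the top-k token indices."""
    logits = vectors @ E.T                       # (n_vectors, vocab)
    return np.argsort(-logits, axis=1)[:, :k]


rng = np.random.default_rng(0)
d_in, d_model, n_tokens, vocab = 16, 8, 3, 50

W = rng.standard_normal((d_model, d_in))         # toy layer writing into d_model
X = rng.standard_normal((n_tokens, d_in))        # forward-pass inputs
T = rng.standard_normal((n_tokens, d_model))     # toy regression targets
E = rng.standard_normal((vocab, d_model))        # stand-in unembedding matrix


def loss(W):
    """Toy loss L = 0.5 * sum_i ||W x_i - t_i||^2."""
    Y = X @ W.T
    return 0.5 * np.sum((Y - T) ** 2)


# For this loss, the backward-pass input at each position is dL/dy_i = y_i - t_i
Delta = X @ W.T - T
G = gradient_from_passes(X, Delta)

# Sanity check: compare one entry with a central finite difference
eps = 1e-5
Wp, Wm = W.copy(), W.copy()
Wp[2, 5] += eps
Wm[2, 5] -= eps
fd = (loss(Wp) - loss(Wm)) / (2 * eps)
assert abs(fd - G[2, 5]) < 1e-4

# A gradient-descent step adds -lr * G to W, so -delta_i is the direction
# in which the update moves the layer's output; project it into the vocabulary
top = project_to_vocab(-Delta, E, k=5)
print("rank of gradient:", np.linalg.matrix_rank(G))
print("top-5 promoted token ids per position:\n", top)
```

Because each term δ_i x_iᵀ is rank one, the gradient never has more independent directions than there are token positions. In this toy setting, projecting the backward-pass vectors (rather than the full gradient matrix) is what makes a per-token vocabulary readout possible.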