Discovering Knowledge-Critical Subnetworks in Pretrained Language Models
This work addresses interpretability and control in large language models, a concern for both researchers and practitioners, although it is incremental in that it builds on existing subnetwork discovery methods.
The authors tackle the problem of localizing and disentangling implicit knowledge representations in pretrained language models. They discover sparse knowledge-critical subnetworks (over 98% sparsity) whose removal precisely suppresses specific relational knowledge while minimizing adverse effects on the rest of the model's behavior.
Pretrained language models (LMs) encode implicit representations of knowledge in their parameters. However, localizing these representations and disentangling them from each other remains an open problem. In this work, we investigate whether pretrained language models contain various knowledge-critical subnetworks: particular sparse computational subgraphs that can, if removed, precisely suppress specific knowledge the model has memorized. We propose a multi-objective differentiable masking scheme that can be applied to both weights and neurons to discover such subnetworks and show that we can use them to precisely remove specific knowledge from models while minimizing adverse effects on the behavior of the original model. We demonstrate our method on multiple GPT2 variants, uncovering highly sparse subnetworks (98%+ sparsity) that are critical for expressing specific collections of relational knowledge. When these subnetworks are removed, the remaining network maintains most of its initial abilities but struggles to represent the suppressed knowledge.
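To make the idea of a multi-objective differentiable mask concrete, the following is a minimal sketch on a toy linear model, not the authors' implementation. Everything in it is an assumption made for illustration: the function names (`forward`, `loss`, `discover`), the squared-error stand-ins for the suppression and maintenance objectives, the sparsity weight `lam_sparse`, and the finite-difference optimizer. Each weight gets a logit whose sigmoid is the soft probability that the weight belongs to the removed subnetwork; the loss rewards erasing the output on a "target knowledge" input, preserving the output on a "control" input, and keeping the removed set small.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(weights, removal, x):
    """Toy linear model after removal: each weight is scaled by (1 - m_i)."""
    return sum(w * (1.0 - m) * xi for w, m, xi in zip(weights, removal, x))

def loss(logits, weights, x_target, x_control, lam_sparse=0.1):
    """Hypothetical three-term objective (illustrative stand-ins only)."""
    m = [sigmoid(z) for z in logits]          # soft removal mask in (0, 1)
    no_removal = [0.0] * len(weights)
    # Suppression: drive the output on the target input toward zero.
    suppress = forward(weights, m, x_target) ** 2
    # Maintenance: keep the output on the control input unchanged.
    original = forward(weights, no_removal, x_control)
    maintain = (forward(weights, m, x_control) - original) ** 2
    # Sparsity: the removed subnetwork should contain few weights.
    sparsity = sum(m) / len(m)
    return suppress + maintain + lam_sparse * sparsity

def discover(weights, x_target, x_control, steps=300, lr=0.5, eps=1e-5):
    """Gradient descent on the mask logits using central differences."""
    logits = [0.0] * len(weights)             # start with m_i = 0.5
    for _ in range(steps):
        grads = []
        for i in range(len(logits)):
            up = logits[:i] + [logits[i] + eps] + logits[i + 1:]
            down = logits[:i] + [logits[i] - eps] + logits[i + 1:]
            diff = loss(up, weights, x_target, x_control) \
                - loss(down, weights, x_target, x_control)
            grads.append(diff / (2 * eps))
        logits = [z - lr * g for z, g in zip(logits, grads)]
    # Threshold the soft mask to get a discrete subnetwork to remove.
    return [sigmoid(z) > 0.5 for z in logits]

# Toy setup: only weight 0 carries the "target knowledge".
weights = [2.0, 1.0, 1.0, 1.0]
x_target = [1.0, 0.0, 0.0, 0.0]
x_control = [0.0, 1.0, 1.0, 1.0]

removed = discover(weights, x_target, x_control)
subnetwork_sparsity = 1.0 - sum(removed) / len(removed)
target_after = forward(weights, removed, x_target)
control_after = forward(weights, removed, x_control)
```

In this toy setup the mask should isolate the single weight used by the target input, so removing it zeroes the target output while leaving the control output intact. A real language model would need automatic differentiation and task-appropriate losses in place of the finite differences and squared errors used here.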