Explaining and Improving Model Behavior with k Nearest Neighbor Representations
This work addresses interpretability and robustness concerns relevant to NLP practitioners, but its contribution is incremental: it applies an existing kNN method to new tasks such as NLI.
The authors tackled the problem of understanding and improving model behavior in NLP by using k-nearest neighbor (kNN) representations to identify training examples that influence predictions and to uncover spurious associations. Backing off to kNN also improved robustness to adversarial inputs without updating model parameters.
Interpretability techniques in NLP have mainly focused on understanding individual predictions using attention visualization or gradient-based saliency maps over tokens. We propose using k-nearest neighbor (kNN) representations to identify training examples responsible for a model's predictions and to obtain a corpus-level understanding of the model's behavior. Apart from interpretability, we show that kNN representations are effective at uncovering learned spurious associations, identifying mislabeled examples, and improving the fine-tuned model's performance. We focus on Natural Language Inference (NLI) as a case study and experiment with multiple datasets. For BERT and RoBERTa, our method backs off to kNN on examples where model confidence is low, without any update to the model parameters. Our results indicate that the kNN approach makes the fine-tuned model more robust to adversarial inputs.
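A minimal sketch of the two uses described above (retrieving the training examples nearest to a test input, and backing off to kNN when the model is unsure), written under assumptions the abstract does not state: every example is already encoded as a fixed-length vector (for instance a representation taken from the fine-tuned BERT or RoBERTa model), neighbors are ranked by Euclidean distance, the kNN label is a majority vote, and the confidence `threshold` is a hypothetical value that would in practice be tuned on held-out data. The function names `nearest_neighbors`, `knn_predict`, and `predict_with_backoff` are illustrative, not the authors' code.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def euclidean(a: Vector, b: Vector) -> float:
    """L2 distance between two representation vectors (assumed distance metric)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest_neighbors(query: Vector, train_reps: List[Vector], k: int) -> List[int]:
    """Indices of the k training examples closest to the query, nearest first.

    The returned indices point at training examples that can be inspected
    as an explanation of the model's behavior on the query.
    """
    order = sorted(range(len(train_reps)), key=lambda i: euclidean(query, train_reps[i]))
    return order[:k]


def knn_predict(query: Vector, train_reps: List[Vector], train_labels: List[int], k: int) -> int:
    """Majority vote over the labels of the k nearest training examples."""
    neighbors = nearest_neighbors(query, train_reps, k)
    votes = Counter(train_labels[i] for i in neighbors)
    # On ties, most_common keeps first-encountered order, so the label of the
    # closer neighbor wins because neighbors are sorted by distance.
    return votes.most_common(1)[0][0]


def predict_with_backoff(
    model_probs: List[float],
    query: Vector,
    train_reps: List[Vector],
    train_labels: List[int],
    k: int = 3,
    threshold: float = 0.6,  # hypothetical value; tune on a dev set
) -> Tuple[int, str]:
    """Return the model's argmax unless its top probability falls below threshold.

    Model parameters are never touched: low-confidence inputs are simply
    relabeled by the kNN vote over training representations.
    """
    confidence = max(model_probs)
    if confidence >= threshold:
        return model_probs.index(confidence), "model"
    return knn_predict(query, train_reps, train_labels, k), "knn"
```

With toy two-dimensional "representations" such as `train_reps = [[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [0.9, 1.0]]` and hypothetical label ids `[0, 0, 2, 2]`, a query at `[1.0, 0.9]` keeps the model's label when the model is confident (top probability 0.8) but falls back to the neighbors' majority label 2 when the top probability is only 0.5. A brute-force sort is used here for clarity; a real training set would call for an indexed nearest-neighbor search.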