Scaling Up Influence Functions
This work removes a scalability bottleneck for machine-learning researchers and practitioners who need to trace the predictions of large models back to their training data. Its contribution is an incremental improvement in computational efficiency.
The paper addresses the efficient computation of influence functions for large Transformer models. It speeds up the inverse Hessian calculation with Arnoldi iteration and reports what is, to the authors' knowledge, the first implementation that scales to models with hundreds of millions of parameters and to datasets with tens of millions to a hundred million examples.
We address efficient calculation of influence functions for tracking predictions back to the training data. We propose and analyze a new approach to speeding up the inverse Hessian calculation based on Arnoldi iteration. With this improvement, we achieve, to the best of our knowledge, the first successful implementation of influence functions that scales to full-size (language and vision) Transformer models with several hundreds of millions of parameters. We evaluate our approach on image classification and sequence-to-sequence tasks with tens of millions to a hundred million training examples. Our code will be available at https://github.com/google-research/jax-influence.
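The idea of using Arnoldi iteration for the inverse Hessian can be illustrated with a small sketch. It assumes that the inverse Hessian is approximated by restricting the Hessian to its top eigenvectors. Arnoldi builds a Krylov basis from Hessian-vector products alone, the eigenpairs of the small projected matrix (Ritz pairs) approximate the dominant eigenpairs, and the influence-style quantity g_test^T H^{-1} g_train is computed in that subspace. This is a hypothetical NumPy illustration, not the jax-influence code: the names `arnoldi`, `top_eigen` and `influence` are invented here, sign conventions are omitted, and an explicit matrix stands in for the Hessian-vector product that a real model would compute with automatic differentiation.

```python
import numpy as np


def arnoldi(hvp, dim, n_iters, rng):
    """Build an orthonormal Krylov basis Q and upper-Hessenberg H with hvp(Q) ~= Q H."""
    Q = np.zeros((dim, n_iters + 1))
    H = np.zeros((n_iters + 1, n_iters))
    q = rng.standard_normal(dim)
    Q[:, 0] = q / np.linalg.norm(q)
    for k in range(n_iters):
        w = hvp(Q[:, k])  # only Hessian-vector products are needed
        for j in range(k + 1):  # modified Gram-Schmidt orthogonalization
            H[j, k] = Q[:, j] @ w
            w = w - H[j, k] * Q[:, j]
        H[k + 1, k] = np.linalg.norm(w)
        if H[k + 1, k] < 1e-10:  # invariant subspace reached; stop early
            return Q[:, : k + 1], H[: k + 1, : k + 1]
        Q[:, k + 1] = w / H[k + 1, k]
    return Q[:, :n_iters], H[:n_iters, :n_iters]


def top_eigen(hvp, dim, n_iters, top_k, seed=0):
    """Approximate the top_k eigenpairs (largest |eigenvalue|) via Ritz pairs."""
    rng = np.random.default_rng(seed)
    Q, H = arnoldi(hvp, dim, n_iters, rng)
    evals, evecs = np.linalg.eig(H)  # eigenpairs of the small projected matrix
    evals, evecs = evals.real, evecs.real  # Hessian is symmetric: drop round-off imaginary parts
    order = np.argsort(-np.abs(evals))[:top_k]
    V = Q @ evecs[:, order]  # lift Ritz vectors back to parameter space
    V /= np.linalg.norm(V, axis=0)
    return evals[order], V


def influence(g_test, g_train, evals, V):
    """g_test^T H^{-1} g_train with H restricted to span(V)."""
    return float(((V.T @ g_test) / evals) @ (V.T @ g_train))
```

When `top_k` equals the full dimension, the result matches an exact linear solve. For large models, `n_iters` and `top_k` are kept far smaller than the parameter count. Each gradient is then projected onto a handful of directions, which makes the per-example cost cheap once the eigenvectors have been computed.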