Robustness to corruption in pre-trained Bayesian neural networks
The paper addresses a robustness problem in BNNs that matters to machine learning practitioners, but the contribution is incremental: it builds on existing training-data-dependent priors and relies on extra test-time information (a minibatch of corrupted test points).
The paper tackles robustness to corruption in Bayesian neural networks (BNNs) by developing ShiftMatch, a training-data-dependent likelihood that matches test-time spatial correlations to training-time ones. Because ShiftMatch leaves the network's training-time likelihood unchanged, it can be applied to samples from pre-trained BNNs. Using pre-trained HMC samples, it yields strong performance improvements on CIFAR-10-C, outperforms prior methods such as EmpCov priors, and is perhaps the first Bayesian method to convincingly beat plain deep ensembles.
We develop ShiftMatch, a new training-data-dependent likelihood for robustness to corruption in Bayesian neural networks (BNNs). ShiftMatch is inspired by the training-data-dependent "EmpCov" priors from Izmailov et al. (2021a), and efficiently matches test-time spatial correlations to those at training time. Critically, ShiftMatch is designed to leave the neural network's training-time likelihood unchanged, allowing it to use publicly available samples from pre-trained BNNs. Using pre-trained HMC samples, ShiftMatch gives strong performance improvements on CIFAR-10-C, outperforms EmpCov priors (though ShiftMatch uses extra information from a minibatch of corrupted test points), and is perhaps the first Bayesian method capable of convincingly outperforming plain deep ensembles.
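
To make the idea of matching test-time correlations to training-time ones concrete, the following is a minimal sketch and not the authors' implementation. It assumes features are flattened into an array of shape (num_samples, num_locations), uses uncentered second-moment matrices as the correlation statistic, and applies a single whiten-then-recolor linear map to a minibatch of corrupted test points. The function names (`sym_matrix_power`, `second_moment`, `match_spatial_correlations`) and the `eps` floor are hypothetical choices for illustration; the paper's actual procedure may differ in how it factorises the statistics and where in the network it applies them.

```python
import numpy as np


def sym_matrix_power(mat, power, eps=1e-6):
    """Raise a symmetric positive semi-definite matrix to a real power.

    Eigenvalues are floored at `eps` (an assumed stabiliser) so that
    negative powers stay finite.
    """
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, eps, None)
    return (eigvecs * eigvals**power) @ eigvecs.T


def second_moment(x):
    """Uncentered second-moment matrix of rows of x, shape (D, D)."""
    return x.T @ x / x.shape[0]


def match_spatial_correlations(x_test, train_moment, eps=1e-6):
    """Linearly transform test features so their second moment matches training.

    x_test:       (N, D) minibatch of (possibly corrupted) test features.
    train_moment: (D, D) second-moment matrix computed on training features.
    """
    test_moment = second_moment(x_test)
    whiten = sym_matrix_power(test_moment, -0.5, eps)    # removes test-time correlations
    recolor = sym_matrix_power(train_moment, 0.5, eps)   # imposes training-time correlations
    return x_test @ whiten @ recolor
```

Because the map acts only on the test inputs or activations, the training-time likelihood and any pre-trained posterior samples are left untouched in this sketch, which mirrors the property the abstract emphasises. The sketch also needs a test minibatch with more samples than feature dimensions for the whitening step to be well conditioned.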