Variational Inference Failures Under Model Symmetries: Permutation Invariant Posteriors for Bayesian Neural Networks
This paper addresses a specific challenge in Bayesian deep learning that is relevant to both researchers and practitioners. It offers an incremental improvement by mitigating biases in VI for models with weight space symmetries.
The paper tackles variational inference (VI) failures in Bayesian neural networks caused by weight space permutation symmetries, which bias the approximate posterior and degrade predictive performance. The proposed solution is a symmetrization mechanism that constructs permutation invariant variational posteriors. It yields improved predictions, a higher ELBO, and a demonstrably better fit to the true posterior.
Weight space symmetries in neural network architectures, such as permutation symmetries in MLPs, give rise to Bayesian neural network (BNN) posteriors with many equivalent modes. This multimodality poses a challenge for variational inference (VI) techniques, which typically rely on approximating the posterior with a unimodal distribution. In this work, we investigate the impact of weight space permutation symmetries on VI. We demonstrate, both theoretically and empirically, that these symmetries lead to biases in the approximate posterior, which degrade predictive performance and posterior fit if not explicitly accounted for. To mitigate this behavior, we leverage the symmetric structure of the posterior and devise a symmetrization mechanism for constructing permutation invariant variational posteriors. We show that the symmetrized distribution has a strictly better fit to the true posterior, and that it can be trained using the original ELBO objective with a modified KL regularization term. We demonstrate experimentally that our approach mitigates the aforementioned biases and results in improved predictions and a higher ELBO.
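The symmetry argument can be made concrete with a minimal sketch in pure Python. It assumes a toy one-hidden-layer MLP with three tanh units and a mean-field Gaussian variational posterior q. The helpers (`forward`, `permute`, `log_q`, `log_q_sym`) and all numbers are hypothetical illustrations, not the authors' code. The sketch checks two facts the abstract relies on. First, permuting hidden units leaves the network function unchanged, which is why the posterior has many equivalent modes. Second, a uniform mixture of q over all hidden-unit permutations is permutation invariant, while q itself is not. This mixture is one generic way to symmetrize a distribution; the paper's exact mechanism and its modified KL term may differ.

```python
import itertools
import math

# Hypothetical toy MLP: 1 input -> H tanh hidden units -> 1 output.
# Parameters are a dict of plain lists (W1: H x 1, b1: H, W2: 1 x H, b2: 1).

def forward(x, p):
    """Evaluate the MLP on a list of input features x."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(p["W1"], p["b1"])]
    return [sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(p["W2"], p["b2"])]

def permute(p, perm):
    """Reorder hidden units: rows of W1/b1 and columns of W2 move together."""
    return {
        "W1": [p["W1"][j] for j in perm],
        "b1": [p["b1"][j] for j in perm],
        "W2": [[row[j] for j in perm] for row in p["W2"]],
        "b2": list(p["b2"]),
    }

def flatten(p):
    """Flatten parameters into one list in a fixed order."""
    flat = []
    for key in ("W1", "b1", "W2", "b2"):
        for item in p[key]:
            flat.extend(item if isinstance(item, list) else [item])
    return flat

def log_q(p, mu, sigma):
    """Log density of a mean-field (diagonal) Gaussian q over the weights."""
    total = 0.0
    for w, m, s in zip(flatten(p), flatten(mu), flatten(sigma)):
        total += -0.5 * ((w - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
    return total

def log_q_sym(p, mu, sigma):
    """Assumed symmetrization: uniform mixture of q over all H! permutations,
    i.e. q_sym(w) = (1/H!) * sum_pi q(pi(w)), computed with log-sum-exp."""
    n_hidden = len(p["b1"])
    logs = [log_q(permute(p, perm), mu, sigma)
            for perm in itertools.permutations(range(n_hidden))]
    top = max(logs)
    return top + math.log(sum(math.exp(v - top) for v in logs)) - math.log(len(logs))

# Hypothetical variational parameters; evaluate at the mean for illustration.
mu = {"W1": [[0.5], [-1.0], [2.0]], "b1": [0.1, 0.0, -0.3],
      "W2": [[1.0, 0.5, -0.7]], "b2": [0.2]}
sigma = {"W1": [[0.3], [0.3], [0.3]], "b1": [0.2, 0.2, 0.2],
         "W2": [[0.4, 0.4, 0.4]], "b2": [0.1]}
w = mu
w_perm = permute(w, (2, 0, 1))

print(forward([0.8], w)[0], forward([0.8], w_perm)[0])        # same function
print(log_q(w, mu, sigma), log_q(w_perm, mu, sigma))          # q is not invariant
print(log_q_sym(w, mu, sigma), log_q_sym(w_perm, mu, sigma))  # q_sym is invariant
```

Enumerating all H! permutations is only feasible for tiny layers. A deep network has one such permutation group per hidden layer, so this brute-force mixture serves only to illustrate the invariance property, not as a practical training scheme.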