Implicit Variational Inference for High-Dimensional Posteriors
This work addresses a key bottleneck in Bayesian deep learning for practitioners who need reliable uncertainty estimates from large-scale models, and represents a significant advance rather than an incremental improvement.
The paper tackles the problem of accurately approximating complex, high-dimensional posterior distributions in Bayesian models. It proposes a neural sampler that specifies an implicit distribution. It also reports the first demonstration of recovering correlations across layers in Bayesian neural networks with tens of millions of latent variables, where it outperforms state-of-the-art uncertainty quantification methods.
In variational inference, the benefits of Bayesian models rely on accurately capturing the true posterior distribution. We propose using neural samplers that specify implicit distributions, which are well-suited for approximating complex multimodal and correlated posteriors in high-dimensional spaces. Our approach introduces novel bounds for approximate inference using implicit distributions by locally linearising the neural sampler. This is distinct from existing methods that rely on additional discriminator networks and unstable adversarial objectives. Furthermore, we present a new sampler architecture that, for the first time, enables implicit distributions over tens of millions of latent variables, addressing computational concerns by using differentiable numerical approximations. We empirically show that our method is capable of recovering correlations across layers in large Bayesian neural networks, a property that is crucial for a network's performance but notoriously challenging to achieve. To the best of our knowledge, no other method has been shown to accomplish this task for such large models. Through experiments in downstream tasks, we demonstrate that our expressive posteriors outperform state-of-the-art uncertainty quantification methods, validating the effectiveness of our training algorithm and the quality of the learned implicit approximation.
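The phrase "locally linearising the neural sampler" can be made concrete with a minimal NumPy sketch of one generic way to do it. This is an illustration under stated assumptions, not the paper's bound or architecture. The assumptions are as follows. The sampler g is a hypothetical one-hidden-layer tanh MLP that maps d-dimensional noise ε ~ N(0, I) to D latent variables. Small Gaussian noise σ·η is added to its output. The entropy of q(z) is approximated by the Gaussian that results from a first-order Taylor expansion of g around a sampled ε₀, averaged over several ε₀. The log-determinant is evaluated in d×d space rather than D×D space using the matrix determinant lemma. That is one standard way to keep the cost linear in D, but the abstract does not say whether the paper uses it. The names `sampler`, `sampler_jacobian`, `linearised_entropy` and `entropy_estimate` are hypothetical.

```python
import numpy as np

def sampler(eps, params):
    """Hypothetical neural sampler g: R^d -> R^D (one tanh hidden layer)."""
    W1, b1, W2, b2 = params
    return W2 @ np.tanh(W1 @ eps + b1) + b2

def sampler_jacobian(eps, params):
    """Exact Jacobian dg/deps of the tanh sampler above, shape (D, d)."""
    W1, b1, W2, _ = params
    h = W1 @ eps + b1
    # W2 @ diag(1 - tanh(h)^2) @ W1, written with broadcasting over columns of W2.
    return (W2 * (1.0 - np.tanh(h) ** 2)) @ W1

def linearised_entropy(eps0, params, sigma):
    """Entropy of the Gaussian obtained by linearising g around eps0.

    Assumed output model: z = g(eps) + sigma * eta, eps ~ N(0, I_d), eta ~ N(0, I_D).
    With g(eps) ~= g(eps0) + J (eps - eps0), z is Gaussian with covariance
    J J^T + sigma^2 I_D. Its log-determinant is computed in d x d space via
        log det(J J^T + s^2 I_D) = (D - d) log s^2 + log det(J^T J + s^2 I_d).
    """
    J = sampler_jacobian(eps0, params)
    D, d = J.shape
    s2 = sigma ** 2
    _, logdet_small = np.linalg.slogdet(J.T @ J + s2 * np.eye(d))
    logdet = (D - d) * np.log(s2) + logdet_small
    # Entropy of a D-dimensional Gaussian: 0.5 * (D log(2 pi e) + log det Sigma).
    return 0.5 * (D * np.log(2 * np.pi * np.e) + logdet)

def entropy_estimate(params, sigma, d, n_samples=64, seed=0):
    """Monte Carlo average of the linearised entropy over eps0 ~ N(0, I_d)."""
    rng = np.random.default_rng(seed)
    vals = [linearised_entropy(rng.standard_normal(d), params, sigma)
            for _ in range(n_samples)]
    return float(np.mean(vals))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, H, D = 4, 16, 1000  # small noise dim, hidden width, many latents
    params = (rng.standard_normal((H, d)), rng.standard_normal(H),
              rng.standard_normal((D, H)) / np.sqrt(H), np.zeros(D))
    print(entropy_estimate(params, sigma=0.1, d=d))
```

Only a d×d matrix is factorised, so for fixed d the cost of the log-determinant grows linearly in D. That is the kind of scaling a sampler over millions of latents would need. At that size, the D×d Jacobian itself would still have to be formed or approximated with care.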