BALI: Learning Neural Networks via Bayesian Layerwise Inference
BALI offers a more efficient approach to Bayesian neural networks. The contribution is incremental: it builds on existing Bayesian methods by adding layerwise optimization.
The paper tackles the problem of learning Bayesian neural networks by introducing a layerwise inference method that treats each layer as a Bayesian linear regression model, achieving performance comparable to or better than leading methods on regression, classification, and out-of-distribution detection benchmarks.
We introduce a new method for learning Bayesian neural networks, treating them as a stack of multivariate Bayesian linear regression models. The main idea is that the layerwise posterior can be inferred exactly if the target outputs of each layer are known. We define these pseudo-targets as the layer outputs from the forward pass, updated by the backpropagated gradients of the objective function. The resulting layerwise posterior is a matrix-normal distribution with a Kronecker-factorized covariance matrix, which can be efficiently inverted. Our method extends to the stochastic mini-batch setting using an exponential moving average over natural-parameter terms, thus gradually forgetting older data. The method converges in few iterations and performs as well as or better than leading Bayesian neural network methods on various regression, classification, and out-of-distribution detection benchmarks.
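The abstract names three ingredients: a per-layer Bayesian linear regression with a matrix-normal posterior, pseudo-targets obtained by a gradient step on the forward-pass outputs, and an exponential moving average over natural-parameter terms. The minimal NumPy sketch below shows these for a single layer. It is an illustration under stated assumptions, not the authors' implementation. It assumes a conjugate prior W ~ MN(0, I/alpha, sigma2 I) whose column covariance matches a fixed noise variance sigma2, pseudo-targets T = Z - step * dL/dZ, and an EMA of the per-sample statistics X^T X / b and X^T T / b, rescaled by the dataset size n_data. The class LayerwiseBLR, the function pseudo_targets, and the parameters alpha, sigma2, rho, step, and n_data are hypothetical names.

```python
import numpy as np


class LayerwiseBLR:
    """Hypothetical sketch of one layer treated as multivariate Bayesian linear regression.

    Assumed model: T = X @ W + E with rows of E ~ N(0, sigma2 * I), and a conjugate
    prior W ~ MN(0, I / alpha, sigma2 * I). The posterior is then matrix-normal,
    W ~ MN(M, U, V), with U = (X^T X + alpha I)^-1, V = sigma2 * I and
    M = U X^T T, so Cov[vec(W)] = kron(V, U).
    """

    def __init__(self, d_in, d_out, alpha=1.0, sigma2=1.0, rho=0.9, n_data=1):
        self.alpha, self.sigma2, self.rho, self.n_data = alpha, sigma2, rho, n_data
        self.d_out = d_out
        self.A = np.zeros((d_in, d_in))   # EMA of per-sample X^T X
        self.B = np.zeros((d_in, d_out))  # EMA of per-sample X^T T

    def update(self, X, T):
        """Blend one mini-batch's natural-parameter terms into the EMA (forgets old data)."""
        b = X.shape[0]
        self.A = self.rho * self.A + (1.0 - self.rho) * (X.T @ X) / b
        self.B = self.rho * self.B + (1.0 - self.rho) * (X.T @ T) / b

    def posterior(self):
        """Return (M, U, V); n_data rescales per-sample averages to dataset totals."""
        d_in = self.A.shape[0]
        U = np.linalg.inv(self.n_data * self.A + self.alpha * np.eye(d_in))  # d_in x d_in only
        M = U @ (self.n_data * self.B)
        V = self.sigma2 * np.eye(self.d_out)
        return M, U, V

    def sample(self, rng):
        """Draw W = M + L_U E L_V^T, whose vec has covariance kron(V, U)."""
        M, U, V = self.posterior()
        E = rng.standard_normal(M.shape)
        return M + np.linalg.cholesky(U) @ E @ np.linalg.cholesky(V).T


def pseudo_targets(Z, grad_Z, step=1.0):
    """Forward-pass outputs moved one gradient step downhill on the loss (assumed rule)."""
    return Z - step * grad_Z


# Usage: a single output layer on synthetic regression data, trained with mini-batches.
rng = np.random.default_rng(0)
n, d_in, d_out = 512, 4, 2
X = rng.standard_normal((n, d_in))
W_true = rng.standard_normal((d_in, d_out))
Y = X @ W_true + 0.1 * rng.standard_normal((n, d_out))

layer = LayerwiseBLR(d_in, d_out, alpha=1.0, sigma2=0.01, rho=0.9, n_data=n)
for epoch in range(5):
    for idx in np.split(rng.permutation(n), 8):  # 8 mini-batches of 64 rows
        Xb, Yb = X[idx], Y[idx]
        M, _, _ = layer.posterior()
        Z = Xb @ M                # forward pass with the current posterior mean
        grad_Z = Z - Yb           # dL/dZ for L = 0.5 * sum((Z - Y) ** 2)
        layer.update(Xb, pseudo_targets(Z, grad_Z, step=1.0))

M, U, V = layer.posterior()
print("max |M - W_true|:", np.abs(M - W_true).max())
```

With squared error on the output layer and step = 1, the pseudo-target reduces to the observed target, which makes a convenient sanity check. In a deeper stack, each hidden layer would instead take its inputs from the previous layer's forward pass and its gradient from backpropagation, yielding non-trivial targets. The only matrices inverted or factorized are d_in x d_in (U) and d_out x d_out (V), rather than one of size d_in*d_out. That is where the Kronecker factorization saves work.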