LG · Sep 23, 2022

Jensen-Shannon Divergence Based Novel Loss Functions for Bayesian Neural Networks

arXiv:2209.11366v4 · 26 citations · h-index: 4
AI Analysis

Aimed at machine learning practitioners, this work addresses optimization and uncertainty-quantification issues in BNNs and offers incremental improvements over existing KL divergence-based methods.

The paper tackles instability and approximation challenges in Bayesian neural networks (BNNs) by proposing novel loss functions based on a modified Jensen-Shannon (JS) divergence. It reports accuracy gains of about 5% on a noise-added CIFAR-10 dataset and about 8% on a regression dataset, plus a roughly 13% reduction in false negatives on a biased histopathology dataset.

Bayesian neural networks (BNNs) are state-of-the-art machine learning methods that can naturally regularize and systematically quantify uncertainties using their stochastic parameters. Kullback-Leibler (KL) divergence-based variational inference used in BNNs suffers from unstable optimization and challenges in approximating light-tailed posteriors due to the unbounded nature of the KL divergence. To resolve these issues, we formulate a novel loss function for BNNs based on a new modification to the generalized Jensen-Shannon (JS) divergence, which is bounded. In addition, we propose a Geometric JS divergence-based loss, which is computationally efficient since it can be evaluated analytically. We found that JS divergence-based variational inference is intractable, and hence employed a constrained optimization framework to formulate these losses. Our theoretical analysis and empirical experiments on multiple regression and classification datasets suggest that the proposed losses perform better than the KL divergence-based loss, especially when the datasets are noisy or biased. Specifically, there are approximately 5% and 8% improvements in accuracy for a noise-added CIFAR-10 dataset and a regression dataset, respectively. There is about a 13% reduction in false negative predictions on a biased histopathology dataset. In addition, we quantify and compare the uncertainty metrics for the regression and classification tasks.
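The abstract rests on two properties: the KL divergence is unbounded, whereas the mixture-based JS divergence is bounded (by ln 2), and a geometric JS variant can be evaluated analytically. The sketch below illustrates both for univariate Gaussians. It is not the authors' loss function: the skew weight `alpha`, the geometric mean G ∝ p^(1-alpha) q^alpha, and all function names are assumptions chosen for illustration, and the paper's modified divergences may weight the terms differently.

```python
import math
import numpy as np


def kl_gauss(mu1, s1, mu2, s2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2)) for univariate Gaussians."""
    return math.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5


def log_normal_pdf(x, mu, s):
    """Log-density of N(mu, s^2), evaluated elementwise on an array."""
    return -0.5 * ((x - mu) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)


def js_mixture(mu1, s1, mu2, s2, alpha=0.5, n=20001):
    """Skew JS divergence with arithmetic mixture M = (1-alpha) p + alpha q.

    There is no closed form for Gaussians, so it is integrated numerically.
    Assumes 0 < alpha < 1. The result is bounded by the binary entropy of
    alpha, which is at most ln 2.
    """
    lo = min(mu1, mu2) - 12 * max(s1, s2)
    hi = max(mu1, mu2) + 12 * max(s1, s2)
    x = np.linspace(lo, hi, n)
    logp = log_normal_pdf(x, mu1, s1)
    logq = log_normal_pdf(x, mu2, s2)
    # log M computed in log space to avoid underflow when p and q barely overlap
    logm = np.logaddexp(math.log(1 - alpha) + logp, math.log(alpha) + logq)

    def integrate(f):
        # Trapezoidal rule on the uniform grid
        return float(np.sum(f[1:] + f[:-1]) * (x[1] - x[0]) / 2)

    kl_pm = integrate(np.exp(logp) * (logp - logm))
    kl_qm = integrate(np.exp(logq) * (logq - logm))
    return (1 - alpha) * kl_pm + alpha * kl_qm


def js_geometric(mu1, s1, mu2, s2, alpha=0.5):
    """Skew geometric JS divergence (assumed form).

    The mixture is replaced by the normalized weighted geometric mean
    G ∝ p^(1-alpha) q^alpha. For Gaussians G is again Gaussian, so the
    divergence reduces to two closed-form KL terms.
    """
    precision = (1 - alpha) / s1**2 + alpha / s2**2
    var_g = 1.0 / precision
    mu_g = var_g * ((1 - alpha) * mu1 / s1**2 + alpha * mu2 / s2**2)
    s_g = math.sqrt(var_g)
    return (1 - alpha) * kl_gauss(mu1, s1, mu_g, s_g) + alpha * kl_gauss(mu2, s2, mu_g, s_g)


if __name__ == "__main__":
    # Move the second Gaussian away from the first and compare the divergences
    for gap in (0.5, 2.0, 8.0):
        print(
            f"mean gap {gap}: KL={kl_gauss(0, 1, gap, 1):.4f}  "
            f"JS={js_mixture(0, 1, gap, 1):.4f} (ln 2 = {math.log(2):.4f})  "
            f"GJS={js_geometric(0, 1, gap, 1):.4f}"
        )
```

As the mean gap grows, KL grows quadratically while the mixture JS saturates just below ln 2, which is the boundedness the abstract relies on. Under this sketch's definition, the geometric variant trades that bound for tractability: for unit variances it equals gap²/8, one quarter of the KL, and needs no numerical integration.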

Code Implementations: 1 repo