LG · ML · Mar 22, 2024

Federated Bayesian Deep Learning: The Application of Statistical Aggregation Methods to Bayesian Models

arXiv:2403.15263v2 · 6 citations · h-index: 3 · IEEE Access
Originality: Incremental advance
AI Analysis

This work addresses the problem of enabling uncertainty-aware federated learning for safety-critical applications such as remote sensing. It is an incremental advance, since it adapts existing aggregation methods to Bayesian models.

The paper tackles the challenge of applying federated learning to Bayesian deep learning models. These models are often well calibrated and can quantify uncertainty, but because their weights are probability distributions, the standard aggregation methods used for deterministic models are either inapplicable or sub-optimal. The paper analyzes six aggregation strategies on IID and non-IID partitions of CIFAR-10, showing that the choice of strategy significantly affects accuracy, calibration, uncertainty quantification, training stability, and client compute requirements.
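To make the aggregation problem concrete, the sketch below assumes a mean-field Gaussian posterior N(mu, sigma^2) for a single weight on each client and contrasts two generic ways of combining clients. The first applies FedAvg's dataset-size-weighted mean to mu and sigma separately, as if they were ordinary weights; the second multiplies the client Gaussians (precisions add, means are precision-weighted). Both rules, and the names `GaussianParam`, `naive_average`, and `product_of_gaussians`, are hypothetical illustrations; they are not claimed to be among the six strategies the paper evaluates.

```python
import math
from dataclasses import dataclass


@dataclass
class GaussianParam:
    """Hypothetical mean-field Gaussian posterior for one weight: N(mu, sigma^2)."""
    mu: float
    sigma: float


def fedavg_scalars(values, sizes):
    """FedAvg-style aggregation: mean of client values weighted by local dataset size."""
    total = sum(sizes)
    return sum(v * n for v, n in zip(values, sizes)) / total


def naive_average(params, sizes):
    """Average mu and sigma independently, ignoring that together they define a distribution."""
    mu = fedavg_scalars([p.mu for p in params], sizes)
    sigma = fedavg_scalars([p.sigma for p in params], sizes)
    return GaussianParam(mu, sigma)


def product_of_gaussians(params):
    """Fuse client Gaussians by multiplying densities.

    Precisions (1 / sigma^2) add, and the fused mean is the precision-weighted
    mean. A full posterior-fusion rule would also divide out the prior that each
    client counted once; that correction is omitted in this sketch.
    """
    precisions = [1.0 / (p.sigma ** 2) for p in params]
    total_precision = sum(precisions)
    mu = sum(prec * p.mu for prec, p in zip(precisions, params)) / total_precision
    return GaussianParam(mu, math.sqrt(1.0 / total_precision))


if __name__ == "__main__":
    clients = [GaussianParam(0.0, 1.0), GaussianParam(2.0, 0.5)]
    print(naive_average(clients, sizes=[100, 100]))  # mu=1.0, sigma=0.75
    print(product_of_gaussians(clients))             # mu=1.6, sigma~0.447
```

With two equally sized clients, naive averaging returns mu = 1.0 and sigma = 0.75, which is wider than the more confident client's sigma of 0.5. The product rule returns mu = 1.6 and sigma of about 0.447, pulling the estimate toward the more certain client and narrowing the posterior. This small example illustrates why the choice of aggregation rule can change both the prediction and the reported uncertainty.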

Federated learning (FL) is an approach to training machine learning models that takes advantage of multiple distributed datasets while maintaining data privacy and reducing the communication costs associated with sharing local datasets. Aggregation strategies have been developed to pool or fuse the weights and biases of distributed deterministic models; however, modern deterministic deep learning (DL) models are often poorly calibrated and lack the ability to communicate a measure of epistemic uncertainty in prediction, which is desirable for remote sensing platforms and safety-critical applications. Conversely, Bayesian DL models are often well calibrated and capable of quantifying and communicating a measure of epistemic uncertainty along with competitive prediction accuracy. Unfortunately, because the weights and biases in Bayesian DL models are defined by a probability distribution, simple application of the aggregation methods associated with FL schemes for deterministic models is either impossible or results in sub-optimal performance. In this work, we use independent and identically distributed (IID) and non-IID partitions of the CIFAR-10 dataset and a fully variational ResNet-20 architecture to analyze six different aggregation strategies for Bayesian DL models. Additionally, we analyze the traditional federated averaging approach applied to an approximate Bayesian Monte Carlo dropout model as a lightweight alternative to more complex variational inference methods in FL. We show that aggregation strategy is a key hyperparameter in the design of a Bayesian FL system with downstream effects on accuracy, calibration, uncertainty quantification, training stability, and client compute requirements.
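The Monte Carlo dropout baseline mentioned above is lightweight because its weights stay deterministic: clients can be combined with ordinary federated averaging, and uncertainty comes from keeping dropout active at inference and averaging several stochastic forward passes. The sketch below shows that idea on a toy linear model with input dropout. The model, the dropout rate `p=0.5`, the sample count, and the function names are hypothetical simplifications, not the paper's ResNet-20 configuration.

```python
import random
import statistics


def fedavg(client_weights, sizes):
    """Federated averaging: weight each client's parameters by its local dataset size."""
    total = sum(sizes)
    n_weights = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, sizes)) / total
        for i in range(n_weights)
    ]


def forward_with_dropout(weights, x, p, rng):
    """Toy linear model y = sum(w_i * x_i) with inverted dropout applied to the inputs."""
    keep = 1.0 - p
    return sum(
        w * xi / keep if rng.random() < keep else 0.0
        for w, xi in zip(weights, x)
    )


def mc_dropout_predict(weights, x, p=0.5, n_samples=200, seed=0):
    """Keep dropout on at test time; return the mean and std of the sampled outputs."""
    rng = random.Random(seed)
    samples = [forward_with_dropout(weights, x, p, rng) for _ in range(n_samples)]
    return statistics.fmean(samples), statistics.stdev(samples)


if __name__ == "__main__":
    # Two hypothetical clients holding 1 and 3 local examples respectively.
    global_w = fedavg([[1.0, 2.0], [3.0, 0.0]], sizes=[1, 3])
    print(global_w)  # [2.5, 0.5]
    mean, std = mc_dropout_predict(global_w, [1.0, 1.0])
    print(f"prediction={mean:.2f}, uncertainty (std)={std:.2f}")
```

The aggregation step is identical to deterministic FedAvg, which is what makes this baseline cheap for clients. The spread of the sampled outputs then serves as the uncertainty estimate; its quality depends on the dropout rate and the number of samples, both of which are arbitrary here.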
