Fed-Focal Loss for imbalanced data classification in Federated Learning
This paper addresses class imbalance in Federated Learning, which degrades training on distributed devices. The contribution appears incremental, since it adapts the existing focal loss to the federated setting.
The paper tackles class imbalance in Federated Learning by introducing Fed-Focal Loss, which reshapes cross-entropy loss to down-weight well-classified examples, combined with a tunable sampling framework. It demonstrates consistently superior performance across benchmarks, achieving over 9% absolute improvement on unbalanced MNIST.
In the Federated Learning setting, a central server coordinates the training of a model across a network of devices. One challenge is variable training performance when the dataset is class-imbalanced. In this paper, we address this challenge by introducing a new loss function called Fed-Focal Loss. It reshapes the cross-entropy loss so that it down-weights the loss assigned to well-classified examples, along the lines of focal loss. Additionally, by leveraging a tunable sampling framework, we take into account selective client model contributions on the central server to further focus the detector during training and hence improve its robustness. Using a detailed experimental analysis with the VIRTUAL (Variational Federated Multi-Task Learning) approach, we demonstrate consistently superior performance in both the balanced and unbalanced scenarios for the MNIST, FEMNIST, VSN and HAR benchmarks. We obtain an improvement of more than 9% (absolute) on the unbalanced MNIST benchmark. We further show that our technique can be adopted across multiple Federated Learning algorithms to obtain improvements.
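The down-weighting idea can be illustrated with a minimal sketch of the standard focal loss, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. This is not the paper's exact Fed-Focal formulation, which the abstract does not spell out, and it omits the sampling framework entirely. The function names, the example logits, and the default `gamma=2.0` are assumptions chosen for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Plain cross-entropy for one example: -log(p_t)."""
    p_t = softmax(logits)[label]
    return -math.log(p_t)


def focal_loss(logits, label, gamma=2.0):
    """Standard focal loss for one example (illustrative, not the paper's code).

    The factor (1 - p_t) ** gamma shrinks the loss of examples the model
    already classifies confidently. gamma = 0 recovers cross-entropy.
    """
    p_t = softmax(logits)[label]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


if __name__ == "__main__":
    logits = [4.0, 0.0, 0.0]  # hypothetical model output favouring class 0
    for label, kind in [(0, "well-classified"), (1, "misclassified")]:
        ce = cross_entropy(logits, label)
        fl = focal_loss(logits, label)
        print(f"{kind}: CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.4f}")
```

In this sketch the modulating factor leaves the loss of a misclassified example almost unchanged while scaling the loss of a confidently correct example close to zero, so gradients from rare, hard classes are not swamped by the many easy examples of the majority class.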