Cyclical Weight Consolidation: Towards Solving Catastrophic Forgetting in Serial Federated Learning
The paper addresses efficiency and stability issues in serial federated learning, with applications such as multi-institutional medical collaborations. The contribution is incremental, since it builds on existing serial FL methods.
The paper tackles the performance fluctuations and lower converged performance that catastrophic forgetting causes in serial federated learning. It introduces cyclical weight consolidation (CWC), which mitigates the fluctuations and raises performance to a level comparable to or better than parallel methods in non-IID settings.
Federated Learning (FL) has gained attention for addressing data scarcity and privacy concerns. While parallel FL algorithms like FedAvg exhibit remarkable performance, they face challenges in scenarios with diverse network speeds and concerns about centralized control, especially in multi-institutional collaborations such as those in the medical domain. Serial FL presents an alternative solution, circumventing these challenges by transferring model updates serially between devices in a cyclical manner. Nevertheless, it is deemed inferior to parallel FL in that (1) its performance shows undesirable fluctuations, and (2) it converges to a lower plateau, particularly when dealing with non-IID data. The observed phenomenon is attributed to catastrophic forgetting: knowledge learned at previous sites is lost as the model trains on the next one. In this paper, to overcome the fluctuation and low efficiency of this iterative learning and forgetting process, we introduce cyclical weight consolidation (CWC), a straightforward yet potent approach specifically tailored for serial FL. CWC employs a consolidation matrix to regulate local optimization. This matrix tracks the significance of each parameter to the overall federation throughout the entire training trajectory, preventing abrupt changes in significant weights. During revisitation, to maintain adaptability, old memory undergoes decay to incorporate new information. Our comprehensive evaluations demonstrate that in various non-IID settings, CWC mitigates the fluctuation behavior of the original serial FL approach and enhances the converged performance consistently and significantly. The improved performance is either comparable to or better than that of vanilla parallel FL.
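The abstract describes the mechanism only in words, so the following is a minimal sketch of how a consolidation matrix with decay might regulate a local update. It assumes an EWC-style quadratic penalty and a multiplicative decay; both are illustrative choices, not the paper's stated formulation. The function names and the `decay`, `strength`, and `lr` parameters are hypothetical.

```python
# Hypothetical sketch of consolidation-regularized local training in serial FL.
# Assumptions (not from the paper): a diagonal consolidation matrix stored as a
# list, an EWC-style quadratic penalty, and multiplicative decay on revisit.

def decay_and_accumulate(consolidation, new_importance, decay=0.5):
    """Decay the old per-parameter importance, then add the importance
    measured at the current site."""
    return [decay * c + n for c, n in zip(consolidation, new_importance)]


def consolidated_loss(task_loss, weights, anchor, consolidation, strength=1.0):
    """Local task loss plus a penalty for moving important weights away
    from the anchor (the weights received from the previous site)."""
    penalty = sum(
        c * (w - a) ** 2 for c, w, a in zip(consolidation, weights, anchor)
    )
    return task_loss + 0.5 * strength * penalty


def local_step(weights, task_grad, anchor, consolidation, lr=0.1, strength=1.0):
    """One gradient-descent step on the consolidated loss. The penalty
    gradient for each parameter is strength * c * (w - a)."""
    return [
        w - lr * (g + strength * c * (w - a))
        for w, g, c, a in zip(weights, task_grad, consolidation, anchor)
    ]


if __name__ == "__main__":
    anchor = [1.0, 1.0]
    consolidation = [10.0, 0.0]  # first weight important, second not
    weights = list(anchor)
    for _ in range(2):
        # Constant task gradient, chosen only to keep the demo simple.
        weights = local_step(weights, [1.0, 1.0], anchor, consolidation)
    print(weights)
```

In this toy run both weights receive the same task gradient, but the weight with high importance is pulled back toward the anchor while the unimportant weight moves freely. The decay step is what would let a revisited site overwrite stale importance instead of accumulating it indefinitely.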