Gradient Correction in Federated Learning with Adaptive Optimization
This work addresses a specific bottleneck in federated learning by enabling effective drift compensation under adaptive optimization. The contribution is incremental, since it builds on existing correction methods.
The paper tackles data heterogeneity in federated learning with adaptive optimizers such as Adam, where naively injecting correction terms degrades performance. It proposes FAdamGC, which integrates drift compensation through a pre-estimation correction term aligned with the optimizer's moment structure. Compared with naively ported SGD-based correction, this yields a better convergence rate under milder assumptions, and lower total communication and computation cost across varying levels of heterogeneity.
In federated learning (FL), model training performance is strongly impacted by data heterogeneity across clients. Client-drift compensation methods have recently emerged as a solution to this issue, introducing correction terms into local model updates. To date, these methods have only been considered under stochastic gradient descent (SGD)-based model training, while modern FL frameworks also employ adaptive optimizers (e.g., Adam) for improved convergence. However, due to the complex interplay between the first and second moments in most adaptive optimization methods, naively injecting correction terms can degrade performance in heterogeneous settings. In this work, we propose FAdamGC, the first algorithm to integrate drift compensation into adaptive federated optimization. The key idea of FAdamGC is to inject a pre-estimation correction term that aligns with the moment structure of adaptive methods. We provide a rigorous convergence analysis of our algorithm in non-convex settings, showing that FAdamGC achieves a better rate under milder assumptions than naively porting SGD-based correction algorithms to adaptive optimizers. Our experimental results demonstrate that FAdamGC consistently outperforms existing methods in total communication and computation cost across varying levels of data heterogeneity, demonstrating the efficacy of correcting gradient information in federated adaptive optimization.
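To make the "pre-estimation" versus "naive" distinction concrete, here is a minimal sketch of a single local Adam step on one client. It is not the FAdamGC update, which the text above does not specify. It assumes a SCAFFOLD-style correction, where the corrected gradient is `g - c_local + c_global`. The names `c_local`, `c_global`, `local_step_pre_estimation`, and `local_step_naive` are hypothetical. In the pre-estimation variant, the correction enters before the moment estimates, so both moments track the corrected gradient. In the naive variant, Adam runs on the raw gradient and the SGD-style correction is added to the update afterwards.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple

Vector = List[float]


@dataclass
class AdamState:
    m: Vector  # first-moment estimate
    v: Vector  # second-moment estimate
    t: int = 0  # step counter for bias correction


def zeros(n: int) -> Vector:
    return [0.0] * n


def corrected_gradient(grad: Vector, c_local: Vector, c_global: Vector) -> Vector:
    # Assumed SCAFFOLD-style control-variate correction: g - c_i + c (hypothetical here).
    return [g - ci + c for g, ci, c in zip(grad, c_local, c_global)]


def adam_step(params: Vector, grad: Vector, state: AdamState, lr: float = 1e-2,
              beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> Vector:
    # Standard Adam update with bias-corrected moments; mutates `state` in place.
    state.t += 1
    new_params = []
    for k, (p, g) in enumerate(zip(params, grad)):
        state.m[k] = beta1 * state.m[k] + (1 - beta1) * g
        state.v[k] = beta2 * state.v[k] + (1 - beta2) * g * g
        m_hat = state.m[k] / (1 - beta1 ** state.t)
        v_hat = state.v[k] / (1 - beta2 ** state.t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_params


def local_step_pre_estimation(params: Vector, grad: Vector, c_local: Vector,
                              c_global: Vector, state: AdamState, lr: float = 1e-2) -> Vector:
    # Correction is applied BEFORE moment estimation, so m and v see the corrected gradient.
    return adam_step(params, corrected_gradient(grad, c_local, c_global), state, lr)


def local_step_naive(params: Vector, grad: Vector, c_local: Vector,
                     c_global: Vector, state: AdamState, lr: float = 1e-2) -> Vector:
    # Naive port: Adam on the raw gradient, then the SGD-style correction (c - c_i)
    # is added to the update, bypassing the second-moment normalization.
    stepped = adam_step(params, grad, state, lr)
    return [p - lr * (c - ci) for p, ci, c in zip(stepped, c_local, c_global)]


def demo() -> Tuple[Vector, Vector, AdamState]:
    # One client, one parameter, illustrative values only (not from the paper).
    params, grad, c_local, c_global = [0.0], [2.0], [3.0], [0.5]
    state_pre = AdamState(m=zeros(1), v=zeros(1))
    x_pre = local_step_pre_estimation(params, grad, c_local, c_global, state_pre)
    state_naive = AdamState(m=zeros(1), v=zeros(1))
    x_naive = local_step_naive(params, grad, c_local, c_global, state_naive)
    return x_pre, x_naive, state_pre


if __name__ == "__main__":
    x_pre, x_naive, _ = demo()
    print("pre-estimation:", x_pre, "naive:", x_naive)
```

On the first step, bias-corrected Adam normalizes the update to a magnitude of roughly `lr`. The pre-estimation variant therefore moves the parameter by about `lr` in the direction of the corrected gradient. The naive variant adds an unnormalized correction on top of the normalized Adam step. This mismatch between the correction and the moment structure is the kind of interplay the abstract warns about.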