Generalized Federated Learning via Sharpness Aware Minimization
This work addresses data heterogeneity in Federated Learning and offers a new approach to improving model generalization, though the contribution is incremental because it builds on existing Sharpness Aware Minimization techniques.
The paper tackles non-IID data distribution in Federated Learning, which makes optimization difficult. It proposes two algorithms, FedSAM and MoFedSAM, that use Sharpness Aware Minimization as the local optimizer, and it reports substantial performance improvements and reduced learning deviation compared to existing methods.
Federated Learning (FL) is a promising framework for performing privacy-preserving, distributed learning with a set of clients. However, the data distribution among clients is often non-IID, i.e., it exhibits distribution shift, which makes efficient optimization difficult. To tackle this problem, many FL algorithms focus on mitigating the effects of data heterogeneity across clients by increasing the performance of the global model. However, almost all of these algorithms use Empirical Risk Minimization (ERM) as the local optimizer, which can easily drive the global model into a sharp valley and cause a large deviation for some local clients. Therefore, in this paper, we revisit the solutions to the distribution shift problem in FL with a focus on local learning generality. To this end, we propose a general, effective algorithm, \texttt{FedSAM}, based on the Sharpness Aware Minimization (SAM) local optimizer, and develop a momentum FL algorithm, \texttt{MoFedSAM}, to bridge local and global models. Theoretically, we provide a convergence analysis of these two algorithms and derive a generalization bound for \texttt{FedSAM}. Empirically, our proposed algorithms substantially outperform existing FL studies and significantly decrease the learning deviation.
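To make the idea of using SAM as the local optimizer concrete, the following is a minimal sketch and not the authors' implementation. It assumes a linear model with squared loss, plain server-side averaging of client models, and synthetic clients whose feature means are shifted to mimic non-IID data. The names `sam_step`, `fed_round`, and `make_clients`, and the values of `lr`, `rho`, and `local_steps`, are hypothetical choices for illustration. The momentum mechanism of \texttt{MoFedSAM} is not covered.

```python
import numpy as np


def loss_and_grad(w, X, y):
    """Mean squared error of a linear model and its gradient w.r.t. w."""
    residual = X @ w - y
    loss = 0.5 * np.mean(residual ** 2)
    grad = X.T @ residual / len(y)
    return loss, grad


def sam_step(w, X, y, lr=0.05, rho=0.05):
    """One SAM update: perturb w toward higher loss, then descend
    using the gradient evaluated at the perturbed point."""
    _, g = loss_and_grad(w, X, y)
    # Ascent direction scaled to radius rho (small constant avoids 0/0).
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    _, g_perturbed = loss_and_grad(w + eps, X, y)
    return w - lr * g_perturbed


def fed_round(w_global, clients, local_steps=5, lr=0.05, rho=0.05):
    """One communication round: each client runs SAM locally from the
    global model, then the server averages the local models
    (assumed aggregation rule: unweighted mean)."""
    local_models = []
    for X, y in clients:
        w = w_global.copy()
        for _ in range(local_steps):
            w = sam_step(w, X, y, lr=lr, rho=rho)
        local_models.append(w)
    return np.mean(local_models, axis=0)


def make_clients(num_clients=4, n=50, seed=0):
    """Synthetic clients with shifted feature means (hypothetical data)."""
    rng = np.random.default_rng(seed)
    w_true = np.array([1.0, -2.0, 0.5])
    clients = []
    for k in range(num_clients):
        X = rng.normal(loc=0.5 * k, scale=1.0, size=(n, w_true.size))
        y = X @ w_true + 0.01 * rng.normal(size=n)
        clients.append((X, y))
    return clients, w_true


if __name__ == "__main__":
    clients, _ = make_clients()
    w = np.zeros(3)
    for _ in range(20):
        w = fed_round(w, clients)
    avg_loss = np.mean([loss_and_grad(w, X, y)[0] for X, y in clients])
    print("global model:", w, "average client loss:", avg_loss)
```

Setting `rho=0` reduces `sam_step` to an ordinary gradient step, which corresponds to the ERM local optimizer that the abstract contrasts against.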