Over-the-Air Computation Aided Federated Learning with the Aggregation of Normalized Gradient
This work addresses communication efficiency in federated learning on mobile devices, but it is incremental: it builds on existing over-the-air computation methods and refines how local gradients are amplified before transmission.
The paper tackles inefficient gradient amplification in over-the-air federated learning by normalizing local gradients before transmission. For smooth loss functions, it proves convergence to a stationary point at a sub-linear rate. For smooth and strongly convex loss functions, it proves that the minimal training loss is reached at a linear rate, up to any small positive tolerance. Experimental results show better convergence performance than benchmark methods.
Over-the-air computation is a communication-efficient solution for federated learning (FL). In such a system, an iterative procedure is performed. Every mobile device updates the local gradient of its private loss function, amplifies it, and transmits it. The server receives the aggregated gradient all at once, then generates updated model parameters and broadcasts them to every mobile device. When selecting the amplification factor, most related works assume that the local gradient always attains its maximal norm, although the norm actually fluctuates over iterations, and this mismatch may degrade convergence performance. To circumvent this problem, we propose normalizing the local gradient before amplifying it. When the loss function is smooth, we prove that the proposed method converges to a stationary point at a sub-linear rate. When the loss function is smooth and strongly convex, we prove that the proposed method achieves the minimal training loss at a linear rate, up to any small positive tolerance. Moreover, we discover a tradeoff between the convergence rate and the tolerance. To speed up convergence, we also formulate problems that optimize the system parameters for both cases. Although these problems are non-convex, we derive their optimal solutions with polynomial complexity. Experimental results show that the proposed method outperforms benchmark methods in convergence performance.
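As a rough illustration of the uplink step described above, the Python sketch below compares two ways of choosing the amplification factor. The first normalizes each local gradient and then amplifies it by sqrt(P). The second uses a fixed factor sqrt(P)/G sized for a worst-case gradient norm G. With the fixed factor, a device whose gradient norm is below G transmits below its power budget, and after the server undoes the scaling, the receiver noise is inflated by G. Normalization lets every device use its full budget. Several assumptions go beyond the abstract: unit channel gains (for example, after perfect channel inversion), additive white Gaussian receiver noise, a per-device power budget `power`, an error-free downlink, and a fixed step size `eta`. The function names `over_the_air_round` and `train` and the quadratic toy losses are hypothetical, and the sketch does not reproduce the paper's optimized system parameters.

```python
import math
import random


def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def over_the_air_round(grads, power, noise_std, rng, max_norm=None):
    """One hypothetical uplink round of over-the-air aggregation.

    Assumes unit channel gains, so the server receives sum_k x_k + noise.
    max_norm=None: each gradient is normalized, then amplified by sqrt(power).
    max_norm=G:    each raw gradient is amplified by the fixed factor sqrt(power)/G,
                   which respects the power budget only if ||g_k|| <= G.
    Returns the de-scaled estimate of the average transmitted vector.
    """
    k, dim = len(grads), len(grads[0])
    amp = math.sqrt(power) if max_norm is None else math.sqrt(power) / max_norm
    received = [0.0] * dim
    for g in grads:
        if max_norm is None:
            n = norm(g)
            payload = [x / n for x in g] if n > 0 else [0.0] * dim
        else:
            payload = g
        # Signals from all devices superpose on the shared channel.
        for i in range(dim):
            received[i] += amp * payload[i]
    # Additive Gaussian noise at the server's receiver.
    received = [r + rng.gauss(0.0, noise_std) for r in received]
    # Undo the amplification and average over the K devices.
    return [r / (k * amp) for r in received]


def train(c, curvatures, eta, steps, power, noise_std, seed=0):
    """Toy FL loop: device k holds f_k(w) = a_k/2 * ||w - c||^2 (shared minimizer c)."""
    rng = random.Random(seed)
    w = [0.0] * len(c)
    for _ in range(steps):
        grads = [[a * (wi - ci) for wi, ci in zip(w, c)] for a in curvatures]
        est = over_the_air_round(grads, power, noise_std, rng)
        # Server update; the new w is broadcast (downlink assumed error-free).
        w = [wi - eta * ei for wi, ei in zip(w, est)]
    return w


if __name__ == "__main__":
    c = [3.0, 4.0]
    for eta in (0.5, 0.05):
        w = train(c, curvatures=[0.5, 1.0, 2.0], eta=eta, steps=200, power=1.0, noise_std=0.0)
        dist = norm([wi - ci for wi, ci in zip(w, c)])
        print(f"eta={eta}: distance to minimizer after 200 steps = {dist:.4f}")
```

In the noiseless toy run, the model ends up within about `eta` of the shared minimizer. A smaller step therefore gives a tighter tolerance but needs more iterations, which loosely resembles the rate/tolerance tradeoff mentioned in the abstract; the paper's actual analysis may differ. The toy losses share a single minimizer, so all normalized local gradients point the same way. With heterogeneous data, the average of normalized gradients is generally not parallel to the average gradient.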