DC · LG · ML · Apr 28, 2020

Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training

arXiv:2004.13336v1 · 52 citations
AI Analysis

This work addresses scalability and performance issues in large-scale deep learning training, particularly for models with large weights or small per-replica batch sizes. It is incremental, building on existing data-parallel methods.

The paper tackles the bottleneck of redundant weight update computation in data-parallel synchronous training by automatically sharding it across replicas, achieving substantial speedups on typical models without code changes and helping set state-of-the-art training performance in MLPerf 0.6.

In data-parallel synchronous training of deep neural networks, different devices (replicas) run the same program on different partitions of the training batch, but the weight update computation is repeated on all replicas, because the weights do not have a batch dimension to partition. This can be a bottleneck for performance and scalability in typical language models with large weights, and in models with small per-replica batch sizes, which are typical in large-scale training. This paper presents an approach to automatically shard the weight update computation across replicas with efficient communication primitives and data formatting, using static analysis and transformations on the training computation graph. We show this technique achieves substantial speedups on typical image and language models on Cloud TPUs, requiring no change to model code. This technique helps close the gap between traditionally expensive (Adam) and cheap (SGD) optimizers, since the weight update then takes only a small part of the training step time and the two have similar peak memory usage. It helped us achieve state-of-the-art training performance in Google's MLPerf 0.6 submission.
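The data flow behind the sharded weight update can be illustrated with a small single-process simulation. This is a sketch under stated assumptions, not the paper's implementation (which works by static analysis and transformation of the training graph): it assumes the common decomposition of an all-reduce into a reduce-scatter followed by an all-gather, uses SGD with momentum as a stand-in optimizer, and assumes the weight length divides evenly by the replica count. The helpers `all_reduce`, `reduce_scatter`, `all_gather`, `replicated_update`, and `sharded_update` are hypothetical plain-Python stand-ins, not a real communication library.

```python
"""Simulated comparison of replicated vs. cross-replica sharded weight update.

All replicas live in one process; the collectives below are hypothetical
plain-Python stand-ins, not calls into a real communication library.
"""
from typing import List, Tuple

Vector = List[float]


def all_reduce(grads: List[Vector]) -> List[Vector]:
    # Every replica receives the full element-wise sum of all gradients.
    total = [sum(vals) for vals in zip(*grads)]
    return [list(total) for _ in grads]


def reduce_scatter(grads: List[Vector]) -> List[Vector]:
    # Replica r receives only shard r of the summed gradient.
    n = len(grads)
    total = [sum(vals) for vals in zip(*grads)]
    size = len(total) // n  # assumes len(weights) is divisible by n
    return [total[r * size:(r + 1) * size] for r in range(n)]


def all_gather(shards: List[Vector]) -> List[Vector]:
    # Every replica receives the concatenation of all shards.
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def momentum_step(w: Vector, m: Vector, g: Vector,
                  lr: float = 0.1, beta: float = 0.9) -> Tuple[Vector, Vector]:
    # Element-wise SGD with momentum; returns (new_weights, new_momentum).
    new_m = [beta * mi + gi for mi, gi in zip(m, g)]
    new_w = [wi - lr * mi for wi, mi in zip(w, new_m)]
    return new_w, new_m


def replicated_update(weights: Vector, momentum: Vector,
                      grads: List[Vector]) -> List[Vector]:
    # Baseline: all-reduce, then every replica redundantly updates ALL weights.
    summed = all_reduce(grads)
    return [momentum_step(weights, momentum, g)[0] for g in summed]


def sharded_update(weights: Vector, momentum: Vector,
                   grads: List[Vector]) -> List[Vector]:
    # Sharded: replica r updates only its 1/n slice, then weights are gathered.
    # The new momentum is discarded here for brevity; in a sharded scheme,
    # momentum shard r would only need to live on replica r.
    n = len(grads)
    size = len(weights) // n
    new_shards = []
    for r, g_shard in enumerate(reduce_scatter(grads)):
        lo, hi = r * size, (r + 1) * size
        w_shard, _ = momentum_step(weights[lo:hi], momentum[lo:hi], g_shard)
        new_shards.append(w_shard)
    return all_gather(new_shards)


# Usage: 4 simulated replicas, 8 weights, each replica with its own gradient.
weights = [1.0] * 8
momentum = [0.0] * 8
grads = [[float(r + i) for i in range(8)] for r in range(4)]

baseline = replicated_update(weights, momentum, grads)
sharded = sharded_update(weights, momentum, grads)
print(baseline == sharded)
```

Both paths produce identical weights on every replica, but in the sharded path each replica runs the optimizer on only 8 / 4 = 2 elements instead of all 8. Because reduce-scatter plus all-gather moves the same data as one all-reduce, the saving comes from removing redundant update computation (and, for stateful optimizers, from each replica holding only its shard of the optimizer state) rather than from extra communication.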
