LG | OC | ML | Dec 30, 2024

Adaptive Batch Size Schedules for Distributed Training of Language Models with Data and Model Parallelism

arXiv:2412.21124v2 | 4 citations | h-index: 8 | CPAL
Originality: Incremental advance
AI Analysis

This work addresses the efficiency-generalization trade-off in large-scale language model training and offers a practical solution for researchers and practitioners who use data and model parallelism. It is an incremental advance, since it builds on existing parallelism strategies.

The paper tackles the challenge of designing adaptive batch size schedules for distributed training of language models. It proposes a theoretically principled approach that outperforms constant batch sizes and heuristic warmup schedules when pretraining Llama 2 family models with up to 3 billion parameters.

An appropriate choice of batch sizes in large-scale model training is crucial, yet it involves an inherent dilemma: large-batch training improves training efficiency in terms of memory utilization, while generalization performance often deteriorates due to small amounts of gradient noise. Despite this dilemma, common practice in language model training often prioritizes training efficiency, employing either constant large batch sizes with data parallelism or batch size warmup schedules. However, such batch size schedules remain heuristic and often fail to adapt to training dynamics, which motivates the design of adaptive batch size schedules. Given the abundance of available datasets and the data-hungry nature of language models, data parallelism has become an indispensable distributed training paradigm, enabling the use of larger batch sizes for gradient computation. However, vanilla data parallelism requires replicas of model parameters, gradients, and optimizer states at each worker, which prohibits training larger models with billions of parameters. To optimize memory usage, more advanced parallelism strategies must be employed. In this work, we propose general-purpose and theoretically principled adaptive batch size schedules compatible with data parallelism and model parallelism. We develop a practical implementation with PyTorch Fully Sharded Data Parallel, facilitating the pretraining of language models of different sizes. We empirically demonstrate that our proposed approaches outperform constant batch sizes and heuristic batch size warmup schedules in the pretraining of models in the Llama 2 family, with particular focus on smaller models with up to 3 billion parameters. We also establish theoretical convergence guarantees for such adaptive batch size schedules with Adam for general smooth nonconvex objectives.
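To make the idea of a batch size schedule that adapts to training dynamics concrete, here is a minimal sketch of one well-known criterion from the adaptive sampling literature, the norm test. The abstract does not say which criterion the paper uses, so the choice of test, the function name `next_batch_size`, and the parameters `eta` and `max_batch_size` are all assumptions made for illustration. The sketch treats each data-parallel worker's gradient as one noisy sample, estimates the variance of the averaged gradient, and grows the batch size only when that variance is large relative to the squared norm of the averaged gradient.

```python
import math


def next_batch_size(worker_grads, batch_size, eta=0.5, max_batch_size=4096):
    """Illustrative norm-test rule (an assumption, not the paper's exact method).

    worker_grads: list of per-worker gradient vectors (lists of floats), each
                  computed on an equal share of the current global batch.
    batch_size:   current global batch size.
    eta:          tolerated ratio of gradient noise to gradient norm.
    """
    num_workers = len(worker_grads)
    if num_workers < 2:
        raise ValueError("need at least two workers to estimate variance")

    dim = len(worker_grads[0])
    # Averaged gradient, as data parallelism would produce after an all-reduce.
    mean = [sum(g[i] for g in worker_grads) / num_workers for i in range(dim)]
    mean_sq_norm = sum(m * m for m in mean)

    # Unbiased estimate of the variance of a single worker's gradient.
    spread = sum(
        sum((g[i] - mean[i]) ** 2 for i in range(dim)) for g in worker_grads
    ) / (num_workers - 1)
    # Variance of the average over num_workers workers.
    var_of_mean = spread / num_workers

    threshold = eta ** 2 * mean_sq_norm
    if var_of_mean <= threshold:
        return batch_size  # noise is small enough: keep the batch size
    if threshold == 0.0:
        return max_batch_size  # zero mean gradient: fall back to the cap

    # Gradient variance scales like 1 / batch_size, so scale up proportionally.
    proposed = math.ceil(batch_size * var_of_mean / threshold)
    return min(max(proposed, batch_size), max_batch_size)


# Workers agree: the batch size is left alone.
print(next_batch_size([[1.0, 0.0], [1.0, 0.0]], batch_size=64))
# Workers disagree: the batch size grows until the test would pass.
print(next_batch_size([[1.0, 1.0], [1.0, -1.0]], batch_size=64))
```

In a sharded setting each worker holds only a slice of the gradient, so the two squared norms would have to be accumulated across shards before the comparison; that communication step is omitted here.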
