Jigsaw: Training Multi-Billion-Parameter AI Weather Models with Optimized Model Parallelism
This work addresses the computational bottlenecks in training multi-billion-parameter AI weather models, enabling more efficient high-resolution forecasting for meteorology and climate science.
The paper tackles the challenge of training large AI weather models by introducing Jigsaw, a model parallelization scheme that combines domain and tensor parallelism to eliminate memory redundancy. With Jigsaw, training scales to 256 GPUs with up to 72% scaling efficiency and peak performance of 9 to 11 PFLOPs.
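To make "eliminating memory redundancy" concrete, here is a minimal storage-accounting sketch. It assumes a hypothetical layout in which each rank owns one latitude band of the input grid (the domain split) and one column block of a weight matrix (the tensor split). The names `Shard`, `split_range`, `make_shards`, and `stored_elements` and all sizes are illustrative only. The sketch shows that, summed over ranks, exactly one copy of the sample and one copy of the weights is stored. It does not model how Jigsaw schedules computation or communication between ranks.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """What one rank stores in this hypothetical layout."""
    rank: int
    lat_rows: tuple[int, int]     # half-open [start, stop) rows of the lat-lon grid
    weight_cols: tuple[int, int]  # half-open [start, stop) columns of one weight matrix


def split_range(n: int, parts: int) -> list[tuple[int, int]]:
    """Split [0, n) into `parts` contiguous, non-overlapping half-open ranges."""
    base, extra = divmod(n, parts)
    ranges, start = [], 0
    for i in range(parts):
        stop = start + base + (1 if i < extra else 0)
        ranges.append((start, stop))
        start = stop
    return ranges


def make_shards(n_lat: int, d_out: int, world_size: int) -> list[Shard]:
    """Give each rank one latitude band (domain split) and one column block (tensor split)."""
    rows = split_range(n_lat, world_size)
    cols = split_range(d_out, world_size)
    return [Shard(r, rows[r], cols[r]) for r in range(world_size)]


def stored_elements(shard: Shard, n_lon: int, channels: int, d_in: int) -> tuple[int, int]:
    """Return (data elements, weight elements) held by one rank."""
    data = (shard.lat_rows[1] - shard.lat_rows[0]) * n_lon * channels
    weights = d_in * (shard.weight_cols[1] - shard.weight_cols[0])
    return data, weights


if __name__ == "__main__":
    # Hypothetical sizes, chosen only for illustration.
    n_lat, n_lon, channels, d_in, d_out, world = 720, 1440, 8, 1024, 1024, 4
    shards = make_shards(n_lat, d_out, world)
    per_rank = [stored_elements(s, n_lon, channels, d_in) for s in shards]
    total_data = sum(d for d, _ in per_rank)
    total_weights = sum(w for _, w in per_rank)
    # Fully sharded: the totals across ranks equal one copy of the sample and of the weights.
    print(total_data == n_lat * n_lon * channels, total_weights == d_in * d_out)
    # Pure data parallelism would instead keep `world` full copies of the weights.
    print("replicated weight elements:", world * d_in * d_out)
```

By contrast, plain data parallelism would keep a full weight copy on every rank. Its weight memory therefore grows with the number of ranks, while the per-rank footprint here shrinks as ranks are added.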
AI-based methods have revolutionized atmospheric forecasting, and recent successes in medium-range forecasting have spurred the development of climate foundation models. Accurately modeling complex atmospheric dynamics at high spatial resolutions and longer lead times requires large neural networks and gigabyte-sized data samples, making accelerator memory and I/O bandwidth the bottlenecks for model training. We introduce WeatherMixer, a multi-layer-perceptron-based architecture whose workload scales linearly with input size, allowing the model to learn global weather phenomena at accuracies similar to numerical weather prediction. To cope with the computational demand, we propose Jigsaw, a novel model parallelization scheme that employs both domain and tensor parallelism, eliminating memory redundancy. Jigsaw exceeds state-of-the-art performance in strong scaling on compute-communication-limited systems and achieves superscalar weak scaling on I/O-bandwidth-limited systems. We scale training to 256 GPUs, reaching peak performance of 9 and 11 PFLOPs (23% and 28% of theoretical peak, respectively) and achieving 68% and 72% scaling efficiency, versus 51% without model parallelism.
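The headline figures can be sanity-checked with a little arithmetic. The sketch below back-computes the theoretical peak implied by each achieved rate and its stated fraction of peak. It assumes the peak refers to all 256 GPUs and notes that the rounded percentages make the results approximate. It also uses the textbook definition of scaling efficiency (measured throughput divided by ideal linear extrapolation from a reference run). The paper's exact reference configuration is not stated here and may differ, and the throughput numbers in the example are hypothetical.

```python
def implied_theoretical_peak(achieved_pflops: float, fraction_of_peak: float) -> float:
    """Theoretical peak implied by an achieved rate and its stated fraction of peak."""
    return achieved_pflops / fraction_of_peak


def scaling_efficiency(throughput_n: float, throughput_ref: float, n: int, n_ref: int = 1) -> float:
    """Measured throughput on n devices over ideal linear scaling from a run on n_ref devices."""
    ideal = throughput_ref * (n / n_ref)
    return throughput_n / ideal


if __name__ == "__main__":
    # Figures quoted in the abstract; the percentages are rounded, so results are approximate.
    for pflops, frac in [(9.0, 0.23), (11.0, 0.28)]:
        peak = implied_theoretical_peak(pflops, frac)
        print(f"{pflops} PFLOPs at {frac:.0%} of peak -> ~{peak:.1f} PFLOPs peak, "
              f"~{1000 * peak / 256:.0f} TFLOPs per GPU")
    # Hypothetical throughputs in samples/s, not measurements from the paper.
    print(scaling_efficiency(throughput_n=18432.0, throughput_ref=100.0, n=256))
```

Both quoted rates imply a theoretical peak of roughly 39 PFLOPs, or about 153 TFLOPs per GPU. Under this definition, an efficiency above 1.0 corresponds to the superscalar weak scaling the abstract reports for I/O-bandwidth-limited systems.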