veScale-FSDP: Flexible and High-Performance FSDP at Scale
This work addresses performance and scalability bottlenecks in large-scale model training, particularly for cutting-edge models such as Gemini and Kimi K2. The contribution is incremental, however, because it builds on existing FSDP frameworks.
The paper tackles the limitations of current Fully Sharded Data Parallel (FSDP) systems, which struggle with structure-aware training methods and non-element-wise optimizers. It introduces veScale-FSDP, a redesigned system that achieves 5–66% higher throughput and 16–30% lower memory usage than existing FSDP systems while scaling to tens of thousands of GPUs.
Fully Sharded Data Parallel (FSDP), also known as ZeRO, is widely used for training large-scale models because of its flexibility and minimal intrusion into model code. However, current FSDP systems struggle with structure-aware training methods (e.g., block-wise quantized training) and with non-element-wise optimizers (e.g., Shampoo and Muon) used in cutting-edge models (e.g., Gemini, Kimi K2). FSDP's fixed element-wise or row-wise sharding formats conflict with these block-structured computations. In addition, today's implementations fall short in communication and memory efficiency, which limits scaling to tens of thousands of GPUs. We introduce veScale-FSDP, a redesigned FSDP system that couples a flexible sharding format, RaggedShard, with a structure-aware planning algorithm to deliver both flexibility and performance at scale. veScale-FSDP natively supports the efficient data placement that FSDP requires, enabling block-wise quantization and non-element-wise optimizers. As a result, veScale-FSDP achieves 5–66% higher throughput and 16–30% lower memory usage than existing FSDP systems while scaling efficiently to tens of thousands of GPUs.
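The conflict between fixed element-wise sharding and block-structured computation can be made concrete with a toy 1-D sketch. It assumes a flattened parameter of 1,000 elements, quantized in blocks of 128 and sharded over 4 ranks, and contrasts a fixed even split with a block-aligned, uneven ("ragged") split. All function names are hypothetical. The block-aligned split only illustrates the general idea of shard boundaries following block boundaries. It is not the paper's RaggedShard format or its planning algorithm, which the abstract does not detail.

```python
from typing import List, Tuple

Shard = Tuple[int, int]  # half-open [start, end) range of flat element indices


def even_element_shards(numel: int, world_size: int) -> List[Shard]:
    """Fixed element-wise sharding: every rank gets ceil(numel / world_size)
    elements (the last rank may get fewer), regardless of block structure."""
    per_rank = -(-numel // world_size)  # ceiling division
    return [
        (min(r * per_rank, numel), min((r + 1) * per_rank, numel))
        for r in range(world_size)
    ]


def block_aligned_ragged_shards(numel: int, world_size: int, block_size: int) -> List[Shard]:
    """Hypothetical ragged sharding: whole blocks are dealt out as evenly as
    possible, so shard sizes may differ but no block is split across ranks."""
    num_blocks = -(-numel // block_size)  # the last block may be partial
    base, extra = divmod(num_blocks, world_size)
    shards, start_block = [], 0
    for r in range(world_size):
        # The first `extra` ranks take one additional block each.
        end_block = start_block + base + (1 if r < extra else 0)
        shards.append((min(start_block * block_size, numel),
                       min(end_block * block_size, numel)))
        start_block = end_block
    return shards


def split_block_boundaries(shards: List[Shard], block_size: int, numel: int) -> List[int]:
    """Interior shard boundaries that fall strictly inside a quantization block."""
    boundaries = sorted({end for _, end in shards if 0 < end < numel})
    return [b for b in boundaries if b % block_size != 0]


if __name__ == "__main__":
    numel, world_size, block_size = 1000, 4, 128
    even = even_element_shards(numel, world_size)
    ragged = block_aligned_ragged_shards(numel, world_size, block_size)
    print("even:  ", even, split_block_boundaries(even, block_size, numel))
    print("ragged:", ragged, split_block_boundaries(ragged, block_size, numel))
```

With the even split, all three interior boundaries (250, 500, 750) fall inside a 128-element block. Each affected rank would hold only part of a block and would need data from a neighboring rank to compute that block's quantization scale. The block-aligned split keeps every block on a single rank at the cost of unequal shard sizes (256, 256, 256, 232). A real system would have to account for that imbalance in its communication and memory planning.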