SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile
This work targets the high engineering effort and performance bottlenecks that AI researchers and practitioners face in distributed training, though the contribution is incremental, building on existing FSDP and compiler techniques.
The paper tackles the complexity and resource demands of distributed training for large models by introducing SimpleFSDP, a PyTorch-native compiler-based FSDP framework, which achieves up to 28.54% memory reduction and 68.67% throughput improvement compared to FSDP2.
Distributed training of large models consumes enormous computational resources and requires substantial engineering effort to compose various training techniques. This paper presents SimpleFSDP, a PyTorch-native compiler-based Fully Sharded Data Parallel (FSDP) framework that has a simple implementation for ease of maintenance and composability, allows full computation-communication graph tracing, and improves performance through compiler backend optimizations. SimpleFSDP's novelty lies in its torch.compile-friendly implementation of collective communications using existing PyTorch primitives, namely parametrizations, selective activation checkpointing, and DTensor. It also introduces first-of-its-kind bucketing and reordering of intermediate representation (IR) nodes in the TorchInductor backend for effective computation-communication overlap. As a result, users can apply these optimizations to wrap model components, automatically or manually, so that exposed communication is minimized. Extensive evaluations of SimpleFSDP on Llama 3 models (including the ultra-large 405B) using TorchTitan demonstrate up to 28.54% memory reduction and 68.67% throughput improvement compared to FSDP2, the most widely adopted eager-mode FSDP framework, when composed with other distributed training techniques.
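The abstract includes no code. To make the two mechanisms it names more concrete (sharded parameters that are all-gathered before use, and bucketing plus reordering of those all-gathers so communication overlaps computation), here is a minimal pure-Python simulation that runs in a single process without PyTorch or a process group. It is a sketch under stated assumptions, not SimpleFSDP's implementation. Every name here (shard, all_gather, bucket, bucketed_all_gather, prefetch_schedule, bucket_cap) is hypothetical, and the collectives are simulated with list operations.

```python
from typing import Dict, List

# Hypothetical, single-process simulation of FSDP-style sharding; NOT SimpleFSDP's code.


def shard(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter into world_size equal chunks, zero-padding the tail."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Simulated all-gather: concatenate every rank's shard and drop the padding."""
    return [x for s in shards for x in s][:numel]


def bucket(numels: Dict[str, int], bucket_cap: int) -> List[List[str]]:
    """Greedily group consecutive parameters so each bucket's total size stays
    within bucket_cap; a parameter larger than the cap gets its own bucket."""
    buckets: List[List[str]] = []
    current: List[str] = []
    size = 0
    for name, n in numels.items():
        if current and size + n > bucket_cap:
            buckets.append(current)
            current, size = [], 0
        current.append(name)
        size += n
    if current:
        buckets.append(current)
    return buckets


def bucketed_all_gather(sharded: Dict[str, List[List[float]]],
                        numels: Dict[str, int],
                        names: List[str]) -> Dict[str, List[float]]:
    """Unshard every parameter in one bucket with a single (simulated) collective."""
    world_size = len(sharded[names[0]])
    # Each rank packs its local shards of all bucketed parameters into one buffer.
    packed = [[x for name in names for x in sharded[name][r]]
              for r in range(world_size)]
    # A real all-gather would deliver every rank's packed buffer to every rank;
    # here `packed` already holds all of them, indexed by rank.
    out: Dict[str, List[float]] = {}
    offset = 0
    for name in names:
        chunk = len(sharded[name][0])
        pieces = [packed[r][offset:offset + chunk] for r in range(world_size)]
        out[name] = all_gather(pieces, numels[name])
        offset += chunk
    return out


def prefetch_schedule(layers: List[str]) -> List[str]:
    """Reorder so the all-gather for layer i+1 is issued before layer i computes,
    letting communication overlap computation."""
    if not layers:
        return []
    schedule = [f"all_gather({layers[0]})"]
    for i, layer in enumerate(layers):
        if i + 1 < len(layers):
            schedule.append(f"all_gather({layers[i + 1]})")
        schedule.append(f"wait({layer})")
        schedule.append(f"compute({layer})")
    return schedule


if __name__ == "__main__":
    params = {"wq": [1.0, 2.0, 3.0], "wk": [4.0, 5.0], "bias": [6.0]}
    numels = {k: len(v) for k, v in params.items()}
    sharded = {k: shard(v, world_size=2) for k, v in params.items()}
    for names in bucket(numels, bucket_cap=4):
        print(names, bucketed_all_gather(sharded, numels, names))
    print(prefetch_schedule(["layer0", "layer1", "layer2"]))
```

Bucketing trades one collective per parameter for one collective per bucket, which reduces per-call launch overhead. Prefetching lets the next layer's all-gather run while the current layer computes. According to the abstract, SimpleFSDP expresses the all-gather through PyTorch parametrizations and DTensor so that torch.compile can trace it, and it applies bucketing and reordering to IR nodes in TorchInductor rather than to Python-level lists as done here.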