Scaling Deep Learning Training with MPMD Pipeline Parallelism
This work targets researchers and engineers who train large deep learning models, and it represents an incremental improvement to pipeline-parallel training systems.
The authors address the problem of scaling deep learning training with JaxPP, a system based on MPMD pipeline parallelism that improves hardware utilization by up to 1.11× over the best-performing SPMD configuration.
We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that lets users implement their own pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, which correspond to pipeline stages, across a cluster of nodes and infers the communication between them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. JaxPP's pipeline-parallel implementation improves hardware utilization by up to 1.11× relative to the best-performing SPMD configuration.
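The abstract does not show JaxPP's programming interface, so the following is only a plain-Python sketch of what a pipeline schedule for gradient accumulation expresses. It is not JaxPP's API, and the names `one_f_one_b`, `infer_communication`, and `simulate` are hypothetical. The sketch assumes a linear chain of stages and the common 1F1B (one-forward-one-backward) schedule, in which each stage processes a batch as several microbatches and accumulates their gradients before a single optimizer update. JaxPP infers communication from the program itself; here the adjacent-stage transfers are hard-coded for illustration.

```python
from typing import List, Tuple

# Hypothetical task encoding: ("F" | "B", microbatch index).
Task = Tuple[str, int]


def one_f_one_b(num_stages: int, num_microbatches: int) -> List[List[Task]]:
    """Per-stage task lists for a 1F1B (one-forward-one-backward) schedule."""
    schedule = []
    for stage in range(num_stages):
        # Warm-up: earlier stages run more forwards before their first backward.
        warmup = min(num_stages - stage - 1, num_microbatches)
        tasks: List[Task] = [("F", m) for m in range(warmup)]
        fwd, bwd = warmup, 0
        # Steady state: alternate one forward with one backward.
        while fwd < num_microbatches:
            tasks.append(("F", fwd))
            fwd += 1
            tasks.append(("B", bwd))
            bwd += 1
        # Cool-down: drain the remaining backwards (gradients keep accumulating).
        while bwd < num_microbatches:
            tasks.append(("B", bwd))
            bwd += 1
        schedule.append(tasks)
    return schedule


def infer_communication(num_stages: int, num_microbatches: int):
    """Point-to-point transfers implied by a linear chain of stages."""
    transfers = []
    for m in range(num_microbatches):
        for s in range(num_stages - 1):
            # Forward: stage s sends activations of microbatch m to stage s + 1.
            transfers.append(("activation", m, s, s + 1))
            # Backward: stage s + 1 sends the activation gradient back to stage s.
            transfers.append(("gradient", m, s + 1, s))
    return transfers


def simulate(schedule: List[List[Task]]) -> int:
    """Count lockstep rounds to run the schedule, assuming unit-time tasks."""
    num_stages = len(schedule)
    done = set()  # (stage, kind, microbatch) finished in earlier rounds
    pos = [0] * num_stages  # next task index per stage (in-order execution)
    rounds = 0
    while any(p < len(t) for p, t in zip(pos, schedule)):
        ready = []
        for s, tasks in enumerate(schedule):
            if pos[s] == len(tasks):
                continue
            kind, m = tasks[pos[s]]
            if kind == "F":
                # Forward needs the previous stage's activations.
                deps = [(s - 1, "F", m)] if s > 0 else []
            else:
                # Backward needs this stage's forward and the next stage's backward.
                deps = [(s, "F", m)]
                if s < num_stages - 1:
                    deps.append((s + 1, "B", m))
            if all(d in done for d in deps):
                ready.append((s, kind, m))
        if not ready:
            raise RuntimeError("schedule deadlocks")
        for s, kind, m in ready:
            done.add((s, kind, m))
            pos[s] += 1
        rounds += 1
    return rounds


if __name__ == "__main__":
    schedule = one_f_one_b(num_stages=3, num_microbatches=3)
    print(schedule[0])  # [('F', 0), ('F', 1), ('F', 2), ('B', 0), ('B', 1), ('B', 2)]
    print(simulate(schedule))  # 10
    print(len(infer_communication(3, 3)))  # 12
```

With S stages, M microbatches, and unit-time tasks, 1F1B finishes in 2(M + S - 1) rounds while each stage is busy for only 2M of them, so the idle (bubble) fraction is (S - 1)/(M + S - 1). Using more microbatches per accumulation step shrinks this idle time, which is one reason the choice of schedule matters for hardware utilization.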