DrJAX: Scalable and Differentiable MapReduce Primitives in JAX
This library addresses the need for efficient and flexible parallel algorithm development in machine learning, particularly for researchers and engineers working with distributed systems. Its contribution is incremental in that it builds on existing JAX capabilities.
The authors address large-scale distributed and parallel machine learning built on MapReduce-style operations by developing DrJAX, a JAX-based library of scalable, differentiable primitives. The resulting framework targets TPUs natively and can be interpreted out to batch-processing systems such as Apache Beam.
We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at \url{https://github.com/google-research/google-research/tree/master/drjax}.
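To make the idea of differentiable MapReduce-style building blocks concrete, the following is a minimal sketch written in plain JAX. It does not use the DrJAX API: the helpers `broadcast`, `map_fn`, and `reduce_mean`, the toy `local_loss`, and the example data are all hypothetical names introduced here for illustration. The sketch assumes the map step can be expressed with `jax.vmap` over a leading group axis and the reduce step with a mean over that axis.

```python
import jax
import jax.numpy as jnp


# Hypothetical helpers for illustration only; these are NOT the DrJAX API.
def broadcast(x, n):
    """Replicate one value across a leading axis of n group members."""
    return jnp.broadcast_to(x, (n,) + jnp.shape(x))


def map_fn(fn, *args):
    """Apply fn independently to each group member (the map step)."""
    return jax.vmap(fn)(*args)


def reduce_mean(xs):
    """Average over the group axis (the reduce step)."""
    return jnp.mean(xs, axis=0)


def local_loss(w, x, y):
    """Toy per-member squared error for a scalar linear model."""
    return (w * x - y) ** 2


def global_loss(w, xs, ys):
    n = xs.shape[0]
    ws = broadcast(w, n)                     # broadcast the shared parameter
    losses = map_fn(local_loss, ws, xs, ys)  # map over group members
    return reduce_mean(losses)               # reduce back to a single value


# Made-up example data: three group members with one data point each.
xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])

loss = global_loss(1.0, xs, ys)
# Because every step is an ordinary JAX operation, jax.grad differentiates
# through the whole broadcast-map-reduce pipeline.
grad = jax.grad(global_loss)(1.0, xs, ys)
# The same function can also be compiled through XLA with jax.jit.
jit_loss = jax.jit(global_loss)(1.0, xs, ys)
```

In this sketch the group axis is an ordinary array dimension, so standard JAX transformations such as `jax.grad` and `jax.jit` apply without special handling. How DrJAX itself represents placements and shards them across devices is not shown here.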