Niket Kumar

2 Papers

89.5 · DC · May 21 · Code
Orbax: Distributed Checkpointing with JAX

Colin Gaffney, Shutong Li, Daniel Ng et al.

In the landscape of high-performance distributed ML systems, JAX has emerged as a framework of choice. However, JAX's modular design philosophy leaves it without a standardized checkpointing solution. In this paper, we introduce Orbax, a modular, JAX-native checkpointing library that abstracts away the complexities of distributed accelerator systems while also providing flexibility for user-friendly checkpoint manipulations throughout the ML model lifecycle. We demonstrate performance exceeding comparable PyTorch competitors by up to 3.5× for saving and 2× for loading. The library is available at https://github.com/google/orbax.

54.2 · SE · Mar 28
A Multi-agent AI System for Deep Learning Model Migration from TensorFlow to JAX

Stoyan Nikolov, Bernhard Konrad, Moritz Gronbach et al.

The rapid development of AI-based products and their underlying models has led to constant innovation in deep learning frameworks. Google has been pioneering machine learning usage across dozens of products. Maintaining the multitude of model source codes in different ML frameworks and versions is a significant challenge. Until now, this maintenance and migration work has been done largely manually by human experts. We describe an AI-based multi-agent system that we built to support automatic migration of TensorFlow-based deep learning models into JAX-based ones. We make three main contributions: First, we show how an AI planner that uses a mix of static analysis with AI instructions can create migration plans for very complex code components that are reliably followed by the combination of an orchestrator and coders, using AI-generated example-based playbooks. Second, we define quality metrics and AI-based judges that accelerate development when the code to evaluate has no tests and has to adhere to strict style and dependency requirements. Third, we demonstrate how the system accelerates code migrations in a large hyperscaler environment on commercial real-world use-cases. Our approach dramatically reduces the time for deep learning model migrations (a 6.4x-8x speedup) and creates a virtuous circle in which AI effectively supports its own development workflow. We expect that the techniques and approaches described here can be generalized to other framework migrations and general code transformation tasks.