Scalify: scale propagation for efficient low-precision LLM training
This work addresses a practical efficiency problem for ML practitioners by offering a more robust and generalizable approach to low-precision training. It appears incremental, however, since it builds on existing tensor scaling techniques.
The authors tackle the challenge of efficiently training large language models in low-precision formats such as float8 by introducing Scalify, a scale propagation paradigm that generalizes existing tensor scaling methods. The result is out-of-the-box support for float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage.
Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify
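The core idea of scale propagation can be illustrated with a tensor stored as a pair (low-precision data, full-precision scale), where each operation computes its output scale from its input scales. The sketch below is a minimal illustration under stated assumptions and is not the jax-scalify API. It uses NumPy with float16 as a stand-in for float8, which NumPy lacks natively. The names `ScaledArray`, `as_scaled`, `scaled_matmul` and `scaled_add` are hypothetical. The specific propagation rules are simplifying assumptions: the product of scales for matmul, and the larger scale for addition.

```python
import numpy as np

LOW = np.float16  # stand-in for float8; NumPy has no native float8 dtype


class ScaledArray:
    """Hypothetical pair (data, scale) representing the values data * scale."""

    def __init__(self, data, scale):
        self.data = np.asarray(data, dtype=LOW)  # low-precision payload
        self.scale = np.float32(scale)           # full-precision scalar scale

    def to_array(self):
        # Materialize the represented values in float32
        return self.data.astype(np.float32) * self.scale


def as_scaled(x):
    """Choose a power-of-two scale so that max|data| lands in [1, 2)."""
    x = np.asarray(x, dtype=np.float32)
    amax = float(np.max(np.abs(x)))
    scale = 1.0 if amax == 0.0 else 2.0 ** np.floor(np.log2(amax))
    return ScaledArray(x / np.float32(scale), scale)


def scaled_matmul(a, b):
    """Assumed propagation rule for matmul: out_scale = a.scale * b.scale.

    The output scale is derived from the input scales only, without
    inspecting the output values; accumulation is done in float32.
    """
    acc = a.data.astype(np.float32) @ b.data.astype(np.float32)
    return ScaledArray(acc, a.scale * b.scale)


def scaled_add(a, b):
    """Assumed propagation rule for addition: align both inputs to the larger scale."""
    s = max(a.scale, b.scale)
    acc = (a.data.astype(np.float32) * (a.scale / s)
           + b.data.astype(np.float32) * (b.scale / s))
    return ScaledArray(acc, s)


# Usage: 1e5 exceeds the float16 maximum (65504), but its scaled form fits.
x = np.full((4, 4), 1e5, dtype=np.float32)
xs = as_scaled(x)            # scale = 2**16, data is about 1.53
y = scaled_matmul(xs, xs)    # scale = 2**32, data stays small and finite
print(y.to_array()[0, 0], (x @ x)[0, 0])
```

The design choice being illustrated is that scales travel alongside the data through the graph, so the low-precision payload stays within its representable range. A realistic float8 setup would also need to guard the data against growth with the reduction size, for example by rescaling at selected points. That part is not modeled here.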