ALX: Large Scale Matrix Factorization on TPUs
This provides a scalable solution for researchers and practitioners working on large matrix factorization problems, though the contribution is incremental, since it adapts existing methods to TPU hardware.
The authors tackle large-scale matrix factorization with ALX, a JAX-based library for distributed Alternating Least Squares on TPUs, and train on a 365M-node dataset at about 20 minutes per epoch using 256 TPU cores.
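As a rough back-of-the-envelope check on what "scaling the number of TPU cores" means here, assume (for illustration only; the exact sharding layout is not stated above) that one embedding table with a row per node is split evenly across all 256 cores and stored as float32 vectors of dimension d:

```latex
\frac{365 \times 10^{6}\ \text{rows}}{256\ \text{cores}} \approx 1.43 \times 10^{6}\ \text{rows per core},
\qquad
1.43 \times 10^{6} \cdot d \cdot 4\ \text{bytes} \approx 5.7\,d\ \text{MB per core}.
```

Under that assumption, doubling the core count halves the per-core share of the table.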
We present ALX, an open-source library for distributed matrix factorization using Alternating Least Squares, written in JAX. Our design makes efficient use of the TPU architecture and scales to matrix factorization problems with O(B) (billions of) rows/columns as the number of available TPU cores grows. To spur future research on large-scale matrix factorization methods and to illustrate the scalability of our own implementation, we also built a real-world web link prediction dataset called WebGraph, which can be easily modeled as a matrix factorization problem. We created several variants of this dataset based on the locality and sparsity properties of sub-graphs. The largest variant of WebGraph has around 365M nodes, and training a single epoch finishes in about 20 minutes with 256 TPU cores. We report speed and performance numbers for ALX on all variants of WebGraph. Both the framework code and the dataset are open-sourced.
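To make the Alternating Least Squares idea concrete, here is a minimal single-device sketch in JAX. It is not ALX's API and does not shard anything across TPU cores. The helper names (`edges_to_matrix`, `solve_side`, `weighted_als`, `objective`), the dense adjacency matrix, and the 0.1 weight given to absent links (an implicit-feedback weighting in the style of weighted ALS) are all hypothetical choices for illustration. Each half-step holds one factor fixed and solves a small regularized least-squares system per row, which is the step that a distributed implementation would parallelize.

```python
import jax
import jax.numpy as jnp


def edges_to_matrix(edges, num_nodes):
    """Hypothetical helper: dense 0/1 adjacency (row = source page, col = target page)."""
    src, dst = jnp.array(edges).T
    return jnp.zeros((num_nodes, num_nodes)).at[src, dst].set(1.0)


@jax.jit
def solve_side(targets, weights, fixed, reg):
    """Exactly minimize the weighted objective over one factor, the other held fixed.

    For each row i solves (sum_j w_ij v_j v_j^T + reg*I) u_i = sum_j w_ij r_ij v_j.
    """
    d = fixed.shape[1]
    # Per-row Gram matrices, shape (n, d, d).
    gram = jnp.einsum("ij,jk,jl->ikl", weights, fixed, fixed) + reg * jnp.eye(d)
    # Per-row right-hand sides, shape (n, d).
    rhs = (weights * targets) @ fixed
    # Batched d x d linear solves, one per row.
    return jnp.linalg.solve(gram, rhs[..., None])[..., 0]


def objective(targets, weights, rows, cols, reg):
    """Weighted squared error plus L2 regularization on both factors."""
    resid = targets - rows @ cols.T
    return jnp.sum(weights * resid**2) + reg * (jnp.sum(rows**2) + jnp.sum(cols**2))


def weighted_als(targets, weights, dim=2, reg=0.1, iters=20, seed=0):
    """Alternate exact solves for row and column embeddings."""
    key = jax.random.PRNGKey(seed)
    cols = 0.1 * jax.random.normal(key, (targets.shape[1], dim))
    for _ in range(iters):
        rows = solve_side(targets, weights, cols, reg)        # fix cols, solve rows
        cols = solve_side(targets.T, weights.T, rows, reg)    # fix rows, solve cols
    return rows, cols


# Toy link graph with 4 pages; observed links get weight 1, absent links 0.1.
edges = [(0, 1), (1, 2), (2, 0), (2, 3), (3, 1)]
adj = edges_to_matrix(edges, 4)
weights = jnp.where(adj > 0, 1.0, 0.1)
rows, cols = weighted_als(adj, weights, dim=2)
scores = rows @ cols.T  # predicted link scores for every (source, target) pair
```

Because each half-step is an exact minimization over one factor, the objective never increases from one iteration to the next, and the per-row solves are independent, which is what makes the method amenable to splitting rows across many devices.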