TOAST: Fast and scalable auto-partitioning based on principled static analysis
The work addresses efficient model partitioning for distributed AI systems. It offers a fully automated solution that improves performance and scalability, although it appears incremental because it builds on existing auto-partitioning approaches.
The paper tackles the problem of partitioning large machine learning models across distributed accelerators. The task is complex, and existing methods for it are error-prone or inefficient. The paper proposes a system that combines static compiler analysis with Monte Carlo Tree Search (MCTS), and this system outperforms state-of-the-art industrial methods across diverse hardware and models.
Partitioning large machine learning models across distributed accelerator systems is a complex process. It requires a series of interdependent decisions, and internal sharding ambiguities complicate those decisions further. Consequently, existing auto-partitioners often suffer from out-of-memory errors or become prohibitively slow when exploring the exponentially large space of possible partitionings. To mitigate this, they artificially restrict the search space. This restriction frequently yields infeasible solutions that violate device memory constraints, or solutions with sub-optimal performance. We propose a system that combines a novel static compiler analysis with a Monte Carlo Tree Search. Our analysis constructs an efficient decision space by identifying (i) tensor dimensions requiring identical sharding, and (ii) partitioning "conflicts" that require resolution. Our system significantly outperforms state-of-the-art industrial methods across diverse hardware platforms and model architectures and discovers previously unknown, superior solutions. The process is fully automated, even for large and complex models.
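The following is a minimal Python sketch of the two kinds of information named in (i) and (ii). The modeling assumptions are ours, not the paper's:

- A dimension is identified by a (tensor, axis) pair.
- Only matrix multiplications are modeled.
- Dimensions that an op ties together are merged with a union-find structure.
- A "conflict" means two axes of the same tensor land in one equivalence class. In that case they cannot both be sharded along the same mesh axis.

The names `UnionFind`, `matmul_constraints`, and `analyze` are hypothetical and do not come from TOAST.

```python
from collections import defaultdict

# Hypothetical, simplified illustration; not TOAST's actual analysis.
# A dimension is identified by a (tensor_name, axis_index) pair.


class UnionFind:
    """Disjoint-set forest over hashable items, with path compression."""

    def __init__(self):
        self.parent = {}

    def find(self, x):
        self.parent.setdefault(x, x)
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: point every node on the path directly at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra


def matmul_constraints(a, b, out):
    """Dimension pairs that must be sharded identically for out[m, n] = a[m, k] @ b[k, n]."""
    return [
        ((a, 0), (out, 0)),  # m: rows of a and rows of out
        ((a, 1), (b, 0)),    # k: the contracting dimension
        ((b, 1), (out, 1)),  # n: columns of b and columns of out
    ]


def analyze(ops):
    """Group dimensions into equivalence classes and report conflicts."""
    uf = UnionFind()
    dims = set()
    for constraints in ops:
        for d1, d2 in constraints:
            dims.update((d1, d2))
            uf.union(d1, d2)

    # Each equivalence class becomes one sharding decision for the search.
    classes = defaultdict(set)
    for d in dims:
        classes[uf.find(d)].add(d)

    # Conflict (assumed definition): two axes of one tensor share a class, so
    # they cannot both be sharded along the same mesh axis without resolution.
    conflicts = []
    for members in classes.values():
        axes_by_tensor = defaultdict(set)
        for tensor, axis in members:
            axes_by_tensor[tensor].add(axis)
        for tensor, axes in axes_by_tensor.items():
            if len(axes) > 1:
                conflicts.append((tensor, tuple(sorted(axes))))
    return list(classes.values()), sorted(conflicts)


# Two-layer MLP: h = x @ w1, y = h @ w2.
mlp = [matmul_constraints("x", "w1", "h"), matmul_constraints("h", "w2", "y")]
classes, conflicts = analyze(mlp)
print(len(classes), conflicts)  # 4 []

# Squaring a matrix, g = x @ x, ties the rows of x to the columns of x.
_, conflicts = analyze([matmul_constraints("x", "x", "g")])
print(conflicts)  # [('g', (0, 1)), ('x', (0, 1))]
```

In this toy setting, the ten tensor dimensions of the MLP collapse into four equivalence classes. A search procedure such as MCTS would then choose one sharding per class rather than one per dimension, which is one way such an analysis can shrink the decision space. The reported conflicts mark the places where the search must also decide how to resolve a clash.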