On Distributed Larger-Than-Memory Subset Selection With Pairwise Submodular Functions
This paper addresses subset selection for large-scale machine learning training, which can reduce training costs and improve model quality. The contribution is incremental, as it builds on existing submodular optimization methods.
The paper tackles the problem of selecting high-quality subsets from massive datasets when even the subset cannot fit in a single machine's memory. It proposes a distributed algorithm that combines bounding and greedy methods, achieves marginal or no loss in quality compared to centralized methods on CIFAR-100 and ImageNet, and scales to 13 billion points.
Modern datasets span billions of samples, making training on all available data infeasible. Selecting a high-quality subset helps reduce training costs and enhance model quality. Submodularity, a discrete analogue of convexity, is commonly used for solving such subset selection problems. However, existing algorithms for optimizing submodular functions are sequential, and prior distributed methods require at least one central machine to fit the target subset in DRAM. At billion-datapoint scale, even the subset may not fit on a single machine, and the sequential algorithms are prohibitively slow. In this paper, we relax the requirement of having a central machine for the target subset by proposing a novel distributed bounding algorithm with provable approximation guarantees. The algorithm iteratively bounds the minimum and maximum utility values to select high-quality points and discard unimportant ones. When bounding does not find the complete subset, we use a multi-round, partition-based distributed greedy algorithm to identify the remaining subset. We discuss how to implement these algorithms in a distributed data processing framework and empirically analyze different configurations. We find high-quality subsets on CIFAR-100 and ImageNet with marginal or no loss in quality compared to centralized methods, and scale to a dataset with 13 billion points.
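The abstract describes the two phases only in prose. The following Python sketch shows one plausible reading of them on a single machine. It is not the authors' implementation, and the bounding rule is an assumption about how "bounding the minimum and maximum utility values" could work. It assumes a pairwise objective f(S) = α·Σ_{v∈S} u(v) − β·Σ_{{v,w}⊆S} s(v,w) with nonnegative, symmetric similarities s, so a point's marginal gain can only shrink as more of its neighbors are selected. A point's maximum gain is computed against the points already selected, and its minimum gain against every point not yet discarded. With k' slots left, a point is kept if its minimum gain beats the k'-th largest maximum gain among the other candidates, and dropped if its maximum gain falls below the k'-th largest minimum gain among the others. Undecided points go to a multi-round partitioned greedy in the spirit of GreeDi, with partitions standing in for distributed workers. All names (`gain`, `bound`, `partitioned_greedy`, `select_subset`) and parameters (`alpha`, `beta`, `partitions`, `rounds`, `seed`) are hypothetical.

```python
import random

# Hypothetical pairwise objective (an assumption; the abstract gives no formula):
#   f(S) = alpha * sum_{v in S} u[v] - beta * sum_{{v, w} in S} sim[v][w]
# u: point -> utility; sim: point -> {neighbor: similarity >= 0}, symmetric.


def gain(v, chosen, u, sim, alpha, beta):
    """Marginal gain f(chosen + {v}) - f(chosen) for v not in chosen."""
    penalty = sum(s for w, s in sim.get(v, {}).items() if w in chosen and w != v)
    return alpha * u[v] - beta * penalty


def kth_largest_excluding(scores, k):
    """Map v to the k-th largest score among points other than v (needs len(scores) > k)."""
    top = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[: k + 1]

    def threshold(v):
        return [s for w, s in top if w != v][k - 1]

    return threshold


def bound(k, u, sim, alpha=1.0, beta=1.0):
    """Hypothetical bounding phase: returns (selected, still-undecided candidates)."""
    selected, candidates = set(), set(u)
    changed = True
    while changed and len(selected) < k and candidates:
        changed = False
        slots = k - len(selected)
        if len(candidates) <= slots:  # every remaining point fits
            selected |= candidates
            candidates = set()
            break
        # Max gain: only already-selected points are certain to penalize v.
        hi = {v: gain(v, selected, u, sim, alpha, beta) for v in candidates}
        # Min gain: assume every point not yet discarded also gets selected.
        pool = selected | candidates
        lo = {v: gain(v, pool, u, sim, alpha, beta) for v in candidates}
        hi_thr = kth_largest_excluding(hi, slots)
        lo_thr = kth_largest_excluding(lo, slots)
        # Keep v if its worst case beats the slots-th largest best case of the others.
        keep = {v for v in candidates if lo[v] > hi_thr(v)}
        # Drop v if its best case is below the slots-th largest worst case of the others.
        drop = {v for v in candidates if hi[v] < lo_thr(v)}
        if keep or drop:
            changed = True
            selected |= keep
            candidates -= keep | drop
    return selected, candidates


def greedy(points, k, selected, u, sim, alpha, beta):
    """Standard greedy: repeatedly add the point with the largest marginal gain."""
    chosen, pool, picked = set(selected), set(points), []
    for _ in range(min(k, len(pool))):
        # max() over sorted ids: ties go to the smallest id, for determinism.
        best = max(sorted(pool), key=lambda v: gain(v, chosen, u, sim, alpha, beta))
        chosen.add(best)
        pool.remove(best)
        picked.append(best)
    return picked


def partitioned_greedy(points, k, selected, u, sim, alpha=1.0, beta=1.0,
                       partitions=4, rounds=2, seed=0):
    """Multi-round partitioned greedy, simulated in one process."""
    rng = random.Random(seed)
    pool = sorted(points)
    for _ in range(rounds):
        if len(pool) <= k:
            break
        rng.shuffle(pool)  # random assignment of points to partitions
        parts = [pool[i::partitions] for i in range(partitions)]
        # Each part would run on its own worker; here they run one after another.
        pool = sorted(v for part in parts
                      for v in greedy(part, k, selected, u, sim, alpha, beta))
    # Final greedy over the union of the partition winners.
    return greedy(pool, k, selected, u, sim, alpha, beta)


def select_subset(k, u, sim, alpha=1.0, beta=1.0):
    """Bounding first; partitioned greedy fills any slots bounding left open."""
    selected, rest = bound(k, u, sim, alpha, beta)
    remaining = k - len(selected)
    if remaining > 0 and rest:
        selected |= set(partitioned_greedy(rest, remaining, selected, u, sim, alpha, beta))
    return selected


if __name__ == "__main__":
    # Toy data: points 0 and 1 are near-duplicates (high mutual similarity).
    u = {0: 5.0, 1: 4.9, 2: 1.0, 3: 0.5, 4: 3.0, 5: 0.2, 6: 9.0}
    sim = {0: {1: 4.0}, 1: {0: 4.0}}
    sel, rest = bound(3, u, sim)
    print(sorted(sel), sorted(rest))  # [6] [0, 1, 2, 4]
    print(sorted(select_subset(3, u, sim)))  # [0, 4, 6]
```

On the toy data, bounding commits point 6 and discards points 3 and 5. Near-duplicates 0 and 1 penalize each other, so the greedy fallback picks 0 and then 4 instead of 1. In this sketch, each bound depends only on a point's utility, the status of its neighbors, and two global top-k' thresholds, which is what would let one bounding pass be expressed as a per-point map plus a small aggregation in a data processing framework.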