GraB-sampler: Optimal Permutation-based SGD Data Sampler for PyTorch
GraB-sampler gives machine learning practitioners an accessible tool for using an optimal data sampling method, though the contribution is incremental because it focuses on implementation rather than new theory.
The authors address the lack of an efficient implementation of the Gradient Balancing (GraB) algorithm, which is theoretically optimal for SGD data sampling, by developing GraB-sampler, a Python library. It achieves comparable training loss and test accuracy with only 8.7% training time overhead and 0.85% peak GPU memory usage overhead.
The online Gradient Balancing (GraB) algorithm, which greedily chooses the example ordering by solving the herding problem using per-sample gradients, has been proved to be the theoretically optimal solution and is guaranteed to outperform Random Reshuffling. However, there is currently no efficient implementation of GraB that the community can easily use. This work presents an efficient Python library, $\textit{GraB-sampler}$, that allows the community to easily use GraB algorithms, and proposes 5 variants of the GraB algorithm. The best-performing configuration of GraB-sampler reproduces the training loss and test accuracy results at a cost of only 8.7% training time overhead and 0.85% peak GPU memory usage overhead.
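To make the greedy ordering step concrete, the following is a minimal sketch of one balancing-and-reordering pass, assuming a common formulation of GraB: each per-sample gradient is centered by a stale mean from the previous epoch, assigned a sign that keeps a running signed sum small, and then placed at the front or back of the next epoch's order. The names `balance_reorder` and `stale_mean` are hypothetical and are not the GraB-sampler API; plain Python lists stand in for PyTorch tensors.

```python
from typing import List, Sequence


def _norm_sq(v: Sequence[float]) -> float:
    """Squared Euclidean norm of a vector."""
    return sum(x * x for x in v)


def balance_reorder(
    order: List[int],
    grads: Sequence[Sequence[float]],
    stale_mean: Sequence[float],
) -> List[int]:
    """Return the next epoch's order from one greedy balancing pass.

    Hypothetical helper for illustration only (not the library's API).
    `grads[i]` is the per-sample gradient of example i, and `stale_mean`
    is assumed to be the mean gradient from the previous epoch.
    """
    running = [0.0] * len(stale_mean)  # running signed sum of centered gradients
    front: List[int] = []  # examples assigned sign +1, kept in visiting order
    back: List[int] = []   # examples assigned sign -1, reversed at the end

    for idx in order:
        centered = [g - m for g, m in zip(grads[idx], stale_mean)]
        plus = [r + c for r, c in zip(running, centered)]
        minus = [r - c for r, c in zip(running, centered)]
        # Greedy choice: pick the sign that keeps the running sum smaller.
        if _norm_sq(plus) <= _norm_sq(minus):
            running = plus
            front.append(idx)
        else:
            running = minus
            back.append(idx)

    # Positively signed examples come first; negatively signed ones follow
    # in reverse order.
    return front + back[::-1]
```

With four one-dimensional gradients `[1], [1], [-1], [-1]`, a zero stale mean, and the initial order `[0, 1, 2, 3]`, this sketch returns `[0, 2, 3, 1]`, so the gradients are visited as `1, -1, -1, 1` and the largest prefix-sum magnitude drops from 2 to 1.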