COGNATE: Acceleration of Sparse Tensor Programs on Emerging Hardware using Transfer Learning
This work addresses the efficient optimization of sparse tensor programs, a problem relevant to hardware designers and developers. It offers a novel transfer learning approach that reduces data requirements, although its improvement over existing cost model methods is incremental.
The paper tackles the challenge of optimizing sparse tensor programs for emerging hardware accelerators. Traditional ML-based cost models are ineffective in this setting for two reasons: program performance is highly sensitive to the sparse input, and early-stage accelerators rely on expensive simulators, which makes large training datasets costly to collect. The paper introduces COGNATE, a framework that uses transfer learning from general-purpose hardware to train cost models. These models match the performance of accelerator-specific models while using only 5% of their training data, and they achieve average speedups of 1.47x for SpMM and 1.39x for SDDMM.
Sparse tensor programs are essential in deep learning and graph analytics, driving the need for optimized processing. To meet this demand, specialized hardware accelerators are being developed. Optimizing these programs for accelerators is challenging for two reasons: program performance is highly sensitive to variations in sparse inputs, and early-stage accelerators rely on expensive simulators. Therefore, ML-based cost models used for optimizing such programs on general-purpose hardware are often ineffective for early-stage accelerators, as they require large datasets for proper training. To this end, we introduce COGNATE, a novel framework that leverages inexpensive data samples from general-purpose hardware (e.g., CPUs) to train cost models, followed by few-shot fine-tuning on emerging hardware. COGNATE exploits the homogeneity of input features across hardware platforms while effectively mitigating heterogeneity, enabling cost model training with just 5% of the data samples needed by accelerator-specific models to achieve comparable performance. We conduct extensive experiments to demonstrate that COGNATE outperforms existing techniques, achieving average speedups of 1.47x (up to 5.46x) for SpMM and 1.39x (up to 4.22x) for SDDMM.
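The pretrain-then-fine-tune idea can be illustrated with a deliberately tiny sketch. It assumes a linear cost model and synthetic data in which the source (CPU) and target (accelerator) platforms share the same feature weights and differ only by a constant offset. The helper names (`train`, `make_samples`) and this whole setup are hypothetical. They do not reflect COGNATE's actual model, features, or fine-tuning procedure.

```python
import random

# Hypothetical linear cost model: predicted cost = w . features + b.
# A generic pretrain-then-fine-tune sketch, NOT COGNATE's architecture.

def predict(w, b, x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def mse(w, b, data):
    return sum((predict(w, b, x) - y) ** 2 for x, y in data) / len(data)

def train(data, w, b, lr, epochs, update_weights=True):
    """Full-batch gradient descent on mean squared error."""
    w = list(w)
    n = len(data)
    for _ in range(epochs):
        grad_w = [0.0] * len(w)
        grad_b = 0.0
        for x, y in data:
            err = predict(w, b, x) - y
            for i, xi in enumerate(x):
                grad_w[i] += 2.0 * err * xi / n
            grad_b += 2.0 * err / n
        if update_weights:  # frozen weights stay at their pretrained values
            w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b

def make_samples(rng, count, offset):
    # Synthetic stand-in for profiled runs with two input features in [0, 1].
    # Both "platforms" share the same feature weights (a toy form of
    # homogeneity) and differ only by a constant offset (toy heterogeneity).
    samples = []
    for _ in range(count):
        x = [rng.random(), rng.random()]
        samples.append((x, 2.0 * x[0] + 3.0 * x[1] + offset))
    return samples

rng = random.Random(0)
cpu_data = make_samples(rng, 200, offset=1.0)    # plentiful, cheap source data
accel_few = make_samples(rng, 5, offset=4.0)     # a handful of target samples
accel_test = make_samples(rng, 50, offset=4.0)   # held-out target data

# 1) Pretrain every parameter on the cheap source-platform data.
w_src, b_src = train(cpu_data, [0.0, 0.0], 0.0, lr=0.3, epochs=1000)

# 2) Few-shot fine-tune: keep the shared feature weights, adapt only the bias.
w_ft, b_ft = train(accel_few, w_src, b_src, lr=0.1, epochs=200,
                   update_weights=False)

mse_pretrained = mse(w_src, b_src, accel_test)
mse_finetuned = mse(w_ft, b_ft, accel_test)
print(f"target MSE before fine-tuning: {mse_pretrained:.4f}")
print(f"target MSE after fine-tuning:  {mse_finetuned:.6f}")
```

In this sketch, freezing the feature weights and adapting only the bias is one simple way to reuse what transfers across platforms while correcting what does not. A practical cost model would use much richer program and sparse-input features than this toy setup.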