LG · PF · PL · Sep 2, 2025

DaCe AD: Unifying High-Performance Automatic Differentiation for Machine Learning and Scientific Computing

arXiv:2509.02197v1 · 1 citation · h-index: 13
Originality: Highly original
AI Analysis

This work addresses the problem of inefficient gradient computation for domain scientists in machine learning and scientific computing, offering a significant performance improvement over existing automatic differentiation frameworks.

The paper tackles the limitations of existing automatic differentiation frameworks, such as limited language support and poor performance on scientific computing codes, by introducing DaCe AD, which outperforms JAX by more than 92 times on average on the NPBench HPC benchmarks without requiring code modifications.

Automatic differentiation (AD) is a set of techniques that systematically applies the chain rule to compute the gradients of functions without requiring human intervention. Although the fundamentals of this technology were established decades ago, it is experiencing a renaissance because it plays a key role in efficiently computing gradients for backpropagation in machine learning algorithms. AD is also crucial for many applications in scientific computing, particularly emerging techniques that integrate machine learning models within scientific simulations and schemes. Existing AD frameworks have four main limitations: limited support for programming languages, the need for code modifications to achieve AD compatibility, limited performance on scientific computing codes, and a naive store-all strategy for the forward-pass data required to compute gradients. These limitations force domain scientists to compute gradients manually for large problems. This work presents DaCe AD, a general, efficient automatic differentiation engine that requires no code modifications. DaCe AD uses a novel algorithm based on integer linear programming (ILP) to optimize the trade-off between storing and recomputing forward-pass values, achieving maximum performance within a given memory constraint. We showcase the generality of our method by applying it to NPBench, a suite of HPC benchmarks with diverse scientific computing patterns, where we outperform JAX, a Python framework with state-of-the-art general AD capabilities, by more than 92 times on average without requiring any code changes.
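The store-versus-recompute trade-off described in the abstract can be illustrated on a toy scalar program. The sketch below is hypothetical and is not DaCe AD's algorithm. It assumes a straight-line chain of scalar operations, intermediates of unit size, a memory budget counted in stored values, and a cost model that simply counts forward re-evaluations. It also replaces the paper's ILP solver with exhaustive search, which is only feasible for tiny programs. All names (`OPS`, `forward`, `backward`, `best_plan`) are illustrative.

```python
import itertools
import math

# Hypothetical toy program: y = f_3(f_2(f_1(f_0(x)))).
# OPS[i] is (f_i, f_i') and maps v_i -> v_{i+1}; v_0 = x.
OPS = [
    (math.sin, math.cos),
    (math.exp, math.exp),
    (lambda v: v * v, lambda v: 2.0 * v),
    (math.tanh, lambda v: 1.0 - math.tanh(v) ** 2),
]


def forward(x, ops):
    """Run the chain and return every intermediate v_0 = x, v_1, ..., v_n."""
    vals = [x]
    for f, _ in ops:
        vals.append(f(vals[-1]))
    return vals


def recompute(i, x, stored, ops):
    """Return (v_i, cost): rebuild v_i from the nearest stored predecessor.

    v_0 = x is always available; cost counts forward re-evaluations.
    Each missing value is rebuilt independently (no caching between steps).
    """
    j = i
    while j > 0 and j not in stored:
        j -= 1
    v = x if j == 0 else stored[j]
    for k in range(j, i):
        v = ops[k][0](v)
    return v, i - j


def backward(x, ops, keep):
    """Reverse-mode sweep where the forward pass retained only indices in `keep`."""
    vals = forward(x, ops)
    stored = {i: vals[i] for i in keep}  # simulate a forward pass storing only `keep`
    grad, cost = 1.0, 0
    for i in reversed(range(len(ops))):
        v_in, c = recompute(i, x, stored, ops)
        cost += c
        grad *= ops[i][1](v_in)  # chain rule: dy/dv_i = dy/dv_{i+1} * f_i'(v_i)
    return grad, cost


def best_plan(x, ops, budget):
    """Choose which of v_1..v_{n-1} to store (at most `budget` values) to minimise
    recomputation. Brute-force stand-in for an ILP solver; exponential in n."""
    best = None
    for r in range(budget + 1):
        for keep in itertools.combinations(range(1, len(ops)), r):
            _, cost = backward(x, ops, set(keep))
            if best is None or cost < best[1]:
                best = (set(keep), cost)
    return best


if __name__ == "__main__":
    x = 0.3
    g_all, c_all = backward(x, OPS, set(range(1, len(OPS))))  # store-all strategy
    keep, cost = best_plan(x, OPS, budget=1)
    g_ckpt, _ = backward(x, OPS, keep)
    print("store-all:", g_all, "recompute cost", c_all)
    print("budget=1:", g_ckpt, "stored", keep, "recompute cost", cost)
```

Storing every intermediate needs no recomputation but uses the most memory; with a budget of one stored value, the search picks the intermediate that minimises re-evaluations, and the gradient is unchanged. A real engine would solve a far larger version of this selection problem with an ILP solver and a richer cost and memory model.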
