Efficient, Accurate and Stable Gradients for Neural ODEs
This work addresses efficiency bottlenecks faced by researchers and practitioners who train Neural ODEs, and represents a strong incremental improvement over existing methods.
The paper tackles the high computational and memory cost of training Neural ODEs by introducing algebraically reversible ODE solvers. These reduce both time and memory relative to recursive checkpointing, while still computing exact gradients and remaining high-order accurate and numerically stable.
Training Neural ODEs requires backpropagating through an ODE solve. The state-of-the-art backpropagation method is recursive checkpointing, which balances recomputation against memory cost. Here, we introduce a class of algebraically reversible ODE solvers that significantly improve upon both the time and memory cost of recursive checkpointing. The presented reversible solvers compute exact gradients and are high-order and numerically stable, strictly improving on previous reversible architectures.
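To make "algebraically reversible" concrete, the following is a minimal Python sketch under stated assumptions. It uses a coupled two-state update (y, z) with a coupling parameter λ (LAM) and explicit Euler as the base one-step increment Ψ_h. The scheme, the example vector field f, the value of LAM, and every function name (psi, forward_step, inverse_step, solve, grad_final_state) are hypothetical illustrations, not the paper's exact construction. The point it shows: each step can be inverted in closed form, so the backward pass reconstructs earlier states instead of storing them (memory independent of the number of steps), and the returned gradient is the exact derivative of the discrete solve rather than an approximation.

```python
import math

LAM = 0.9  # hypothetical coupling parameter lambda in (0, 1]


def f(t, y):
    """Example scalar vector field (assumed): dy/dt = sin(t) - tanh(y)."""
    return math.sin(t) - math.tanh(y)


def df_dy(t, y):
    """Partial derivative of f with respect to y."""
    return -(1.0 - math.tanh(y) ** 2)


def psi(t, y, h):
    """Base one-step increment Psi_h(t, y); explicit Euler is an assumption."""
    return h * f(t, y)


def forward_step(t, y, z, h):
    """One step of the coupled scheme: update y from z, then z from the new y."""
    y_next = LAM * y + (1.0 - LAM) * z + psi(t, z, h)
    z_next = z - psi(t + h, y_next, -h)  # equals z + h * f(t + h, y_next)
    return y_next, z_next


def inverse_step(t, y_next, z_next, h):
    """Recover (y_n, z_n) from (y_{n+1}, z_{n+1}) algebraically, nothing stored."""
    z = z_next + psi(t + h, y_next, -h)
    y = (y_next - (1.0 - LAM) * z - psi(t, z, h)) / LAM
    return y, z


def solve(x0, t0, h, n_steps):
    """Forward solve from y_0 = z_0 = x0, keeping only the current state."""
    y, z = x0, x0
    for k in range(n_steps):
        y, z = forward_step(t0 + k * h, y, z, h)
    return y, z


def grad_final_state(x0, t0, h, n_steps):
    """Exact d(y_N)/d(x0) of the discrete solve via a reconstructing backward pass."""
    y, z = solve(x0, t0, h, n_steps)
    a_y, a_z = 1.0, 0.0  # adjoints of (y_N, z_N) for the loss L = y_N
    for k in reversed(range(n_steps)):
        t = t0 + k * h
        y_next = y
        y, z = inverse_step(t, y, z, h)  # reconstruct state k from state k+1
        # Sensitivity reaching y_{k+1} directly and through z_{k+1}.
        a_y_tot = a_y + a_z * h * df_dy(t + h, y_next)
        # Chain rule through forward_step (right side uses the old a_z).
        a_y, a_z = (LAM * a_y_tot,
                    a_z + a_y_tot * ((1.0 - LAM) + h * df_dy(t, z)))
    return a_y + a_z  # y_0 and z_0 both equal x0
```

For example, solve(1.0, 0.0, 0.05, 20) followed by 20 calls to inverse_step, run from the last step back to the first, recovers y_0 = z_0 = 1.0 up to round-off. One design consequence: the forward map's Jacobian has determinant exactly λ (a scaling by λ composed with a volume-preserving shear), so each inverse step expands volumes by 1/λ and reconstruction round-off grows accordingly; this is why λ should not be chosen too small in a sketch like this.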