MALI: A memory efficient and reverse accurate integrator for Neural ODEs
Intended for researchers and practitioners working with continuous-depth models, this work addresses two problems that previously kept Neural ODEs out of large-scale tasks such as ImageNet: high memory cost and inaccurate gradient estimation.
Neural ODEs suffer from either inaccurate gradient estimation or high memory cost. This paper proposes MALI, a new integrator with a constant memory cost and an accurate reverse-time trajectory. MALI makes it feasible to train Neural ODEs on ImageNet, where it outperforms a well-tuned ResNet; it also significantly outperforms the adjoint method in time series modeling and achieves state-of-the-art performance on continuous generative models.
Neural ordinary differential equations (Neural ODEs) are a new family of deep-learning models with continuous depth. However, numerical estimation of the gradient in the continuous case is not well solved: existing implementations of the adjoint method suffer from an inaccurate reverse-time trajectory, while the naive method and the adaptive checkpoint adjoint method (ACA) have a memory cost that grows with integration time. In this work, based on the asynchronous leapfrog (ALF) solver, we propose the Memory-efficient ALF Integrator (MALI), which, like the adjoint method, has a memory cost that is constant w.r.t. the number of solver steps, and which guarantees an accurate reverse-time trajectory (and hence accurate gradient estimation). We validate MALI on various tasks: for image recognition, to our knowledge MALI is the first method to enable feasible training of a Neural ODE on ImageNet and to outperform a well-tuned ResNet, while existing methods fail due to either a heavy memory burden or inaccuracy; for time series modeling, MALI significantly outperforms the adjoint method; and for continuous generative models, MALI achieves new state-of-the-art performance. We provide a PyPI package at https://jzkay12.github.io/TorchDiffEqPack/
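The constant-memory claim rests on the ALF step being algebraically invertible: given only the final state, every earlier state can be recomputed by running the step backward, so nothing needs to be stored per step. The sketch below illustrates this on a scalar ODE. It is written from our reading of the ALF update (an auxiliary variable v approximating f(z, t) is carried alongside z and initialized as v = f(z0, t0)). The function names, the test vector field and the step size are illustrative assumptions, not the TorchDiffEqPack API, and the sketch omits the gradient computation that MALI performs on top of the reconstructed trajectory. For contrast, it also re-solves an explicit Euler trajectory backward in time, a crude stand-in for the adjoint method's reverse-time solve.

```python
import math
from typing import Callable, Tuple

# f(z, t) -> dz/dt for a scalar state (illustrative; real models use tensors).
VectorField = Callable[[float, float], float]
State = Tuple[float, float, float]  # (z, v, t)


def alf_step(f: VectorField, z: float, v: float, t: float, h: float) -> State:
    """One asynchronous leapfrog step (hypothetical helper name)."""
    s1 = t + h / 2
    k1 = z + v * h / 2          # half step using the carried velocity v
    u1 = f(k1, s1)              # single vector-field evaluation at the midpoint
    v_out = 2 * u1 - v          # reflect v about u1
    z_out = k1 + v_out * h / 2  # second half step using the new velocity
    return z_out, v_out, s1 + h / 2


def alf_inverse_step(f: VectorField, z_out: float, v_out: float, t_out: float, h: float) -> State:
    """Algebraic inverse of alf_step: recovers (z, v, t) from the step's output."""
    s1 = t_out - h / 2
    k1 = z_out - v_out * h / 2  # undo the second half step
    u1 = f(k1, s1)              # re-evaluate at the same midpoint as the forward step
    v = 2 * u1 - v_out          # undo the reflection
    z = k1 - v * h / 2          # undo the first half step
    return z, v, s1 - h / 2


def alf_forward(f: VectorField, z0: float, t0: float, h: float, n_steps: int) -> State:
    z, v, t = z0, f(z0, t0), t0  # initialize v = f(z0, t0)
    for _ in range(n_steps):
        z, v, t = alf_step(f, z, v, t, h)
    return z, v, t  # only the end state is kept: memory does not grow with n_steps


def alf_reconstruct(f: VectorField, z: float, v: float, t: float, h: float, n_steps: int) -> State:
    for _ in range(n_steps):
        z, v, t = alf_inverse_step(f, z, v, t, h)
    return z, v, t


def euler_round_trip(f: VectorField, z0: float, t0: float, h: float, n_steps: int) -> float:
    z, t = z0, t0
    for _ in range(n_steps):  # forward explicit Euler
        z, t = z + h * f(z, t), t + h
    for _ in range(n_steps):  # re-solve backward in time from the end point
        z, t = z - h * f(z, t), t - h
    return z


def field(z: float, t: float) -> float:
    # Arbitrary nonlinear, time-dependent test dynamics (assumption).
    return math.sin(z) - 0.5 * z + math.cos(t)


z0, t0, h, n = 1.0, 0.0, 0.01, 200
zT, vT, tT = alf_forward(field, z0, t0, h, n)
z_rec, v_rec, t_rec = alf_reconstruct(field, zT, vT, tT, h, n)
alf_err = abs(z_rec - z0)
euler_err = abs(euler_round_trip(field, z0, t0, h, n) - z0)
print(f"ALF reconstruction error:   {alf_err:.2e}")
print(f"Euler reconstruction error: {euler_err:.2e}")
```

Running it should show the ALF round trip recovering z(0) to round-off precision, while the Euler round trip misses it by an error of order h. The exactness shown here is algebraic; it says nothing about ALF's accuracy or stability as a forward solver, and round-off can still be amplified over long horizons.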