Efficient Training of Neural SDEs Using Stochastic Optimal Control
This work addresses the cost of training neural SDEs for uncertainty-aware time-series analysis and offers an incremental improvement to the training procedure.
The paper tackles the computational challenge of variational inference for neural stochastic differential equations (SDEs) by proposing a hierarchical method that decomposes the control term into linear and non-linear components, resulting in faster convergence and lower initialization cost.
We present a hierarchical, control-theory-inspired method for variational inference (VI) for neural stochastic differential equations (SDEs). VI for neural SDEs is a promising avenue for uncertainty-aware reasoning in time series, but it is computationally challenging because maximizing the ELBO is an iterative process. In this work, we propose to decompose the control term into a linear component and a residual non-linear component, and we use stochastic optimal control to derive the optimal control term for linear SDEs. By modeling the non-linear component with a neural network, we show how to train neural SDEs efficiently without sacrificing their expressive power. Since the linear part of the control term is optimal and does not need to be learned, training starts at a lower cost and we observe faster convergence.
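The core idea above is to split the control into a linear part with a closed-form optimum plus a learned non-linear residual. The following is a minimal toy sketch of that split, not the paper's method. It swaps in the textbook finite-horizon linear-quadratic regulator (LQR) for a scalar linear SDE. In that setting the optimal feedback comes from a backward Riccati ODE, and with additive noise the same gain stays optimal. The paper instead derives its linear term within the ELBO-based VI objective, which may differ. The parameter values and the names `solve_riccati`, `residual_control`, `hybrid_control` and `simulate_cost` are hypothetical. The zero-valued residual marks where a neural network would go.

```python
import numpy as np


def solve_riccati(a, b, q, r, q_T, T, n_steps):
    """Integrate the scalar LQR Riccati ODE backward in time with explicit Euler.

    -dP/dt = 2 a P - (b**2 / r) P**2 + q,  terminal condition P(T) = q_T.
    Returns P on the grid t_k = k * T / n_steps, k = 0..n_steps.
    """
    dt = T / n_steps
    P = np.empty(n_steps + 1)
    P[-1] = q_T
    for k in range(n_steps - 1, -1, -1):
        p = P[k + 1]
        P[k] = p + dt * (2 * a * p - (b**2 / r) * p**2 + q)
    return P


def residual_control(t, x):
    """Hypothetical stand-in for the neural residual. It is zero-initialized,
    so the total control starts exactly at the linear optimum."""
    return np.zeros_like(x)


def simulate_cost(control, a, b, sigma, q, r, q_T, x0, T, n_steps, n_paths, seed=0):
    """Monte Carlo estimate of E[ int (q x^2 + r u^2) dt + q_T x_T^2 ]
    for dX = (a X + b u) dt + sigma dW, simulated with Euler-Maruyama."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.full(n_paths, x0, dtype=float)
    cost = np.zeros(n_paths)
    for k in range(n_steps):
        u = control(k, x)
        cost += (q * x**2 + r * u**2) * dt
        x = x + (a * x + b * u) * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_paths)
    cost += q_T * x**2
    return cost.mean()


# Illustrative toy parameters (assumptions, not taken from the paper).
a, b, sigma, q, r, q_T = 0.5, 1.0, 0.5, 1.0, 1.0, 1.0
x0, T, n_steps, n_paths = 1.0, 1.0, 200, 4000
dt = T / n_steps
P = solve_riccati(a, b, q, r, q_T, T, n_steps)


def hybrid_control(k, x):
    # Linear part: closed-form optimal feedback u_lin = -(b / r) P(t) x.
    u_lin = -(b / r) * P[k] * x
    # Non-linear part: would be learned; here it is the zero residual.
    return u_lin + residual_control(k * dt, x)


def no_control(k, x):
    # Baseline: no control at all.
    return np.zeros_like(x)


args = (a, b, sigma, q, r, q_T, x0, T, n_steps, n_paths)
cost_hybrid = simulate_cost(hybrid_control, *args)
cost_none = simulate_cost(no_control, *args)
print(f"cost with linear optimum + zero residual: {cost_hybrid:.3f}")
print(f"cost with no control:                     {cost_none:.3f}")
```

In this toy setting, the residual starts at zero, so the hybrid controller begins at the linear optimum. A learned network would then only need to correct for non-linear effects, instead of starting from the much higher uncontrolled cost.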