Accurate and Reliable Forecasting using Stochastic Differential Equations
This work addresses unreliable uncertainty estimates in regression, a problem for practitioners who need accurate forecasts, although it may appear incremental because it builds on existing heteroscedastic neural networks.
The paper tackles uncertainty quantification in deep learning by developing SDE-HNN, a heteroscedastic neural network that uses stochastic differential equations to model the interaction between the predictive mean and variance. It reports significant improvements in predictive performance and uncertainty calibration on challenging datasets.
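The abstract does not specify the drift, the diffusion, or the enhanced solver used in SDE-HNN. As a generic illustration only, the sketch below integrates a two-dimensional state (a predictive mean and a log-variance) with the standard fixed-step Euler-Maruyama scheme, so that each quantity's evolution can depend on the other. The names `euler_maruyama`, `toy_drift`, and `toy_diffusion` and the toy dynamics are hypothetical stand-ins for learned networks, not the paper's model.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, t0, t1, n_steps, rng):
    """Integrate dX = f(X, t) dt + g(X, t) dW on [t0, t1] with fixed-step Euler-Maruyama.

    The noise is diagonal: each state component has its own independent Brownian motion.
    """
    if n_steps <= 0 or t1 <= t0:
        raise ValueError("need n_steps > 0 and t1 > t0")
    dt = (t1 - t0) / n_steps
    sqrt_dt = math.sqrt(dt)
    x, t = list(x0), t0
    for _ in range(n_steps):
        f = drift(x, t)
        g = diffusion(x, t)
        # Brownian increments dW ~ N(0, dt), one per state component.
        dw = [rng.gauss(0.0, sqrt_dt) for _ in x]
        x = [xi + fi * dt + gi * dwi for xi, fi, gi, dwi in zip(x, f, g, dw)]
        t += dt
    return x


# Hypothetical toy dynamics on the state [mu, log_var]; NOT the SDE-HNN networks.
def toy_drift(state, t):
    mu, log_var = state
    # Mean relaxes toward 1.0, more slowly when the predicted variance is large.
    d_mu = (1.0 - mu) / (1.0 + math.exp(log_var))
    # Log-variance relaxes toward log(0.1).
    d_log_var = -(log_var - math.log(0.1))
    return [d_mu, d_log_var]


def toy_diffusion(state, t):
    _, log_var = state
    # Noise on the mean scales with the predicted standard deviation; log-variance is noise-free.
    return [0.1 * math.exp(0.5 * log_var), 0.0]


if __name__ == "__main__":
    rng = random.Random(0)
    mu_T, log_var_T = euler_maruyama(toy_drift, toy_diffusion, [0.0, 0.0], 0.0, 1.0, 100, rng)
    print(f"mean={mu_T:.3f}, variance={math.exp(log_var_T):.3f}")
```

Passing a seeded `random.Random` makes sample paths reproducible. Plain Euler-Maruyama is only the simplest baseline scheme; the paper's enhanced solver is described as targeting learning stability and is not reproduced here.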
It is critical yet challenging for deep learning models to properly characterize the uncertainty that is pervasive in real-world environments. Although much effort has been made, for example through heteroscedastic neural networks (HNNs), few approaches have demonstrated satisfactory practicality, owing to varying degrees of compromise among learning efficiency, quality of uncertainty estimates, and predictive performance. Moreover, existing HNNs typically fail to construct an explicit interaction between the prediction and its associated uncertainty. This paper aims to remedy these issues by developing SDE-HNN, a new heteroscedastic neural network equipped with stochastic differential equations (SDEs) to characterize the interaction between the predictive mean and variance of HNNs for accurate and reliable regression. Theoretically, we show the existence and uniqueness of the solution to the devised neural SDE. In addition, based on the bias-variance trade-off for the optimization in SDE-HNN, we design an enhanced numerical SDE solver to improve learning stability. Finally, to evaluate predictive uncertainty more systematically, we present two new diagnostic uncertainty metrics. Experiments on challenging datasets show that our method significantly outperforms state-of-the-art baselines in terms of both predictive performance and uncertainty quantification, delivering well-calibrated and sharp prediction intervals.
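The two new diagnostic metrics proposed in the paper are not defined in this abstract. For reference, calibration and sharpness of prediction intervals are commonly summarized by prediction interval coverage probability (PICP) and mean prediction interval width (MPIW). The sketch below computes those standard quantities, not the paper's metrics, assuming each prediction is Gaussian with mean `mu` and variance `var`, as a heteroscedastic network outputs. The function name `gaussian_interval_metrics` is hypothetical.

```python
import math
from statistics import NormalDist


def gaussian_interval_metrics(y_true, mu, var, level=0.95):
    """Return (PICP, MPIW) for central Gaussian prediction intervals at the given level.

    PICP: fraction of targets inside [mu - z*sigma, mu + z*sigma] (calibration).
    MPIW: average interval width 2*z*sigma (sharpness; smaller is sharper).
    """
    if not (len(y_true) == len(mu) == len(var)) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    if not 0.0 < level < 1.0:
        raise ValueError("level must be in (0, 1)")
    z = NormalDist().inv_cdf(0.5 + level / 2.0)  # about 1.96 for level=0.95
    covered, total_width = 0, 0.0
    for y, m, v in zip(y_true, mu, var):
        half = z * math.sqrt(v)
        covered += (m - half <= y <= m + half)  # bool counts as 0 or 1
        total_width += 2.0 * half
    n = len(y_true)
    return covered / n, total_width / n


if __name__ == "__main__":
    picp, mpiw = gaussian_interval_metrics([0.0, 1.0, 5.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
    print(f"PICP={picp:.3f}, MPIW={mpiw:.3f}")
```

At level 0.95, a well-calibrated model should give a PICP close to 0.95; among models with comparable PICP, a smaller MPIW indicates sharper intervals.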