Better Conditional Density Estimation for Neural Networks
This work addresses the problem of modeling full conditional distributions rather than point predictions in machine learning, offering incremental improvements for applications requiring uncertainty quantification.
The paper tackles conditional density estimation (CDE) for neural networks by introducing Multiscale Nets (MSNs) and CDE Trend Filtering (CDE-TF). Both outperform baseline methods such as multinomial classifiers and mixture density networks on simulated and real-world datasets, with MSNs excelling in high-data regimes and CDE-TF in low-data scenarios.
The vast majority of the neural network literature focuses on predicting point values for a given set of response variables, conditioned on a feature vector. In many cases, however, we need to model the full joint conditional distribution over the response variables rather than simply making point predictions. In this paper, we present two novel approaches to such conditional density estimation (CDE): Multiscale Nets (MSNs) and CDE Trend Filtering (CDE-TF). MSNs transform the CDE regression task into a hierarchical classification task by decomposing the density into a series of half-spaces and learning a binary probability for each split. CDE-TF applies a k-th order graph trend filtering penalty to the unnormalized logits of a multinomial classifier network, with each edge in the graph connecting neighboring points on a discretized version of the density. We compare both methods against plain multinomial classifier networks and mixture density networks (MDNs) on a simulated dataset and three real-world datasets. The results suggest the two methods are complementary: MSNs work well in a high-data-per-feature regime, while CDE-TF is well suited to few-samples-per-feature scenarios where overfitting is a primary concern.
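As a rough illustration of the two constructions described above, the following Python sketch shows one way they could look for a one-dimensional response discretized into bins. It is a minimal sketch under stated assumptions, not the paper's implementation: the function names, the heap-ordered layout of split logits, the use of a sigmoid for each split, and the repeated-first-difference form of the penalty on a chain graph are all hypothetical choices made for illustration. In a real model the logits would be outputs of a network conditioned on the feature vector.

```python
import math


def sigmoid(z):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def multiscale_bin_probs(node_logits, depth):
    """Convert binary split logits into probabilities over 2**depth bins.

    ASSUMPTION: node_logits holds 2**depth - 1 values in heap order
    (root at index 0, children of node i at 2*i + 1 and 2*i + 2), and each
    logit gives the probability of the upper half of that node's interval.
    """
    if len(node_logits) != 2 ** depth - 1:
        raise ValueError("expected 2**depth - 1 split logits")
    probs = []
    for b in range(2 ** depth):
        node, p = 0, 1.0
        # Walk from the root to bin b, reading b's bits from most significant.
        for level in reversed(range(depth)):
            go_right = (b >> level) & 1
            p_right = sigmoid(node_logits[node])
            p *= p_right if go_right else 1.0 - p_right
            node = 2 * node + 1 + go_right
        probs.append(p)
    return probs


def trend_filter_penalty(logits, k):
    """L1 norm of the (k+1)-th order differences of neighboring bin logits.

    ASSUMPTION: the graph is a simple chain over adjacent bins, so the
    penalty reduces to repeated first differences; k = 0 penalizes jumps
    between neighbors, k = 1 penalizes changes in slope, and so on.
    """
    diffs = list(logits)
    for _ in range(k + 1):
        diffs = [b - a for a, b in zip(diffs, diffs[1:])]
    return sum(abs(d) for d in diffs)
```

With all split logits at zero, `multiscale_bin_probs([0.0, 0.0, 0.0], 2)` yields four equal bins of 0.25, and the bin probabilities sum to one for any choice of logits because each split divides its parent's mass in two. For the penalty, logits that are linear across bins, such as `[0, 1, 2, 3]`, incur zero cost at `k = 1`. In training, a penalty of this kind would be added to the classification loss with a tuning weight.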