Learning Gradients of Convex Functions with Monotone Gradient Networks
This work addresses a specific problem in machine learning for signal processing and optimal transport applications, offering an incremental improvement over existing data-driven methods for learning convex functions.
The paper tackles the problem of learning gradients of convex functions, which matter for applications such as gradient-based optimization and optimal transport, by proposing two monotone gradient neural network architectures, C-MGN and M-MGN. The results show that these networks are easier to train, learn monotone gradient fields more accurately, and use significantly fewer parameters than state-of-the-art methods.
While much effort has been devoted to deriving and analyzing effective convex formulations of signal processing problems, the gradients of convex functions also have critical applications ranging from gradient-based optimization to optimal transport. Recent works have explored data-driven methods for learning convex objective functions, but learning their monotone gradients is seldom studied. In this work, we propose C-MGN and M-MGN, two monotone gradient neural network architectures for directly learning the gradients of convex functions. We show that, compared to state-of-the-art methods, our networks are easier to train, learn monotone gradient fields more accurately, and use significantly fewer parameters. We further demonstrate their ability to learn optimal transport mappings to augment driving image data.
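To make "monotone gradient" concrete: the gradient g of a differentiable convex function satisfies <g(x) - g(y), x - y> >= 0 for every pair of points x and y. The sketch below is not the C-MGN or M-MGN architecture. It is only a minimal numerical check of that inequality, assuming NumPy and using softmax (the gradient of the convex log-sum-exp function) as a stand-in for a learned map. The helper name min_monotonicity_gap and the sampling scheme are hypothetical choices made for illustration.

```python
import numpy as np


def softmax(x):
    # Gradient of the convex function f(x) = log(sum(exp(x))).
    # Subtracting the max is a standard numerical-stability trick.
    z = np.exp(x - np.max(x))
    return z / z.sum()


def min_monotonicity_gap(grad, dim, n_pairs=1000, seed=0):
    # Hypothetical helper: sample random pairs (x, y) and return the
    # smallest observed value of <grad(x) - grad(y), x - y>.
    # A monotone map never yields a negative value (up to rounding).
    rng = np.random.default_rng(seed)
    xs = rng.normal(size=(n_pairs, dim))
    ys = rng.normal(size=(n_pairs, dim))
    gaps = [np.dot(grad(x) - grad(y), x - y) for x, y in zip(xs, ys)]
    return float(min(gaps))


# Softmax is the gradient of a convex function, so the gap stays non-negative.
gap_convex = min_monotonicity_gap(softmax, dim=5)

# The map x -> -x is the gradient of a concave function and is not monotone.
gap_nonmonotone = min_monotonicity_gap(lambda x: -x, dim=5)

print("softmax gap:", gap_convex)
print("negation gap:", gap_nonmonotone)
```

A sampled check like this can only reveal violations; it cannot prove monotonicity. An architecture that is monotone by construction, which is what the abstract describes, does not depend on such sampling.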