Learning Set Functions with Implicit Differentiation
This work addresses a computational bottleneck for machine learning researchers and practitioners working on large-scale set function optimization, though the contribution is incremental because it builds on prior methods.
The paper tackles the computational cost of learning set functions from optimal subset oracle data by replacing automatic differentiation with implicit differentiation. This reduces memory usage by up to 90% and speeds up training by 2 to 5 times on tasks such as product recommendation and compound selection.
Ou et al. (2022) introduce the problem of learning set functions from data generated by a so-called optimal subset oracle. Their approach approximates the underlying utility function with an energy-based model, whose parameters are estimated via mean-field variational inference. Ou et al. (2022) show that this estimation reduces to fixed-point iterations; however, as the number of iterations increases, automatic differentiation quickly becomes computationally prohibitive due to the size of the Jacobians that are stacked during backpropagation. We address this challenge with implicit differentiation and examine the convergence conditions for the fixed-point iterations. We empirically demonstrate the efficiency of our method on synthetic and real-world subset selection applications, including product recommendation, set anomaly detection, and compound selection tasks.
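To illustrate the general idea of differentiating through a fixed point without unrolling the iterations, the following is a minimal sketch rather than the method of the paper. It assumes a hypothetical sigmoid update z = sigmoid(W z + theta) as a stand-in for the mean-field iteration; the names `f`, `solve_fixed_point`, `implicit_jacobian`, `W`, and `theta` are illustrative and do not come from the source. At a fixed point z* = f(z*, theta), the implicit function theorem gives dz*/dtheta = (I - df/dz)^{-1} df/dtheta, so only the converged iterate is needed and no per-iteration Jacobians have to be stored.

```python
import numpy as np


def f(z, theta, W):
    # Hypothetical fixed-point map (an assumption for illustration):
    # a sigmoid update, which is a contraction when the norm of W is small.
    return 1.0 / (1.0 + np.exp(-(W @ z + theta)))


def solve_fixed_point(theta, W, tol=1e-12, max_iter=1000):
    # Plain fixed-point iteration z <- f(z, theta), stopped once the
    # largest coordinate change falls below tol.
    z = np.full(theta.shape, 0.5)
    for _ in range(max_iter):
        z_next = f(z, theta, W)
        if np.max(np.abs(z_next - z)) < tol:
            return z_next
        z = z_next
    return z


def implicit_jacobian(z_star, W):
    # Implicit function theorem at z* = f(z*, theta):
    #   dz*/dtheta = (I - df/dz)^{-1} df/dtheta
    # For the sigmoid map, the derivative at the fixed point is z*(1 - z*).
    s = z_star * (1.0 - z_star)
    df_dz = s[:, None] * W      # equals diag(s) @ W
    df_dtheta = np.diag(s)
    n = z_star.size
    return np.linalg.solve(np.eye(n) - df_dz, df_dtheta)


# Usage on a small random instance (values chosen only for the demo).
rng = np.random.default_rng(0)
W = 0.1 * rng.standard_normal((4, 4))
theta = rng.standard_normal(4)
z_star = solve_fixed_point(theta, W)
J = implicit_jacobian(z_star, W)
```

The memory cost of this sketch is independent of the number of iterations, because the Jacobian is obtained from a single linear solve at the converged point. The solve is only well defined when I - df/dz is invertible, which holds in particular when the map is a contraction; this is where convergence conditions for the iteration become relevant.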