LG · AI · OC · ML · Dec 8, 2023

TaskMet: Task-Driven Metric Learning for Model Learning

arXiv:2312.05250v2 · 17 citations · h-index: 37 · Has Code · NIPS
Originality: Incremental advance
AI Analysis

This work addresses a challenge faced by practitioners who deploy models in real-world applications where task-specific performance is critical. It is an incremental advance, building on existing metric learning and task-aware training approaches.

The paper tackles a common failure mode: deep learning models trained for accurate predictions can perform poorly on downstream tasks because the prediction and task objectives conflict. It proposes TaskMet, a method that learns a metric in the prediction space so that model learning is aligned with the task's requirements, and reports improved performance in decision-focused learning and reinforcement learning settings.

Deep learning models are often deployed in downstream tasks that the training procedure may not be aware of. For example, models solely trained to achieve accurate predictions may struggle to perform well on downstream tasks because seemingly small prediction errors may incur drastic task errors. The standard end-to-end learning approach is to make the task loss differentiable or to introduce a differentiable surrogate that the model can be trained on. In these settings, the task loss needs to be carefully balanced with the prediction loss because they may have conflicting objectives. We propose to take the task loss signal one level deeper than the parameters of the model and use it to learn the parameters of the loss function the model is trained on, which can be done by learning a metric in the prediction space. This approach does not alter the optimal prediction model itself, but rather changes the model learning to emphasize the information important for the downstream task. This enables us to achieve the best of both worlds: a prediction model trained in the original prediction space while also being valuable for the desired downstream task. We validate our approach through experiments conducted in two main settings: 1) decision-focused model learning scenarios involving portfolio optimization and budget allocation, and 2) reinforcement learning in noisy environments with distracting states. The source code to reproduce our experiments is available at https://github.com/facebookresearch/taskmet
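The core idea in the abstract, training the model on a prediction loss whose metric is itself learned from the task loss, is a bilevel problem. The sketch below is a toy, pure-Python illustration under assumptions that are ours, not the paper's: a diagonal metric over two output coordinates, a model with one shared weight `w` (so it cannot fit both outputs and the metric decides which to favor), a closed-form inner solve, and a hypothetical task loss that only looks at the first coordinate. The outer gradient uses the implicit function theorem at the inner optimum, which is our choice for this sketch. All names (`inner_solve`, `task_loss`, `hypergradient`, `learn_metric`) are hypothetical and not taken from the taskmet repository.

```python
import math

# Hypothetical toy data: scalar input x, 2-dim target y = (2x, -x).
XS = [1.0, 2.0, 3.0, 4.0]
YS = [(2.0 * x, -1.0 * x) for x in XS]
SUM_X2 = sum(x * x for x in XS)


def inner_solve(lam):
    """Inner problem: minimise the metric-weighted prediction loss
    L(w, lam) = sum_n sum_i lam_i * (y_ni - w * x_n)**2
    over a single shared weight w (closed form for this toy model)."""
    num = sum(x * sum(li * yi for li, yi in zip(lam, y)) for x, y in zip(XS, YS))
    return num / (sum(lam) * SUM_X2)


def task_loss(w):
    """Hypothetical downstream task that only cares about output coordinate 0."""
    return sum((y[0] - w * x) ** 2 for x, y in zip(XS, YS))


def hypergradient(lam, w):
    """d task_loss(w*(lam)) / d lam_i via the implicit function theorem:
    dw*/dlam_i = -(d2L/dw dlam_i) / (d2L/dw2), evaluated at the inner optimum."""
    d2l_dw2 = 2.0 * sum(lam) * SUM_X2
    dtask_dw = -2.0 * sum(x * (y[0] - w * x) for x, y in zip(XS, YS))
    grads = []
    for i in range(len(lam)):
        d2l_dw_dlam = -2.0 * sum(x * (y[i] - w * x) for x, y in zip(XS, YS))
        grads.append(dtask_dw * (-d2l_dw_dlam / d2l_dw2))
    return grads


def learn_metric(steps=200, lr=0.01):
    """Outer problem: gradient descent on log-weights phi, with lam = exp(phi) > 0."""
    phi = [0.0, 0.0]
    for _ in range(steps):
        lam = [math.exp(p) for p in phi]
        w = inner_solve(lam)  # re-solve the inner problem for the current metric
        g = hypergradient(lam, w)
        # Chain rule through lam_i = exp(phi_i): dT/dphi_i = dT/dlam_i * lam_i.
        phi = [p - lr * gi * li for p, gi, li in zip(phi, g, lam)]
    lam = [math.exp(p) for p in phi]
    return lam, inner_solve(lam)


if __name__ == "__main__":
    w0 = inner_solve([1.0, 1.0])
    lam, w = learn_metric()
    print("uniform metric: w =", w0, "task loss =", task_loss(w0))
    print("learned metric:", lam, "w =", w, "task loss =", task_loss(w))
```

Running the script, the learned weights tilt toward the first coordinate and `w` moves from 0.5 toward 2, the value that is optimal for the task coordinate, while the model is still trained only on a (reweighted) prediction loss. If the model could fit both coordinates exactly (for example, with a separate weight per output), every positive metric would yield the same minimizer; this mirrors the abstract's point that the metric changes how the model is learned rather than what the optimal model is.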

Code Implementations: 1 repo