LG | Dec 5, 2023

Multitask Learning Can Improve Worst-Group Outcomes

CMU
arXiv:2312.03151v2 | 2 citations | h-index: 51 | Has Code | Trans. Mach. Learn. Res.
AI Analysis

This work addresses fairness across diverse user groups in ML and offers a method that improves on existing robust optimization techniques. The contribution is incremental, since it builds on prior multitask learning and fine-tuning frameworks.

The paper tackles the problem of improving worst-group accuracy so that machine learning models produce equitable outcomes. It finds that a regularized multitask learning approach consistently outperforms Just-Train-Twice (JTT), a representative distributionally robust optimization method, on both average and worst-group performance across vision and NLP datasets.

In order to create machine learning systems that serve a variety of users well, it is vital to not only achieve high average performance but also ensure equitable outcomes across diverse groups. However, most machine learning methods are designed to improve a model's average performance on a chosen end task without consideration for their impact on worst group error. Multitask learning (MTL) is one such widely used technique. In this paper, we seek not only to understand the impact of MTL on worst-group accuracy but also to explore its potential as a tool to address the challenge of group-wise fairness. We primarily consider the standard setting of fine-tuning a pre-trained model, where, following recent work \citep{gururangan2020don, dery2023aang}, we multitask the end task with the pre-training objective constructed from the end task data itself. In settings with few or no group annotations, we find that multitasking often, but not consistently, achieves better worst-group accuracy than Just-Train-Twice (JTT; \citet{pmlr-v139-liu21f}) -- a representative distributionally robust optimization (DRO) method. Leveraging insights from synthetic data experiments, we propose to modify standard MTL by regularizing the joint multitask representation space. We run a large number of fine-tuning experiments across computer vision and natural language processing datasets and find that our regularized MTL approach \emph{consistently} outperforms JTT on both average and worst-group outcomes. Our official code can be found here: \href{https://github.com/atharvajk98/MTL-group-robustness.git}{\url{https://github.com/atharvajk98/MTL-group-robustness}}.
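The abstract compares methods on two metrics, average accuracy and worst-group accuracy. As a minimal sketch of how these two quantities differ (this is not taken from the paper's code; the function names and the toy data are hypothetical), worst-group accuracy is the minimum of the per-group accuracies:

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Return a dict mapping each group id to its accuracy."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds, labels, groups):
    """Accuracy of the group the model serves worst (min over groups)."""
    return min(group_accuracies(preds, labels, groups).values())


def average_accuracy(preds, labels):
    """Plain accuracy over all examples, ignoring group membership."""
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


# Hypothetical toy data: group "a" is classified perfectly, group "b" is not.
preds = [1, 1, 0, 0, 1, 0]
labels = [1, 1, 0, 1, 0, 0]
groups = ["a", "a", "a", "b", "b", "b"]

avg = average_accuracy(preds, labels)              # 4 of 6 correct
worst = worst_group_accuracy(preds, labels, groups)  # group "b": 1 of 3 correct
```

In this toy example the average accuracy hides the fact that one group is served much worse than the other, which is the gap the paper targets. The sketch assumes group annotations are available at evaluation time and that the inputs are non-empty.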

Code Implementations: 1 repo