LG · AI · Feb 21, 2021

Constrained Optimization to Train Neural Networks on Critical and Under-Represented Classes

arXiv:2102.12894v4 · 34 citations
Originality: Incremental advance
AI Analysis

This paper addresses high false positive rates in imbalanced binary and multi-class classification for clinical applications, where misclassifying critical cases has severe consequences. The advance is incremental: it builds on existing loss functions and optimization methods.

The paper tackles class imbalance in deep neural networks, particularly for critical classes such as cancer. It formulates training as a constrained optimization problem that reduces false positive rates at high true positive rates. In experiments on medical imaging, CIFAR10, and CIFAR100 datasets, this improves accuracy on critical classes and reduces misclassification of non-critical classes.
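To make the quantity "false positive rate at a high true positive rate" concrete, here is a minimal Python/NumPy sketch of how it could be measured from classifier scores. The function name `fpr_at_tpr`, the default target of 0.95, and the convention "predict positive when score >= threshold" are assumptions for illustration and are not taken from the paper. Under that convention, reaching a higher TPR means moving the threshold down, which is what lets more negatives through.

```python
import numpy as np

def fpr_at_tpr(scores, labels, target_tpr=0.95):
    """Return (threshold, fpr) at the highest threshold whose TPR >= target_tpr.

    Assumed convention: a sample is predicted positive when score >= threshold.
    labels are 1 for the positive (critical) class and 0 otherwise.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    pos = np.sort(scores[labels == 1])[::-1]   # positive scores, descending
    neg = scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one positive and one negative sample")
    # Number of positives that must be detected to reach the target TPR.
    k = max(int(np.ceil(target_tpr * len(pos))), 1)
    threshold = pos[k - 1]                      # score of the k-th best positive
    fpr = float(np.mean(neg >= threshold))      # negatives that also pass
    return threshold, fpr
```

Calling it with a stricter `target_tpr` on the same scores can only keep or lower the threshold, so the reported FPR can only stay the same or grow. That trade-off is the one the summary above describes.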

Deep neural networks (DNNs) are notorious for making more mistakes on classes that have substantially fewer samples than the others during training. Such class imbalance is ubiquitous in clinical applications and crucial to handle, because the classes with fewer samples most often correspond to critical cases (e.g., cancer) where misclassifications can have severe consequences. To avoid missing such cases, binary classifiers need to be operated at high True Positive Rates (TPRs) by adjusting the decision threshold, but this comes at the cost of very high False Positive Rates (FPRs) for problems with class imbalance. Existing methods for learning under class imbalance most often do not take this into account. We argue that prediction accuracy should be improved by emphasizing the reduction of FPRs at high TPRs for problems where misclassification of positive (i.e., critical) class samples is associated with a higher cost. To this end, we pose the training of a DNN for binary classification as a constrained optimization problem and introduce a novel constraint that can be used with existing loss functions to enforce maximal area under the ROC curve (AUC) by prioritizing FPR reduction at high TPR. We solve the resulting constrained optimization problem using an Augmented Lagrangian method (ALM). Going beyond binary classification, we also propose two possible extensions of the proposed constraint for multi-class classification problems. We present experimental results for image-based binary and multi-class classification applications using an in-house medical imaging dataset, CIFAR10, and CIFAR100. Our results demonstrate that the proposed method improves on the baselines in the majority of cases by attaining higher accuracy on critical classes while reducing the misclassification rate for non-critical class samples.
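The abstract does not spell out the constraint or the update rules, so the following is only a rough sketch of how an augmented Lagrangian loop of this general kind could look, not the authors' implementation. It makes several assumptions, all hypothetical: a linear scorer instead of a DNN, binary cross-entropy as the base loss, one pairwise hinge constraint per positive sample with margin `delta`, plain gradient descent for the inner minimization, and a fixed penalty weight `rho` (many ALM variants increase it over time).

```python
import numpy as np

def pairwise_auc(scores_pos, scores_neg):
    """Fraction of (positive, negative) pairs ranked correctly."""
    return float(np.mean(scores_pos[:, None] > scores_neg[None, :]))

def constraint_terms(w, pos, neg, delta):
    """Per-positive constraint values c_p and their gradients w.r.t. w.

    c_p = mean_n max(0, delta - (s_p - s_n)) with s(x) = w.x + b.
    c_p = 0 means positive p outscores every negative by at least delta.
    """
    # margin[p, n] = delta - (s_p - s_n); the bias cancels in the difference.
    margin = delta - ((pos @ w)[:, None] - (neg @ w)[None, :])
    active = (margin > 0).astype(float)            # violated pairs only
    c = (margin * active).mean(axis=1)             # shape (P,)
    # d c_p / d w = mean over n of active[p, n] * (x_n - x_p)
    dc = (active @ neg - active.sum(axis=1, keepdims=True) * pos) / len(neg)
    return c, dc

def train_alm(X, y, delta=1.0, rho=1.0, outer=10, inner=200, lr=0.05):
    """Minimize BCE + mean_p [lam_p * c_p + (rho / 2) * c_p**2] over (w, b),
    then update the multipliers with lam_p <- lam_p + rho * c_p."""
    pos, neg = X[y == 1], X[y == 0]
    w, b = np.zeros(X.shape[1]), 0.0
    lam = np.zeros(len(pos))                       # one multiplier per positive
    for _ in range(outer):
        for _ in range(inner):                     # inner loop: gradient descent
            err = 1.0 / (1.0 + np.exp(-(X @ w + b))) - y
            c, dc = constraint_terms(w, pos, neg, delta)
            grad_w = X.T @ err / len(y) + ((lam + rho * c) @ dc) / len(pos)
            w = w - lr * grad_w
            b = b - lr * err.mean()                # constraint does not depend on b
        c, _ = constraint_terms(w, pos, neg, delta)
        lam = lam + rho * c                        # multiplier (dual) update
    return w, b, lam
```

A usage sketch would build `X` with a small positive class and a large negative class, call `train_alm(X, y)` with `y` as a 0/1 float array, and compare `pairwise_auc` or the FPR at a fixed high TPR against a model trained with the base loss alone. Since every `c_p` is non-negative by construction, the multipliers only grow while a positive sample remains insufficiently separated, which is how the loop concentrates effort on the hardest positives.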

Code Implementations: 1 repo
