Improving Robustness by Enhancing Weak Subnets
This work addresses the robustness of deep learning models, a critical problem for AI safety and reliability, though the contribution appears incremental because it builds on existing training methods.
The paper tackles the susceptibility of deep networks to input perturbations by identifying and enhancing weak internal sub-networks (subnets) that correlate with poor robustness. The proposed EWS training procedure improves both robustness against corrupted images and accuracy on clean data, and it further improves performance when combined with data augmentation and adversarial training methods.
Despite their success, deep networks have been shown to be highly susceptible to perturbations, which often cause significant drops in accuracy. In this paper, we investigate model robustness on perturbed inputs by studying the performance of internal sub-networks (subnets). Interestingly, we observe that most subnets show particularly poor robustness against perturbations. More importantly, these weak subnets are correlated with the overall lack of robustness. To tackle this phenomenon, we propose a new training procedure that identifies and enhances weak subnets (EWS) to improve robustness. Specifically, we develop a search algorithm to find particularly weak subnets and explicitly strengthen them via knowledge distillation from the full network. We show that EWS greatly improves both robustness against corrupted images and accuracy on clean data. Because EWS is complementary to popular data augmentation methods, it consistently improves robustness when combined with them. To highlight the flexibility of our approach, we also combine EWS with popular adversarial training methods, resulting in improved adversarial robustness.
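
The two steps named above, searching for a weak subnet and scoring it against the full network for distillation, can be illustrated with a toy sketch. This is not the paper's algorithm: the abstract does not specify the search procedure, so plain random search stands in for it, a subnet is assumed to be a subset of hidden units in a tiny two-layer network, and "weakness" is assumed to be the KL divergence between full-network and subnet outputs. All names (`forward`, `find_weak_subnet`, `keep`, `trials`) are hypothetical.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def forward(x, w1, w2, mask):
    """Two-layer ReLU network; mask[j] == 0 drops hidden unit j (assumed subnet form)."""
    hidden = [max(0.0, sum(wi * xi for wi, xi in zip(row, x))) * m
              for row, m in zip(w1, mask)]
    logits = [sum(wk * h for wk, h in zip(row, hidden)) for row in w2]
    return softmax(logits)

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q), used here as the distillation loss (teacher p, student q)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def find_weak_subnet(x, w1, w2, keep, trials, rng):
    """Random-search stand-in: return the sampled mask whose output
    diverges most from the full network, together with that divergence."""
    n_hidden = len(w1)
    full = forward(x, w1, w2, [1] * n_hidden)  # teacher: all units active
    worst_mask, worst_gap = None, -1.0
    for _ in range(trials):
        kept = set(rng.sample(range(n_hidden), keep))
        mask = [1 if j in kept else 0 for j in range(n_hidden)]
        gap = kl_divergence(full, forward(x, w1, w2, mask))
        if gap > worst_gap:
            worst_mask, worst_gap = mask, gap
    return worst_mask, worst_gap

# Toy setup with random weights and a single random input.
rng = random.Random(0)
n_in, n_hidden, n_out = 4, 8, 3
w1 = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_hidden)]
w2 = [[rng.uniform(-1, 1) for _ in range(n_hidden)] for _ in range(n_out)]
x = [rng.uniform(-1, 1) for _ in range(n_in)]

# The returned divergence is the quantity a training step would minimize
# for the selected subnet, typically alongside the ordinary task loss.
weak_mask, distill_loss = find_weak_subnet(x, w1, w2, keep=4, trials=20, rng=rng)
```

In an actual training loop the distillation term would be minimized by gradient descent on the shared weights, with the full network's output treated as a fixed target; the sketch stops at computing the loss to stay dependency-free.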