SOLAR: Switchable Output Layer for Accuracy and Robustness in Once-for-All Training
This work addresses the challenge of maintaining high accuracy and robustness in flexible neural network deployment, a practical concern for machine learning practitioners, though the contribution is incremental because it builds on existing OFA frameworks.
The paper tackles degraded calibration and performance in Once-for-All (OFA) training, which it attributes to excessive parameter sharing, by proposing SOLAR, a switchable output layer that assigns a separate classification head to each sub-net. Across multiple datasets and backbones, this yields accuracy improvements of up to 4.71% and robustness gains of up to 9.01%.
Once-for-All (OFA) training enables a single super-net to generate multiple sub-nets tailored to diverse deployment scenarios, supporting flexible trade-offs among accuracy, robustness, and model size without retraining. However, as the number of supported sub-nets increases, excessive parameter sharing in the backbone limits representational capacity, leading to degraded calibration and reduced overall performance. To address this, we propose SOLAR (Switchable Output Layer for Accuracy and Robustness in Once-for-All Training), a simple yet effective technique that assigns each sub-net a separate classification head. By decoupling the logit learning process across sub-nets, the Switchable Output Layer (SOL) reduces representational interference and improves optimization, without altering the shared backbone. We evaluate SOLAR on five datasets (SVHN, CIFAR-10, STL-10, CIFAR-100, and TinyImageNet) using four super-net backbones (ResNet-34, WideResNet-16-8, WideResNet-40-2, and MobileNetV2) for two OFA training frameworks (OATS and SNNs). Experiments show that SOLAR outperforms the baseline methods: compared to OATS, it improves sub-net accuracy by up to 1.26%, 4.71%, 1.67%, and 1.76%, and robustness by up to 9.01%, 7.71%, 2.72%, and 1.26% on SVHN, CIFAR-10, STL-10, and CIFAR-100, respectively. Compared to SNNs, it improves TinyImageNet accuracy by up to 2.93%, 2.34%, and 1.35% using ResNet-34, WideResNet-16-8, and MobileNetV2 backbones (with 8 sub-nets), respectively.
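
The core idea of the abstract, one classification head per sub-net on top of a shared backbone, can be illustrated with a minimal sketch in plain Python. This is not the authors' implementation: the class name `SwitchableOutputLayer`, the `feature_dims` and `subnet_idx` parameters, the Gaussian initialization, and the choice to let each sub-net have its own feature width are all assumptions made for illustration. The backbone itself is omitted, and `features` stands in for whatever vector it would produce for the selected sub-net.

```python
import random


class SwitchableOutputLayer:
    """Hypothetical sketch: one linear classification head per sub-net.

    The shared backbone is not modeled here; only the output layer is switched.
    """

    def __init__(self, feature_dims, num_classes, seed=0):
        rng = random.Random(seed)
        self.num_classes = num_classes
        # heads[k] holds (weights, bias) used only by sub-net k.
        # weights is a num_classes x feature_dims[k] matrix (list of rows).
        self.heads = []
        for dim in feature_dims:
            weights = [[rng.gauss(0.0, 0.01) for _ in range(dim)]
                       for _ in range(num_classes)]
            bias = [0.0] * num_classes
            self.heads.append((weights, bias))

    def forward(self, features, subnet_idx):
        """Compute logits with the head that belongs to the chosen sub-net."""
        weights, bias = self.heads[subnet_idx]
        if len(features) != len(weights[0]):
            raise ValueError("feature size does not match the selected head")
        return [sum(w * x for w, x in zip(row, features)) + b
                for row, b in zip(weights, bias)]


# Usage: three sub-nets with (assumed) feature widths 4, 8, and 16.
sol = SwitchableOutputLayer(feature_dims=[4, 8, 16], num_classes=10)
logits = sol.forward([1.0] * 8, subnet_idx=1)  # uses only head 1
```

Because each head owns its parameters, a gradient step computed for one sub-net's logits would not touch the heads of the others, which is the decoupling the abstract describes. The backbone features would still be shared in a real super-net.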