LG · Oct 11, 2023
MatFormer: Nested Transformer for Elastic Inference
Devvrit, Sneha Kudugunta, Aditya Kusupati et al. · uw
Foundation models are applied in a broad spectrum of settings with different inference constraints, from massive multi-accelerator clusters to resource-constrained standalone mobile devices. However, the substantial costs associated with training these models often limit the number of unique model sizes that can be offered. Consequently, practitioners are compelled to select a model that may not be optimally aligned with their specific latency and cost requirements. We present MatFormer, a novel Transformer architecture designed to provide elastic inference across diverse deployment constraints. MatFormer achieves this by incorporating a nested Feed Forward Network (FFN) block structure within a standard Transformer model. During training, we optimize the parameters of multiple nested FFN blocks with varying sizes, enabling the extraction of hundreds of accurate smaller models without incurring additional computational costs. We empirically validate the efficacy of MatFormer across different model classes (decoders and encoders) and modalities (language and vision), demonstrating its potential for real-world deployment. We show that an 850M-parameter decoder-only MatFormer language model (MatLM) allows us to extract multiple smaller models spanning from 582M to 850M parameters, each exhibiting better validation loss and one-shot downstream evaluations than independently trained counterparts. Furthermore, we observe that smaller encoders extracted from a universal MatFormer-based ViT (MatViT) encoder preserve the metric-space structure for adaptive large-scale retrieval. Finally, we showcase that speculative decoding with the accurate and consistent submodels extracted from MatFormer can lead to a significant reduction in inference latency. Project website: https://devvrit.github.io/matformer/
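The nested FFN idea can be illustrated with a small sketch. It assumes that "nested" means each smaller block reuses the first m hidden units of the largest block, and that training sums a loss over several sizes; the abstract does not spell out either detail. All function names, the ReLU activation, and the squared-error loss are hypothetical choices for illustration, not the authors' implementation.

```python
import random


def make_ffn(d_model, d_ff, seed=0):
    """Random FFN weights, stored as one row per hidden unit (illustrative only)."""
    rng = random.Random(seed)
    w_in = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]
    w_out = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]
    return w_in, w_out


def nested_ffn(x, w_in, w_out, m):
    """Apply the FFN using only the first m hidden units (assumed nesting rule)."""
    y = [0.0] * len(x)
    for j in range(m):
        # Hidden activation of unit j with a ReLU nonlinearity (assumption).
        h = max(0.0, sum(w * xi for w, xi in zip(w_in[j], x)))
        for i in range(len(y)):
            y[i] += h * w_out[j][i]
    return y


def extract_submodel(w_in, w_out, m):
    """Slice out the weights of the size-m block; no retraining involved."""
    return w_in[:m], w_out[:m]


def joint_loss(x, target, w_in, w_out, sizes):
    """Sum of squared errors over all nested sizes (hypothetical objective)."""
    total = 0.0
    for m in sizes:
        y = nested_ffn(x, w_in, w_out, m)
        total += sum((yi - ti) ** 2 for yi, ti in zip(y, target))
    return total
```

Under these assumptions, calling `extract_submodel(w_in, w_out, m)` for any m up to the full hidden width yields a smaller block that shares its parameters with the larger ones, which is one way to read the claim that many submodels can be extracted without extra training cost.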
LG · Mar 20, 2023
Greedy Pruning with Group Lasso Provably Generalizes for Matrix Sensing
Nived Rajaraman, Devvrit, Aryan Mokhtari et al.
Pruning schemes have been widely used in practice to reduce the complexity of trained models with a massive number of parameters. In fact, several practical studies have shown that if a pruned model is fine-tuned with some gradient-based updates, it generalizes well to new samples. Although the above pipeline, which we refer to as pruning + fine-tuning, has been extremely successful in lowering the complexity of trained models, very little is known about the theory behind this success. In this paper, we address this issue by investigating the pruning + fine-tuning framework on the overparameterized matrix sensing problem with the ground truth $U_\star \in \mathbb{R}^{d \times r}$ and the overparameterized model $U \in \mathbb{R}^{d \times k}$ with $k \gg r$. We study the approximate local minima of the mean square error, augmented with a smooth version of a group Lasso regularizer, $\sum_{i=1}^k \| U e_i \|_2$. In particular, we provably show that pruning all the columns below a certain explicit $\ell_2$-norm threshold results in a solution $U_{\text{prune}}$ which has the minimum number of columns $r$, yet is close to the ground truth in training loss. Moreover, in the subsequent fine-tuning phase, gradient descent initialized at $U_{\text{prune}}$ converges at a linear rate to its limit. While our analysis provides insights into the role of regularization in pruning, we also show that running gradient descent in the absence of regularization results in models which are not suitable for greedy pruning, i.e., many columns could have their $\ell_2$ norm comparable to that of the maximum. To the best of our knowledge, our results provide the first rigorous insights on why greedy pruning + fine-tuning leads to smaller models which also generalize well.
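The regularizer $\sum_{i=1}^k \| U e_i \|_2$ and the column-pruning step can be sketched as follows. This is a minimal illustration, not the paper's procedure: it computes the plain (non-smoothed) group Lasso penalty, and the threshold value is a hypothetical placeholder because the abstract does not state the explicit threshold. Function names are made up for this example.

```python
import math


def column_norms(U):
    """l2 norm of each column U e_i, with U given as a list of d rows of length k."""
    k = len(U[0])
    return [math.sqrt(sum(row[j] ** 2 for row in U)) for j in range(k)]


def group_lasso(U):
    """Plain group Lasso penalty: the sum of column l2 norms (no smoothing)."""
    return sum(column_norms(U))


def greedy_prune(U, threshold):
    """Drop every column whose l2 norm is below the threshold.

    The threshold is a hypothetical input here; the paper derives an explicit one.
    Returns the pruned matrix and the indices of the columns that were kept.
    """
    norms = column_norms(U)
    keep = [j for j, n in enumerate(norms) if n >= threshold]
    pruned = [[row[j] for j in keep] for row in U]
    return pruned, keep


# Toy overparameterized factor with k = 4 columns, two of them nearly zero.
U = [
    [1.0, 0.01, 0.0, 0.02],
    [0.0, 0.00, 2.0, 0.00],
    [1.0, 0.02, 0.0, 0.01],
]
U_prune, kept = greedy_prune(U, threshold=0.1)
```

In this toy matrix the two near-zero columns fall below the threshold, so `kept` is `[0, 2]` and `U_prune` has two columns. The abstract's point is that regularized training tends to produce this kind of separation in column norms, whereas unregularized gradient descent may not.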
LG · Nov 28, 2020
Voting based ensemble improves robustness of defensive models
Devvrit, Minhao Cheng, Cho-Jui Hsieh et al.
Developing robust models against adversarial perturbations has been an active area of research, and many algorithms have been proposed to train individual robust models. Taking these pretrained robust models, we aim to study whether it is possible to create an ensemble to further improve robustness. Several previous attempts tackled this problem by ensembling the soft-label predictions and have been shown to be vulnerable to the latest attack methods. In this paper, we show that if the robust training loss is diverse enough, a simple hard-label based voting ensemble can improve the robust error over each individual model. Furthermore, given a pool of robust models, we develop a principled way to select which models to ensemble. Finally, to verify the improved robustness, we conduct extensive experiments to study how to attack a voting-based ensemble and develop several new white-box attacks. On the CIFAR-10 dataset, by ensembling several state-of-the-art pre-trained defense models, our method achieves 59.8% robust accuracy, outperforming all the existing defensive models without using additional data.
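A hard-label voting ensemble of the kind described above can be sketched in a few lines. The sketch assumes each model is a callable that returns an integer class label, and it breaks ties in favor of the smallest label; the abstract specifies neither the interface nor a tie-breaking rule, so both are illustrative assumptions. The model selection procedure and the attacks are not shown.

```python
from collections import Counter


def hard_vote(labels):
    """Majority vote over hard labels; ties go to the smallest label (assumption)."""
    counts = Counter(labels)
    return max(counts, key=lambda c: (counts[c], -c))


def ensemble_predict(models, x):
    """Query each model for its hard label on x and return the voted label."""
    return hard_vote([model(x) for model in models])


# Hypothetical toy "models": threshold classifiers on a scalar input.
toy_models = [
    lambda x: 1 if x > 0.0 else 0,
    lambda x: 1 if x > 0.5 else 0,
    lambda x: 1 if x > 1.0 else 0,
]
```

With these toy models, `ensemble_predict(toy_models, 0.7)` returns the label chosen by two of the three classifiers. Because only the discrete labels are combined, the ensemble output does not depend on the models' confidence scores, which is the contrast with soft-label ensembling drawn in the abstract.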