10 Papers

ML · Jul 20, 2023
Cluster-aware Semi-supervised Learning: Relational Knowledge Distillation Provably Learns Clustering

Yijun Dong, Kevin Miller, Qi Lei et al.

Despite the empirical success and practical significance of (relational) knowledge distillation that matches (the relations of) features between teacher and student models, the corresponding theoretical interpretations remain limited for various knowledge distillation paradigms. In this work, we take an initial step toward a theoretical understanding of relational knowledge distillation (RKD), with a focus on semi-supervised classification problems. We start by casting RKD as spectral clustering on a population-induced graph unveiled by a teacher model. Via a notion of clustering error that quantifies the discrepancy between the predicted and ground truth clusterings, we illustrate that RKD over the population provably leads to low clustering error. Moreover, we provide a sample complexity bound for RKD with limited unlabeled samples. For semi-supervised learning, we further demonstrate the label efficiency of RKD through a general framework of cluster-aware semi-supervised learning that assumes low clustering errors. Finally, by unifying data augmentation consistency regularization into this cluster-aware framework, we show that despite the common effect of learning accurate clusterings, RKD facilitates a "global" perspective through spectral clustering, whereas consistency regularization focuses on a "local" perspective via expansion.
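
To make the "RKD as spectral clustering" view concrete, here is a minimal NumPy sketch under assumptions of our own. The teacher-induced graph is a Gaussian-kernel affinity on teacher features with a hypothetical `bandwidth`. The RKD objective is taken to be the Frobenius mismatch $\|FF^\top - A\|_F$ between student relations and the normalized teacher affinity $A$. Because $A$ is positive semidefinite, the best rank-$k$ $F$ comes from its top-$k$ eigenpairs (Eckart-Young), so the minimizer is a spectral embedding. The function names and the farthest-point cluster assignment are illustrative and are not the paper's construction.

```python
import numpy as np

def teacher_affinity(teacher_feats, bandwidth=1.0):
    """Normalized Gaussian-kernel affinity D^{-1/2} W D^{-1/2} on teacher features (assumed graph)."""
    sq = np.sum(teacher_feats**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * teacher_feats @ teacher_feats.T, 0.0)
    W = np.exp(-d2 / (2.0 * bandwidth**2))
    deg = W.sum(axis=1)
    return W / np.sqrt(np.outer(deg, deg))

def rkd_minimizer(A, k):
    """Rank-k F minimizing ||F F^T - A||_F for PSD A: top-k eigenvectors scaled by sqrt(eigenvalues)."""
    evals, evecs = np.linalg.eigh(A)            # eigenvalues in ascending order
    top = np.argsort(evals)[::-1][:k]
    return evecs[:, top] * np.sqrt(np.maximum(evals[top], 0.0))

def assign_clusters(F, k):
    """Cluster rows of F by cosine similarity to k anchor rows chosen by farthest-point traversal."""
    U = F / np.linalg.norm(F, axis=1, keepdims=True)
    anchors = [0]
    for _ in range(1, k):
        sims = np.max(U @ U[anchors].T, axis=1)   # similarity to the closest existing anchor
        anchors.append(int(np.argmin(sims)))
    return np.argmax(U @ U[anchors].T, axis=1)

# Two well-separated blobs stand in for teacher features of two classes
rng = np.random.default_rng(0)
feats = np.vstack([rng.normal(0.0, 0.5, (20, 2)), rng.normal(10.0, 0.5, (20, 2))])
labels = assign_clusters(rkd_minimizer(teacher_affinity(feats), k=2), k=2)
```

On separable data like this, the student embedding that best matches the teacher's relations already separates the clusters. No labels are needed for that step, which is the part the cluster-aware semi-supervised framework then builds on.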

LG · Jul 8, 2024
Sketchy Moment Matching: Toward Fast and Provable Data Selection for Finetuning

Yijun Dong, Hoang Phan, Xiang Pan et al.

We revisit data selection in a modern context of finetuning from a fundamental perspective. Extending the classical wisdom of variance minimization in low dimensions to high-dimensional finetuning, our generalization analysis unveils the importance of additionally reducing bias induced by low-rank approximation. Inspired by the variance-bias tradeoff in high dimensions from the theory, we introduce Sketchy Moment Matching (SkMM), a scalable data selection scheme with two stages. (i) First, the bias is controlled using gradient sketching that explores the finetuning parameter space for an informative low-dimensional subspace $\mathcal{S}$; (ii) then the variance is reduced over $\mathcal{S}$ via moment matching between the original and selected datasets. Theoretically, we show that gradient sketching is fast and provably accurate: selecting $n$ samples by reducing variance over $\mathcal{S}$ preserves the fast-rate generalization $O(\dim(\mathcal{S})/n)$, independent of the parameter dimension. Empirically, we concretize the variance-bias balance via synthetic experiments and demonstrate the effectiveness of SkMM for finetuning in real vision tasks.
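
A rough NumPy sketch of the two-stage structure follows. It relies on assumptions that go beyond the abstract. Per-sample finetuning gradients are given as rows of a matrix `G`. The sketch is a Gaussian random projection of hypothetical size `sketch_dim`. The subspace $\mathcal{S}$ is spanned by the top `rank` right singular vectors of the sketched gradients. Moment matching uses a swapped-in greedy heuristic that matches the second moment of the selected set to that of the full set. The code shows the shape of the idea (sketch first, then match moments in $\mathcal{S}$) and is not SkMM's actual optimization.

```python
import numpy as np

def sketch_subspace(G, sketch_dim, rank, rng):
    """Stage (i): Gaussian-sketch the gradients, then keep a rank-`rank` subspace S."""
    Gamma = rng.standard_normal((G.shape[1], sketch_dim)) / np.sqrt(sketch_dim)
    Gs = G @ Gamma                                   # (n, sketch_dim) sketched gradients
    _, _, Vt = np.linalg.svd(Gs, full_matrices=False)
    return Gs @ Vt[:rank].T                          # (n, rank) coordinates in S

def greedy_moment_matching(Z, n_select):
    """Stage (ii), greedy stand-in: pick rows whose mean outer product tracks the full set's."""
    target = Z.T @ Z / len(Z)                        # second moment of the full dataset in S
    S = np.zeros_like(target)                        # running sum of z z^T over chosen rows
    chosen = []
    available = np.ones(len(Z), dtype=bool)
    sq_norms = np.sum(Z**2, axis=1)
    for t in range(n_select):
        c = 1.0 / (t + 1)
        A = c * S - target                           # candidate moment minus target = A + c z z^T
        # ||A + c z z^T||_F^2 = ||A||_F^2 + 2c z^T A z + c^2 ||z||^4; the first term is shared
        cost = 2 * c * np.einsum("ij,jk,ik->i", Z, A, Z) + c**2 * sq_norms**2
        cost[~available] = np.inf
        i = int(np.argmin(cost))
        chosen.append(i)
        available[i] = False
        S += np.outer(Z[i], Z[i])
    gap = np.linalg.norm(S / n_select - target)      # remaining second-moment mismatch
    return np.array(chosen), gap

rng = np.random.default_rng(0)
G = rng.standard_normal((200, 1000))                 # stand-in for per-sample gradients
Z = sketch_subspace(G, sketch_dim=64, rank=8, rng=rng)
idx, gap = greedy_moment_matching(Z, n_select=20)
```

The expensive high-dimensional work is a single matrix product with `Gamma`. All selection logic then runs in the `rank`-dimensional space, which mirrors the abstract's point that the guarantee depends on $\dim(\mathcal{S})$ rather than on the parameter dimension.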

AI · Jul 26, 2024
Greedy Output Approximation: Towards Efficient Structured Pruning for LLMs Without Retraining

Jianwei Li, Yijun Dong, Qi Lei

To remove redundant components of large language models (LLMs) without incurring significant computational costs, this work focuses on single-shot pruning without a retraining phase. We simplify the pruning process for Transformer-based LLMs by identifying a depth-2 pruning structure that functions independently. Additionally, we propose two inference-aware pruning criteria derived from the optimization perspective of output approximation, which outperform traditional training-aware metrics such as gradient and Hessian. We also introduce a two-step reconstruction technique to mitigate pruning errors without model retraining. Experimental results demonstrate that our approach significantly reduces computational costs and hardware requirements while maintaining superior performance across various datasets and models.
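
The general flavor of output-approximation pruning followed by a reconstruction step can be sketched on a single two-layer MLP block. In this block, removing hidden neuron $j$ deletes row $j$ of `W1` and column $j$ of `W2` together, which is one example of a depth-2 structure. The sketch makes two hypothetical choices. First, the neuron score $\|W_2[:, j]\|^2 \cdot \mathrm{mean}(h_j^2)$ over calibration inputs is used as a proxy for each neuron's contribution to the output. Second, the reconstruction is a plain least-squares refit of `W2`. Neither is claimed to be one of the paper's two criteria or its two-step reconstruction.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def prune_hidden_neurons(W1, W2, X_calib, keep):
    """Prune the hidden layer of y = W2 relu(W1 x) down to `keep` neurons, then refit W2."""
    H = relu(X_calib @ W1.T)                         # (n, hidden) activations on calibration data
    Y = H @ W2.T                                     # (n, out) outputs of the unpruned block
    # Hypothetical score: output energy each neuron contributes on calibration data
    scores = np.sum(W2**2, axis=0) * np.mean(H**2, axis=0)
    kept = np.sort(np.argsort(scores)[::-1][:keep])
    W2_naive = W2[:, kept]                           # simply drop the pruned columns
    # Reconstruction without retraining: least-squares refit to match the original outputs
    W2_refit = np.linalg.lstsq(H[:, kept], Y, rcond=None)[0].T
    err_naive = np.linalg.norm(H[:, kept] @ W2_naive.T - Y)
    err_refit = np.linalg.norm(H[:, kept] @ W2_refit.T - Y)
    return W1[kept], W2_refit, err_naive, err_refit

rng = np.random.default_rng(0)
W1, W2 = rng.standard_normal((64, 16)), rng.standard_normal((8, 64))
X = rng.standard_normal((256, 16))
W1_p, W2_p, err_naive, err_refit = prune_hidden_neurons(W1, W2, X, keep=32)
```

Only forward passes on calibration data and one linear solve are needed, with no gradient-based retraining. Because the refit is the least-squares optimum over all choices of the second matrix, `err_refit` can never exceed `err_naive`.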

NA · Apr 2
Attention Mechanisms Through the Lens of Numerical Methods: Approximation Methods and Alternative Formulations

Michel Fabrice Serret, Alice Cortinovis, Yijun Dong et al.

The attention mechanism is the computational core of modern Transformer architectures, but its quadratic complexity in the input sequence length is the bottleneck for large-scale inference. This has motivated a rapidly growing body of work aimed at accelerating attention through approximation and reformulation. In this survey, we revisit attention mechanisms through the lens of numerical analysis, with a particular emphasis on tools and perspectives from numerical linear algebra. Our goal is twofold: first, we aim to systematically review and classify fast approximation methods according to the numerical principles they exploit. These include sparsity and clustering approaches, low-rank and subspace projection techniques, randomized sketching methods, and tensor-based decompositions. We also discuss kernel-inspired reformulations of attention and recent architectural variants, such as Latent Attention, that modify the standard softmax formulation to improve efficiency. Second, by presenting these developments within a unified mathematical framework, we aim to bridge the gap between disciplines and highlight opportunities for further contributions from computational mathematics, particularly numerical linear algebra, to the design of scalable attention mechanisms.
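
The kernel-inspired reformulations the survey covers can be illustrated with one example. The NumPy sketch below replaces the softmax kernel $\exp(q^\top k)$ with a feature-map kernel $\phi(q)^\top \phi(k)$ and then uses associativity to avoid forming the $n \times n$ score matrix. The choice $\phi(x) = \mathrm{elu}(x) + 1$ is an assumption borrowed from the linear-attention literature. Both functions compute the same output: the first takes time quadratic in the sequence length $n$, the second linear. This defines a different kernel from softmax and is not an approximation of it.

```python
import numpy as np

def feature_map(x):
    """elu(x) + 1, a positive feature map (assumed choice, common in linear-attention work)."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def quadratic_kernel_attention(Q, K, V):
    """Quadratic in n: build the full n x n kernel score matrix, then row-normalize."""
    S = feature_map(Q) @ feature_map(K).T
    return (S @ V) / S.sum(axis=1, keepdims=True)

def linear_kernel_attention(Q, K, V):
    """Linear in n: regroup (phi(Q) phi(K)^T) V as phi(Q) (phi(K)^T V)."""
    phi_q, phi_k = feature_map(Q), feature_map(K)
    kv = phi_k.T @ V                 # (d, d_v) summary of all keys and values
    z = phi_k.sum(axis=0)            # (d,) summary used for row normalization
    return (phi_q @ kv) / (phi_q @ z)[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
out_quad = quadratic_kernel_attention(Q, K, V)
out_lin = linear_kernel_attention(Q, K, V)
```

The saving comes entirely from the order of matrix products. The $n \times n$ matrix `S` is never materialized in the second function.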

CV · Oct 4, 2022
Adaptively Weighted Data Augmentation Consistency Regularization for Robust Optimization under Concept Shift

Yijun Dong, Yuege Xie, Rachel Ward

Concept shift is a prevailing problem in natural tasks like medical image segmentation, where samples usually come from different subpopulations with varying correlations between features and labels. One common type of concept shift in medical image segmentation is the "information imbalance" between label-sparse samples with few (if any) segmentation labels and label-dense samples with plentiful labeled pixels. Existing distributionally robust algorithms have focused on adaptively truncating/down-weighting the "less informative" (i.e., label-sparse in our context) samples. To exploit data features of label-sparse samples more efficiently, we propose an adaptively weighted online optimization algorithm, AdaWAC, that incorporates data augmentation consistency regularization into sample reweighting. Our method introduces a set of trainable weights to balance the supervised loss and unsupervised consistency regularization of each sample separately. At the saddle point of the underlying objective, the weights assign label-dense samples to the supervised loss and label-sparse samples to the unsupervised consistency regularization. We provide a convergence guarantee by recasting the optimization as online mirror descent on a saddle point problem. Our empirical results demonstrate that AdaWAC not only enhances segmentation performance and sample efficiency but also improves robustness to concept shift on various medical image segmentation tasks with different UNet-style backbones.
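
The per-sample weighting can be illustrated with a mirror-ascent step on each sample's 2-simplex of weights. With the entropic mirror map, mirror ascent reduces to an exponentiated-gradient update. This is a hedged sketch. We assume a min-max objective in which the weights ascend on the per-sample losses while the model parameters (not shown) descend on the weighted sum. The learning rate and loss values are hypothetical. With fixed losses, each sample's weight drifts toward whichever term is larger.

```python
import numpy as np

def mirror_ascent_weights(beta, sup_loss, cons_loss, lr):
    """One exponentiated-gradient step on per-sample weights.

    beta: (n, 2) rows on the probability simplex; column 0 weights the supervised
    loss, column 1 the consistency regularizer (assumed layout).
    """
    losses = np.stack([sup_loss, cons_loss], axis=1)  # gradient of the weighted sum w.r.t. beta
    logits = np.log(beta) + lr * losses
    logits -= logits.max(axis=1, keepdims=True)          # numerical stability
    new_beta = np.exp(logits)
    return new_beta / new_beta.sum(axis=1, keepdims=True)

def weighted_objective(beta, sup_loss, cons_loss):
    """Objective the model parameters would descend on (model update omitted)."""
    return float(np.sum(beta[:, 0] * sup_loss + beta[:, 1] * cons_loss))

# Hypothetical per-sample losses: sample 0 has the larger supervised loss, sample 1 the smaller
sup_loss = np.array([2.0, 0.1])
cons_loss = np.array([0.5, 0.5])
beta = np.full((2, 2), 0.5)
for _ in range(50):
    beta = mirror_ascent_weights(beta, sup_loss, cons_loss, lr=0.5)
```

After these updates, sample 0 is weighted almost entirely on the supervised loss and sample 1 almost entirely on the consistency term. This is the kind of split the abstract describes at the saddle point.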

LG · Feb 10
A Task-Centric Theory for Iterative Self-Improvement with Easy-to-Hard Curricula

Chenruo Liu, Yijun Dong, Yiqiu Shen et al.

Iterative self-improvement fine-tunes an autoregressive large language model (LLM) on reward-verified outputs generated by the LLM itself. In contrast to the empirical success of self-improvement, the theoretical foundation of this generative, iterative procedure in a practical, finite-sample setting remains limited. We make progress toward this goal by modeling each round of self-improvement as maximum-likelihood fine-tuning on a reward-filtered distribution and deriving finite-sample guarantees for the expected reward. Our analysis reveals an explicit feedback loop where better models accept more data per iteration, supporting sustained self-improvement while explaining eventual saturation of such improvement. Adopting a task-centric view by considering reasoning tasks with multiple difficulty levels, we further prove quantifiable conditions on model initialization, task difficulty, and sample budget where easy-to-hard curricula provably achieve better guarantees than training on fixed mixtures of tasks. Our analyses are validated via Monte-Carlo simulations and controlled experiments on graph-based reasoning tasks.
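
Below is a toy simulation of self-improvement rounds, each modeled as maximum-likelihood fitting on a reward-filtered sample. The "model" is a categorical distribution over a handful of candidate outputs, and the verifier is a hypothetical 0/1 reward vector. Additive `smoothing` is our own assumption, added so that every output keeps nonzero probability. The simulation is far simpler than the paper's finite-sample setting. It only reproduces the feedback loop in miniature: a model that places more mass on rewarded outputs accepts more samples in the next round, and acceptance then saturates near the sample budget.

```python
import numpy as np

def self_improve_round(probs, reward, n_samples, rng, smoothing=1e-3):
    """Sample from the model, keep reward-verified outputs, refit the categorical model by MLE."""
    samples = rng.choice(len(probs), size=n_samples, p=probs)
    accepted = samples[reward[samples] == 1]          # reward filter (hypothetical 0/1 verifier)
    counts = np.bincount(accepted, minlength=len(probs)).astype(float)
    new_probs = (counts + smoothing) / (counts.sum() + smoothing * len(probs))
    return new_probs, len(accepted)

rng = np.random.default_rng(0)
reward = np.array([1, 0, 0, 1])                      # outputs 0 and 3 pass verification
probs = np.full(4, 0.25)                             # initial model
history = []                                         # number of accepted samples per round
for _ in range(3):
    probs, n_accepted = self_improve_round(probs, reward, n_samples=1000, rng=rng)
    history.append(n_accepted)
```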

LG · Oct 3, 2023
Randomized Dimension Reduction with Statistical Guarantees

Yijun Dong

Large models and enormous data are essential driving forces of the unprecedented successes achieved by modern algorithms, especially in scientific computing and machine learning. Nevertheless, the growing dimensionality and model complexity, as well as the non-negligible workload of data pre-processing, also bring formidable costs to such successes in both computation and data aggregation. As the slowdown of Moore's Law limits hardware-level reductions in the cost of computation, fast heuristics for expensive classical routines and efficient algorithms for exploiting limited data are increasingly indispensable for pushing the limits of what algorithms can achieve. This thesis explores some such algorithms for fast execution and efficient data utilization. From the computational efficiency perspective, we design and analyze fast randomized low-rank decomposition algorithms for large matrices based on "matrix sketching", which can be regarded as a dimension reduction strategy in the data space. These include the randomized pivoting-based interpolative and CUR decompositions discussed in Chapter 2 and the randomized subspace approximations discussed in Chapter 3. From the sample efficiency perspective, we focus on learning algorithms with various incorporations of data augmentation that provably improve generalization and distributional robustness. Specifically, Chapter 4 presents a sample complexity analysis for data augmentation consistency regularization, where we view sample efficiency through the lens of dimension reduction in the function space. Then, in Chapter 5, we introduce an adaptively weighted data augmentation consistency regularization algorithm for distributionally robust optimization with applications in medical image segmentation.
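
Matrix sketching for an interpolative decomposition can be illustrated with one common sketch-then-pivot recipe. First, compress the rows of `A` with a Gaussian sketch. Next, run greedy column pivoting (the pivot rule of column-pivoted QR) on the small sketch to choose `k` skeleton columns. Finally, solve for the interpolation matrix `X` with $A \approx A[:, \text{cols}]\,X$, again using only the sketch. This is a generic sketch and not necessarily one of the specific randomized pivoting schemes analyzed in the thesis. The `oversample` parameter is hypothetical. For an exactly rank-`k` matrix, the reconstruction is exact up to rounding.

```python
import numpy as np

def pivoted_columns(Y, k):
    """Greedy column pivoting (as in column-pivoted QR) on the sketch Y; returns k column indices."""
    R = Y.astype(float).copy()
    cols = []
    for _ in range(k):
        norms = np.linalg.norm(R, axis=0)
        norms[cols] = -1.0                           # never pick the same column twice
        j = int(np.argmax(norms))
        cols.append(j)
        q = R[:, j] / np.linalg.norm(R[:, j])
        R -= np.outer(q, q @ R)                      # remove the chosen direction from all columns
    return cols

def sketched_column_id(A, k, oversample=5, rng=None):
    """Column interpolative decomposition A ~= A[:, cols] @ X computed from a row sketch of A."""
    rng = np.random.default_rng(0) if rng is None else rng
    Omega = rng.standard_normal((k + oversample, A.shape[0]))
    Y = Omega @ A                                    # (k + oversample, n): small sketch of A
    cols = pivoted_columns(Y, k)
    X = np.linalg.lstsq(Y[:, cols], Y, rcond=None)[0]   # interpolation matrix from the sketch
    return cols, X

rng = np.random.default_rng(1)
A = rng.standard_normal((60, 5)) @ rng.standard_normal((5, 40))   # exactly rank 5
cols, X = sketched_column_id(A, k=5, rng=rng)
```

Pivoting and the solve for `X` both touch only the `(k + oversample) x n` sketch, never the full `A`. This is where the speedup over pivoting on `A` directly comes from.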

LG · Feb 7, 2025
Discrepancies are Virtue: Weak-to-Strong Generalization through Lens of Intrinsic Dimension

Yijun Dong, Yicheng Li, Yunai Li et al.

Weak-to-strong (W2S) generalization is a type of finetuning (FT) where a strong (large) student model is trained on pseudo-labels generated by a weak teacher. Surprisingly, W2S FT often outperforms the weak teacher. We seek to understand this phenomenon through the observation that FT often occurs in intrinsically low-dimensional spaces. Leveraging the low intrinsic dimensionality of FT, we analyze W2S in the ridgeless regression setting from a variance reduction perspective. For a strong student-weak teacher pair with sufficiently expressive low-dimensional feature subspaces $\mathcal{V}_s, \mathcal{V}_w$, we provide an exact characterization of the variance that dominates the generalization error of W2S. This unveils a virtue of discrepancy between the strong and weak models in W2S: the variance of the weak teacher is inherited by the strong student in $\mathcal{V}_s \cap \mathcal{V}_w$, while reduced by a factor of $\mathrm{dim}(\mathcal{V}_s)/N$ in the subspace of discrepancy $\mathcal{V}_w \setminus \mathcal{V}_s$ with $N$ pseudo-labels for W2S. Our analysis further casts light on the sample complexities and the scaling of performance gap recovery in W2S. The analysis is supported by experiments on synthetic regression problems, as well as real vision and NLP tasks.
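
The intersection-versus-discrepancy picture can be checked with a toy NumPy experiment under simplifying assumptions that are ours, not the paper's. Inputs are isotropic Gaussian, and the student's subspace $\mathcal{V}_s$ is the first `ks` coordinates. The ground truth $\theta^*$ lies in $\mathcal{V}_s$. The weak teacher's error is a fixed direction `a`, standing in for its variance. The student runs ridgeless (least-squares) regression on $N$ pseudo-labels. Teacher error along a direction inside $\mathcal{V}_s$ is inherited in full. Teacher error outside $\mathcal{V}_s$ acts like label noise for the student and shrinks roughly like `ks / N`.

```python
import numpy as np

def w2s_student_risk(a, theta_star, ks, N, rng):
    """Excess risk of a student restricted to the first `ks` coordinates (assumed V_s)."""
    d = len(theta_star)
    X = rng.standard_normal((N, d))                  # isotropic Gaussian inputs
    pseudo = X @ (theta_star + a)                    # weak teacher = truth + error direction a
    beta = np.linalg.lstsq(X[:, :ks], pseudo, rcond=None)[0]   # ridgeless fit on student features
    # With isotropic inputs: risk = ||beta - theta*_{V_s}||^2 + ||theta* outside V_s||^2
    return float(np.sum((beta - theta_star[:ks])**2) + np.sum(theta_star[ks:]**2))

rng = np.random.default_rng(0)
d, ks, N = 20, 5, 2000
theta_star = np.zeros(d)
theta_star[:ks] = 1.0
a_shared = np.eye(d)[0]        # teacher error inside V_s (the V_s ∩ V_w case); teacher risk = 1
a_discrepant = np.eye(d)[10]   # teacher error outside V_s (the V_w \ V_s case); teacher risk = 1
risk_shared = w2s_student_risk(a_shared, theta_star, ks, N, rng)
risk_discrepant = w2s_student_risk(a_discrepant, theta_star, ks, N, rng)
```

Both teachers have the same risk, but the student reproduces the first teacher's error exactly. Against the second teacher, the student's risk is on the order of `ks / N`. In this toy, the discrepancy between the two models is what allows the student to outperform its teacher.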

LG · Sep 28, 2025
Does Weak-to-strong Generalization Happen under Spurious Correlations?

Chenruo Liu, Yijun Dong, Qi Lei

We initiate a unified theoretical and algorithmic study of a key problem in weak-to-strong (W2S) generalization: when a strong pre-trained student is fine-tuned with pseudolabels from a weaker teacher on a downstream task with spurious correlations, does W2S happen, and how can it be improved when it fails? We consider two sources of spurious correlations caused by group imbalance: (i) a weak teacher fine-tuned on group-imbalanced labeled data with a minority group of fraction $\eta_\ell$, and (ii) a group-imbalanced unlabeled set pseudolabeled by the teacher with a minority group of fraction $\eta_u$. Theoretically, a precise characterization of W2S gain at the proportional asymptotic limit shows that W2S always happens with sufficient pseudolabels when $\eta_u = \eta_\ell$ but may fail when $\eta_u \ne \eta_\ell$, where W2S gain diminishes as $(\eta_u - \eta_\ell)^2$ increases. Our theory is corroborated by extensive experiments on various spurious correlation benchmarks and teacher-student pairs. To boost W2S performance when it fails, we further propose a simple, effective algorithmic remedy that retrains the strong student on its high-confidence data subset after W2S fine-tuning. Our algorithm is group-label-free and achieves consistent, substantial improvements over vanilla W2S fine-tuning.
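
The selection step of the remedy can be sketched as below, under two assumptions: the student's confidence is its maximum class probability, and retraining uses the student's own predicted labels on the kept subset. The abstract specifies neither, and `keep_frac` is a hypothetical knob. Group labels are never consulted, in line with the group-label-free property of the method.

```python
import numpy as np

def high_confidence_subset(probs, keep_frac):
    """Return indices of the most confident predictions and the student's labels for them."""
    confidence = probs.max(axis=1)                   # assumed confidence score
    n_keep = max(1, int(round(keep_frac * len(confidence))))
    idx = np.sort(np.argsort(-confidence, kind="stable")[:n_keep])
    return idx, probs[idx].argmax(axis=1)

# Hypothetical student probabilities after W2S fine-tuning
student_probs = np.array([[0.90, 0.10],
                          [0.55, 0.45],
                          [0.20, 0.80],
                          [0.60, 0.40]])
idx, labels = high_confidence_subset(student_probs, keep_frac=0.5)
# idx -> [0, 2], labels -> [0, 1]; the student would then be retrained on (x[idx], labels)
```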

LG · Feb 24, 2022
Sample Efficiency of Data Augmentation Consistency Regularization

Shuo Yang, Yijun Dong, Rachel Ward et al.

Data augmentation is popular in the training of large neural networks; currently, however, there is no clear theoretical comparison between different algorithmic choices on how to use augmented data. In this paper, we take a step in this direction: we first present a simple and novel analysis for linear regression with label-invariant augmentations, demonstrating that data augmentation consistency (DAC) is intrinsically more efficient than empirical risk minimization on augmented data (DA-ERM). The analysis is then extended to misspecified augmentations (i.e., augmentations that change the labels), which again demonstrates the merit of DAC over DA-ERM. Further, we extend our analysis to non-linear models (e.g., neural networks) and present generalization bounds. Finally, we perform experiments that make a clean, apples-to-apples comparison (i.e., with no extra modeling or data tweaks) between DAC and DA-ERM using CIFAR-100 and WideResNet; together, these results demonstrate the superior efficacy of DAC.
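
In the linear, label-invariant setting, DAC can be read as a hard constraint $x^\top\theta = x_{\mathrm{aug}}^\top\theta$ for every augmented pair. The fit is then restricted to the directions that the augmentations never move. DA-ERM, by contrast, simply regresses on the original plus augmented samples. The NumPy sketch below contrasts the two. Two pieces are illustrative assumptions rather than the paper's experimental setup: the data model (three invariant coordinates, plus seven nuisance coordinates that augmentations resample) and the exact-constraint form of DAC.

```python
import numpy as np

def dac_fit(X, y, X_aug):
    """Least squares under the hard consistency constraint X @ theta == X_aug @ theta."""
    Delta = X_aug - X
    _, s, Vt = np.linalg.svd(Delta)                  # full_matrices=True, so Vt is (d, d)
    rank = int(np.sum(s > 1e-10 * max(s.max(), 1.0)))
    N = Vt[rank:].T                                  # basis of directions augmentations never move
    w = np.linalg.lstsq(X @ N, y, rcond=None)[0]     # regress only within that subspace
    return N @ w

def da_erm_fit(X, y, X_aug):
    """ERM on original plus augmented samples, each augmented copy keeping the original label."""
    return np.linalg.lstsq(np.vstack([X, X_aug]), np.concatenate([y, y]), rcond=None)[0]

rng = np.random.default_rng(0)
n, d, d_inv = 50, 10, 3
theta_star = np.concatenate([[1.0, 2.0, 3.0], np.zeros(d - d_inv)])
X = rng.standard_normal((n, d))
y = X @ theta_star + 0.1 * rng.standard_normal(n)
X_aug = X.copy()
X_aug[:, d_inv:] = rng.standard_normal((n, d - d_inv))  # label-invariant: resample nuisance coords
theta_dac = dac_fit(X, y, X_aug)
theta_erm = da_erm_fit(X, y, X_aug)
```

In this toy, DAC fits 3 free parameters instead of 10 and places exactly zero weight on the nuisance coordinates. This is a small-scale instance of the dimension-reduction view of DAC's sample efficiency.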