Cong Fang

LG
h-index: 19
41 papers
1,609 citations
Novelty: 62%
AI Score: 61

41 Papers

ST · Mar 6, 2023
Environment Invariant Linear Least Squares

Jianqing Fan, Cong Fang, Yihong Gu et al. · princeton

This paper considers a multi-environment linear regression model in which data from multiple experimental settings are collected. The joint distribution of the response variable and covariates may vary across different environments, yet the conditional expectations of $y$ given the unknown set of important variables are invariant. Such a statistical model is related to the problem of endogeneity, causal inference, and transfer learning. The motivation behind it is illustrated by how the goals of prediction and attribution are inherent in estimating the true parameter and the important variable set. We construct a novel environment invariant linear least squares (EILLS) objective function, a multi-environment version of linear least-squares regression that leverages the above conditional expectation invariance structure and heterogeneity among different environments to determine the true parameter. Our proposed method is applicable without any additional structural knowledge and can identify the true parameter under a near-minimal identification condition. We establish non-asymptotic $\ell_2$ error bounds on the estimation error for the EILLS estimator in the presence of spurious variables. Moreover, we further show that the $\ell_0$ penalized EILLS estimator can achieve variable selection consistency in high-dimensional regimes. These non-asymptotic results demonstrate the sample efficiency of the EILLS estimator and its capability to circumvent the curse of endogeneity in an algorithmic manner without any prior structural knowledge. To the best of our knowledge, this paper is the first to realize statistically efficient invariance learning in the general linear model.

LG · Mar 26, 2022
A Roadmap for Big Model

Sha Yuan, Hanyu Zhao, Shuai Zhao et al. · bytedance, pku

With the rapid development of deep learning, training Big Models (BMs) for multiple downstream tasks has become a popular paradigm. Researchers have achieved various outcomes in constructing BMs and applying them in many fields. At present, there is a lack of work that surveys the overall progress of BMs and guides follow-up research. In this paper, we cover not only the BM technologies themselves but also the prerequisites for BM training and applications with BMs, dividing the BM review into four parts: Resource, Models, Key Technologies and Application. We introduce 16 specific BM-related topics in those four parts: Data, Knowledge, Computing System, Parallel Training System, Language Model, Vision Model, Multi-modal Model, Theory & Interpretability, Commonsense Reasoning, Reliability & Security, Governance, Evaluation, Machine Translation, Text Generation, Dialogue and Protein Research. For each topic, we summarize the current studies and propose some future research directions. At the end of this paper, we discuss the further development of BMs from a more general view.

NA · Nov 29, 2017
Convergence Rates Analysis of The Quadratic Penalty Method and Its Applications to Decentralized Distributed Optimization

Huan Li, Cong Fang, Zhouchen Lin

In this paper, we study a variant of the quadratic penalty method for linearly constrained convex problems, which has already been widely used but actually lacks theoretical justification. Namely, the penalty parameter steadily increases and the penalized objective function is minimized inexactly rather than exactly, e.g., with only one step of the proximal gradient descent. For such a variant of the quadratic penalty method, we give counterexamples to show that it may not give a solution to the original constrained problem. By choosing special penalty parameters, we ensure the convergence and further establish the convergence rates of $O\left(\frac{1}{\sqrt{K}}\right)$ for the generally convex problems and $O\left(\frac{1}{K}\right)$ for strongly convex ones, where $K$ is the number of iterations. Furthermore, by adopting Nesterov's extrapolation we show that the convergence rates can be improved to $O\left(\frac{1}{K}\right)$ for the generally convex problems and $O\left(\frac{1}{K^2}\right)$ for strongly convex ones. When applied to the decentralized distributed optimization, the penalty methods studied in this paper become the widely used distributed gradient method and the fast distributed gradient method. However, due to the totally different analysis framework, we can improve their $O\left(\frac{\log K}{\sqrt{K}}\right)$ and $O\left(\frac{\log K}{K}\right)$ convergence rates to $O\left(\frac{1}{\sqrt{K}}\right)$ and $O\left(\frac{1}{K}\right)$ with fewer assumptions on the network topology for general convex problems. Using our analysis framework, we also extend the fast distributed gradient method to a communication efficient version, i.e., finding an $\varepsilon$ solution in $O\left(\frac{1}{\varepsilon}\right)$ communications and $O\left(\frac{1}{\varepsilon^{2+δ}}\right)$ computations for the non-smooth problems, where $δ$ is a small constant.
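The variant described above (increasing penalty, one inexact step per iteration) can be sketched on a toy problem. This is a minimal illustration, not the paper's algorithm: the objective $\frac{1}{2}\|x - c\|^2$, the schedule $\beta_k = k$, the step size, and the function name `one_step_quadratic_penalty` are all assumptions chosen for this example. Because the toy objective is smooth, the proximal gradient step reduces to a plain gradient step.

```python
import numpy as np

def one_step_quadratic_penalty(c, A, b, num_iters=1000):
    """Minimize 0.5*||x - c||^2 subject to A x = b (toy problem).

    Each iteration takes ONE gradient step on the penalized objective
        0.5*||x - c||^2 + (beta_k / 2) * ||A x - b||^2
    with an increasing penalty beta_k = k (an illustrative schedule,
    not the schedule analyzed in the paper).
    """
    x = np.zeros_like(c)
    sigma_sq = np.linalg.norm(A, 2) ** 2  # squared spectral norm of A
    for k in range(1, num_iters + 1):
        beta = float(k)
        grad = (x - c) + beta * A.T @ (A @ x - b)
        # 1 / (smoothness constant of the penalized objective)
        step = 1.0 / (1.0 + beta * sigma_sq)
        x = x - step * grad
    return x

# Toy instance: project c onto the hyperplane x1 + x2 + x3 = 0.
c = np.array([1.0, 2.0, 3.0])
A = np.ones((1, 3)) / np.sqrt(3.0)
b = np.zeros(1)
x = one_step_quadratic_penalty(c, A, b)
print(x)  # close to the exact projection [-1, 0, 1]
```

On this instance the constraint violation after iteration $k$ shrinks like $1/(1+\beta_k)$, so the iterate approaches the constrained solution as the penalty grows. The abstract notes that this is not guaranteed for arbitrary penalty schedules.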

LG · May 31
What Makes a Strong Model? A Unified Spectral Analysis of Knowledge Transfer over High-dimensional Linear Regression

Wendao Wu, Fangqing Zhang, Haihan Zhang et al.

Teacher-Student Knowledge Transfer (KT) is ubiquitous in modern machine learning, ranging from classical model compression via Knowledge Distillation (KD) to the emergent phenomenon of Weak-to-Strong (W2S) generalization. While existing studies offer isolated insights, a unified theoretical framework explaining the efficacy of KT across these disparate regimes remains lacking. In this work, we establish a unified spectral analysis of SGD dynamics in high-dimensional linear regression, elucidating the efficiency of KT across seemingly disparate regimes. We characterize KT efficiency through two distinct mechanisms: \emph{Spectral Horizon Expansion} in KD, which enables the capture of statistically inaccessible high-frequency signals, and \emph{Spectral Denoising} in W2S, where the student acts as a filter for optimization noise. Our framework unifies these phenomena, revealing that the efficacy of transfer is governed by the interplay between implicit regularization and heterogeneous spectral learning speeds over the spectrum.

CV · Jun 21, 2023
Task-Robust Pre-Training for Worst-Case Downstream Adaptation

Jianghui Wang, Yang Chen, Xingyu Xie et al.

Pre-training has achieved remarkable success when transferred to downstream tasks. In machine learning, we care about not only the good performance of a model but also its behavior under reasonable shifts of condition. The same philosophy holds when pre-training a foundation model. However, the foundation model may not behave uniformly well across a series of related downstream tasks. This happens, for example, in mask recovery regression when the recovery ability or the training instances diverge: pattern features may be extracted predominantly during pre-training, while semantic features are also required on a downstream task. This paper considers pre-training a model that guarantees uniformly good performance over the downstream tasks. We call this goal $\textit{downstream-task robustness}$. Our method first separates the upstream task into several representative ones and applies a simple minimax loss for pre-training. We then design an efficient algorithm to solve the minimax loss and prove its convergence in the convex setting. In experiments on both large-scale natural language processing and computer vision datasets, we show that our method improves the metrics on worst-case downstream tasks. Additionally, we provide some theoretical explanations for why our loss is beneficial. Specifically, we show that fewer samples are inherently required for the most challenging downstream task in some cases.

LG · Sep 23, 2023
CORE: Common Random Reconstruction for Distributed Optimization with Provable Low Communication Complexity

Pengyun Yue, Hanzhen Zhao, Cong Fang et al.

With distributed machine learning being a prominent technique for large-scale machine learning tasks, communication complexity has become a major bottleneck for speeding up training and scaling up the number of machines. In this paper, we propose a new technique named Common randOm REconstruction (CORE), which can be used to compress the information transmitted between machines in order to reduce communication complexity without other strict conditions. Specifically, CORE projects the vector-valued information to a low-dimensional one through common random vectors and reconstructs the information with the same random noises after communication. We apply CORE to two distributed tasks, namely convex optimization on linear models and generic non-convex optimization, and design new distributed algorithms that achieve provably lower communication complexities. For example, we show that for linear models a CORE-based algorithm can encode the gradient vector in $\mathcal{O}(1)$ bits (versus $\mathcal{O}(d)$) with a convergence rate that is no worse, improving on existing results.
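The project-then-reconstruct idea can be sketched as follows. This is a hedged illustration of the general principle only: the Gaussian random vectors, the shared integer seed, and the function names `core_encode` and `core_decode` are assumptions for this example, not the paper's exact construction. The sender transmits $m$ inner products; the receiver regenerates the same random vectors from the shared seed and forms an estimate that is unbiased because $\mathbb{E}[\xi\xi^\top] = I$ for standard Gaussian $\xi$.

```python
import numpy as np

def core_encode(g, m, seed):
    """Sender: project g onto m common random Gaussian vectors.

    Only the m resulting scalars need to be communicated.
    """
    rng = np.random.default_rng(seed)
    xi = rng.standard_normal((m, g.shape[0]))  # common random vectors
    return xi @ g

def core_decode(p, d, seed):
    """Receiver: regenerate the same vectors from the shared seed
    and reconstruct an unbiased estimate of the original vector."""
    rng = np.random.default_rng(seed)
    xi = rng.standard_normal((p.shape[0], d))  # identical to sender's xi
    return xi.T @ p / p.shape[0]

g = np.array([1.0, -2.0, 0.5, 3.0, 0.0])
payload = core_encode(g, m=4, seed=0)       # 4 scalars on the wire
g_hat = core_decode(payload, d=g.shape[0], seed=0)
print(payload.shape, g_hat)
```

A single reconstruction is noisy; the estimate is only correct in expectation, so an optimizer using it relies on averaging across rounds, much as stochastic gradient methods do.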

OC · Mar 2, 2023
PAPAL: A Provable PArticle-based Primal-Dual ALgorithm for Mixed Nash Equilibrium

Shihong Ding, Hanze Dong, Cong Fang et al.

We consider the non-convex non-concave objective function in two-player zero-sum continuous games. The existence of pure Nash equilibrium requires stringent conditions, posing a major challenge for this problem. To circumvent this difficulty, we examine the problem of identifying a mixed Nash equilibrium, where strategies are randomized and characterized by probability distributions over continuous domains. To this end, we propose PArticle-based Primal-dual ALgorithm (PAPAL) tailored for a weakly entropy-regularized min-max optimization over probability distributions. This algorithm employs the stochastic movements of particles to represent the updates of random strategies for the $ε$-mixed Nash equilibrium. We offer a comprehensive convergence analysis of the proposed algorithm, demonstrating its effectiveness. In contrast to prior research that attempted to update particle importance without movements, PAPAL is the first implementable particle-based algorithm accompanied by non-asymptotic quantitative convergence results, running time, and sample complexity guarantees. Our framework contributes novel insights into the particle-based algorithms for continuous min-max optimization in the general non-convex non-concave setting.

LG · Jan 29
Partial Feedback Online Learning

Shihao Shao, Cong Fang, Zhouchen Lin et al.

We study a new learning protocol, termed partial-feedback online learning, where each instance admits a set of acceptable labels, but the learner observes only one acceptable label per round. We highlight that, while classical version space is widely used for online learnability, it does not directly extend to this setting. We address this obstacle by introducing a collection version space, which maintains sets of hypotheses rather than individual hypotheses. Using this tool, we obtain a tight characterization of learnability in the set-realizable regime. In particular, we define the Partial-Feedback Littlestone dimension (PFLdim) and the Partial-Feedback Measure Shattering dimension (PMSdim), and show that they tightly characterize the minimax regret for deterministic and randomized learners, respectively. We further identify a nested inclusion condition under which deterministic and randomized learnability coincide, resolving an open question of Raman et al. (2024b). Finally, given a hypothesis space H, we show that beyond set realizability, the minimax regret can be linear even when |H|=2, highlighting a barrier beyond set realizability.

LG · Sep 15, 2024
The Optimality of (Accelerated) SGD for High-Dimensional Quadratic Optimization

Haihan Zhang, Yuanshi Liu, Qianwen Chen et al.

Stochastic gradient descent (SGD) is a widely used algorithm in machine learning, particularly for neural network training. Recent studies on SGD for canonical quadratic optimization or linear regression show that it generalizes well under suitable high-dimensional settings. However, a fundamental question has yet to be well studied: for what kinds of high-dimensional learning problems can SGD and its accelerated variants achieve optimality? This paper investigates SGD with two components that are essential in practice: an exponentially decaying step size schedule and momentum. We establish the convergence upper bound for momentum accelerated SGD (ASGD) and propose concrete classes of learning problems under which SGD or ASGD achieves min-max optimal convergence rates. The characterization of the target function is based on standard power-law decays in (functional) linear regression. Our results unveil new insights for understanding the learning bias of SGD: (i) SGD is efficient in learning ``dense'' features where the corresponding weights are subject to an infinity norm constraint; (ii) SGD is efficient for easy problems without suffering from the saturation effect; (iii) momentum can improve the order of the convergence rate when the learning problem is relatively hard. To our knowledge, this is the first work to clearly identify the optimal boundary of SGD versus ASGD for the problem under mild settings.
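As a rough illustration of an exponentially decaying step size schedule, the sketch below runs single-pass SGD on a small linear regression and halves the step size after each stage. Everything specific here is an assumption made for the example: the stage length $T/\log_2 T$, the halving factor, the Gaussian data model, and the function name `sgd_step_decay`. Momentum is omitted, and the paper's exact schedule may differ.

```python
import numpy as np

def sgd_step_decay(w_star, num_steps=4096, eta0=0.1, noise_std=0.1, seed=0):
    """Single-pass SGD for y = <x, w_star> + noise, x ~ N(0, I).

    The step size is halved after every stage of roughly
    num_steps / log2(num_steps) iterations (illustrative choice).
    """
    rng = np.random.default_rng(seed)
    d = w_star.shape[0]
    w = np.zeros(d)
    stage_len = max(1, int(num_steps / np.log2(num_steps)))
    for t in range(num_steps):
        eta = eta0 * 0.5 ** (t // stage_len)   # exponential decay by stage
        x = rng.standard_normal(d)             # one fresh sample per step
        y = x @ w_star + noise_std * rng.standard_normal()
        w -= eta * (x @ w - y) * x             # gradient of 0.5*(x.w - y)^2
    return w

w_star = np.array([1.0, -1.0, 0.5, 0.0, 2.0])
w = sgd_step_decay(w_star)
print(np.linalg.norm(w - w_star))
```

The early stages with a large step size remove the initial bias quickly, while the later stages with small steps average out the gradient noise.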

LG · Sep 29, 2025 · Code
Conda: Column-Normalized Adam for Training Large Language Models Faster

Junjie Wang, Pan Zhou, Yiming Dong et al.

Large language models (LLMs) have demonstrated impressive generalization and emergent capabilities, yet their pre-training remains computationally expensive and sensitive to optimization dynamics. While Adam-based optimizers offer fast convergence by adapting learning rates coordinate-wise, recent studies reveal that their updates often suffer from poor spectral conditioning and low-rank structures, hindering efficiency. Muon addresses this issue via global spectral normalization but lacks the per-coordinate adaptivity of Adam. In this work, we propose Column-Normalized Adam (Conda), a novel optimizer that bridges the strengths of both approaches. Conda projects updates into an orthogonal subspace and applies column-wise second moment normalization based on the projected gradients, thereby achieving improved spectral conditioning while maintaining coordinate-wise adaptivity. This design alleviates the spectral pathologies of Adam while preserving its fast convergence behavior. Extensive experiments on the LLaMA and GPT-2 series show that Conda consistently outperforms AdamW, Muon, and other baselines in pre-training. Remarkably, on the LLaMA series, Conda achieves 2-2.5× the convergence speed of AdamW, measured in both training steps and training time. Further ablations demonstrate its robustness under diverse training setups. These results collectively highlight Conda as an effective and broadly applicable optimizer for large-scale LLM training. The code is released at https://github.com/jie040109/Conda

LG · Apr 11
Mild Over-Parameterization Benefits Asymmetric Tensor PCA

Shihong Ding, Weicheng Lin, Cong Fang

Asymmetric Tensor PCA (ATPCA) is a prototypical model for studying the trade-offs between sample complexity, computation, and memory. Existing algorithms for this problem typically require at least $d^{\left\lceil\overline{k}/2\right\rceil}$ state memory cost to recover the signal, where $d$ is the vector dimension and $\overline{k}$ is the tensor order. We focus on the setting where $\overline{k} \geq 4$ is even and consider (stochastic) gradient descent-based algorithms under a limited memory budget, which permits only mild over-parameterization of the model. We propose a matrix-parameterized method (in $d^{2}$ state memory cost) using a novel three-phase alternating-update algorithm to address the problem and demonstrate how mild over-parameterization facilitates learning in two key aspects: (i) it improves sample efficiency, allowing our method to achieve \emph{near-optimal} $d^{\overline{k}-2}$ sample complexity in our limited memory setting; and (ii) it enhances adaptivity to problem structure, a previously unrecognized phenomenon, where the required sample size naturally decreases as consecutive vectors become more aligned, and in the symmetric limit attains $d^{\overline{k}/2}$, matching the \emph{best} known polynomial-time complexity. To our knowledge, this is the \emph{first} tractable algorithm for ATPCA with $d^{\overline{k}}$-independent memory costs.

LG · Feb 26
PRAC: Principal-Random Subspace for LLM Activation Compression and Memory-Efficient Training

Yanyi Li, Yimu Zhang, Cong Fang

Activations have become the primary memory bottleneck in large-batch LLM training. However, existing compression methods fail to exploit the spectral structure of activations, resulting in slow convergence or limited compression. To address this, we bridge the relationship between the algorithm's fast convergence and the requirements for subspace projection, and show that an effective compression should yield an unbiased estimate of the original activation with low variance. We propose Principal-Random Subspace for LLM Activation Compression (PRAC), which novelly decomposes activations into two components: a principal subspace captured via SVD to retain dominant information, and a random subspace sampled from the orthogonal complement to approximate the tail. By introducing a precise scaling factor, we prove that PRAC yields an unbiased gradient estimator with minimum variance under certain conditions. Extensive experiments on pre-training and fine-tuning tasks demonstrate that PRAC achieves up to 36% total memory reduction with negligible performance degradation and minimal computational cost.
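The principal-plus-random decomposition can be sketched in a few lines. This is an illustrative reading of the abstract, not the authors' implementation: the scaling factor $(n-r)/k$, the use of a QR factorization to draw a uniformly random $k$-dimensional subspace of the orthogonal complement, and the names `compress` and `reconstruct` are assumptions. Under these choices the reconstruction is unbiased, because a uniformly random rank-$k$ projector on an $(n-r)$-dimensional space has expectation $\frac{k}{n-r} I$.

```python
import numpy as np

def compress(X, r, k, rng):
    """Compress activations X (tokens x n) into r principal and
    k random-subspace coefficients per token."""
    _, _, Vt = np.linalg.svd(X, full_matrices=True)
    V_top = Vt[:r].T                 # (n, r) principal directions
    V_rest = Vt[r:].T                # (n, n - r) orthogonal complement
    m = V_rest.shape[1]
    # Orthonormal basis of a random k-dim subspace of the complement.
    G, _ = np.linalg.qr(rng.standard_normal((m, k)))
    Q = V_rest @ G                   # (n, k)
    scale = m / k                    # assumed factor that removes the bias
    return X @ V_top, V_top, X @ Q, Q, scale

def reconstruct(coef_top, V_top, coef_rand, Q, scale):
    """Principal part kept exactly; tail estimated from the random part."""
    return coef_top @ V_top.T + scale * (coef_rand @ Q.T)

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 8))
X_hat = reconstruct(*compress(X, r=2, k=2, rng=rng))
print(np.linalg.norm(X_hat - X) / np.linalg.norm(X))
```

A single reconstruction carries noise in the tail directions; averaging many independent reconstructions recovers `X`, which is the sense in which the estimator is unbiased.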

LG · May 6
SPHERE: Mitigating the Loss of Spectral Plasticity in Mixture-of-Experts for Deep Reinforcement Learning

Lirui Luo, Guoxi Zhang, Hongming Xu et al.

In deep reinforcement learning (DRL), an agent is trained from a stream of experience. In a continual learning setting, such agents can suffer from plasticity loss: their ability to learn new skills from new experiences diminishes over training. Recently, Mixture-of-Experts (MoE) networks have been reported to enable scaling laws and facilitate the learning of diverse skills. However, in continual reinforcement learning settings, their performance can degenerate as learning proceeds, indicating a loss of plasticity. To address this, building on Neural Tangent Kernel (NTK) theory, we formalize the plasticity loss in MoE policies as a loss of spectral plasticity. We then derive a tractable proxy for spectral plasticity, one expressible in terms of individual expert feature matrices. Leveraging this proxy, we introduce SPHERE, a practical Parseval penalty tailored for MoE-based policies that alleviates the loss of spectral plasticity. On MetaWorld and HumanoidBench, SPHERE improves average success under continual RL by 133% and 50% over an unregularized MoE baseline, while maintaining higher spectral plasticity throughout training.

LG · May 1
Near-optimal and Efficient First-Order Algorithm for Multi-Task Learning with Shared Linear Representation

Shihong Ding, Fangyu Du, Cong Fang

Multi-task learning (MTL) has emerged as a pivotal paradigm in machine learning by leveraging shared structures across multiple related tasks. Despite its empirical success, the development of likelihood-based efficiently solvable algorithms--even for shared linear representations--remains largely underdeveloped, primarily due to the non-convex structure intrinsic to matrix factorization. This paper introduces a first-order algorithm that jointly learns a shared representation and task-specific parameters, with guaranteed efficiency. Notably, it converges in $\widetilde{\mathcal{O}}(1)$ iterations and attains a \emph{near-optimal} estimation error of $\widetilde{\mathcal{O}}(dk/(TN))$, \emph{improving} over existing likelihood-based methods by a factor of $k$, where $d$, $k$, $T$, $N$ denote input dimension, representation dimension, task count, and samples per task, respectively. Our results justify that likelihood-based first-order methods can efficiently solve the MTL problem.

CV · Mar 2
MVR: Multi-view Video Reward Shaping for Reinforcement Learning

Lirui Luo, Guoxi Zhang, Hongming Xu et al.

Reward design is of great importance for solving complex tasks with reinforcement learning. Recent studies have explored using image-text similarity produced by vision-language models (VLMs) to augment rewards of a task with visual feedback. A common practice linearly adds VLM scores to task or success rewards without explicit shaping, potentially altering the optimal policy. Moreover, such approaches, often relying on single static images, struggle with tasks whose desired behavior involves complex, dynamic motions spanning multiple visually different states. Furthermore, single viewpoints can occlude critical aspects of an agent's behavior. To address these issues, this paper presents Multi-View Video Reward Shaping (MVR), a framework that models the relevance of states regarding the target task using videos captured from multiple viewpoints. MVR leverages video-text similarity from a frozen pre-trained VLM to learn a state relevance function that mitigates the bias towards specific static poses inherent in image-based methods. Additionally, we introduce a state-dependent reward shaping formulation that integrates task-specific rewards and VLM-based guidance, automatically reducing the influence of VLM guidance once the desired motion pattern is achieved. We confirm the efficacy of the proposed framework with extensive experiments on challenging humanoid locomotion tasks from HumanoidBench and manipulation tasks from MetaWorld, verifying the design choices through ablation studies.

LG · Mar 2
Accelerating Single-Pass SGD for Generalized Linear Prediction

Qian Chen, Shihong Ding, Cong Fang

We study generalized linear prediction under a streaming setting, where each iteration uses only one fresh data point for a gradient-level update. While momentum is well-established in deterministic optimization, a fundamental open question is whether it can accelerate such single-pass non-quadratic stochastic optimization. We propose the first algorithm that successfully incorporates momentum via a novel data-dependent proximal method, achieving dual-momentum acceleration. Our derived excess risk bound decomposes into three components: an improved optimization error, a minimax optimal statistical error, and a higher-order model-misspecification error. The proof handles mis-specification via a fine-grained stationary analysis of inner updates, while localizing statistical error through a two-phase outer-loop analysis. As a result, we resolve the open problem posed by Jain et al. [2018a] and demonstrate that momentum acceleration is more effective than variance reduction for generalized linear prediction in the streaming setting.

ST · May 7, 2024
Causality Pursuit from Heterogeneous Environments via Neural Adversarial Invariance Learning

Yihong Gu, Cong Fang, Peter Bühlmann et al. · princeton

Pursuing causality from data is a fundamental problem in scientific discovery, treatment intervention, and transfer learning. This paper introduces a novel algorithmic method for addressing nonparametric invariance and causality learning in regression models across multiple environments, where the joint distribution of response variables and covariates varies, but the conditional expectations of outcome given an unknown set of quasi-causal variables are invariant. The challenge of finding such an unknown set of quasi-causal or invariant variables is compounded by the presence of endogenous variables that have heterogeneous effects across different environments. The proposed Focused Adversarial Invariant Regularization (FAIR) framework utilizes an innovative minimax optimization approach that drives regression models toward prediction-invariant solutions through adversarial testing. Leveraging the representation power of neural networks, FAIR neural networks (FAIR-NN) are introduced for causality pursuit. It is shown that FAIR-NN can find the invariant variables and quasi-causal variables under a minimal identification condition and that the resulting procedure is adaptive to low-dimensional composition structures in a non-asymptotic analysis. Under a structural causal model, variables identified by FAIR-NN represent pragmatic causality and provably align with exact causal mechanisms under conditions of sufficient heterogeneity. Computationally, FAIR-NN employs a novel Gumbel approximation with decreased temperature and a stochastic gradient descent ascent algorithm. The procedures are demonstrated using simulated and real-data examples.

ML · Feb 13, 2025
Optimal Algorithms in Linear Regression under Covariate Shift: On the Importance of Precondition

Yuanshi Liu, Haihan Zhang, Qian Chen et al.

A common pursuit in modern statistical learning is to attain satisfactory generalization outside the source data distribution (OOD). In theory, the challenge remains unsolved even under the canonical setting of covariate shift for the linear model. This paper studies the foundational (high-dimensional) linear regression where the ground truth variables are confined to an ellipse-shaped constraint and addresses two fundamental questions in this regime: (i) given the target covariate matrix, what is the min-max \emph{optimal} algorithm under covariate shift? (ii) for what kinds of target classes do the commonly used SGD-type algorithms achieve optimality? Our analysis starts with establishing a tight lower generalization bound via a Bayesian Cramér-Rao inequality. For (i), we prove that the optimal estimator can be simply a certain linear transformation of the best estimator for the source distribution. Given the source and target matrices, we show that the transformation can be efficiently computed via a convex program. The min-max optimal analysis for SGD leverages the idea that both the accumulated updates of the applied algorithms and the ideal transformation can be viewed as preconditions on the learning variables. We provide sufficient conditions under which SGD with its acceleration variants attains optimality.

ST · Jan 29, 2025
Fundamental Computational Limits in Pursuing Invariant Causal Prediction and Invariance-Guided Regularization

Yihong Gu, Cong Fang, Yang Xu et al. · princeton

Pursuing invariant prediction from heterogeneous environments opens the door to learning causality in a purely data-driven way and has several applications in causal discovery and robust transfer learning. However, existing methods such as ICP [Peters et al., 2016] and EILLS [Fan et al., 2024] that can attain sample-efficient estimation are based on exponential time algorithms. In this paper, we show that such a problem is intrinsically hard in computation: the decision problem, testing whether a non-trivial prediction-invariant solution exists across two environments, is NP-hard even for the linear causal relationship. In a world where P$\neq$NP, our results imply that the estimation error rate can be arbitrarily slow using any computationally efficient algorithm. This suggests that pursuing causality is fundamentally harder than detecting associations when no prior assumption is provided. Given that there is almost no hope of computational improvement in the worst case, this paper proposes a method capable of attaining both computationally and statistically efficient estimation under additional conditions. Furthermore, our estimator is a distributionally robust estimator with an ellipse-shaped uncertainty set where more uncertainty is placed on spurious directions than invariant directions, resulting in a smooth interpolation between the most predictive solution and the causal solution by varying the invariance hyper-parameter. Non-asymptotic results and empirical applications support the claim.

LG · Dec 6, 2023
Accelerated Gradient Algorithms with Adaptive Subspace Search for Instance-Faster Optimization

Yuanshi Liu, Hanzhen Zhao, Yang Xu et al.

Gradient-based minimax optimal algorithms have greatly promoted the development of continuous optimization and machine learning. One seminal work due to Yurii Nesterov [Nes83a] established $\tilde{\mathcal{O}}(\sqrt{L/μ})$ gradient complexity for minimizing an $L$-smooth $μ$-strongly convex objective. However, an ideal algorithm would adapt to the explicit complexity of a particular objective function and incur faster rates for simpler problems, prompting us to reconsider two defects of existing optimization modeling and analysis. (i) Worst-case optimality is neither instance optimality nor the optimality observed in practice. (ii) The traditional $L$-smoothness condition may not be the primary abstraction/characterization for modern practical problems. In this paper, we open up a new way to design and analyze gradient-based algorithms with direct applications in machine learning, including linear regression and beyond. We introduce two factors $(α, τ_α)$ to refine the description of the degenerated condition of the optimization problems based on the observation that the singular values of the Hessian often drop sharply. We design adaptive algorithms that solve simpler problems without prior knowledge, using fewer gradient or analogous oracle accesses. The algorithms also improve the state-of-the-art complexities for several problems in machine learning, thereby solving the open problem of how to design faster algorithms in light of the known complexity lower bounds. Specifically, when the nuclear norm is bounded by $\mathcal{O}(1)$, we achieve an optimal $\tilde{\mathcal{O}}(μ^{-1/3})$ (vs. $\tilde{\mathcal{O}}(μ^{-1/2})$) gradient complexity for linear regression. We hope this work prompts a rethinking of the difficulty of modern problems in optimization.

LG · Feb 13, 2025
Scaling Law for Stochastic Gradient Descent in Quadratically Parameterized Linear Regression

Shihong Ding, Haihan Zhang, Hanzhen Zhao et al.

In machine learning, the scaling law describes how model performance improves as the model and data sizes scale up. From a learning theory perspective, this class of results establishes upper and lower generalization bounds for a specific learning algorithm. Here, running a specific algorithm with a specific model parameterization often provides a crucial implicit regularization effect, leading to good generalization. To characterize the scaling law, previous theoretical studies mainly focus on linear models, whereas feature learning, a notable process that contributes to the remarkable empirical success of neural networks, is missing from the analysis. This paper studies the scaling law over a linear regression with the model being quadratically parameterized. We consider infinite-dimensional data and slope ground truth, both signals exhibiting certain power-law decay rates. We study convergence rates for Stochastic Gradient Descent and demonstrate that the learning rates for variables automatically adapt to the ground truth. As a result, in the canonical linear regression, we provide explicit separations for generalization curves between SGD with and without feature learning, and the information-theoretical lower bound that is agnostic to the parametrization method and the algorithm. Our analysis for decaying ground truth provides a new characterization of the learning dynamics of the model.

LG · Oct 19, 2025
Improving Model Representation and Reducing KV Cache via Skip Connections with First Value Heads

Zhoutong Wu, Yuan Zhang, Yiming Dong et al.

Transformer models have driven breakthroughs across various language tasks through their strong capability to learn rich contextual representations. Scaling them to improve representation, however, often demands substantial memory and compute costs, such as the Key-Value (KV) cache used during auto-regressive decoding. Skip connections offer a promising way to improve representation without bloating resource usage, yet most prior works either improve expressivity while leaving KV costs unchanged, or reduce memory at the cost of weaker representation. In this work, we propose SkipV1Former, a Transformer variant that uses skip connections from the first layer's Value heads to strengthen model representation and reduce KV cache. Specifically, from the second block onward, each layer reuses half of its Value heads from the very first layer while computing the other half as usual, cutting Value projections and V cache by nearly 50\%. Theoretically, we show that routing uncompressed first-layer Values into deeper layers restores information lost to compression and accelerates the model's implicit mesa-optimization, a key pattern of Transformers in auto-regressive tasks. Empirically, across different model scales, SkipV1Former delivers consistent reductions of approximately 25\% in KV cache while improving perplexity relative to standard Multi-Head Attention (MHA) Transformers and some advanced variants. Moreover, we propose a recipe for uptraining existing MHA Transformer checkpoints to SkipV1Former with only 10-15\% additional compute. Finally, SkipV1Former can be seamlessly combined with advanced methods like Group-Query Attention and Multi-Latent Attention to achieve further KV cache savings and performance improvement. When combined with YOCO, it cuts KV cache size by nearly 50\% while still improving performance.
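The "nearly 50\%" V-cache and "approximately 25\%" KV-cache figures follow from simple counting, sketched below. The accounting is an assumption based on the abstract's description: the first layer caches all of its Value heads once, every later layer caches only the half it computes itself, and Keys are cached in full everywhere. The function name `kv_cache_reduction` is hypothetical.

```python
def kv_cache_reduction(num_layers: int) -> float:
    """Fraction of per-token KV cache saved relative to standard MHA.

    Sizes are measured in units of (num_heads * head_dim) per token,
    so the result depends only on the number of layers.
    """
    baseline = 2.0 * num_layers               # full K and full V per layer
    keys = float(num_layers)                  # K cache is unchanged
    values = 1.0 + 0.5 * (num_layers - 1)     # full V in layer 1, half after
    return 1.0 - (keys + values) / baseline

for layers in (12, 24, 48):
    print(layers, round(kv_cache_reduction(layers), 4))
```

In closed form the saving is $(L-1)/(4L)$ for $L$ layers, which approaches 25\% as the model gets deeper; the V cache alone shrinks by $(L-1)/(2L)$, approaching 50\%.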

LGOct 10, 2025
AdaPM: a Partial Momentum Algorithm for LLM Training

Yimu Zhang, Yuanshi Liu, Cong Fang

In the training of large language models, momentum is widely used and often demonstrated to achieve significant acceleration. However, storing momentum typically presents memory challenges. In this paper, we propose AdaPM, an adaptive training strategy that leverages partial momentum to implement a memory-efficient optimizer. To this end, AdaPM utilizes a non-uniform momentum design: for most blocks, full momentum is not necessary to preserve the performance of the optimization. In the momentum design of AdaPM, to mitigate the bias and performance loss caused by partial momentum, we enhance the partial momentum by a bias correction technique. Empirically, we verify that our approach reduces memory by over $90\%$ in momentum while maintaining both efficiency and performance for pretraining various language models ranging from 60M to 1.5B, as well as for supervised fine-tuning and RLHF. AdaPM can further reduce memory by up to $95\%$ in optimizer states by combining the memory-efficient technique on the second-order statistic, saving over $30\%$ GPU hours for pretraining GPT-2 1.5B.

MLMay 28, 2025
Learning Curves of Stochastic Gradient Descent in Kernel Regression

Haihan Zhang, Weicheng Lin, Yuanshi Liu et al.

This paper considers a canonical problem in kernel regression: how well does a model perform when it is trained by popular online first-order algorithms, compared to offline ones such as ridge and ridgeless regression? In this paper, we analyze the foundational single-pass Stochastic Gradient Descent (SGD) in kernel regression under a source condition where the optimal predictor may not even belong to the RKHS, i.e., the model is misspecified. Specifically, we focus on the inner product kernel over the sphere and characterize the exact orders of the excess risk curves under different scales of sample sizes $n$ relative to the input dimension $d$. Surprisingly, we show that SGD achieves minimax optimal rates up to constants across all the scales, without suffering from saturation, a prevalent phenomenon observed in (ridge) regression, except when the model is highly misspecified and the learning is in a final stage where $n\gg d^γ$ with any constant $γ>0$. The main reason SGD overcomes the curse of saturation is the exponentially decaying step size schedule, a common practice in deep neural network training. As a byproduct, we provide the \emph{first} provable advantage of the scheme over the iterative averaging method in the common setting.
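
To make the phrase "exponentially decaying step size schedule" concrete, here is a minimal sketch of one common form of such a schedule, in which the step size is halved after each of roughly $\log_2 n$ equal-length phases of a single pass over $n$ samples. The abstract does not give the exact schedule analyzed in the paper, so the halving factor, the phase length, and the name `step_decay_schedule` are all assumptions made for illustration.

```python
import math


def step_decay_schedule(n: int, eta0: float) -> list[float]:
    """Return n step sizes that start at eta0 and halve after every phase.

    Assumption: about log2(n) phases of equal length; this is an
    illustrative choice, not the schedule from the paper.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    num_phases = max(1, math.ceil(math.log2(n)))
    phase_len = math.ceil(n / num_phases)
    return [eta0 * 0.5 ** (t // phase_len) for t in range(n)]


# Single-pass SGD would consume one step size per sample, in order.
schedule = step_decay_schedule(16, eta0=1.0)
print(schedule)
```

In a single-pass setting the $t$-th sample is used once with the $t$-th step size, so late samples make only small corrections to the iterate.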

AIJun 17, 2024
Relational Learning in Pre-Trained Models: A Theory from Hypergraph Recovery Perspective

Yang Chen, Cong Fang, Zhouchen Lin et al.

Foundation Models (FMs) have demonstrated remarkable insights into the relational dynamics of the world, leading to the crucial question: how do these models acquire an understanding of world hybrid relations? Traditional statistical learning, particularly for prediction problems, may overlook the rich and inherently structured information from the data, especially regarding the relationships between objects. We introduce a mathematical model that formalizes relational learning as hypergraph recovery to study pre-training of FMs. In our framework, the world is represented as a hypergraph, with data abstracted as random samples from hyperedges. We theoretically examine the feasibility of a Pre-Trained Model (PTM) to recover this hypergraph and analyze the data efficiency in a minimax near-optimal style. By integrating rich graph theories into the realm of PTMs, our mathematical framework offers powerful tools for an in-depth understanding of pre-training from a unique perspective and can be used under various scenarios. As an example, we extend the framework to entity alignment in multimodal learning.

QUANT-PHJun 5, 2024
Quantum Algorithms and Lower Bounds for Finite-Sum Optimization

Yexin Zhang, Chenyi Zhang, Cong Fang et al.

Finite-sum optimization has wide applications in machine learning, covering important problems such as support vector machines, regression, etc. In this paper, we initiate the study of solving finite-sum optimization problems by quantum computing. Specifically, let $f_1,\ldots,f_n\colon\mathbb{R}^d\to\mathbb{R}$ be $\ell$-smooth convex functions and $ψ\colon\mathbb{R}^d\to\mathbb{R}$ be a $μ$-strongly convex proximal function. The goal is to find an $ε$-optimal point for $F(\mathbf{x})=\frac{1}{n}\sum_{i=1}^n f_i(\mathbf{x})+ψ(\mathbf{x})$. We give a quantum algorithm with complexity $\tilde{O}\big(n+\sqrt{d}+\sqrt{\ell/μ}\big(n^{1/3}d^{1/3}+n^{-2/3}d^{5/6}\big)\big)$, improving the classical tight bound $\tilde{Θ}\big(n+\sqrt{n\ell/μ}\big)$. We also prove a quantum lower bound $\tilde{Ω}(n+n^{3/4}(\ell/μ)^{1/4})$ when $d$ is large enough. Both our quantum upper and lower bounds can extend to the cases where $ψ$ is not necessarily strongly convex, or each $f_i$ is Lipschitz but not necessarily smooth. In addition, when $F$ is nonconvex, our quantum algorithm can find an $ε$-critical point using $\tilde{O}(n+\ell(d^{1/3}n^{1/3}+\sqrt{d})/ε^2)$ queries.

AIMar 19, 2024
End-to-End Neuro-Symbolic Reinforcement Learning with Textual Explanations

Lirui Luo, Guoxi Zhang, Hongming Xu et al.

Neuro-symbolic reinforcement learning (NS-RL) has emerged as a promising paradigm for explainable decision-making, characterized by the interpretability of symbolic policies. NS-RL entails structured state representations for tasks with visual observations, but previous methods cannot refine the structured states with rewards due to a lack of efficiency. Accessibility also remains an issue, as extensive domain knowledge is required to interpret symbolic policies. In this paper, we present a neuro-symbolic framework for jointly learning structured states and symbolic policies, whose key idea is to distill the vision foundation model into an efficient perception module and refine it during policy learning. Moreover, we design a pipeline to prompt GPT-4 to generate textual explanations for the learned policies and decisions, significantly reducing users' cognitive load to understand the symbolic policies. We verify the efficacy of our approach on nine Atari tasks and present GPT-generated explanations for policies and decisions.

LGMar 3, 2024
The Implicit Bias of Heterogeneity towards Invariance: A Study of Multi-Environment Matrix Sensing

Yang Xu, Yihong Gu, Cong Fang · princeton

Models are expected to engage in invariance learning, which involves distinguishing the core relations that remain consistent across varying environments to ensure the predictions are safe, robust and fair. While existing works consider specific algorithms to realize invariance learning, we show that a model has the potential to learn invariance through standard training procedures. In other words, this paper studies the implicit bias of Stochastic Gradient Descent (SGD) over heterogeneous data and shows that the implicit bias drives the model learning towards an invariant solution. We call this phenomenon implicit invariance learning. Specifically, we theoretically investigate the multi-environment low-rank matrix sensing problem where, in each environment, the signal comprises (i) a lower-rank invariant part shared across all environments; and (ii) a significantly varying environment-dependent spurious component. The key insight is that, by simply employing large-step-size, large-batch SGD sequentially in each environment without any explicit regularization, the oscillation caused by heterogeneity can provably prevent the model from learning spurious signals. The model reaches the invariant solution after a certain number of iterations. In contrast, a model learned using pooled SGD over all data would simultaneously learn both the invariant and spurious signals. Overall, we unveil another implicit bias that results from the symbiosis between the heterogeneity of data and modern algorithms, which is, to the best of our knowledge, the first in the literature.

LGMay 22, 2023
Policy Representation via Diffusion Probability Model for Reinforcement Learning

Long Yang, Zhixiong Huang, Fenghao Lei et al.

Popular reinforcement learning (RL) algorithms tend to produce a unimodal policy distribution, which weakens the expressiveness of complicated policies and diminishes the ability to explore. The diffusion probability model is powerful for learning complicated multimodal distributions and has shown promising potential for applications in RL. In this paper, we formally build a theoretical foundation of policy representation via the diffusion probability model and provide practical implementations of diffusion policy for online model-free RL. Concretely, we characterize diffusion policy as a stochastic process, which is a new approach to representing a policy. Then we present a convergence guarantee for diffusion policy, which provides a theory to understand the multimodality of diffusion policy. Furthermore, we propose DIPO, an implementation for model-free online RL with DIffusion POlicy. To the best of our knowledge, DIPO is the first algorithm to solve model-free online RL problems with the diffusion model. Finally, extensive empirical results show the effectiveness and superiority of DIPO on the standard continuous control MuJoCo benchmark.

LGJan 29, 2021
Exploring Deep Neural Networks via Layer-Peeled Model: Minority Collapse in Imbalanced Training

Cong Fang, Hangfeng He, Qi Long et al.

In this paper, we introduce the \textit{Layer-Peeled Model}, a nonconvex yet analytically tractable optimization program, in a quest to better understand deep neural networks that are trained for a sufficiently long time. As the name suggests, this new model is derived by isolating the topmost layer from the remainder of the neural network, followed by imposing certain constraints separately on the two parts of the network. We demonstrate that the Layer-Peeled Model, albeit simple, inherits many characteristics of well-trained neural networks, thereby offering an effective tool for explaining and predicting common empirical patterns of deep learning training. First, when working on class-balanced datasets, we prove that any solution to this model forms a simplex equiangular tight frame, which in part explains the recently discovered phenomenon of neural collapse \cite{papyan2020prevalence}. More importantly, when moving to the imbalanced case, our analysis of the Layer-Peeled Model reveals a hitherto unknown phenomenon that we term \textit{Minority Collapse}, which fundamentally limits the performance of deep learning models on the minority classes. In addition, we use the Layer-Peeled Model to gain insights into how to mitigate Minority Collapse. Interestingly, this phenomenon is first predicted by the Layer-Peeled Model before being confirmed by our computational experiments.

LGDec 27, 2020
Mathematical Models of Overparameterized Neural Networks

Cong Fang, Hanze Dong, Tong Zhang

Deep learning has achieved considerable empirical success in recent years. However, while many ad hoc tricks have been discovered by practitioners, until recently there has been a lack of theoretical understanding of the tricks invented in the deep learning literature. Practitioners have long known that overparameterized neural networks are easy to learn, and in the past few years there have been important theoretical developments in the analysis of overparameterized neural networks. In particular, it was shown that such systems behave like convex systems under various restricted settings, such as for two-layer NNs, and when learning is restricted locally in the so-called neural tangent kernel space around specialized initializations. This paper discusses some of this recent progress, which has led to a significantly better understanding of neural networks. We will focus on the analysis of two-layer neural networks, and explain the key mathematical models, with their algorithmic implications. We will then discuss challenges in understanding deep neural networks and some current research directions.

LGOct 5, 2020
Improved Analysis of Clipping Algorithms for Non-convex Optimization

Bohang Zhang, Jikai Jin, Cong Fang et al.

Gradient clipping is commonly used in training deep neural networks, partly due to its practicability in relieving the exploding gradient problem. Recently, \citet{zhang2019gradient} show that clipped (stochastic) Gradient Descent (GD) converges faster than vanilla GD/SGD by introducing a new assumption called $(L_0, L_1)$-smoothness, which characterizes the violent fluctuation of gradients typically encountered in deep neural networks. However, their iteration complexities on the problem-dependent parameters are rather pessimistic, and theoretical justification of clipping combined with other crucial techniques, e.g., momentum acceleration, is still lacking. In this paper, we bridge the gap by presenting a general framework to study the clipping algorithms, which also takes momentum methods into consideration. We provide convergence analysis of the framework in both deterministic and stochastic settings, and demonstrate the tightness of our results by comparing them with existing lower bounds. Our results imply that the efficiency of clipping methods will not degenerate even in highly non-smooth regions of the landscape. Experiments confirm the superiority of clipping-based methods in deep learning tasks.
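
For readers unfamiliar with the technique, the following is a minimal sketch of norm-based gradient clipping combined with a momentum buffer. It is a generic textbook-style variant written for illustration, not the general framework proposed in the paper; the choice to clip the momentum buffer (rather than the raw gradient), the default hyperparameters, and the function names are all assumptions.

```python
import math


def clip_by_norm(vec: list[float], max_norm: float) -> list[float]:
    """Rescale vec so that its Euclidean norm is at most max_norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm <= max_norm:
        return list(vec)  # short vectors (including zero) pass through unchanged
    scale = max_norm / norm
    return [v * scale for v in vec]


def clipped_momentum_step(x, grad, momentum, lr=0.1, beta=0.9, max_norm=1.0):
    """One update: exponential-average momentum, then clip, then step."""
    new_momentum = [beta * m + (1.0 - beta) * g for m, g in zip(momentum, grad)]
    update = clip_by_norm(new_momentum, max_norm)
    new_x = [xi - lr * u for xi, u in zip(x, update)]
    return new_x, new_momentum


# A very large gradient moves the iterate by at most lr * max_norm.
x, m = clipped_momentum_step([0.0, 0.0], [30.0, 40.0], [0.0, 0.0])
print(x, m)
```

Because the update direction is rescaled rather than truncated coordinate-wise, clipping bounds the step length while preserving the direction of the momentum buffer.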

MLJul 3, 2020
Modeling from Features: a Mean-field Framework for Over-parameterized Deep Neural Networks

Cong Fang, Jason D. Lee, Pengkun Yang et al.

This paper proposes a new mean-field framework for over-parameterized deep neural networks (DNNs), which can be used to analyze neural network training. In this framework, a DNN is represented by probability measures and functions over its features (that is, the function values of the hidden units over the training data) in the continuous limit, instead of the neural network parameters as most existing studies have done. This new representation overcomes the degenerate situation where all the hidden units essentially have only one meaningful hidden unit in each middle layer, and further leads to a simpler representation of DNNs, for which the training objective can be reformulated as a convex optimization problem via suitable re-parameterization. Moreover, we construct a non-linear dynamics called neural feature flow, which captures the evolution of an over-parameterized DNN trained by Gradient Descent. We illustrate the framework via the standard DNN and the Residual Network (Res-Net) architectures. Furthermore, we show, for Res-Net, when the neural feature flow process converges, it reaches a global minimal solution under suitable conditions. Our analysis leads to the first global convergence proof for over-parameterized neural network training with more than $3$ layers in the mean-field regime.

LGNov 18, 2019
Convex Formulation of Overparameterized Deep Neural Networks

Cong Fang, Yihong Gu, Weizhong Zhang et al.

Analysis of over-parameterized neural networks has drawn significant attention in recent years. It was shown that such systems behave like convex systems under various restricted settings, such as for two-level neural networks, and when learning is only restricted locally in the so-called neural tangent kernel space around specialized initializations. However, there are no theoretical techniques that can analyze fully trained deep neural networks encountered in practice. This paper solves this fundamental problem by investigating such overparameterized deep neural networks when fully trained. We generalize a new technique called neural feature repopulation, originally introduced in (Fang et al., 2019a) for two-level neural networks, to analyze deep neural networks. It is shown that under suitable representations, overparameterized deep neural networks are inherently convex, and when optimized, the system can learn effective features suitable for the underlying learning task under mild conditions. This new analysis is consistent with empirical observations that deep neural networks are capable of learning efficient feature representations. Therefore, the highly unexpected result of this paper can satisfactorily explain the practical success of deep neural networks. Empirical studies confirm that predictions of our theory are consistent with results observed in practice.

LGOct 25, 2019
Over Parameterized Two-level Neural Networks Can Learn Near Optimal Feature Representations

Cong Fang, Hanze Dong, Tong Zhang

Recently, over-parameterized neural networks have been extensively analyzed in the literature. However, the previous studies cannot satisfactorily explain why fully trained neural networks are successful in practice. In this paper, we present a new theoretical framework for analyzing over-parameterized neural networks which we call neural feature repopulation. Our analysis can satisfactorily explain the empirical success of two-level neural networks that are trained by standard learning algorithms. Our key theoretical result is that, in the limit of an infinite number of hidden neurons, over-parameterized two-level neural networks trained via standard (noisy) gradient descent learn a well-defined feature distribution (population), and the limiting feature distribution is nearly optimal for the underlying learning task under certain conditions. Empirical studies confirm that predictions of our theory are consistent with the results observed in practice.

OCMar 4, 2019
A Stochastic Trust Region Method for Non-convex Minimization

Zebang Shen, Pan Zhou, Cong Fang et al.

We target the problem of finding a local minimum in non-convex finite-sum minimization. Towards this goal, we first prove that the trust region method with inexact gradient and Hessian estimation can achieve a convergence rate of order $\mathcal{O}(1/{k^{2/3}})$ as long as those differential estimations are sufficiently accurate. Combining this result with a novel Hessian estimator, we propose the sample-efficient stochastic trust region (STR) algorithm, which finds an $(ε, \sqrt{ε})$-approximate local minimum within $\mathcal{O}({\sqrt{n}}/{ε^{1.5}})$ stochastic Hessian oracle queries. This improves the state-of-the-art result by $\mathcal{O}(n^{1/6})$. Experiments verify the theoretical conclusions and the efficiency of STR.

OCFeb 1, 2019
Sharp Analysis for Nonconvex SGD Escaping from Saddle Points

Cong Fang, Zhouchen Lin, Tong Zhang

In this paper, we give a sharp analysis for Stochastic Gradient Descent (SGD) and prove that SGD is able to efficiently escape from saddle points and find an $(ε, O(ε^{0.5}))$-approximate second-order stationary point in $\tilde{O}(ε^{-3.5})$ stochastic gradient computations for generic nonconvex optimization problems, when the objective function satisfies gradient-Lipschitz, Hessian-Lipschitz, and dispersive noise assumptions. This result subverts the classical belief that SGD requires at least $O(ε^{-4})$ stochastic gradient computations for obtaining an $(ε,O(ε^{0.5}))$-approximate second-order stationary point. Such SGD rate matches, up to a polylogarithmic factor of problem-dependent parameters, the rate of most accelerated nonconvex stochastic optimization algorithms that adopt additional techniques, such as Nesterov's momentum acceleration, negative curvature search, as well as quadratic and cubic regularization tricks. Our novel analysis gives new insights into nonconvex SGD and can be potentially generalized to a broad class of stochastic optimization algorithms.

LGDec 29, 2018
Hessian-Aware Zeroth-Order Optimization for Black-Box Adversarial Attack

Haishan Ye, Zhichao Huang, Cong Fang et al.

Zeroth-order optimization is an important research topic in machine learning. In recent years, it has become a key tool in black-box adversarial attack to neural network based image classifiers. However, existing zeroth-order optimization algorithms rarely extract second-order information of the model function. In this paper, we utilize the second-order information of the objective function and propose a novel \textit{Hessian-aware zeroth-order algorithm} called \texttt{ZO-HessAware}. Our theoretical result shows that \texttt{ZO-HessAware} has an improved zeroth-order convergence rate and query complexity under structured Hessian approximation, where we propose a few approximation methods for estimating Hessian. Our empirical studies on the black-box adversarial attack problem validate that our algorithm can achieve improved success rates with a lower query complexity.

LGNov 5, 2018
Lifted Proximal Operator Machines

Jia Li, Cong Fang, Zhouchen Lin

We propose a new optimization method for training feed-forward neural networks. By rewriting the activation function as an equivalent proximal operator, we approximate a feed-forward neural network by adding the proximal operators to the objective function as penalties; hence we call the method the lifted proximal operator machine (LPOM). LPOM is block multi-convex in all layer-wise weights and activations. This allows us to use block coordinate descent to update the layer-wise weights and activations in parallel. Most notably, we only use the mapping of the activation function itself, rather than its derivatives, thus avoiding the gradient vanishing or blow-up issues in gradient-based training methods. Our method is therefore applicable to various non-decreasing Lipschitz continuous activation functions, which can be saturating and non-differentiable. LPOM does not require more auxiliary variables than the layer-wise activations, thus using roughly the same amount of memory as stochastic gradient descent (SGD) does. We further prove the convergence of updating the layer-wise weights and activations. Experiments on MNIST and CIFAR-10 datasets testify to the advantages of LPOM.

OCJul 4, 2018
SPIDER: Near-Optimal Non-Convex Optimization via Stochastic Path Integrated Differential Estimator

Cong Fang, Chris Junchi Li, Zhouchen Lin et al.

In this paper, we propose a new technique named \textit{Stochastic Path-Integrated Differential EstimatoR} (SPIDER), which can be used to track many deterministic quantities of interest with significantly reduced computational cost. We apply SPIDER to two tasks, namely the stochastic first-order and zeroth-order methods. For the stochastic first-order method, combining SPIDER with normalized gradient descent, we propose two new algorithms, namely SPIDER-SFO and SPIDER-SFO\textsuperscript{+}, that solve non-convex stochastic optimization problems using stochastic gradients only. We provide sharp error-bound results on their convergence rates. In particular, we prove that the SPIDER-SFO and SPIDER-SFO\textsuperscript{+} algorithms achieve a record-breaking gradient computation cost of $\mathcal{O}\left( \min( n^{1/2} ε^{-2}, ε^{-3} ) \right)$ for finding an $ε$-approximate first-order and $\tilde{\mathcal{O}}\left( \min( n^{1/2} ε^{-2}+ε^{-2.5}, ε^{-3} ) \right)$ for finding an $(ε, \mathcal{O}(ε^{0.5}))$-approximate second-order stationary point, respectively. In addition, we prove that SPIDER-SFO nearly matches the algorithmic lower bound for finding approximate first-order stationary points under the gradient Lipschitz assumption in the finite-sum setting. For the stochastic zeroth-order method, we prove a cost of $\mathcal{O}( d \min( n^{1/2} ε^{-2}, ε^{-3}) )$, which outperforms all existing results.
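
The idea of tracking a gradient through accumulated differences, combined with normalized steps, can be sketched on a toy problem. The code below is an illustrative sketch under stated assumptions, not the authors' implementation: it uses a one-dimensional finite sum $f(x)=\frac{1}{n}\sum_i \frac{1}{2}(a_i x-b_i)^2$, a recursive estimate of the form $v_k = \nabla f_S(x_k) - \nabla f_S(x_{k-1}) + v_{k-1}$ with a full-gradient refresh every `q` steps, and hypothetical names and hyperparameters (`spider_estimate`, `run_spider`, `q`, `batch`, `lr`).

```python
import random


def batch_grad(a, b, x, idx):
    """Average gradient of 0.5 * (a_i * x - b_i)^2 over the indices in idx."""
    return sum(a[i] * (a[i] * x - b[i]) for i in idx) / len(idx)


def spider_estimate(a, b, x, x_prev, v_prev, idx):
    """Path-integrated estimate: previous estimate plus a mini-batch gradient difference."""
    return batch_grad(a, b, x, idx) - batch_grad(a, b, x_prev, idx) + v_prev


def run_spider(a, b, x0, steps=60, q=10, batch=2, lr=0.05, seed=0):
    """Toy loop: refresh with a full gradient every q steps, else update recursively."""
    rng = random.Random(seed)
    n = len(a)
    x_prev, x, v = x0, x0, 0.0
    for k in range(steps):
        if k % q == 0:
            v = batch_grad(a, b, x, range(n))        # full-gradient refresh
        else:
            idx = [rng.randrange(n) for _ in range(batch)]
            v = spider_estimate(a, b, x, x_prev, v, idx)
        x_prev = x
        if v != 0.0:
            x = x - lr * v / abs(v)                  # normalized step (a sign step in 1-D)
    return x


a = [1.0, 2.0, 0.5, 1.5]
b = [2.0, 1.0, 1.0, 3.0]
print(run_spider(a, b, x0=5.0))
```

The same mini-batch is evaluated at both the current and the previous iterate, so the difference term is small whenever consecutive iterates are close, which is what the normalized step enforces.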

OCFeb 27, 2018
Accelerating Asynchronous Algorithms for Convex Optimization by Momentum Compensation

Cong Fang, Yameng Huang, Zhouchen Lin

Asynchronous algorithms have attracted much attention recently due to the crucial demands on solving large-scale optimization problems. However, the accelerated versions of asynchronous algorithms are rarely studied. In this paper, we propose the "momentum compensation" technique to accelerate asynchronous algorithms for convex problems. Specifically, we first accelerate the plain Asynchronous Gradient Descent, which achieves a faster $O(1/\sqrt{ε})$ (vs. $O(1/ε)$) convergence rate for non-strongly convex functions, and $O(\sqrt{κ}\log(1/ε))$ (vs. $O(κ\log(1/ε))$) for strongly convex functions to reach an $ε$-approximate minimizer with the condition number $κ$. We further apply the technique to accelerate modern stochastic asynchronous algorithms such as Asynchronous Stochastic Coordinate Descent and Asynchronous Stochastic Gradient Descent. Both of the resulting practical algorithms are faster in order than existing ones. To the best of our knowledge, we are the first to consider accelerated algorithms that allow updating by delayed gradients and are the first to propose truly accelerated asynchronous algorithms. Finally, the experimental results on a shared memory system show that acceleration can lead to significant performance gains on ill-conditioned problems.