Suvrit Sra

LG
h-index: 64
Papers: 103
Citations: 7,663
Novelty: 60%
AI Score: 61

103 Papers

LG · Dec 30, 2022
Cost-Driven Representation Learning for Linear Quadratic Gaussian Control: Part I

Yi Tian, Kaiqing Zhang, Russ Tedrake et al. · MIT

We study the task of learning state representations from potentially high-dimensional observations, with the goal of controlling an unknown partially observable system. We pursue a cost-driven approach, where a dynamic model in some latent state space is learned by predicting the costs without predicting the observations or actions. In particular, we focus on an intuitive cost-driven state representation learning method for solving Linear Quadratic Gaussian (LQG) control, one of the most fundamental partially observable control problems. As our main results, we establish finite-sample guarantees of finding a near-optimal state representation function and a near-optimal controller using the directly learned latent model, for finite-horizon time-varying LQG control problems. To the best of our knowledge, despite various empirical successes, finite-sample guarantees of such a cost-driven approach remain elusive. Our result underscores the value of predicting multi-step costs, an idea that is key to our theory, and notably also an idea that is known to be empirically valuable for learning state representations. A second part of this work, that is to appear as Part II, addresses the infinite-horizon linear time-invariant setting; it also extends the results to an approach that implicitly learns the latent dynamics, inspired by the recent empirical breakthrough of MuZero in model-based reinforcement learning.

LG · Jun 1, 2023
Transformers learn to implement preconditioned gradient descent for in-context learning

Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand et al.

Several recent works demonstrate that transformers can implement algorithms like gradient descent. By a careful construction of weights, these works show that multiple layers of transformers are expressive enough to simulate iterations of gradient descent. Going beyond the question of expressivity, we ask: Can transformers learn to implement such algorithms by training over random problem instances? To our knowledge, we make the first theoretical progress on this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix not only adapts to the input distribution but also to the variance induced by data inadequacy. For a transformer with $L$ attention layers, we prove certain critical points of the training objective implement $L$ iterations of preconditioned gradient descent. Our results call for future theoretical studies on learning algorithms by training transformers.

LG · Oct 2, 2023
Linear attention is (maybe) all you need (to understand transformer optimization)

Kwangjun Ahn, Xiang Cheng, Minhak Song et al.

Transformer training is notoriously difficult, requiring a careful design of optimizers and use of various heuristics. We make progress towards understanding the subtleties of training Transformers by carefully studying a simple yet canonical linearized shallow Transformer model. Specifically, we train linear Transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023), and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that our proposed linearized models can reproduce several prominent aspects of Transformer training dynamics. Consequently, the results obtained in this paper suggest that a simple linearized Transformer model could actually be a valuable, realistic abstraction for understanding Transformer optimization.

OC · Apr 3, 2022
Understanding the unstable convergence of gradient descent

Kwangjun Ahn, Jingzhao Zhang, Suvrit Sra

Most existing analyses of (stochastic) gradient descent rely on the condition that for $L$-smooth costs, the step size is less than $2/L$. However, many works have observed that in machine learning applications step sizes often do not fulfill this condition, yet (stochastic) gradient descent still converges, albeit in an unstable manner. We investigate this unstable convergence phenomenon from first principles, and discuss key causes behind it. We also identify its main characteristics, and how they interrelate based on both theory and experiments, offering a principled view toward understanding the phenomenon.
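
The classical $2/L$ threshold mentioned above can be seen on a toy problem. The following is a minimal sketch, not taken from the paper: it assumes the one-dimensional quadratic $f(x) = (L/2)x^2$, and the function name and parameter values are illustrative. On this quadratic the iterates contract exactly when the step size is below $2/L$; the sketch shows only the classical condition and does not reproduce the unstable convergence the paper studies.

```python
def gradient_descent_quadratic(smoothness, step_size, x0=1.0, num_steps=50):
    """Run gradient descent on f(x) = (smoothness / 2) * x**2 and return all iterates."""
    x = x0
    iterates = [x]
    for _ in range(num_steps):
        grad = smoothness * x         # f'(x) = L * x
        x = x - step_size * grad      # equivalent to x <- (1 - step_size * L) * x
        iterates.append(x)
    return iterates


L_smooth = 4.0  # assumed smoothness constant, so the threshold 2 / L equals 0.5

below_threshold = gradient_descent_quadratic(L_smooth, step_size=0.4)
above_threshold = gradient_descent_quadratic(L_smooth, step_size=0.6)

# Below 2 / L the iterates shrink toward 0; above it they grow without bound.
print(abs(below_threshold[-1]), abs(above_threshold[-1]))
```

With these values the per-step multiplier is $1 - 0.4 \cdot 4 = -0.6$ in the first run and $1 - 0.6 \cdot 4 = -1.4$ in the second, so the sign of the iterate flips each step in both cases but only the first run converges.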

NA · Dec 16, 2015
On the matrix square root via geometric optimization

Suvrit Sra

This paper is triggered by the preprint "Computing Matrix Squareroot via Non Convex Local Search" by Jain et al. (arXiv:1507.05854), which analyzes gradient-descent for computing the square root of a positive definite matrix. Contrary to claims of Jain et al., our experiments reveal that Newton-like methods compute matrix square roots rapidly and reliably, even for highly ill-conditioned matrices and without requiring commutativity. We observe that gradient-descent converges very slowly primarily due to tiny step-sizes and ill-conditioning. We derive an alternative first-order method based on geodesic convexity: our method admits a transparent convergence analysis ($< 1$ page), attains linear rate, and displays reliable convergence even for rank deficient problems. Though superior to gradient-descent, ultimately our method is also outperformed by a well-known scaled Newton method. Nevertheless, the primary value of our work is its conceptual value: it shows that for deriving gradient based methods for the matrix square root, the manifold geometric view of positive definite matrices can be much more advantageous than the Euclidean view.

67.0 · LG · Apr 29
A projection-based framework for gradient-free and parallel learning

Andreas Bergmeister, Manish Krishan Lal, Stefanie Jegelka et al.

We present a feasibility-seeking approach to neural network training. This mathematical optimization framework is distinct from conventional gradient-based loss minimization and uses projection operators and iterative projection algorithms. We reformulate training as a large-scale feasibility problem: finding network parameters and states that satisfy local constraints derived from its elementary operations. Training then involves projecting onto these constraints, a local operation that can be parallelized across the network. We introduce PJAX, a JAX-based software framework that enables this paradigm. PJAX composes projection operators for elementary operations, automatically deriving the solution operators for the feasibility problems (akin to autodiff for derivatives). It inherently supports GPU/TPU acceleration, provides a familiar NumPy-like API, and is extensible. We train diverse architectures (MLPs, CNNs, RNNs) on standard benchmarks using PJAX, demonstrating its functionality and generality. Our results show that this approach is a compelling alternative to gradient-based training, with clear advantages in parallelism and the ability to handle non-differentiable operations.

LG · Feb 24, 2023
On the Training Instability of Shuffling SGD with Batch Normalization

David X. Wu, Chulhee Yun, Suvrit Sra

We uncover how SGD interacts with batch normalization and can exhibit undesirable training dynamics such as divergence. More precisely, we study how Single Shuffle (SS) and Random Reshuffle (RR) -- two widely used variants of SGD -- interact surprisingly differently in the presence of batch normalization: RR leads to much more stable evolution of training loss than SS. As a concrete example, for regression using a linear network with batch normalization, we prove that SS and RR converge to distinct global optima that are "distorted" away from gradient descent. Thereafter, for classification we characterize conditions under which training divergence for SS and RR can, and cannot occur. We present explicit constructions to show how SS leads to distorted optima in regression and divergence for classification, whereas RR avoids both distortion and divergence. We validate our results by confirming them empirically in realistic settings, and conclude that the separation between SS and RR used with batch normalization is relevant in practice.
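
The difference between the two shuffling schemes named above comes down to how the per-epoch sample order is generated. The following is a minimal sketch under our own assumptions (the function names, the `seed` parameter, and the use of Python's `random` module are illustrative and not from the paper): Single Shuffle draws one permutation and reuses it in every epoch, while Random Reshuffle draws a fresh permutation per epoch. Both visit every sample exactly once per epoch, i.e. they sample without replacement.

```python
import random


def single_shuffle_orders(num_samples, num_epochs, seed=0):
    """Single Shuffle (SS): shuffle once, then reuse the same order every epoch."""
    rng = random.Random(seed)
    order = list(range(num_samples))
    rng.shuffle(order)
    return [list(order) for _ in range(num_epochs)]


def random_reshuffle_orders(num_samples, num_epochs, seed=0):
    """Random Reshuffle (RR): draw a fresh permutation at the start of each epoch."""
    rng = random.Random(seed)
    orders = []
    for _ in range(num_epochs):
        order = list(range(num_samples))
        rng.shuffle(order)
        orders.append(order)
    return orders


ss_orders = single_shuffle_orders(num_samples=8, num_epochs=3)
rr_orders = random_reshuffle_orders(num_samples=8, num_epochs=3)

for epoch, (ss, rr) in enumerate(zip(ss_orders, rr_orders)):
    print(f"epoch {epoch}: SS visits {ss}, RR visits {rr}")
```

An SGD loop would consume these index lists in order to form its minibatches; the sketch covers only the ordering and leaves out batch normalization and the model.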

OC · Jul 10, 2023
Invex Programs: First Order Algorithms and Their Convergence

Adarsh Barik, Suvrit Sra, Jean Honorio

Invex programs are a special kind of non-convex problems which attain global minima at every stationary point. While classical first-order gradient descent methods can solve them, they converge very slowly. In this paper, we propose new first-order algorithms to solve the general class of invex problems. We identify sufficient conditions for convergence of our algorithms and provide rates of convergence. Furthermore, we go beyond unconstrained problems and provide a novel projected gradient method for constrained invex programs with convergence rate guarantees. We compare and contrast our results with existing first-order algorithms for a variety of unconstrained and constrained invex problems. To the best of our knowledge, our proposed algorithm is the first algorithm to solve constrained invex programs.

OC · Jun 22, 2022
On a class of geodesically convex optimization problems solved via Euclidean MM methods

Melanie Weber, Suvrit Sra

We study geodesically convex (g-convex) problems that can be written as a difference of Euclidean convex functions. This structure arises in several optimization problems in statistics and machine learning, e.g., for matrix scaling, M-estimators for covariances, and Brascamp-Lieb inequalities. Our work offers efficient algorithms that on the one hand exploit g-convexity to ensure global optimality along with guarantees on iteration complexity. On the other hand, the split structure permits us to develop Euclidean Majorization-Minorization algorithms that help us bypass the need to compute expensive Riemannian operations such as exponential maps and parallel transport. We illustrate our results by specializing them to a few concrete optimization problems that have been previously studied in the machine learning literature. Ultimately, we hope our work helps motivate the broader search for mixed Euclidean-Riemannian optimization algorithms.

LG · Oct 31, 2025
Cross-fluctuation phase transitions reveal sampling dynamics in diffusion models

Sai Niranjan Ramachandran, Manish Krishan Lal, Suvrit Sra

We analyse how the sampling dynamics of distributions evolve in score-based diffusion models using cross-fluctuations, a centered-moment statistic from statistical physics. Specifically, we show that starting from an unbiased isotropic normal distribution, samples undergo sharp, discrete transitions, eventually forming distinct events of a desired distribution while progressively revealing finer structure. As this process is reversible, these transitions also occur in reverse, where intermediate states progressively merge, tracing a path back to the initial distribution. We demonstrate that these transitions can be detected as discontinuities in $n^{\text{th}}$-order cross-fluctuations. For variance-preserving SDEs, we derive a closed form for these cross-fluctuations that is efficiently computable for the reverse trajectory. We find that detecting these transitions directly boosts sampling efficiency, accelerates class-conditional and rare-class generation, and improves two zero-shot tasks (image classification and style transfer) without expensive grid search or retraining. We also show that this viewpoint unifies classical coupling and mixing from finite Markov chains with continuous dynamics while extending to stochastic SDEs and non-Markovian samplers. Our framework therefore bridges discrete Markov chain theory, phase analysis, and modern generative modeling.

NA · Apr 29, 2012
Explicit eigenvalues of certain scaled trigonometric matrices

Suvrit Sra

In a very recent paper "On eigenvalues and equivalent transformation of trigonometric matrices" (D. Zhang, Z. Lin, and Y. Liu, LAA 436, 71--78 (2012)), the authors motivated and discussed a trigonometric matrix that arises in the design of finite impulse response (FIR) digital filters. The eigenvalues of this matrix shed light on the FIR filter design, so obtaining them in closed form was investigated. Zhang et al. proved that their matrix had rank 4 and they conjectured closed form expressions for its eigenvalues, leaving a rigorous proof as an open problem. This paper studies trigonometric matrices significantly more general than theirs, deduces their rank, and derives closed forms for their eigenvalues. As a corollary, it yields a short proof of the conjectures in the aforementioned paper.

LG · Oct 22, 2024 · Code
Graph Transformers Dream of Electric Flow

Xiang Cheng, Lawrence Carin, Suvrit Sra

We show theoretically and empirically that the linear Transformer, when applied to graph data, can implement algorithms that solve canonical problems such as electric flow and eigenvector decomposition. The Transformer has access to information on the input graph only via the graph's incidence matrix. We present explicit weight configurations for implementing each algorithm, and we bound the constructed Transformers' errors by the errors of the underlying algorithms. Our theoretical findings are corroborated by experiments on synthetic data. Additionally, on a real-world molecular regression task, we observe that the linear Transformer is capable of learning a more effective positional encoding than the default one based on Laplacian eigenvectors. Our work is an initial step towards elucidating the inner-workings of the Transformer for graph data. Code is available at https://github.com/chengxiang/LinearGraphTransformer

LG · Feb 25, 2022 · Code
Sign and Basis Invariant Networks for Spectral Graph Representation Learning

Derek Lim, Joshua Robinson, Lingxiao Zhao et al.

We introduce SignNet and BasisNet -- new neural architectures that are invariant to two key symmetries displayed by eigenvectors: (i) sign flips, since if $v$ is an eigenvector then so is $-v$; and (ii) more general basis symmetries, which occur in higher dimensional eigenspaces with infinitely many choices of basis eigenvectors. We prove that under certain conditions our networks are universal, i.e., they can approximate any continuous function of eigenvectors with the desired invariances. When used with Laplacian eigenvectors, our networks are provably more expressive than existing spectral methods on graphs; for instance, they subsume all spectral graph convolutions, certain spectral graph invariants, and previously proposed graph positional encodings as special cases. Experiments show that our networks significantly outperform existing baselines on molecular graph regression, learning expressive graph representations, and learning neural fields on triangle meshes. Our code is available at https://github.com/cptq/SignNet-BasisNet

LG · Jun 21, 2021 · Code
Can contrastive learning avoid shortcut solutions?

Joshua Robinson, Li Sun, Ke Yu et al.

The generalization of representations learned via contrastive learning depends crucially on what features of the data are extracted. However, we observe that the contrastive loss does not always sufficiently guide which features are extracted, a behavior that can negatively impact the performance on downstream tasks via "shortcuts", i.e., by inadvertently suppressing important predictive features. We find that feature extraction is influenced by the difficulty of the so-called instance discrimination task (i.e., the task of discriminating pairs of similar points from pairs of dissimilar ones). Although harder pairs improve the representation of some features, the improvement comes at the cost of suppressing previously well represented features. In response, we propose implicit feature modification (IFM), a method for altering positive and negative samples in order to guide contrastive models towards capturing a wider variety of predictive features. Empirically, we observe that IFM reduces feature suppression, and as a result improves performance on vision and medical imaging tasks. The code is available at: https://github.com/joshr17/IFM

61.9 · ML · May 9
Tight Generalization Bounds for Noiseless Inverse Optimization

Pouria Fatemi, Hoomaan Maskan, Suvrit Sra et al.

Inverse optimization (IO) seeks to infer the parameters of a decision-maker's objective from observed context--action data. We study noiseless IO, where demonstrations are generated by a ground-truth objective. We provide a high-probability ${O}(\frac{d}{T})$ generalization bound for the induced action set, where $d$ is the number of unknown parameters and $T$ is the size of the training dataset. We strengthen these guarantees under additional conditions that ensure uniqueness of the chosen action, bringing our IO guarantees in line with best-arm identification results in the bandit literature. We further show that the ${O}(\frac{d}{T})$ rate is tight over all consistent estimators considered here, and extend the result to both instantaneous and cumulative regret. Notably, the resulting regret lower bound matches the corresponding upper bounds in the adversarial setting, indicating that the stochastic IO setting is effectively adversarial for the class of estimators studied here. Finally, we propose a parameter-free algorithm with lower per-iteration complexity than generic solvers. Experiments validate the predicted rates and illustrate the tightness of our bounds.

LG · Dec 11, 2023
Transformers Implement Functional Gradient Descent to Learn Non-Linear Functions In Context

Xiang Cheng, Yuxin Chen, Suvrit Sra

Many neural network architectures are known to be Turing Complete, and can thus, in principle, implement arbitrary algorithms. However, Transformers are unique in that they can implement gradient-based learning algorithms under simple parameter configurations. This paper provides theoretical and empirical evidence that (non-linear) Transformers naturally learn to implement gradient descent in function space, which in turn enables them to learn non-linear functions in context. Our results apply to a broad class of combinations of non-linear architectures and non-linear in-context learning tasks. Additionally, we show that the optimal choice of non-linear activation depends in a natural way on the class of functions that need to be learned.

69.0 · LG · May 1
Trees to Flows and Back: Unifying Decision Trees and Diffusion Models

Sai Niranjan Ramachandran, Suvrit Sra

Decision trees and diffusion models are ostensibly disparate model classes, one discrete and hierarchical, the other continuous and dynamic. This work unifies the two by establishing a crisp mathematical correspondence between hierarchical decision trees and diffusion processes in appropriate limiting regimes. Our unification reveals a shared optimization principle: Global Trajectory Score Matching (GTSM), for which gradient boosting (in an idealized version) is asymptotically optimal. We underscore the conceptual value of our work through two key practical instantiations: \treeflow, which achieves competitive generation quality on tabular data with higher fidelity and a $2\times$ computational speedup, and \dsmtree, a novel distillation method that transfers hierarchical decision logic into neural networks, matching teacher performance within 2% on many benchmarks.

ST · Feb 15, 2024
Efficient Sampling on Riemannian Manifolds via Langevin MCMC

Xiang Cheng, Jingzhao Zhang, Suvrit Sra

We study the task of efficiently sampling from a Gibbs distribution $d π^* = e^{-h} d {vol}_g$ over a Riemannian manifold $M$ via (geometric) Langevin MCMC; this algorithm involves computing exponential maps in random Gaussian directions and is efficiently implementable in practice. The key to our analysis of Langevin MCMC is a bound on the discretization error of the geometric Euler-Maruyama scheme, assuming $\nabla h$ is Lipschitz and $M$ has bounded sectional curvature. Our error bound matches the error of Euclidean Euler-Maruyama in terms of its stepsize dependence. Combined with a contraction guarantee for the geometric Langevin Diffusion under Kendall-Cranston coupling, we prove that the Langevin MCMC iterates lie within $ε$-Wasserstein distance of $π^*$ after $\tilde{O}(ε^{-2})$ steps, which matches the iteration complexity for Euclidean Langevin MCMC. Our results apply in general settings where $h$ can be nonconvex and $M$ can have negative Ricci curvature. Under additional assumptions that the Riemannian curvature tensor has bounded derivatives, and that $π^*$ satisfies a $CD(\cdot,\infty)$ condition, we analyze the stochastic gradient version of Langevin MCMC, and bound its iteration complexity by $\tilde{O}(ε^{-2})$ as well.

OC · Mar 11, 2025
Revisiting Frank-Wolfe for Structured Nonconvex Optimization

Hoomaan Maskan, Yikun Hou, Suvrit Sra et al.

We introduce a new projection-free (Frank-Wolfe) method for optimizing structured nonconvex functions that are expressed as a difference of two convex functions. This problem class subsumes smooth nonconvex minimization, positioning our method as a promising alternative to the classical Frank-Wolfe algorithm. DC decompositions are not unique; by carefully selecting a decomposition, we can better exploit the problem structure, improve computational efficiency, and adapt to the underlying problem geometry to find better local solutions. We prove that the proposed method achieves a first-order stationary point in $O(1/ε^2)$ iterations, matching the complexity of the standard Frank-Wolfe algorithm for smooth nonconvex minimization in general. Specific decompositions can, for instance, yield a gradient-efficient variant that requires only $O(1/ε)$ calls to the gradient oracle. Finally, we present numerical experiments demonstrating the effectiveness of the proposed method compared to the standard Frank-Wolfe algorithm.

LG · Mar 8
Cost-Driven Representation Learning for Linear Quadratic Gaussian Control: Part II

Yi Tian, Kaiqing Zhang, Russ Tedrake et al.

We study the problem of state representation learning for control from partial and potentially high-dimensional observations. We approach this problem via cost-driven state representation learning, in which we learn a dynamical model in a latent state space by predicting cumulative costs. In particular, we establish finite-sample guarantees on finding a near-optimal representation function and a near-optimal controller using the learned latent model for infinite-horizon time-invariant Linear Quadratic Gaussian (LQG) control. We study two approaches to cost-driven representation learning, which differ in whether the transition function of the latent state is learned explicitly or implicitly. The first approach has also been investigated in Part I of this work, for finite-horizon time-varying LQG control. The second approach closely resembles MuZero, a recent breakthrough in empirical reinforcement learning, in that it learns latent dynamics implicitly by predicting cumulative costs. A key technical contribution of this Part II is to prove persistency of excitation for a new stochastic process that arises from the analysis of quadratic regression in our approach, and may be of independent interest.

OC · Jul 25, 2025
Linearly Convergent Algorithms for Nonsmooth Problems with Unknown Smooth Pieces

Zhe Zhang, Suvrit Sra

We develop efficient algorithms for optimizing piecewise smooth (PWS) functions where the underlying partition of the domain into smooth pieces is unknown. For PWS functions satisfying a quadratic growth (QG) condition, we propose a bundle-level (BL) type method that achieves global linear convergence -- to our knowledge, the first such result for any algorithm for this problem class. We extend this method to handle approximately PWS functions and to solve weakly-convex PWS problems, improving the state-of-the-art complexity to match the benchmark for smooth non-convex optimization. Furthermore, we introduce the first verifiable and accurate termination criterion for PWS optimization. Similar to the gradient norm in smooth optimization, this certificate tightly characterizes the optimality gap under the QG condition, and can moreover be evaluated without knowledge of any problem parameters. We develop a search subroutine for this certificate and embed it within a guess-and-check framework, resulting in an almost parameter-free algorithm for both the convex QG and weakly-convex settings.

LG · Jan 27, 2025
Implicit Bias in Matrix Factorization and its Explicit Realization in a New Architecture

Yikun Hou, Suvrit Sra, Alp Yurtsever

Gradient descent for matrix factorization exhibits an implicit bias toward approximately low-rank solutions. While existing theories often assume the boundedness of iterates, empirically the bias persists even with unbounded sequences. This reflects a dynamic where factors develop low-rank structure while their magnitudes increase, tending to align with certain directions. To capture this behavior in a stable way, we introduce a new factorization model: $X\approx UDV^\top$, where $U$ and $V$ are constrained within norm balls, while $D$ is a diagonal factor allowing the model to span the entire search space. Experiments show that this model consistently exhibits a strong implicit bias, yielding truly (rather than approximately) low-rank solutions. Extending the idea to neural networks, we introduce a new model featuring constrained layers and diagonal components that achieves competitive performance on various regression and classification tasks while producing lightweight, low-rank representations.

OC · Jun 18, 2024
First-Order Methods for Linearly Constrained Bilevel Optimization

Guy Kornowski, Swati Padmanabhan, Kai Wang et al.

Algorithms for bilevel optimization often encounter Hessian computations, which are prohibitive in high dimensions. While recent works offer first-order methods for unconstrained bilevel problems, the constrained setting remains relatively underexplored. We present first-order linearly constrained optimization methods with finite-time hypergradient stationarity guarantees. For linear equality constraints, we attain $ε$-stationarity in $\widetilde{O}(ε^{-2})$ gradient oracle calls, which is nearly-optimal. For linear inequality constraints, we attain $(δ,ε)$-Goldstein stationarity in $\widetilde{O}(d{δ^{-1} ε^{-3}})$ gradient oracle calls, where $d$ is the upper-level dimension. Finally, we obtain for the linear inequality setting dimension-free rates of $\widetilde{O}({δ^{-1} ε^{-4}})$ oracle complexity under the additional assumption of oracle access to the optimal dual variable. Along the way, we develop new nonsmooth nonconvex optimization methods with inexact oracles. We verify these guarantees with preliminary numerical experiments.

OC · May 22, 2024
Riemannian Bilevel Optimization

Sanchayan Dutta, Xiang Cheng, Suvrit Sra

We develop new algorithms for Riemannian bilevel optimization. We focus in particular on batch and stochastic gradient-based methods, with the explicit goal of avoiding second-order information such as Riemannian hyper-gradients. We propose and analyze $\mathrm{RF^2SA}$, a method that leverages first-order gradient information to navigate the complex geometry of Riemannian manifolds efficiently. Notably, $\mathrm{RF^2SA}$ is a single-loop algorithm, and thus easier to implement and use. Under various setups, including stochastic optimization, we provide explicit convergence rates for reaching $ε$-stationary points. We also address the challenge of optimizing over Riemannian manifolds with constraints by adjusting the multiplier in the Lagrangian, ensuring convergence to the desired solution without requiring access to second-order derivatives.

LG · May 25, 2023
How to escape sharp minima with random perturbations

Kwangjun Ahn, Ali Jadbabaie, Suvrit Sra

Modern machine learning applications have witnessed the remarkable success of optimization algorithms that are designed to find flat minima. Motivated by this design choice, we undertake a formal study that (i) formulates the notion of flat minima, and (ii) studies the complexity of finding them. Specifically, we adopt the trace of the Hessian of the cost function as a measure of flatness, and use it to formally define the notion of approximate flat minima. Under this notion, we then analyze algorithms that find approximate flat minima efficiently. For general cost functions, we discuss a gradient-based algorithm that finds an approximate flat local minimum efficiently. The main component of the algorithm is to use gradients computed from randomly perturbed iterates to estimate a direction that leads to flatter minima. For the setting where the cost function is an empirical risk over training data, we present a faster algorithm that is inspired by a recently proposed practical algorithm called sharpness-aware minimization, supporting its success in practice.

LG · May 24, 2023
The Crucial Role of Normalization in Sharpness-Aware Minimization

Yan Dai, Kwangjun Ahn, Suvrit Sra

Sharpness-Aware Minimization (SAM) is a recently proposed gradient-based optimizer (Foret et al., ICLR 2021) that greatly improves the prediction performance of deep neural networks. Consequently, there has been a surge of interest in explaining its empirical success. We focus, in particular, on understanding the role played by normalization, a key component of the SAM updates. We theoretically and empirically study the effect of normalization in SAM for both convex and non-convex functions, revealing two key roles played by normalization: i) it helps in stabilizing the algorithm; and ii) it enables the algorithm to drift along a continuum (manifold) of minima -- a property identified by recent theoretical works that is the key to better performance. We further argue that these two properties of normalization make SAM robust against the choice of hyper-parameters, supporting the practicality of SAM. Our conclusions are backed by various experiments.

OC · Feb 13, 2022
Sion's Minimax Theorem in Geodesic Metric Spaces and a Riemannian Extragradient Algorithm

Peiyuan Zhang, Jingzhao Zhang, Suvrit Sra

Deciding whether saddle points exist or are approximable for nonconvex-nonconcave problems is usually intractable. This paper takes a step towards understanding a broad class of nonconvex-nonconcave minimax problems that do remain tractable. Specifically, it studies minimax problems over geodesic metric spaces, which provide a vast generalization of the usual convex-concave saddle point problems. The first main result of the paper is a geodesic metric space version of Sion's minimax theorem; we believe our proof is novel and broadly accessible as it relies on the finite intersection property alone. The second main result is a specialization to geodesically complete Riemannian manifolds: here, we devise and analyze the complexity of first-order methods for smooth minimax problems.

ST · Dec 29, 2021
Time varying regression with hidden linear dynamics

Ali Jadbabaie, Horia Mania, Devavrat Shah et al.

We revisit a model for time-varying linear regression that assumes the unknown parameters evolve according to a linear dynamical system. Counterintuitively, we show that when the underlying dynamics are stable the parameters of this model can be estimated from data by combining just two ordinary least squares estimates. We offer a finite sample guarantee on the estimation error of our method and discuss certain advantages it has over Expectation-Maximization (EM), which is the main approach proposed by prior work.

LG · Dec 21, 2021
Max-Margin Contrastive Learning

Anshul Shah, Suvrit Sra, Rama Chellappa et al.

Standard contrastive learning approaches usually require a large number of negatives for effective unsupervised learning and often exhibit slow convergence. We suspect this behavior is due to the suboptimal selection of negatives used for offering contrast to the positives. We counter this difficulty by taking inspiration from support vector machines (SVMs) to present max-margin contrastive learning (MMCL). Our approach selects negatives as the sparse support vectors obtained via a quadratic optimization problem, and contrastiveness is enforced by maximizing the decision margin. As SVM optimization can be computationally demanding, especially in an end-to-end setting, we present simplifications that alleviate the computational burden. We validate our approach on standard vision benchmark datasets, demonstrating better performance in unsupervised representation learning over state-of-the-art, while having better empirical convergence properties.

OCNov 4, 2021
Understanding Riemannian Acceleration via a Proximal Extragradient Framework

Jikai Jin, Suvrit Sra

We contribute to advancing the understanding of Riemannian accelerated gradient methods. In particular, we revisit Accelerated Hybrid Proximal Extragradient (A-HPE), a powerful framework for obtaining Euclidean accelerated methods (Monteiro and Svaiter, 2013). Building on A-HPE, we then propose and analyze Riemannian A-HPE. The core of our analysis consists of two key components: (i) a set of new insights into Euclidean A-HPE itself; and (ii) a careful control of metric distortion caused by Riemannian geometry. We illustrate our framework by obtaining a few existing and new Riemannian accelerated gradient methods as special cases, while characterizing their acceleration as corollaries of our main results.

LGOct 20, 2021
Minibatch vs Local SGD with Shuffling: Tight Convergence Bounds and Beyond

Chulhee Yun, Shashank Rajput, Suvrit Sra

In distributed learning, local SGD (also known as federated averaging) and its simple baseline minibatch SGD are widely studied optimization methods. Most existing analyses of these methods assume independent and unbiased gradient estimates obtained via with-replacement sampling. In contrast, we study shuffling-based variants: minibatch and local Random Reshuffling, which draw stochastic gradients without replacement and are thus closer to practice. For smooth functions satisfying the Polyak-Łojasiewicz condition, we obtain convergence bounds (in the large epoch regime) which show that these shuffling-based variants converge faster than their with-replacement counterparts. Moreover, we prove matching lower bounds showing that our convergence analysis is tight. Finally, we propose an algorithmic modification called synchronized shuffling that leads to convergence rates faster than our lower bounds in near-homogeneous settings.

LGOct 12, 2021
Neural Network Weights Do Not Converge to Stationary Points: An Invariant Measure Perspective

Jingzhao Zhang, Haochuan Li, Suvrit Sra et al.

This work examines the deep disconnect between existing theoretical analyses of gradient-based algorithms and the practice of training deep neural networks. Specifically, we provide numerical evidence that in large-scale neural network training (e.g., ImageNet + ResNet101, and WT103 + TransformerXL models), the neural network's weights do not converge to stationary points where the gradient of the loss is zero. Remarkably, however, we observe that even though the weights do not converge to stationary points, the progress in minimizing the loss function halts and training loss stabilizes. Inspired by this observation, we propose a new perspective based on ergodic theory of dynamical systems to explain it. Rather than studying the evolution of weights, we study the evolution of the distribution of weights. We prove convergence of the distribution of weights to an approximate invariant measure, thereby explaining how the training loss can stabilize without weights necessarily converging to stationary points. We further discuss how this perspective can better align optimization theory with empirical observations in machine learning practice.

LGMar 12, 2021
Can Single-Shuffle SGD be Better than Reshuffling SGD and GD?

Chulhee Yun, Suvrit Sra, Ali Jadbabaie

We propose matrix norm inequalities that extend the Recht-Ré (2012) conjecture on a noncommutative AM-GM inequality by supplementing it with another inequality that accounts for single-shuffle, which is a widely used without-replacement sampling scheme that shuffles only once in the beginning and is overlooked in the Recht-Ré conjecture. Instead of general positive semidefinite matrices, we restrict our attention to positive definite matrices with small enough condition numbers, which are more relevant to matrices that arise in the analysis of SGD. For such matrices, we conjecture that the means of matrix products corresponding to with- and without-replacement variants of SGD satisfy a series of spectral norm inequalities that can be summarized as: "single-shuffle SGD converges faster than random-reshuffle SGD, which is in turn faster than with-replacement SGD." We present theorems that support our conjecture by proving several special cases.

LGFeb 5, 2021
Provably Efficient Algorithms for Multi-Objective Competitive RL

Tiancheng Yu, Yi Tian, Jingzhao Zhang et al.

We study multi-objective reinforcement learning (RL) where an agent's reward is represented as a vector. In settings where an agent competes against opponents, its performance is measured by the distance of its average return vector to a target set. We develop statistically and computationally efficient algorithms to approach the associated target set. Our results extend Blackwell's approachability theorem (Blackwell, 1956) to tabular RL, where strategic exploration becomes essential. The algorithms presented are adaptive; their guarantees hold even without Blackwell's approachability condition. If the opponents use fixed policies, we give an improved rate of approaching the target set while also tackling the more ambitious goal of simultaneously minimizing a scalar cost function. We discuss our analysis for this special case by relating our results to previous works on constrained RL. To our knowledge, this work provides the first provably efficient algorithms for vector-valued Markov games and our theoretical guarantees are near-optimal.

LGDec 31, 2020
Why do classifier accuracies show linear trends under distribution shift?

Horia Mania, Suvrit Sra

Recent studies of generalization in deep learning have observed a puzzling trend: accuracies of models on one data distribution are approximately linear functions of the accuracies on another distribution. We explain this trend under an intuitive assumption on model similarity, which was verified empirically in prior work. More precisely, we assume the probability that two models agree in their predictions is higher than what we can infer from their accuracy levels alone. Then, we show that a linear trend must occur when evaluating models on two distributions unless the size of the distribution shift is large. This work emphasizes the value of understanding model similarity, which can have an impact on the generalization and robustness of classification models.

LGOct 28, 2020
Online Learning in Unknown Markov Games

Yi Tian, Yuanhao Wang, Tiancheng Yu et al.

We study online learning in unknown Markov games, a problem that arises in episodic multi-agent reinforcement learning where the actions of the opponents are unobservable. We show that in this challenging setting, achieving sublinear regret against the best response in hindsight is statistically hard. We then consider a weaker notion of regret by competing with the minimax value of the game, and present an algorithm that achieves a sublinear $\tilde{\mathcal{O}}(K^{2/3})$ regret after $K$ episodes. This is the first sublinear regret bound (to our knowledge) for online learning in unknown Markov games. Importantly, our regret bound is independent of the size of the opponents' action spaces. As a result, even when the opponents' actions are fully observable, our regret bound improves upon existing analysis (e.g., Xie et al., 2020) by an exponential factor in the number of opponents.

LGOct 23, 2020
Coping with Label Shift via Distributionally Robust Optimisation

Jingzhao Zhang, Aditya Menon, Andreas Veit et al.

The label shift problem refers to the supervised learning setting where the train and test label distributions do not match. Existing work addressing label shift usually assumes access to an unlabelled test sample. This sample may be used to estimate the test label distribution, and to then train a suitably re-weighted classifier. While approaches using this idea have proven effective, their scope is limited as it is not always feasible to access the target domain; further, they require repeated retraining if the model is to be deployed in multiple test environments. Can one instead learn a single classifier that is robust to arbitrary label shifts from a broad family? In this paper, we answer this question by proposing a model that minimises an objective based on distributionally robust optimisation (DRO). We then design and analyse a gradient descent-proximal mirror ascent algorithm tailored for large-scale problems to optimise the proposed objective. Finally, through experiments on CIFAR-100 and ImageNet, we show that our technique can significantly improve performance over a number of baselines in settings where label shift is present.

LGOct 9, 2020
Contrastive Learning with Hard Negative Samples

Joshua Robinson, Ching-Yao Chuang, Suvrit Sra et al.

How can you sample good negative examples for contrastive learning? We argue that, as with metric learning, contrastive learning of representations benefits from hard negative samples (i.e., points that are difficult to distinguish from an anchor point). The key challenge toward using hard negatives is that contrastive methods must remain unsupervised, making it infeasible to adopt existing negative sampling strategies that use true similarity information. In response, we develop a new family of unsupervised sampling methods for selecting hard negative samples where the user can control the hardness. A limiting case of this sampling results in a representation that tightly clusters each class, and pushes different classes as far apart as possible. The proposed method improves downstream performance across multiple modalities, requires only a few additional lines of code to implement, and introduces no computational overhead.
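To make "controllable hardness" concrete, here is a minimal sketch of one way to reweight negatives. It assumes that hardness is measured by cosine similarity to the anchor and that a scalar `beta` sets the concentration; both are illustrative assumptions, and neither the function `hardness_weights` nor the toy vectors come from the abstract.

```python
import numpy as np

def hardness_weights(anchor, negatives, beta):
    """Weights over negatives proportional to exp(beta * cosine similarity).

    beta = 0 gives uniform weights; larger beta concentrates mass on the
    negatives that look most like the anchor (the "hard" ones).
    This weighting rule is an assumption made for illustration.
    """
    a = anchor / np.linalg.norm(anchor)
    N = negatives / np.linalg.norm(negatives, axis=1, keepdims=True)
    logits = beta * (N @ a)
    logits = logits - logits.max()      # stabilise the exponentials
    w = np.exp(logits)
    return w / w.sum()

# Hypothetical 2-D embeddings: the first negative is nearly parallel to the anchor.
anchor = np.array([1.0, 0.0])
negatives = np.array([[0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])

w_uniform = hardness_weights(anchor, negatives, beta=0.0)
w_hard = hardness_weights(anchor, negatives, beta=5.0)
```

No supervision is used here: the weights depend only on embedding similarity, which is the constraint the abstract emphasises.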

LGJun 24, 2020
Towards Minimax Optimal Reinforcement Learning in Factored Markov Decision Processes

Yi Tian, Jian Qian, Suvrit Sra

We study minimax optimal reinforcement learning in episodic factored Markov decision processes (FMDPs), which are MDPs with conditionally independent transition components. Assuming the factorization is known, we propose two model-based algorithms. The first one achieves minimax optimal regret guarantees for a rich class of factored structures, while the second one enjoys better computational complexity with a slightly worse regret. A key new ingredient of our algorithms is the design of a bonus term to guide exploration. We complement our algorithms by presenting several structure-dependent lower bounds on regret for FMDPs that reveal the difficulty hiding in the intricacy of the structures.

OCJun 12, 2020
SGD with shuffling: optimal rates without component convexity and large epoch requirements

Kwangjun Ahn, Chulhee Yun, Suvrit Sra

We study without-replacement SGD for solving finite-sum optimization problems. Specifically, depending on how the indices of the finite-sum are shuffled, we consider the RandomShuffle (shuffle at the beginning of each epoch) and SingleShuffle (shuffle only once) algorithms. First, we establish minimax optimal convergence rates of these algorithms up to poly-log factors. Notably, our analysis is general enough to cover gradient dominated nonconvex costs, and does not rely on the convexity of individual component functions unlike existing optimal convergence results. Secondly, assuming convexity of the individual components, we further sharpen the tight convergence results for RandomShuffle by removing two drawbacks common to all prior work: the large number of epochs required for the results to hold, and extra poly-log factor gaps to the lower bound.
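The two shuffling schemes differ only in when the permutation is drawn. The sketch below shows this on a toy least-squares finite sum. The function name `shuffled_sgd`, the constant step size, and the synthetic data are assumptions for illustration, not the step-size schedule analysed in the paper.

```python
import numpy as np

def shuffled_sgd(A, b, epochs, lr, scheme="random", seed=0):
    """Without-replacement SGD on f(x) = (1/n) * sum_i 0.5 * (a_i @ x - b_i)**2."""
    rng = np.random.default_rng(seed)
    n, d = A.shape
    x = np.zeros(d)
    perm = rng.permutation(n)           # SingleShuffle reuses this order every epoch
    for _ in range(epochs):
        if scheme == "random":          # RandomShuffle: fresh permutation per epoch
            perm = rng.permutation(n)
        for i in perm:                  # each component is visited exactly once
            grad_i = (A[i] @ x - b[i]) * A[i]
            x = x - lr * grad_i
    return x

# Noiseless synthetic problem (hypothetical data).
rng = np.random.default_rng(1)
A = rng.normal(size=(50, 3))
x_true = np.array([1.0, -2.0, 0.5])
b = A @ x_true

x_rr = shuffled_sgd(A, b, epochs=200, lr=0.01, scheme="random")
x_ss = shuffled_sgd(A, b, epochs=200, lr=0.01, scheme="single")
```

With-replacement SGD would instead draw `i = rng.integers(n)` at every step, so some components may be visited several times in an epoch and others not at all.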

OCJun 8, 2020
Beyond Worst-Case Analysis in Stochastic Approximation: Moment Estimation Improves Instance Complexity

Jingzhao Zhang, Hongzhou Lin, Subhro Das et al.

We study oracle complexity of gradient based methods for stochastic approximation problems. Though in many settings optimal algorithms and tight lower bounds are known for such problems, these optimal algorithms do not achieve the best performance when used in practice. We address this theory-practice gap by focusing on instance-dependent complexity instead of worst case complexity. In particular, we first summarize known instance-dependent complexity results and categorize them into three levels. We identify the domination relation between different levels and propose a fourth instance-dependent bound that dominates existing ones. We then provide a sufficient condition according to which an adaptive algorithm with moment estimation can achieve the proposed bound without knowledge of noise levels. Our proposed algorithm and its analysis provide a theoretical justification for the success of moment estimation as it achieves improved instance complexity.

OCMay 17, 2020
Understanding Nesterov's Acceleration via Proximal Point Method

Kwangjun Ahn, Suvrit Sra

The proximal point method (PPM) is a fundamental method in optimization that is often used as a building block for designing optimization algorithms. In this work, we use the PPM to provide conceptually simple derivations along with convergence analyses of different versions of Nesterov's accelerated gradient method (AGM). The key observation is that AGM is a simple approximation of PPM, which results in an elementary derivation of the update equations and stepsizes of AGM. This view also leads to a transparent and conceptually simple analysis of AGM's convergence by using the analysis of PPM. The derivations also naturally extend to the strongly convex case. Ultimately, the results presented in this paper are of both didactic and conceptual value; they unify and explain existing variants of AGM while motivating other accelerated methods for practically relevant settings.
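For reference, a PPM step solves x_{k+1} = argmin_x f(x) + (1/(2*eta)) * ||x - x_k||^2. The sketch below runs this iteration on a strongly convex quadratic, where the step has a closed form. It illustrates PPM only, not the paper's derivation of AGM; the function name `ppm_quadratic` and the test problem are assumptions.

```python
import numpy as np

def ppm_quadratic(Q, c, x0, eta, steps):
    """Proximal point iterations for f(x) = 0.5 * x @ Q @ x - c @ x.

    Setting the gradient of f(x) + ||x - x_k||^2 / (2 * eta) to zero gives
    (Q + I / eta) x_{k+1} = c + x_k / eta, which is solved exactly each step.
    """
    M = Q + np.eye(len(x0)) / eta
    x = x0.copy()
    for _ in range(steps):
        x = np.linalg.solve(M, c + x / eta)
    return x

# Hypothetical 2-D problem with a positive definite Q.
Q = np.array([[3.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, -1.0])

x_ppm = ppm_quadratic(Q, c, np.zeros(2), eta=1.0, steps=100)
x_star = np.linalg.solve(Q, c)          # exact minimiser, for comparison
```

Each step requires solving a subproblem exactly, which is what makes PPM impractical in general and motivates cheaper approximations of it.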

OCApr 18, 2020
On Tight Convergence Rates of Without-replacement SGD

Kwangjun Ahn, Suvrit Sra

For solving finite-sum optimization problems, SGD with without-replacement sampling has been shown empirically to outperform standard (with-replacement) SGD. Denoting by $n$ the number of components in the cost and $K$ the number of epochs of the algorithm, several recent works have shown convergence rates of without-replacement SGD that have better dependency on $n$ and $K$ than the baseline rate of $O(1/(nK))$ for SGD. However, there are two main limitations shared among those works: the rates have extra poly-logarithmic factors in $nK$, and, denoting by $\kappa$ the condition number of the problem, the rates hold only after $\kappa^c\log(nK)$ epochs for some $c>0$. In this work, we overcome these limitations by analyzing step sizes that vary across epochs.

LGFeb 19, 2020
Strength from Weakness: Fast Learning Using Weak Supervision

Joshua Robinson, Stefanie Jegelka, Suvrit Sra

We study generalization properties of weakly supervised learning, that is, learning where only a few "strong" labels (the actual target of our prediction) are present but many more "weak" labels are available. In particular, we show that having access to weak labels can significantly accelerate the learning rate for the strong task to the fast rate of $\mathcal{O}(1/n)$, where $n$ denotes the number of strongly labeled data points. This acceleration can happen even if by itself the strongly labeled data admits only the slower $\mathcal{O}(1/\sqrt{n})$ rate. The actual acceleration depends continuously on the number of weak labels available, and on the relation between the two tasks. Our theoretical results are reflected empirically across a range of tasks and illustrate how weak labels speed up learning on the strong task.

OCFeb 10, 2020
Complexity of Finding Stationary Points of Nonsmooth Nonconvex Functions

Jingzhao Zhang, Hongzhou Lin, Stefanie Jegelka et al.

We provide the first non-asymptotic analysis for finding stationary points of nonsmooth, nonconvex functions. In particular, we study the class of Hadamard semi-differentiable functions, perhaps the largest class of nonsmooth functions for which the chain rule of calculus holds. This class contains examples such as ReLU neural networks and others with non-differentiable activation functions. We first show that finding an $ε$-stationary point with first-order methods is impossible in finite time. We then introduce the notion of $(δ, ε)$-stationarity, which allows for an $ε$-approximate gradient to be the convex combination of generalized gradients evaluated at points within distance $δ$ to the solution. We propose a series of randomized first-order methods and analyze their complexity of finding a $(δ, ε)$-stationary point. Furthermore, we provide a lower bound and show that our stochastic algorithm has min-max optimal dependence on $δ$. Empirically, our methods perform well for training ReLU neural networks.
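The verbal definition of $(δ, ε)$-stationarity above can be written out as follows. This is a sketch in assumed notation: $\partial f$ denotes the generalized gradient and $\partial_\delta f$ is a shorthand introduced here, not necessarily the paper's symbol.

```latex
\partial_\delta f(x) \;:=\; \operatorname{conv}\Bigl( \bigcup_{y \,:\, \|y - x\| \le \delta} \partial f(y) \Bigr),
\qquad
x \text{ is } (\delta,\epsilon)\text{-stationary}
\iff
\operatorname{dist}\bigl(0,\ \partial_\delta f(x)\bigr) \le \epsilon .
```

Under this reading, taking $\delta = 0$ recovers ordinary $\epsilon$-stationarity, since $\partial_0 f(x) = \partial f(x)$.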

OCJan 24, 2020
From Nesterov's Estimate Sequence to Riemannian Acceleration

Kwangjun Ahn, Suvrit Sra

We propose the first global accelerated gradient method for Riemannian manifolds. Toward establishing our result we revisit Nesterov's estimate sequence technique and develop an alternative analysis for it that may also be of independent interest. Then, we extend this analysis to the Riemannian setting, localizing the key difficulty due to non-Euclidean structure into a certain "metric distortion." We control this distortion by developing a novel geometric inequality, which permits us to propose and analyze a Riemannian counterpart to Nesterov's accelerated gradient method.

OCDec 6, 2019
Why are Adaptive Methods Good for Attention Models?

Jingzhao Zhang, Sai Praneeth Karimireddy, Andreas Veit et al.

While stochastic gradient descent (SGD) is still the de facto algorithm in deep learning, adaptive methods like Clipped SGD/Adam have been observed to outperform SGD across important tasks, such as attention models. The settings under which SGD performs poorly in comparison to adaptive methods are not well understood yet. In this paper, we provide empirical and theoretical evidence that a heavy-tailed distribution of the noise in stochastic gradients is one cause of SGD's poor performance. We provide the first tight upper and lower convergence bounds for adaptive gradient methods under heavy-tailed noise. Further, we demonstrate how gradient clipping plays a key role in addressing heavy-tailed gradient noise. Subsequently, we show how clipping can be applied in practice by developing an adaptive coordinate-wise clipping algorithm (ACClip) and demonstrate its superior performance on BERT pretraining and finetuning tasks.
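To illustrate the difference between global and coordinate-wise clipping, here is a minimal sketch with fixed thresholds. It is not ACClip: the abstract does not say how ACClip adapts its thresholds, so the helper names and the fixed `tau` below are assumptions.

```python
import numpy as np

def clip_global(g, tau):
    """Rescale g so that its Euclidean norm is at most tau (direction preserved)."""
    norm = np.linalg.norm(g)
    return g if norm <= tau else g * (tau / norm)

def clip_coordinatewise(g, tau):
    """Clamp each coordinate of g to [-tau, tau]; tau may be a scalar or a vector."""
    return np.clip(g, -tau, tau)

# Hypothetical stochastic gradient with one heavy-tailed coordinate.
g = np.array([0.1, -0.2, 30.0])

g_global = clip_global(g, tau=1.0)
g_coord = clip_coordinatewise(g, tau=1.0)
```

Global clipping shrinks every coordinate because of the single outlier, whereas coordinate-wise clipping leaves the well-behaved coordinates untouched.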

LGDec 3, 2019
Learning Adversarial MDPs with Bandit Feedback and Unknown Transition

Chi Jin, Tiancheng Jin, Haipeng Luo et al.

We consider the problem of learning in episodic finite-horizon Markov decision processes with an unknown transition function, bandit feedback, and adversarial losses. We propose an efficient algorithm that achieves $\tilde{\mathcal{O}}(L|X|\sqrt{|A|T})$ regret with high probability, where $L$ is the horizon, $|X|$ is the number of states, $|A|$ is the number of actions, and $T$ is the number of episodes. To the best of our knowledge, our algorithm is the first to ensure $\tilde{\mathcal{O}}(\sqrt{T})$ regret in this challenging setting; in fact it achieves the same regret bound as Rosenberg & Mansour (2019a), who consider an easier setting with full-information feedback. Our key technical contributions are two-fold: a tighter confidence set for the transition function, and an optimistic loss estimator that is inversely weighted by an upper occupancy bound.

OCOct 9, 2019
Projection-free nonconvex stochastic optimization on Riemannian manifolds

Melanie Weber, Suvrit Sra

We study stochastic projection-free methods for constrained optimization of smooth functions on Riemannian manifolds, i.e., with additional constraints beyond the parameter domain being a manifold. Specifically, we introduce stochastic Riemannian Frank-Wolfe methods for nonconvex and geodesically convex problems. We present algorithms for both purely stochastic optimization and finite-sum problems. For the latter, we develop variance-reduced methods, including a Riemannian adaptation of the recently proposed Spider technique. For all settings, we recover convergence rates that are comparable to the best-known rates for their Euclidean counterparts. Finally, we discuss applications to two classic tasks: the computation of the Karcher mean of positive definite matrices and Wasserstein barycenters for multivariate normal distributions. For both tasks, stochastic Frank-Wolfe methods yield state-of-the-art empirical performance.

LGJul 22, 2019
Efficient Policy Learning for Non-Stationary MDPs under Adversarial Manipulation

Tiancheng Yu, Suvrit Sra

A Markov Decision Process (MDP) is a popular model for reinforcement learning. However, its commonly used assumption of stationary dynamics and rewards is too stringent and fails to hold in adversarial, nonstationary, or multi-agent problems. We study an episodic setting where the parameters of an MDP can differ across episodes. We learn a reliable policy for this potentially adversarial MDP by developing an Adversarial Reinforcement Learning (ARL) algorithm that reduces our MDP to a sequence of adversarial bandit problems. ARL achieves $O(\sqrt{SATH^3})$ regret, which is optimal with respect to $S$, $A$, and $T$, and its dependence on $H$ is the best (even for the usual stationary MDP) among existing model-free methods.