h-index 66
92 papers
4,010 citations
Novelty 57%
AI Score 61

92 Papers

LG | May 4, 2022 | Code
FedNest: Federated Bilevel, Minimax, and Compositional Optimization

Davoud Ataee Tarzanagh, Mingchen Li, Christos Thrampoulidis et al.

Standard federated optimization methods successfully apply to stochastic problems with single-level structure. However, many contemporary ML problems -- including adversarial robustness, hyperparameter tuning, and actor-critic -- fall under nested bilevel programming that subsumes minimax and compositional optimization. In this work, we propose FedNest: A federated alternating stochastic gradient method to address general nested problems. We establish provable convergence rates for FedNest in the presence of heterogeneous data and introduce variations for bilevel, minimax, and compositional optimization. FedNest introduces multiple innovations including federated hypergradient computation and variance reduction to address inner-level heterogeneity. We complement our theory with experiments on hyperparameter & hyper-representation learning and minimax optimization that demonstrate the benefits of our method in practice. Code is available at https://github.com/ucr-optml/FedNest.

LG | Mar 3, 2022 | Code
Provable and Efficient Continual Representation Learning

Yingcong Li, Mingchen Li, M. Salman Asif et al.

In continual learning (CL), the goal is to design models that can learn a sequence of tasks without catastrophic forgetting. While there is a rich set of techniques for CL, relatively little understanding exists on how representations built by previous tasks benefit new tasks that are added to the network. To address this, we study the problem of continual representation learning (CRL) where we learn an evolving representation as new tasks arrive. Focusing on zero-forgetting methods where tasks are embedded in subnetworks (e.g., PackNet), we first provide experiments demonstrating CRL can significantly boost sample efficiency when learning new tasks. To explain this, we establish theoretical guarantees for CRL by providing sample complexity and generalization error bounds for new tasks by formalizing the statistical benefits of previously-learned representations. Our analysis and experiments also highlight the importance of the order in which we learn the tasks. Specifically, we show that CL benefits if the initial tasks have large sample size and high "representation diversity". Diversity ensures that adding new tasks incurs small representation mismatch and can be learned with few samples while training only a few additional nonzero weights. Finally, we ask whether one can ensure that each task subnetwork is efficient during inference time while retaining the benefits of representation learning. To this end, we propose an inference-efficient variation of PackNet called Efficient Sparse PackNet (ESPN) which employs joint channel & weight pruning. ESPN embeds tasks in channel-sparse subnets requiring up to 80% fewer FLOPs to compute while approximately retaining accuracy and is very competitive with a variety of baselines. In summary, this work takes a step towards data and compute-efficient CL with a representation learning perspective. GitHub page: https://github.com/ucr-optml/CtRL

OC | Nov 9, 2011
A Simplified Approach to Recovery Conditions for Low Rank Matrices

Samet Oymak, Karthik Mohan, Maryam Fazel et al.

Recovering sparse vectors and low-rank matrices from noisy linear measurements has been the focus of much recent research. Various reconstruction algorithms have been studied, including $\ell_1$ and nuclear norm minimization as well as $\ell_p$ minimization with $p<1$. These algorithms are known to succeed if certain conditions on the measurement map are satisfied. Proofs of robust recovery for matrices have so far been much more involved than in the vector case. In this paper, we show how several robust classes of recovery conditions can be extended from vectors to matrices in a simple and transparent way, leading to the best known restricted isometry and nullspace conditions for matrix recovery. Our results rely on the ability to "vectorize" matrices through the use of a key singular value inequality.

LG | Jan 17, 2023
Transformers as Algorithms: Generalization and Stability in In-context Learning

Yingcong Li, M. Emrullah Ildiz, Dimitris Papailiopoulos et al.

In-context learning (ICL) is a type of prompting where a transformer model operates on a sequence of (input, output) examples and performs inference on-the-fly. In this work, we formalize in-context learning as an algorithm learning problem where a transformer model implicitly constructs a hypothesis function at inference-time. We first explore the statistical aspects of this abstraction through the lens of multitask learning: We obtain generalization bounds for ICL when the input prompt is (1) a sequence of i.i.d. (input, label) pairs or (2) a trajectory arising from a dynamical system. The crux of our analysis is relating the excess risk to the stability of the algorithm implemented by the transformer. We characterize when transformer/attention architecture provably obeys the stability condition and also provide empirical verification. For generalization on unseen tasks, we identify an inductive bias phenomenon in which the transfer learning risk is governed by the task complexity and the number of MTL tasks in a highly predictable manner. Finally, we provide numerical evaluations that (1) demonstrate transformers can indeed implement near-optimal algorithms on classical regression problems with i.i.d. and dynamic data, (2) provide insights on stability, and (3) verify our theoretical predictions.

LG | Jun 6, 2023
On the Role of Attention in Prompt-tuning

Samet Oymak, Ankit Singh Rawat, Mahdi Soltanolkotabi et al.

Prompt-tuning is an emerging strategy to adapt large language models (LLM) to downstream tasks by learning a (soft-)prompt parameter from data. Despite its success in LLMs, there is limited theoretical understanding of the power of prompt-tuning and the role of the attention mechanism in prompting. In this work, we explore prompt-tuning for one-layer attention architectures and study contextual mixture-models where each input token belongs to a context-relevant or -irrelevant set. We isolate the role of prompt-tuning through a self-contained prompt-attention model. Our contributions are as follows: (1) We show that softmax-prompt-attention is provably more expressive than softmax-self-attention and linear-prompt-attention under our contextual data model. (2) We analyze the initial trajectory of gradient descent and show that it learns the prompt and prediction head with near-optimal sample complexity and demonstrate how prompt can provably attend to sparse context-relevant tokens. (3) Assuming a known prompt but an unknown prediction head, we characterize the exact finite sample performance of prompt-attention which reveals the fundamental performance limits and the precise benefit of the context information. We also provide experiments that verify our theoretical insights on real datasets and demonstrate how prompt-tuning enables the model to attend to context-relevant information.

LG | Aug 31, 2023
Transformers as Support Vector Machines

Davoud Ataee Tarzanagh, Yingcong Li, Christos Thrampoulidis et al.

Since its inception in "Attention Is All You Need", transformer architecture has led to revolutionary advancements in NLP. The attention layer within the transformer admits a sequence of input tokens $X$ and makes them interact through pairwise similarities computed as softmax$(XQK^\top X^\top)$, where $(K,Q)$ are the trainable key-query parameters. In this work, we establish a formal equivalence between the optimization geometry of self-attention and a hard-margin SVM problem that separates optimal input tokens from non-optimal tokens using linear constraints on the outer-products of token pairs. This formalism allows us to characterize the implicit bias of 1-layer transformers optimized with gradient descent: (1) Optimizing the attention layer with vanishing regularization, parameterized by $(K,Q)$, converges in direction to an SVM solution minimizing the nuclear norm of the combined parameter $W=KQ^\top$. Instead, directly parameterizing by $W$ minimizes a Frobenius norm objective. We characterize this convergence, highlighting that it can occur toward locally-optimal directions rather than global ones. (2) Complementing this, we prove the local/global directional convergence of gradient descent under suitable geometric conditions. Importantly, we show that over-parameterization catalyzes global convergence by ensuring the feasibility of the SVM problem and by guaranteeing a benign optimization landscape devoid of stationary points. (3) While our theory applies primarily to linear prediction heads, we propose a more general SVM equivalence that predicts the implicit bias with nonlinear heads. Our findings are applicable to arbitrary datasets and their validity is verified via experiments. We also introduce several open problems and research directions. We believe these findings inspire the interpretation of transformers as a hierarchy of SVMs that separates and selects optimal tokens.

LG | Jun 23, 2023
Max-Margin Token Selection in Attention Mechanism

Davoud Ataee Tarzanagh, Yingcong Li, Xuechen Zhang et al.

Attention mechanism is a central component of the transformer architecture which led to the phenomenal success of large language models. However, the theoretical principles underlying the attention mechanism are poorly understood, especially its nonconvex optimization dynamics. In this work, we explore the seminal softmax-attention model $f(\boldsymbol{X})=\langle \boldsymbol{Xv}, \texttt{softmax}(\boldsymbol{XWp})\rangle$, where $\boldsymbol{X}$ is the token sequence and $(\boldsymbol{v},\boldsymbol{W},\boldsymbol{p})$ are trainable parameters. We prove that running gradient descent on $\boldsymbol{p}$, or equivalently $\boldsymbol{W}$, converges in direction to a max-margin solution that separates $\textit{locally-optimal}$ tokens from non-optimal ones. This clearly formalizes attention as an optimal token selection mechanism. Remarkably, our results are applicable to general data and precisely characterize $\textit{optimality}$ of tokens in terms of the value embeddings $\boldsymbol{Xv}$ and problem geometry. We also provide a broader regularization path analysis that establishes the margin maximizing nature of attention even for nonlinear prediction heads. When optimizing $\boldsymbol{v}$ and $\boldsymbol{p}$ simultaneously with logistic loss, we identify conditions under which the regularization paths directionally converge to their respective hard-margin SVM solutions where $\boldsymbol{v}$ separates the input features based on their labels. Interestingly, the SVM formulation of $\boldsymbol{p}$ is influenced by the support vector geometry of $\boldsymbol{v}$. Finally, we verify our theoretical findings via numerical experiments and provide insights.
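The softmax-attention model in the abstract above can be written out directly. The following is a minimal sketch that only evaluates $f(X)=\langle Xv, \texttt{softmax}(XWp)\rangle$ on a toy input; it does not run gradient descent or reproduce the paper's analysis. The token matrix, the parameter values, and the helper names (`matvec`, `attention_output`) are assumptions made for illustration. Scaling `p` along a fixed direction shows how the softmax weights concentrate on a single token, which is the kind of token selection the abstract describes.

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(M, x):
    # Matrix-vector product for a list-of-rows matrix.
    return [sum(a * b for a, b in zip(row, x)) for row in M]

def attention_output(X, v, W, p):
    """Evaluate f(X) = <Xv, softmax(XWp)>; X holds one token per row."""
    values = matvec(X, v)              # value embedding of each token, Xv
    scores = matvec(X, matvec(W, p))   # attention logits, XWp
    weights = softmax(scores)
    return sum(w * s for w, s in zip(weights, values)), weights

# Toy data (hypothetical numbers, not taken from the paper).
X = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
v = [0.5, 2.0]
W = [[1.0, 0.0], [0.0, 1.0]]
p_direction = [0.0, 1.0]

# Growing the norm of p along a fixed direction sharpens the softmax.
for scale in (1.0, 10.0, 100.0):
    p = [scale * c for c in p_direction]
    out, weights = attention_output(X, v, W, p)
    print(scale, out, weights)
```

In this toy setup the second token has the largest logit along `p_direction`, so its weight approaches 1 as the scale grows and the output approaches that token's value embedding.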

LG | Nov 8, 2023
Effective Restoration of Source Knowledge in Continual Test Time Adaptation

Fahim Faisal Niloy, Sk Miraj Ahmed, Dripta S. Raychaudhuri et al.

Traditional test-time adaptation (TTA) methods face significant challenges in adapting to dynamic environments characterized by continuously changing long-term target distributions. These challenges primarily stem from two factors: catastrophic forgetting of previously learned valuable source knowledge and gradual error accumulation caused by miscalibrated pseudo labels. To address these issues, this paper introduces an unsupervised domain change detection method that is capable of identifying domain shifts in dynamic environments and subsequently resets the model parameters to the original source pre-trained values. By restoring the knowledge from the source, it effectively corrects the negative consequences arising from the gradual deterioration of model parameters caused by ongoing shifts in the domain. Our method involves progressive estimation of global batch-norm statistics specific to each domain, while keeping track of changes in the statistics triggered by domain shifts. Importantly, our method is agnostic to the specific adaptation technique employed and thus, can be incorporated to existing TTA methods to enhance their performance in dynamic environments. We perform extensive experiments on benchmark datasets to demonstrate the superior performance of our method compared to state-of-the-art adaptation methods.

LG | Jul 10, 2023
FedYolo: Augmenting Federated Learning with Pretrained Transformers

Xuechen Zhang, Mingchen Li, Xiangyu Chang et al.

The growth and diversity of machine learning applications motivate a rethinking of learning with mobile and edge devices. How can we address diverse client goals and learn with scarce heterogeneous data? While federated learning aims to address these issues, it has challenges hindering a unified solution. Large transformer models have been shown to work across a variety of tasks achieving remarkable few-shot adaptation. This raises the question: Can clients use a single general-purpose model, rather than custom models for each task, while obeying device and network constraints? In this work, we investigate pretrained transformers (PTF) to achieve these on-device learning goals and thoroughly explore the roles of model size and modularity, where the latter refers to adaptation through modules such as prompts or adapters. Focusing on federated learning, we demonstrate that: (1) Larger scale shrinks the accuracy gaps between alternative approaches and improves heterogeneity robustness. Scale allows clients to run more local SGD epochs which can significantly reduce the number of communication rounds. At the extreme, clients can achieve respectable accuracy locally highlighting the potential of fully-local learning. (2) Modularity, by design, enables $>100\times$ less communication in bits. Surprisingly, it also boosts the generalization capability of local adaptation methods and the robustness of smaller PTFs. Finally, it enables clients to solve multiple unrelated tasks simultaneously using a single PTF, whereas full updates are prone to catastrophic forgetting. These insights on scale and modularity motivate a new federated learning approach we call "You Only Load Once" (FedYolo): The clients load a full PTF model once and all future updates are accomplished through communication-efficient modules with limited catastrophic-forgetting, where each task is assigned to its own module.

LG | Jul 13, 2024
Fine-grained Analysis of In-context Linear Estimation: Data, Architecture, and Beyond

Yingcong Li, Ankit Singh Rawat, Samet Oymak

Recent research has shown that Transformers with linear attention are capable of in-context learning (ICL) by implementing a linear estimator through gradient descent steps. However, the existing results on the optimization landscape apply under stylized settings where task and feature vectors are assumed to be IID and the attention weights are fully parameterized. In this work, we develop a stronger characterization of the optimization and generalization landscape of ICL through contributions on architectures, low-rank parameterization, and correlated designs: (1) We study the landscape of 1-layer linear attention and 1-layer H3, a state-space model. Under a suitable correlated design assumption, we prove that both implement 1-step preconditioned gradient descent. We show that thanks to its native convolution filters, H3 also has the advantage of implementing sample weighting and outperforming linear attention in suitable settings. (2) By studying correlated designs, we provide new risk bounds for retrieval augmented generation (RAG) and task-feature alignment which reveal how ICL sample complexity benefits from distributional alignment. (3) We derive the optimal risk for low-rank parameterized attention weights in terms of covariance spectrum. Through this, we also shed light on how LoRA can adapt to a new distribution by capturing the shift between task covariances. Experimental results corroborate our theoretical findings. Overall, this work explores the optimization and risk landscape of ICL in practically meaningful settings and contributes to a more thorough understanding of its mechanics.
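As a rough illustration of the "1-step preconditioned gradient descent" estimator named in the abstract above, the sketch below takes one preconditioned gradient step from $w=0$ on the least-squares loss $\frac{1}{2n}\sum_i (y_i - w^\top x_i)^2$ and uses the result to predict a query label. It is not the paper's attention or H3 construction; the function name, the toy data, and the choice of preconditioner `P` are assumptions for illustration only.

```python
def one_step_pgd_predict(X, y, x_query, P):
    """Predict <x_query, w> with w = P (X^T y / n): one preconditioned GD step from w = 0."""
    n = len(X)
    d = len(X[0])
    # The gradient of (1/2n) * sum_i (y_i - <w, x_i>)^2 at w = 0 is -(1/n) X^T y,
    # so a unit step along the preconditioned negative gradient gives w = P (X^T y / n).
    xty = [sum(X[i][j] * y[i] for i in range(n)) / n for j in range(d)]
    w = [sum(P[j][k] * xty[k] for k in range(d)) for j in range(d)]
    prediction = sum(a * b for a, b in zip(x_query, w))
    return prediction, w

# Toy in-context examples (hypothetical): labels come from a hidden linear task.
X = [[1.0, 0.0], [0.0, 1.0]]
w_true = [2.0, -3.0]
y = [sum(a * b for a, b in zip(row, w_true)) for row in X]

# Assumed preconditioner: the inverse of X^T X / n for this particular design.
P = [[2.0, 0.0], [0.0, 2.0]]

prediction, w_hat = one_step_pgd_predict(X, y, [1.0, 1.0], P)
print(prediction, w_hat)
```

With this preconditioner the single step coincides with the least-squares solution on the toy design; a different `P` (for example, a scaled identity) would generally give a biased but cheaper estimate.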

LG | Aug 29, 2022
Finite Sample Identification of Bilinear Dynamical Systems

Yahya Sattar, Samet Oymak, Necmiye Ozay

Bilinear dynamical systems are ubiquitous in many different domains and they can also be used to approximate more general control-affine systems. This motivates the problem of learning bilinear systems from a single trajectory of the system's states and inputs. Under a mild marginal mean-square stability assumption, we identify how much data is needed to estimate the unknown bilinear system up to a desired accuracy with high probability. Our sample complexity and statistical error rates are optimal in terms of the trajectory length, the dimensionality of the system and the input size. Our proof technique relies on an application of the martingale small-ball condition. This enables us to correctly capture the properties of the problem; specifically, our error rates do not deteriorate with increasing instability. Finally, we show that numerical experiments are well-aligned with our theoretical results.

SY | Aug 16, 2023
Can Transformers Learn Optimal Filtering for Unknown Systems?

Haldun Balim, Zhe Du, Samet Oymak et al.

Transformer models have shown great success in natural language processing; however, their potential remains mostly unexplored for dynamical systems. In this work, we investigate the optimal output estimation problem using transformers, which generate output predictions using all the past ones. Particularly, we train the transformer using various distinct systems and then evaluate the performance on unseen systems with unknown dynamics. Empirically, the trained transformer adapts exceedingly well to different unseen systems and even matches the optimal performance given by the Kalman filter for linear systems. In more complex settings with non-i.i.d. noise, time-varying dynamics, and nonlinear dynamics like a quadrotor system with unknown parameters, transformers also demonstrate promising results. To support our experimental findings, we provide statistical guarantees that quantify the amount of training data required for the transformer to achieve a desired excess risk. Finally, we point out some limitations by identifying two classes of problems that lead to degraded performance, highlighting the need for caution when using transformers for control and estimation.

86.9 | LG | Apr 20
Learning to Correct: Calibrated Reinforcement Learning for Multi-Attempt Chain-of-Thought

Muhammed Emrullah Ildiz, Halil Alperen Gozeten, Ege Onur Taga et al.

State-of-the-art reasoning models utilize long chain-of-thought (CoT) to solve increasingly complex problems using more test-time computation. In this work, we explore a long CoT setting where the model makes up to K successive attempts at solving a problem, in which each attempt is allowed to build on earlier ones after the model receives hard verifier feedback. This motivates RL methods that can harness per-attempt rewards by carefully weighting individual attempts. We study optimizing the Verification@K reward (the model succeeds by the K-th attempt) and show that naively weighting the attempts by their pass/fail outcomes results in biased gradients. We introduce Calibrated Attempt-Level (CAL) GRPO by devising a weighting strategy to obtain unbiased gradients while maintaining small variance. Our theory reveals how incorporating per-attempt rewards influences the training and the eventual Verification@K performance. Experiments, baselines, and ablations on synthetic and real data corroborate our theory and the benefits of CAL-GRPO over vanilla GRPO as well as naive weighting.

LG | Jul 8, 2024
On the Power of Convolution Augmented Transformer

Mingchen Li, Xuechen Zhang, Yixiao Huang et al.

The transformer architecture has catalyzed revolutionary advances in language modeling. However, recent architectural recipes, such as state-space models, have bridged the performance gap. Motivated by this, we examine the benefits of Convolution-Augmented Transformer (CAT) for recall, copying, and length generalization tasks. CAT incorporates convolutional filters in the K/Q/V embeddings of an attention layer. Through CAT, we show that the locality of the convolution synergizes with the global view of the attention. Unlike comparable architectures, such as Mamba or transformer, CAT can provably solve the associative recall (AR) and copying tasks using a single layer while also enjoying guaranteed length generalization. We also establish computational tradeoffs between convolution and attention by characterizing how convolution can mitigate the need for full attention by summarizing the context window and creating salient summary tokens to attend. Evaluations on real datasets corroborate our findings and demonstrate that CAT and its variations indeed enhance the language modeling performance.

91.3 | LG | Apr 18
In-Context Learning Under Regime Change

Carson Dudley, Yutong Bi, Xiaofeng Liu et al.

Non-stationary sequences arise naturally in control, forecasting, and decision-making. The data-generating process shifts at unknown times, and models must detect the change, discard or downweight obsolete evidence, and adapt to new dynamics on the fly. Transformer-based foundation models increasingly rely on in-context learning for time series forecasting, tabular prediction, and continuous control. As these models are deployed in non-stationary environments, understanding their ability to detect and adapt to regime shifts is important. We formalize this as an in-context change-point detection problem and formally establish the existence of transformer models that solve this problem. Our construction demonstrates that model complexity, in layers and parameters, depends on the level of information available about the change-point location, from no knowledge to knowing exact timing. We validate our results with experiments on synthetic linear regression and linear dynamical systems, where trained transformers match the performance of optimal baselines across information levels. We also show that encoding and incorporating change-point knowledge indeed improves the real-world performance of a pretrained foundation model on infectious disease forecasting and on financial volatility forecasting around Federal Open Market Committee (FOMC) announcements without retraining, demonstrating practical applicability to real-world regime changes.

LG | Mar 8, 2023
Provable Pathways: Learning Multiple Tasks over Multiple Paths

Yingcong Li, Samet Oymak

Constructing useful representations across a large number of tasks is a key requirement for sample-efficient intelligent systems. A traditional idea in multitask learning (MTL) is building a shared representation across tasks which can then be adapted to new tasks by tuning last layers. A desirable refinement of using a shared one-fits-all representation is to construct task-specific representations. To this end, recent PathNet/muNet architectures represent individual tasks as pathways within a larger supernet. The subnetworks induced by pathways can be viewed as task-specific representations that are compositions of modules within the supernet's computation graph. This work explores the pathways proposal through the lens of statistical learning: We first develop novel generalization bounds for empirical risk minimization problems learning multiple tasks over multiple paths (Multipath MTL). In conjunction, we formalize the benefits of the resulting multipath representation when adapting to new downstream tasks. Our bounds are expressed in terms of Gaussian complexity, lead to tangible guarantees for the class of linear representations, and provide novel insights into the quality and benefits of a multipath representation. When the computation graph is a tree, Multipath MTL hierarchically clusters the tasks and builds cluster-specific representations. We provide further discussion and experiments for hierarchical MTL and rigorously identify the conditions under which Multipath MTL is provably superior to traditional MTL approaches with shallow supernets.

LG | Feb 2, 2023
Stochastic Contextual Bandits with Long Horizon Rewards

Yuzhen Qin, Yingcong Li, Fabio Pasqualetti et al.

The growing interest in complex decision-making and language modeling problems highlights the importance of sample-efficient learning over very long horizons. This work takes a step in this direction by investigating contextual linear bandits where the current reward depends on at most $s$ prior actions and contexts (not necessarily consecutive), up to a time horizon of $h$. In order to avoid polynomial dependence on $h$, we propose new algorithms that leverage sparsity to discover the dependence pattern and arm parameters jointly. We consider both the data-poor ($T<h$) and data-rich ($T\ge h$) regimes, and derive respective regret upper bounds $\tilde O(d\sqrt{sT} +\min\{ q, T\})$ and $\tilde O(\sqrt{sdT})$, with sparsity $s$, feature dimension $d$, total time horizon $T$, and $q$ that is adaptive to the reward dependence pattern. Complementing upper bounds, we also show that learning over a single trajectory brings inherent challenges: While the dependence pattern and arm parameters form a rank-1 matrix, circulant matrices are not isometric over rank-1 manifolds and sample complexity indeed benefits from the sparse reward dependence structure. Our results necessitate a new analysis to address long-range temporal dependencies across data and avoid polynomial dependence on the reward horizon $h$. Specifically, we utilize connections to the restricted isometry property of circulant matrices formed by dependent sub-Gaussian vectors and establish new guarantees that are also of independent interest.

LG | May 12, 2022
Representation Learning for Context-Dependent Decision-Making

Yuzhen Qin, Tommaso Menara, Samet Oymak et al.

Humans are capable of adjusting to changing environments flexibly and quickly. Empirical evidence has revealed that representation learning plays a crucial role in endowing humans with such a capability. Inspired by this observation, we study representation learning in the sequential decision-making scenario with contextual changes. We propose an online algorithm that is able to learn and transfer context-dependent representations and show that it significantly outperforms the existing ones that do not learn representations adaptively. As a case study, we apply our algorithm to the Wisconsin Card Sorting Task, a well-established test for the mental flexibility of humans in sequential decision-making. By comparing our algorithm with the standard Q-learning and Deep-Q learning algorithms, we demonstrate the benefits of adaptive representation learning.

50.9 | CL | Apr 8
Enabling Intrinsic Reasoning over Dense Geospatial Embeddings with DFR-Gemma

Xuechen Zhang, Aviv Slobodkin, Joydeep Paul et al.

Representation learning for geospatial and spatio-temporal data plays a critical role in enabling general-purpose geospatial intelligence. Recent geospatial foundation models, such as the Population Dynamics Foundation Model (PDFM), encode complex population and mobility dynamics into compact embeddings. However, their integration with Large Language Models (LLMs) remains limited. Existing approaches to LLM integration treat these embeddings as retrieval indices or convert them into textual descriptions for reasoning, introducing redundancy, token inefficiency, and numerical inaccuracies. We propose Direct Feature Reasoning-Gemma (DFR-Gemma), a novel framework that enables LLMs to reason directly over dense geospatial embeddings. DFR aligns high-dimensional embeddings with the latent space of an LLM via a lightweight projector, allowing embeddings to be injected as semantic tokens alongside natural language instructions. This design eliminates the need for intermediate textual representations and enables intrinsic reasoning over spatial features. To evaluate this paradigm, we introduce a multi-task geospatial benchmark that pairs embeddings with diverse question-answer tasks, including feature querying, comparison, and semantic description. Experimental results show that DFR allows LLMs to decode latent spatial patterns and perform accurate zero-shot reasoning across tasks, while significantly improving efficiency compared to text-based baselines. Our results demonstrate that treating embeddings as primary data inputs provides a more direct, efficient, and scalable approach to multimodal geospatial intelligence.

88.6 | ME | Mar 20
Learning to Bet for Horizon-Aware Anytime-Valid Testing

Ege Onur Taga, Samet Oymak, Shubhanshu Shekhar

We develop horizon-aware anytime-valid tests and confidence sequences for bounded means under a strict deadline $N$. Using the betting/e-process framework, we cast horizon-aware betting as a finite-horizon optimal control problem with state space $(t, \log W_t)$, where $t$ is the time and $W_t$ is the test martingale value. We first show that in certain interior regions of the state space, policies that deviate significantly from Kelly betting are provably suboptimal, while Kelly betting reaches the threshold with high probability. We then identify sufficient conditions showing that outside this region, more aggressive betting than Kelly can be better if the bettor is behind schedule, and less aggressive can be better if the bettor is ahead. Taken together these results suggest a simple phase diagram in the $(t, \log W_t)$ plane, delineating regions where Kelly, fractional Kelly, and aggressive betting may be preferable. Guided by this phase diagram, we introduce a Deep Reinforcement Learning approach based on a universal Deep Q-Network (DQN) agent that learns a single policy from synthetic experience and maps simple statistics of past observations to bets across horizons and null values. In limited-horizon experiments, the learned DQN policy yields state-of-the-art results.
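To make the $(t, \log W_t)$ state in the abstract above concrete, here is a minimal sketch of a betting-style test martingale for observations in $[0,1]$ under a null mean $m$, using the standard wealth update $W_t = \prod_{i\le t} (1 + \lambda (x_i - m))$ and rejecting once $W_t \ge 1/\alpha$ before a deadline. The fixed bet $\lambda$, the function names, and the data are assumptions for illustration; this is neither Kelly betting nor the learned DQN policy from the paper.

```python
import math

def log_wealth_path(xs, null_mean, bet):
    """Return [log W_1, ..., log W_T] for W_t = prod_{i<=t} (1 + bet * (x_i - null_mean))."""
    log_w = 0.0
    path = []
    for x in xs:
        # The bet must keep every factor positive, i.e. 1 + bet * (x - null_mean) > 0.
        log_w += math.log(1.0 + bet * (x - null_mean))
        path.append(log_w)
    return path

def first_rejection_time(xs, null_mean, bet, alpha, deadline):
    """First t <= deadline with W_t >= 1/alpha, or None if the deadline passes first."""
    threshold = math.log(1.0 / alpha)
    path = log_wealth_path(xs[:deadline], null_mean, bet)
    for t, log_w in enumerate(path, start=1):
        if log_w >= threshold:
            return t
    return None

# Hypothetical data: every observation equals 1 while the null claims a mean of 0.5.
observations = [1.0] * 20
print(first_rejection_time(observations, null_mean=0.5, bet=1.0, alpha=0.05, deadline=20))
print(first_rejection_time(observations, null_mean=0.5, bet=1.0, alpha=0.05, deadline=5))
```

The second call uses a shorter deadline with the same fixed bet, which illustrates why a bettor who is behind schedule might prefer a more aggressive bet than one chosen without regard to the horizon.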

LG | Feb 16
Covariance-Aware Transformers for Quadratic Programming and Decision Making

Kutay Tire, Yufan Zhang, Ege Onur Taga et al.

We explore the use of transformers for solving quadratic programs and how this capability benefits decision-making problems that involve covariance matrices. We first show that the linear attention mechanism can provably solve unconstrained QPs by tokenizing the matrix variables (e.g., $A$ of the objective $\frac{1}{2}x^\top Ax+b^\top x$) row-by-row and emulating gradient descent iterations. Furthermore, by incorporating MLPs, a transformer block can solve (i) $\ell_1$-penalized QPs by emulating iterative soft-thresholding and (ii) $\ell_1$-constrained QPs when equipped with an additional feedback loop. Our theory motivates us to introduce Time2Decide: a generic method that enhances a time series foundation model (TSFM) by explicitly feeding the covariance matrix between the variates. We empirically find that Time2Decide uniformly outperforms the base TSFM model for the classical portfolio optimization problem that admits an $\ell_1$-constrained QP formulation. Remarkably, Time2Decide also outperforms the classical "Predict-then-Optimize (PtO)" procedure, where we first forecast the returns and then explicitly solve a constrained QP, in suitable settings. Our results demonstrate that transformers benefit from explicit use of second-order statistics, and this can enable them to effectively solve complex decision-making problems, like portfolio construction, in one forward pass.
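The iterations that the abstract above says a transformer can emulate are standard and can be sketched directly. The code below runs plain gradient descent on $\frac{1}{2}x^\top Ax+b^\top x$ and, when an $\ell_1$ penalty is given, iterative soft-thresholding (proximal gradient descent). It shows only the underlying optimization steps, not the tokenization or the attention/MLP construction from the paper; the function names, step size, and toy problem are assumptions.

```python
def matvec(A, x):
    # Matrix-vector product for a list-of-rows matrix.
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def soft_threshold(z, tau):
    # Proximal operator of tau * |z|.
    if z > tau:
        return z - tau
    if z < -tau:
        return z + tau
    return 0.0

def solve_qp(A, b, step, iters, l1_penalty=0.0):
    """Minimize 0.5 x^T A x + b^T x + l1_penalty * ||x||_1 by (proximal) gradient descent."""
    x = [0.0] * len(b)
    for _ in range(iters):
        grad = [g + bi for g, bi in zip(matvec(A, x), b)]  # gradient of the smooth part: Ax + b
        x = [soft_threshold(xi - step * gi, step * l1_penalty)
             for xi, gi in zip(x, grad)]
    return x

# Toy problem (hypothetical): A is positive definite, step <= 1 / largest eigenvalue.
A = [[2.0, 0.0], [0.0, 1.0]]
b = [-2.0, -0.1]

x_plain = solve_qp(A, b, step=0.4, iters=200)                    # unconstrained QP
x_sparse = solve_qp(A, b, step=0.4, iters=200, l1_penalty=0.5)   # l1-penalized QP
print(x_plain, x_sparse)
```

With `l1_penalty=0.0` the soft-thresholding step is the identity, so the same loop covers both the unconstrained and the penalized case; the penalty drives the weakly supported coordinate to exactly zero.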

LG | Jun 2, 2023
Federated Multi-Sequence Stochastic Approximation with Local Hypergradient Estimation

Davoud Ataee Tarzanagh, Mingchen Li, Pranay Sharma et al.

Stochastic approximation with multiple coupled sequences (MSA) has found broad applications in machine learning as it encompasses a rich class of problems including bilevel optimization (BLO), multi-level compositional optimization (MCO), and reinforcement learning (specifically, actor-critic methods). However, designing provably-efficient federated algorithms for MSA has been an elusive question even for the special case of double sequence approximation (DSA). Towards this goal, we develop FedMSA which is the first federated algorithm for MSA, and establish its near-optimal communication complexity. As core novelties, (i) FedMSA enables the provable estimation of hypergradients in BLO and MCO via local client updates, which has been a notable bottleneck in prior theory, and (ii) our convergence guarantees are sensitive to the heterogeneity-level of the problem. We also incorporate momentum and variance reduction techniques to achieve further acceleration leading to near-optimal rates. Finally, we provide experiments that support our theory and demonstrate the empirical benefits of FedMSA. As an example, FedMSA enables order-of-magnitude savings in communication rounds compared to prior federated BLO schemes.

75.9LGMay 3Code
Stochastic Sparse Attention for Memory-Bound Inference

Kyle Lee, Corentin Delacour, Kevin Callahan-Coray et al.

Autoregressive decoding becomes bandwidth-limited at long contexts, as generating each token requires reading all $n_k$ key and value vectors from the KV cache. We present Stochastic Additive No-mulT Attention (SANTA), a method that sparsifies value-cache access by sampling $S \ll n_k$ indices from the post-softmax distribution and aggregating only those value rows. This yields an unbiased estimator of the post-softmax value aggregation while replacing value-stage multiply-accumulates with gather-and-add. We introduce stratified sampling to design variance-reduced, GPU-friendly variants, demonstrating $1.5\times$ decode-step attention kernel speedup over FlashInfer and FlashDecoding on an NVIDIA RTX 6000 Ada while matching baseline accuracy at 32k-token contexts. Finally, we propose Bernoulli $qK^\mathsf{T}$ sampling as a complementary technique to sparsify the score stage, reducing key-feature access through stochastic ternary queries. Both methods are orthogonal to upstream techniques such as ternary quantization, low-rank projections, and KV-cache compression. Together, they point toward sparse, multiplier-free, and energy-efficient inference. We open-source our kernels at: https://github.com/OPUSLab/SANTA.git
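The unbiasedness claim can be illustrated with a small sketch: if indices are drawn i.i.d. from the softmax weights $p$, the average of the gathered value rows has expectation $\sum_i p_i v_i$, which is the exact attention output. The code below uses plain i.i.d. sampling with replacement; it does not reproduce the stratified variants or the GPU kernels described in the abstract, and all function names are hypothetical.

```python
import math
import random


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def exact_attention(scores, values):
    """Full post-softmax value aggregation: sum_i p_i * v_i."""
    p = softmax(scores)
    d = len(values[0])
    return [sum(p[i] * values[i][k] for i in range(len(values))) for k in range(d)]


def sampled_attention(scores, values, num_samples, rng):
    """Estimate the aggregation by gathering and adding sampled value rows."""
    p = softmax(scores)
    idx = rng.choices(range(len(values)), weights=p, k=num_samples)
    d = len(values[0])
    out = [0.0] * d
    for i in idx:
        for k in range(d):
            out[k] += values[i][k]  # gather-and-add, no per-row multiply
    return [o / num_samples for o in out]


scores = [2.0, 0.5, -1.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
exact = exact_attention(scores, values)
estimate = sampled_attention(scores, values, 20000, random.Random(0))
```

The only multiplication in the sampled path is the final division by the sample count, which is the sense in which the value stage becomes gather-and-add in this sketch.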

60.4LGMay 21
Evolutionary Multi-Task Optimization for LLM-Guided Program Discovery

Halil Alperen Gozeten, Xuechen Zhang, Emrullah Ildiz et al.

Recent LLM-guided evolutionary search methods have shown that iterative program mutation can discover strong algorithms, but they typically optimize each task independently, even when related tasks share reusable structure. We introduce Evolutionary Multi-Task Optimization (EMO) for LLM-guided program discovery, and propose EMO-STA (Shared-Then-Adapt), a two-stage framework that first evolves a shared archive of executable programs across a task family and then adapts selected shared candidates to each target task. Within EMO-STA, we explore multiple adaptation strategies, including warm-starting from the shared archive, adapting the best average shared program, and adapting the shared program that performs best on each target task. Across eight task families spanning continuous optimization, geometric construction, modeling, and algorithmic optimization, EMO-STA improves over matched-compute single-task evolution in most settings, with STA Best-Local providing the strongest in-distribution adaptation and STA Best-Shared yielding robust transfer to unseen tasks. Compute-allocation experiments show that allocating a substantial fraction of the family-level budget to shared evolution is consistently beneficial, with roughly balanced shared and adaptation budgets often being optimal. Beyond compute efficiency, we show that shared evolution can mitigate overfitting in low-evidence settings (e.g. few training data), including ARC tasks and time-series feature engineering, by favoring programs that generalize across all tasks rather than exploiting task-specific brittle artifacts.

95.5LGMay 15
VSPO: Vector-Steered Policy Optimization for Behavioral Control

Xuechen Zhang, Zijian Huang, Kai Yang et al.

Modern language models often need to optimize a primary accuracy objective while also accommodating secondary behavioral preferences, such as verbosity, agreeableness, or the level of technical expertise in their responses. In practice, a base model may exhibit a desired behavior very rarely or not at all. Thus, endowing the model with a target behavior creates a sparse behavioral reward bottleneck. To address such multi-objective problems, we introduce Vector-Steered Policy Optimization (VSPO) which employs a steering vector associated with the target behavior to control the behavior intensity of the generated rollouts. VSPO is obtained by modifying GRPO to sample rollouts with varying steering intensities. This process can be interpreted as an on-policy latent self-distillation procedure where the model internalizes its steering vector. By varying steering intensities, VSPO upsamples rare behaviors and enriches rollout diversity, which alleviates the sparse reward issue and provably accelerates the policy optimization. Through comprehensive theory and experiments, we establish that VSPO has favorable properties compared to vanilla reward shaping and other alternative approaches. Specifically, under a bandit abstraction, VSPO provably achieves better iteration complexity than reward-shaped GRPO when the steering-induced distributions are sufficiently aligned with the target behavior. We evaluate VSPO across multiple reasoning benchmarks, including MATH and MMLU-Pro, for four target behaviors: explanation expertise, confidence expression, robustness to misleading context, and response verbosity. Our results show that VSPO consistently improves the control along target behavior while maintaining or improving task accuracy compared with reward shaping, teacher-trace distillation, and guidance-based baselines.

71.6LGMay 11
Latent Chain-of-Thought Improves Structured-Data Transformers

Carson Dudley, Samet Oymak

Chain-of-thought and more broadly test-time compute are known to augment the expressive capabilities of language models and have led to major innovations in reasoning. Motivated by this success, this paper explores latent chain-of-thought as well as the impact of depth and looping for time-series and tabular data. We propose a recurrent scheme in which a structured-data transformer, after an initial forward pass, compresses its query-position hidden states into feedback tokens that are appended to the input and processed again, allowing multiple rounds of latent computation before prediction. We compare CoT models against a same-depth no-CoT baseline, a deeper baseline matched to the CoT model in effective depth, and a looped transformer with weight-tied recurrence but no additional chain-of-thought tokens. Across 36 datasets in time-series forecasting and tabular prediction, latent chain-of-thought improves over the baseline on 8/9 time-series datasets (+10.99% average gain) and 22/27 tabular datasets (+5.31% average gain). Across both settings, the CoT models perform the best on average. These results demonstrate that chain-of-thought is a useful axis for scaling test-time compute for structured data.

LGFeb 6, 2024
Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks

Jongho Park, Jaeseung Park, Zheyang Xiong et al.

State-space models (SSMs), such as Mamba (Gu & Dao, 2023), have been proposed as alternatives to Transformer networks in language modeling, by incorporating gating, convolutions, and input-dependent token selection to mitigate the quadratic cost of multi-head attention. Although SSMs exhibit competitive performance, their in-context learning (ICL) capabilities, a remarkable emergent property of modern language models that enables task execution without parameter optimization, remain underexplored compared to Transformers. In this study, we evaluate the ICL performance of SSMs, focusing on Mamba, against Transformer models across various tasks. Our results show that SSMs perform comparably to Transformers in standard regression ICL tasks, while outperforming them in tasks like sparse parity learning. However, SSMs fall short in tasks involving non-standard retrieval functionality. To address these limitations, we introduce a hybrid model, MambaFormer, that combines Mamba with attention blocks, surpassing individual models in tasks where they struggle independently. Our findings suggest that hybrid architectures offer promising avenues for enhancing ICL in language models.

LGJan 4, 2022Code
AutoBalance: Optimized Loss Functions for Imbalanced Data

Mingchen Li, Xuechen Zhang, Christos Thrampoulidis et al.

Imbalanced datasets are commonplace in modern machine learning problems. The presence of under-represented classes or groups with sensitive attributes results in concerns about generalization and fairness. Such concerns are further exacerbated by the fact that large capacity deep nets can perfectly fit the training data and appear to achieve perfect accuracy and fairness during training, but perform poorly during test. To address these challenges, we propose AutoBalance, a bi-level optimization framework that automatically designs a training loss function to optimize a blend of accuracy and fairness-seeking objectives. Specifically, a lower-level problem trains the model weights, and an upper-level problem tunes the loss function by monitoring and optimizing the desired objective over the validation data. Our loss design enables personalized treatment for classes/groups by employing a parametric cross-entropy loss and individualized data augmentation schemes. We evaluate the benefits and performance of our approach for the application scenarios of imbalanced and group-sensitive classification. Extensive empirical evaluations demonstrate the benefits of AutoBalance over state-of-the-art approaches. Our experimental findings are complemented with theoretical insights on loss function design and the benefits of train-validation split. All code is available open-source.

LGMar 12, 2024
Mechanics of Next Token Prediction with Self-Attention

Yingcong Li, Yixiao Huang, M. Emrullah Ildiz et al.

Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: What does a single self-attention layer learn from next-token prediction? We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: (1) Hard retrieval: Given input sequence, self-attention precisely selects the high-priority input tokens associated with the last input token. (2) Soft composition: It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.

LGFeb 21, 2024
From Self-Attention to Markov Models: Unveiling the Dynamics of Generative Transformers

M. Emrullah Ildiz, Yixiao Huang, Yingcong Li et al.

Modern language models rely on the transformer architecture and attention mechanism to perform language understanding and text generation. In this work, we study learning a 1-layer self-attention model from a set of prompts and associated output data sampled from the model. We first establish a precise mapping between the self-attention mechanism and Markov models: Inputting a prompt to the model samples the output token according to a context-conditioned Markov chain (CCMC) which weights the transition matrix of a base Markov chain. Additionally, incorporating positional encoding results in position-dependent scaling of the transition probabilities. Building on this formalism, we develop identifiability/coverage conditions for the prompt distribution that guarantee consistent estimation and establish sample complexity guarantees under IID samples. Finally, we study the problem of learning from a single output trajectory generated from an initial prompt. We characterize an intriguing winner-takes-all phenomenon where the generative process implemented by self-attention collapses into sampling a limited subset of tokens due to its non-mixing nature. This provides a mathematical explanation to the tendency of modern LLMs to generate repetitive text. In summary, the equivalence to CCMC provides a simple but powerful framework to study self-attention and its properties.

CLApr 17, 2024
Efficient Contextual LLM Cascades through Budget-Constrained Policy Learning

Xuechen Zhang, Zijian Huang, Ege Onur Taga et al.

Recent successes in natural language processing have led to the proliferation of large language models (LLMs) by multiple providers. Each LLM offering has different inference accuracy, monetary cost, and latency, and their accuracy further depends on the exact wording of the question (i.e., the specific prompt). At the same time, users often have a limit on monetary budget and latency to answer all their questions, and they do not know which LLMs to choose for each question to meet their accuracy and long term budget requirements. To navigate this rich design space, we propose TREACLE (Thrifty Reasoning via Context-Aware LLM and Prompt Selection), a reinforcement learning policy that jointly selects the model and prompting scheme while respecting the user's monetary cost and latency constraints. TREACLE uses the problem context, including question text embeddings (reflecting the type or difficulty of a query) and the response history (reflecting the consistency of previous responses) to make smart decisions. Our evaluations on standard reasoning datasets (GSM8K, CSQA, and LLC) with various LLMs and prompts show that TREACLE enables cost savings of up to 85% compared to baselines, while maintaining high accuracy. Importantly, it provides the user with the ability to gracefully trade off accuracy for cost.

MLOct 24, 2024
High-dimensional Analysis of Knowledge Distillation: Weak-to-Strong Generalization and Scaling Laws

M. Emrullah Ildiz, Halil Alperen Gozeten, Ege Onur Taga et al.

A growing number of machine learning scenarios rely on knowledge distillation where one uses the output of a surrogate model as labels to supervise the training of a target model. In this work, we provide a sharp characterization of this process for ridgeless, high-dimensional regression, under two settings: (i) model shift, where the surrogate model is arbitrary, and (ii) distribution shift, where the surrogate model is the solution of empirical risk minimization with out-of-distribution data. In both cases, we characterize the precise risk of the target model through non-asymptotic bounds in terms of sample size and data distribution under mild conditions. As a consequence, we identify the form of the optimal surrogate model, which reveals the benefits and limitations of discarding weak features in a data-dependent fashion. In the context of weak-to-strong (W2S) generalization, this has the interpretation that (i) W2S training, with the surrogate as the weak model, can provably outperform training with strong labels under the same data budget, but (ii) it is unable to improve the data scaling law. We validate our results on numerical experiments both on ridgeless regression and on neural network architectures.

LGFeb 22, 2025
TimePFN: Effective Multivariate Time Series Forecasting with Synthetic Data

Ege Onur Taga, M. Emrullah Ildiz, Samet Oymak

The diversity of time series applications and scarcity of domain-specific data highlight the need for time-series models with strong few-shot learning capabilities. In this work, we propose a novel training scheme and a transformer-based architecture, collectively referred to as TimePFN, for multivariate time-series (MTS) forecasting. TimePFN is based on the concept of Prior-data Fitted Networks (PFN), which aims to approximate Bayesian inference. Our approach consists of (1) generating synthetic MTS data through diverse Gaussian process kernels and the linear coregionalization method, and (2) a novel MTS architecture capable of utilizing both temporal and cross-channel dependencies across all input patches. We evaluate TimePFN on several benchmark datasets and demonstrate that it outperforms the existing state-of-the-art models for MTS forecasting in both zero-shot and few-shot settings. Notably, fine-tuning TimePFN with as few as 500 data points nearly matches full dataset training error, and even 50 data points yield competitive results. We also find that TimePFN exhibits strong univariate forecasting performance, attesting to its generalization ability. Overall, this work unlocks the power of synthetic data priors for MTS forecasting and facilitates strong zero- and few-shot forecasting performance.

LGNov 12, 2024
Retrieval Augmented Time Series Forecasting

Kutay Tire, Ege Onur Taga, Muhammed Emrullah Ildiz et al.

Retrieval-augmented generation (RAG) is a central component of modern LLM systems, particularly in scenarios where up-to-date information is crucial for accurately responding to user queries or when queries exceed the scope of the training data. The advent of time-series foundation models (TSFM), such as Chronos, and the need for effective zero-shot forecasting performance across various time-series domains motivates the question: Do benefits of RAG similarly carry over to time series forecasting? In this paper, we advocate that the dynamic and event-driven nature of time-series data makes RAG a crucial component of TSFMs and introduce a principled RAG framework for time-series forecasting, called Retrieval Augmented Forecasting (RAF). Within RAF, we develop efficient strategies for retrieving related time-series examples and incorporating them into forecast. Through experiments and mechanistic studies, we demonstrate that RAF indeed improves the forecasting accuracy across diverse time series domains and the improvement is more significant for larger TSFM sizes.

LGNov 19, 2024
Selective Attention: Enhancing Transformer through Principled Context Control

Xuechen Zhang, Xiangyu Chang, Mingchen Li et al.

The attention mechanism within the transformer architecture enables the model to weigh and combine tokens based on their relevance to the query. While self-attention has enjoyed major success, it notably treats all queries $q$ in the same way by applying the mapping $V^\top\text{softmax}(Kq)$, where $V,K$ are the value and key embeddings respectively. In this work, we argue that this uniform treatment hinders the ability to control contextual sparsity and relevance. As a solution, we introduce the $\textit{Selective Self-Attention}$ (SSA) layer that augments the softmax nonlinearity with a principled temperature scaling strategy. By controlling temperature, SSA adapts the contextual sparsity of the attention map to the query embedding and its position in the context window. Through theory and experiments, we demonstrate that this alleviates attention dilution, aids the optimization process, and enhances the model's ability to control softmax spikiness of individual queries. We also incorporate temperature scaling for value embeddings and show that it boosts the model's ability to suppress irrelevant/noisy tokens. Notably, SSA is a lightweight method which introduces less than 0.5% new parameters through a weight-sharing strategy and can be fine-tuned on existing LLMs. Extensive empirical evaluations demonstrate that SSA-equipped models achieve a noticeable and consistent accuracy improvement on language modeling benchmarks.
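To make the role of temperature concrete, here is a small sketch of the mapping $V^\top\text{softmax}(Kq/\tau)$ with a scalar temperature $\tau$ passed in by hand. In SSA the temperature is produced by the model as a function of the query and its position; that learned component is not reproduced here, and the function names and toy inputs are assumptions made for illustration.

```python
import math


def attention_weights(K, q, temperature=1.0):
    """Softmax over key-query scores divided by a temperature."""
    scores = [sum(kj * qj for kj, qj in zip(k, q)) / temperature for k in K]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def attend(K, V, q, temperature=1.0):
    """Compute V^T softmax(K q / temperature) as a list of floats."""
    w = attention_weights(K, q, temperature)
    d = len(V[0])
    return [sum(w[i] * V[i][k] for i in range(len(V))) for k in range(d)]


K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
q = [1.0, 0.5]

w_sharp = attention_weights(K, q, temperature=0.1)   # spiky, nearly one-hot
w_base = attention_weights(K, q, temperature=1.0)
w_flat = attention_weights(K, q, temperature=10.0)   # close to uniform
out_sharp = attend(K, V, q, temperature=0.1)
```

A low temperature concentrates the weights on the highest-scoring key, while a high temperature spreads them out; this is the contextual-sparsity knob the abstract refers to.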

80.8SYApr 26
On the Generalization Properties of Selective State-Space Models for Filtering Tasks for Unknown Systems

Alex Tang, M. Emrullah Ildiz, Batin Kurt et al.

Selective State-Space Models (SSMs) such as Mamba have emerged as an alternative architecture to self-attention based transformers in sequence modeling tasks. Recent works have demonstrated the use of transformers in some filtering and output prediction tasks via in-context learning. In this paper, we analyze whether structured SSMs can work equally well for filtering of unknown systems. In particular, we train the SSM on trajectory samples from a set of systems. At run-time, the SSM is given the outputs of an unknown system from the same set and is expected to predict the next output online. Theoretically, under appropriate assumptions, we derive generalization bounds as to why SSMs succeed in such tasks. Empirically, we demonstrate the performance via several numerical examples. We also discuss the advantages and disadvantages of SSMs versus transformers for this task.

LGJun 20, 2025
BREAD: Branched Rollouts from Expert Anchors Bridge SFT & RL for Reasoning

Xuechen Zhang, Zijian Huang, Yingcong Li et al.

Small language models (SLMs) struggle to learn complex reasoning behaviors, especially when high-quality traces are scarce or difficult to learn from. The standard training approach combines a supervised fine-tuning (SFT) stage, often to distill capabilities of a larger model, followed by a reinforcement learning (RL) stage such as Group Relative Policy Optimization (GRPO). In this paper, we investigate the fundamental limitations of this SFT + RL paradigm and propose methods to overcome them. Under a suitable theoretical model, we demonstrate that the SFT + RL strategy can fail completely when (1) the expert's traces are too difficult for the small model to express, or (2) the small model's initialization has exponentially small likelihood of success. To address these, we introduce BREAD: a GRPO variant that unifies the SFT and RL stages via partial expert guidance and branched rollouts. When self-generated traces fail, BREAD adaptively inserts short expert prefixes/hints, allowing the small model to complete the rest of the reasoning path, and ensuring that each update includes at least one successful trace. This mechanism both densifies the reward signal and induces a natural learning curriculum. BREAD requires fewer than 40% of ground-truth traces, consistently outperforming standard GRPO while speeding up the training by about 3 times. Importantly, we demonstrate that BREAD helps the model solve problems that are otherwise unsolvable by the SFT + RL strategy, highlighting how branched rollouts and expert guidance can substantially boost SLM reasoning.

LGMay 29, 2025
Continuous Chain of Thought Enables Parallel Exploration and Reasoning

Halil Alperen Gozeten, M. Emrullah Ildiz, Xuechen Zhang et al.

Modern language models generate chain-of-thought traces by autoregressively sampling tokens from a finite vocabulary. While this discrete sampling has achieved remarkable success, conducting chain-of-thought with continuously-valued tokens (CoT2) offers a richer and more expressive alternative. Our work provides new theoretical guarantees and algorithms for CoT2, motivated by logical reasoning tasks that inherently require search capabilities. Theoretically, we establish how CoT2 enables the model to track multiple discrete traces in parallel, and quantify the level of achievable parallelism and its benefits for inference efficiency. We also provide a CoT2-based one-layer transformer construction that solves the combinatorial "subset sum problem" given a sufficient embedding dimension. These insights arise from a novel and effective supervision strategy where we match the language model outputs to the empirical token distributions of a set of target traces. Complementing this, we introduce sampling strategies that unlock policy optimization methods for CoT2. Our primary strategy samples and composes $K$ discrete tokens at each decoding step to control the level of parallelism. Experiments confirm that (i) the optimal level of parallelism is governed by the embedding dimension, (ii) our continuous supervision strategy can outperform alternative methods, and (iii) policy optimization with CoT2 indeed improves the performance of the model beyond its initial discrete or continuous supervision.

LGMay 12, 2025
Making Small Language Models Efficient Reasoners: Intervention, Supervision, Reinforcement

Xuechen Zhang, Zijian Huang, Chenshun Ni et al.

Recent research enhances language model reasoning by scaling test-time compute via longer chain-of-thought traces. This often improves accuracy but also introduces redundancy and high computational cost, especially for small language models distilled with supervised fine-tuning (SFT). In this work, we propose new algorithms to improve token-efficient reasoning with small-scale models by effectively trading off accuracy and computation. We first show that the post-SFT model fails to determine the optimal stopping point of the reasoning process, resulting in verbose and repetitive outputs. Verbosity also significantly varies across wrong vs correct responses. To address these issues, we propose two solutions: (1) Temperature scaling (TS) to control the stopping point for the thinking phase and thereby trace length, and (2) TLDR: a length-regularized reinforcement learning method based on GRPO that facilitates multi-level trace length control (e.g. short, medium, long reasoning). Experiments on four reasoning benchmarks, MATH500, AMC, AIME24 and OlympiadBench, demonstrate that TS is highly effective compared to s1's budget forcing approach and TLDR significantly improves token efficiency by about 50% with minimal to no accuracy loss over the SFT baseline. Moreover, TLDR also facilitates flexible control over the response length, offering a practical and effective solution for token-efficient reasoning in small models. Ultimately, our work reveals the importance of stopping time control, highlights shortcomings of pure SFT, and provides effective algorithmic recipes.

LGMar 14, 2025
Test-Time Training Provably Improves Transformers as In-context Learners

Halil Alperen Gozeten, M. Emrullah Ildiz, Xuechen Zhang et al.

Test-time training (TTT) methods explicitly update the weights of a model to adapt to the specific test instance, and they have found success in a variety of settings, including most recently language modeling and reasoning. To demystify this success, we investigate a gradient-based TTT algorithm for in-context learning, where we train a transformer model on the in-context demonstrations provided in the test prompt. Specifically, we provide a comprehensive theoretical characterization of linear transformers when the update rule is a single gradient step. Our theory (i) delineates the role of alignment between pretraining distribution and target task, (ii) demystifies how TTT can alleviate distribution shift, and (iii) quantifies the sample complexity of TTT including how it can significantly reduce the eventual sample size required for in-context learning. As our empirical contribution, we study the benefits of TTT for TabPFN, a tabular foundation model. In line with our theory, we demonstrate that TTT significantly reduces the required sample size for tabular classification (3 to 5 times fewer) unlocking substantial inference efficiency with a negligible training cost.

CLJun 10, 2025
Extrapolation by Association: Length Generalization Transfer in Transformers

Ziyang Cai, Nayoung Lee, Avi Schwarzschild et al.

Transformer language models have demonstrated impressive generalization capabilities in natural language domains, yet we lack a fine-grained understanding of how such generalization arises. In this paper, we investigate length generalization--the ability to extrapolate from shorter to longer inputs--through the lens of task association. We find that length generalization can be transferred across related tasks. That is, training a model with a longer and related auxiliary task can lead it to generalize to unseen and longer inputs from some other target task. We demonstrate this length generalization transfer across diverse algorithmic tasks, including arithmetic operations, string transformations, and maze navigation. Our results show that transformer models can inherit generalization capabilities from similar tasks when trained jointly. Moreover, we observe similar transfer effects in pretrained language models, suggesting that pretraining equips models with reusable computational scaffolding that facilitates extrapolation in downstream settings. Finally, we provide initial mechanistic evidence that length generalization transfer correlates with the re-use of the same attention heads between the tasks. Together, our findings deepen our understanding of how transformers generalize to out-of-distribution inputs and highlight the compositional reuse of inductive structure across tasks.

LGApr 6, 2025
Gating is Weighting: Understanding Gated Linear Attention through In-context Learning

Yingcong Li, Davoud Ataee Tarzanagh, Ankit Singh Rawat et al.

Linear attention methods offer a compelling alternative to softmax attention due to their efficiency in recurrent decoding. Recent research has focused on enhancing standard linear attention by incorporating gating while retaining its computational benefits. Such Gated Linear Attention (GLA) architectures include competitive models such as Mamba and RWKV. In this work, we investigate the in-context learning capabilities of the GLA model and make the following contributions. We show that a multilayer GLA can implement a general class of Weighted Preconditioned Gradient Descent (WPGD) algorithms with data-dependent weights. These weights are induced by the gating mechanism and the input, enabling the model to control the contribution of individual tokens to prediction. To further understand the mechanics of this weighting, we introduce a novel data model with multitask prompts and characterize the optimization landscape of learning a WPGD algorithm. Under mild conditions, we establish the existence and uniqueness (up to scaling) of a global minimum, corresponding to a unique WPGD solution. Finally, we translate these findings to explore the optimization landscape of GLA and shed light on how gating facilitates context-aware learning and when it is provably better than vanilla linear attention.

LGFeb 13, 2024
FLASH: Federated Learning Across Simultaneous Heterogeneities

Xiangyu Chang, Sk Miraj Ahmed, Srikanth V. Krishnamurthy et al.

The key premise of federated learning (FL) is to train ML models across a diverse set of data-owners (clients), without exchanging local data. An overarching challenge to date is client heterogeneity, which may arise not only from variations in data distribution, but also in data quality, as well as compute/communication latency. An integrated view of these diverse and concurrent sources of heterogeneity is critical; for instance, low-latency clients may have poor data quality, and vice versa. In this work, we propose FLASH (Federated Learning Across Simultaneous Heterogeneities), a lightweight and flexible client selection algorithm that outperforms state-of-the-art FL frameworks under extensive sources of heterogeneity, by trading off the statistical information associated with the client's data quality, data distribution, and latency. FLASH is the first method, to our knowledge, for handling all these heterogeneities in a unified manner. To do so, FLASH models the learning dynamics through contextual multi-armed bandits (CMAB) and dynamically selects the most promising clients. Through extensive experiments, we demonstrate that FLASH achieves substantial and consistent improvements over state-of-the-art baselines -- as much as 10% in absolute accuracy -- thanks to its unified approach. Importantly, FLASH also outperforms federated aggregation methods that are designed to handle highly heterogeneous settings and even enjoys a performance boost when integrated with them.

LGJan 4, 2024
CONTRAST: Continual Multi-source Adaptation to Dynamic Distributions

Sk Miraj Ahmed, Fahim Faisal Niloy, Xiangyu Chang et al.

Adapting to dynamic data distributions is a practical yet challenging task. One effective strategy is to use a model ensemble, which leverages the diverse expertise of different models to transfer knowledge to evolving data distributions. However, this approach faces difficulties when the dynamic test distribution is available only in small batches and without access to the original source data. To address the challenge of adapting to dynamic distributions in such practical settings, we propose Continual Multi-source Adaptation to Dynamic Distributions (CONTRAST), a novel method that optimally combines multiple source models to adapt to the dynamic test data. CONTRAST has two distinguishing features. First, it efficiently computes the optimal combination weights to combine the source models to adapt to the test data distribution continuously as a function of time. Second, it identifies which of the source model parameters to update so that only the model that is most correlated with the target data is adapted, leaving the less correlated ones untouched; this mitigates the issue of "forgetting" the source model parameters by focusing only on the source model that exhibits the strongest correlation with the test batch distribution. Through theoretical analysis, we show that the proposed method is able to optimally combine the source models and prioritize updates to the model least prone to forgetting. Experimental analysis on diverse datasets demonstrates that the combination of multiple source models does at least as well as the best source (with hindsight knowledge), and performance does not degrade as the test data distribution changes over time (robust to forgetting).

LGJun 18, 2025
When and How Unlabeled Data Provably Improve In-Context Learning

Yingcong Li, Xiangyu Chang, Muti Kara et al.

Recent research shows that in-context learning (ICL) can be effective even when demonstrations have missing or incorrect labels. To shed light on this capability, we examine a canonical setting where the demonstrations are drawn according to a binary Gaussian mixture model (GMM) and a certain fraction of the demonstrations have missing labels. We provide a comprehensive theoretical study to show that: (1) The loss landscape of one-layer linear attention models recovers the optimal fully-supervised estimator but completely fails to exploit unlabeled data; (2) In contrast, multilayer or looped transformers can effectively leverage unlabeled data by implicitly constructing estimators of the form $\sum_{i\ge 0} a_i (X^\top X)^iX^\top y$ with $X$ and $y$ denoting features and partially-observed labels (with missing entries set to zero). We characterize the class of polynomials that can be expressed as a function of depth and draw connections to Expectation Maximization, an iterative pseudo-labeling algorithm commonly used in semi-supervised learning. Importantly, the leading polynomial power is exponential in depth, so a mild amount of depth/looping suffices. As an application of the theory, we propose looping off-the-shelf tabular foundation models to enhance their semi-supervision capabilities. Extensive evaluations on real-world datasets show that our method significantly improves semi-supervised tabular learning performance over standard single-pass inference.
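
To make the estimator form $\sum_{i\ge 0} a_i (X^\top X)^iX^\top y$ concrete, here is a minimal sketch that evaluates it for a given, finite coefficient list. The function name `polynomial_estimator` and the coefficient values are assumptions for illustration; the abstract does not specify how a trained model sets the $a_i$.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(A: Matrix, v: Vector) -> Vector:
    """Multiply matrix A by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def polynomial_estimator(X: Matrix, y: Vector, coeffs: List[float]) -> Vector:
    """Return sum_i coeffs[i] * (X^T X)^i X^T y.

    Rows of X are features; y holds labels, with missing labels set to 0.
    """
    Xt = transpose(X)
    term = matvec(Xt, y)                 # i = 0 term: X^T y
    estimate = [0.0] * len(term)
    for a in coeffs:
        estimate = [e + a * t for e, t in zip(estimate, term)]
        term = matvec(Xt, matvec(X, term))  # apply X^T X to get the next power
    return estimate


# Second row is an unlabeled point (label set to 0).
X = [[1.0, 0.0], [1.0, 1.0]]
y = [1.0, 0.0]
supervised_only = polynomial_estimator(X, y, [1.0])       # [1.0, 0.0]
with_unlabeled = polynomial_estimator(X, y, [1.0, 0.5])   # [2.0, 0.5]
```

With only the $i = 0$ term, the unlabeled row contributes nothing because its label entry is zero. Higher powers of $X^\top X$ bring its features into the estimate, which is how unlabeled data can enter the prediction.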

IRDec 17, 2025
SmartChunk Retrieval: Query-Aware Chunk Compression with Planning for Efficient Document RAG

Xuechen Zhang, Koustava Goswami, Samet Oymak et al.

Retrieval-augmented generation (RAG) has strong potential for producing accurate and factual outputs by combining language models (LMs) with evidence retrieved from large text corpora. However, current pipelines are limited by static chunking and flat retrieval: documents are split into short, predetermined, fixed-size chunks, embeddings are retrieved uniformly, and generation relies on whatever chunks are returned. This design brings challenges, as retrieval quality is highly sensitive to chunk size, often introduces noise from irrelevant or misleading chunks, and scales poorly to large corpora. We present SmartChunk retrieval, a query-adaptive framework for efficient and robust long-document question answering (QA). SmartChunk uses (i) a planner that predicts the optimal chunk abstraction level for each query, and (ii) a lightweight compression module that produces high-level chunk embeddings without repeated summarization. By adapting retrieval granularity on the fly, SmartChunk balances accuracy with efficiency and avoids the drawbacks of fixed strategies. Notably, our planner can reason about chunk abstractions through a novel reinforcement learning scheme, STITCH, which boosts accuracy and generalization. To reflect real-world applications, where users face diverse document types and query styles, we evaluate SmartChunk on five QA benchmarks plus one out-of-domain dataset. Across these evaluations, SmartChunk outperforms state-of-the-art RAG baselines, while reducing cost. Further analysis demonstrates strong scalability with larger corpora and consistent gains on out-of-domain datasets, highlighting its effectiveness as a general framework for adaptive retrieval.

LGNov 24, 2025
Mitigating Participation Imbalance Bias in Asynchronous Federated Learning

Xiangyu Chang, Manyi Yao, Srikanth V. Krishnamurthy et al.

In Asynchronous Federated Learning (AFL), the central server immediately updates the global model with each arriving client's contribution. As a result, clients perform their local training on different model versions, causing information staleness (delay). In federated environments with non-IID local data distributions, this asynchronous pattern amplifies the adverse effect of client heterogeneity (due to different data distributions, local objectives, etc.), as faster clients contribute more frequent updates, biasing the global model. We term this phenomenon heterogeneity amplification. Our work provides a theoretical analysis that maps AFL design choices to their resulting error sources when heterogeneity amplification occurs. Guided by our analysis, we propose ACE (All-Client Engagement AFL), which mitigates participation imbalance through immediate, non-buffered updates that use the latest information available from all clients. We also introduce a delay-aware variant, ACED, to balance client diversity against update staleness. Experiments on different models for different tasks across diverse heterogeneity and delay settings validate our analysis and demonstrate the robust performance of our approaches.

LGMay 22, 2025
Attention with Trained Embeddings Provably Selects Important Tokens

Diyuan Wu, Aleksandr Shevchenko, Samet Oymak et al.

Token embeddings play a crucial role in language modeling but, despite this practical relevance, their theoretical understanding remains limited. Our paper addresses the gap by characterizing the structure of embeddings obtained via gradient descent. Specifically, we consider a one-layer softmax attention model with a linear head for binary classification, i.e., $\texttt{Softmax}( p^\top E_X^\top ) E_X v = \frac{ \sum_{i=1}^T \exp(p^\top E_{x_i}) E_{x_i}^\top v}{\sum_{j=1}^T \exp(p^\top E_{x_{j}}) }$, where $E_X = [ E_{x_1} , \dots, E_{x_T} ]^\top$ contains the embeddings of the input sequence, $p$ is the embedding of the $\mathrm{\langle cls \rangle}$ token and $v$ the output vector. First, we show that, already after a single step of gradient training with the logistic loss, the embeddings $E_X$ capture the importance of tokens in the dataset by aligning with the output vector $v$ proportionally to the frequency with which the corresponding tokens appear in the dataset. Then, after training $p$ via gradient flow until convergence, the softmax selects the important tokens in the sentence (i.e., those that are predictive of the label), and the resulting $\mathrm{\langle cls \rangle}$ embedding maximizes the margin for such a selection. Experiments on real-world datasets (IMDB, Yelp) exhibit a phenomenology close to that unveiled by our theory.
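
The model in the formula above can be evaluated directly. The following is a minimal sketch of the forward pass only, with hypothetical toy embeddings; the function name `attention_output` is an assumption for this example, and no training procedure from the paper is reproduced.

```python
import math
from typing import List


def attention_output(E_X: List[List[float]], p: List[float], v: List[float]) -> float:
    """Compute Softmax(p^T E_X^T) E_X v for one sequence.

    E_X: one embedding row per token, p: <cls> embedding, v: output vector.
    """
    scores = [sum(pi * ei for pi, ei in zip(p, e)) for e in E_X]   # p^T E_{x_i}
    m = max(scores)                                                # stabilize the exponentials
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    values = [sum(ei * vi for ei, vi in zip(e, v)) for e in E_X]   # E_{x_i}^T v
    return sum(w * val for w, val in zip(weights, values)) / total


# Two toy tokens whose embeddings align with +v and -v respectively.
E_X = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, -1.0]
uniform = attention_output(E_X, [0.0, 0.0], v)     # equal attention, output 0.0
selective = attention_output(E_X, [10.0, 0.0], v)  # attention concentrates on token 1
```

When $p$ is zero the softmax averages the tokens uniformly; as $p$ aligns with one token's embedding, the output approaches that token's value $E_{x_i}^\top v$, which mirrors the token-selection behavior described in the abstract.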

LGMar 13, 2025
Identifying Trustworthiness Challenges in Deep Learning Models for Continental-Scale Water Quality Prediction

Xiaobo Xia, Xiaofeng Liu, Jiale Liu et al.

Water quality is foundational to environmental sustainability, ecosystem resilience, and public health. Deep learning offers transformative potential for large-scale water quality prediction and scientific insight generation. However, its widespread adoption in high-stakes operational decision-making, such as pollution mitigation and equitable resource allocation, is prevented by unresolved trustworthiness challenges, including performance disparity, robustness, uncertainty, interpretability, generalizability, and reproducibility. In this work, we present a multi-dimensional, quantitative evaluation of trustworthiness, benchmarking three state-of-the-art deep learning architectures: recurrent (LSTM), operator-learning (DeepONet), and transformer-based (Informer), trained on 37 years of data from 482 U.S. basins to predict 20 water quality variables. Our investigation reveals systematic performance disparities tied to process complexity, data availability, and basin heterogeneity. Management-critical variables remain the least predictable and most uncertain. Robustness tests reveal pronounced sensitivity to outliers and corrupted targets; notably, the architecture with the strongest baseline performance (LSTM) proves most vulnerable under data corruption. Attribution analyses align for simple variables but diverge for nutrients, underscoring the need for multi-method interpretability. Spatial generalization to ungauged basins remains poor across all models. This work serves as a timely call to action for advancing trustworthy data-driven methods for water resources management and offers critical insights for researchers, decision-makers, and practitioners seeking to leverage artificial intelligence (AI) responsibly in environmental management.

CLMar 3, 2025
Provable Benefits of Task-Specific Prompts for In-context Learning

Xiangyu Chang, Yingcong Li, Muti Kara et al.

The in-context learning capabilities of modern language models have motivated a deeper mathematical understanding of sequence models. A line of recent work has shown that linear attention models can emulate projected gradient descent iterations to implicitly learn the task vector from the data provided in the context window. In this work, we consider a novel setting where the global task distribution can be partitioned into a union of conditional task distributions. We then examine the use of task-specific prompts and prediction heads for learning the prior information associated with the conditional task distribution using a one-layer attention model. Our results on the loss landscape show that task-specific prompts facilitate a covariance-mean decoupling, where prompt-tuning explains the conditional mean of the distribution whereas the variance is learned/explained through in-context learning. Incorporating a task-specific head further aids this process by entirely decoupling the estimation of the mean and variance components. This covariance-mean perspective similarly explains how jointly training the prompt and attention weights can provably help over fine-tuning after pretraining.