Taiji Suzuki

LG
h-index: 34
105 papers
3,008 citations
Novelty: 55%
AI Score: 62

105 Papers

15.4 · AI · Jun 3 · Code
LeanMarathon: Toward Reliable AI Co-Mathematicians through Long-Horizon Lean Autoformalization

Yuanhe Zhang, Yuekai Sun, Taiji Suzuki et al.

Long-horizon autoformalization of research mathematics fails not only at hard lemmas, but at scale: statements drift, dependencies tangle, context decays, and local repairs corrupt distant work. We present LeanMarathon, a multi-agent harness for reliable research-level Lean autoformalization. Its core abstraction is an evolving blueprint: a Lean file that serves simultaneously as formal proof skeleton, natural-language proof graph, and shared system of record. Four contract-scoped agents construct, audit, prove, and repair this blueprint. These agents are coordinated by a two-stage orchestrator that first stabilizes target fidelity through adversarial review and then discharges the proof directed acyclic graph (DAG) from its dynamic leaves upward in parallel CI-gated rounds. LeanMarathon turns one brittle multi-hour run into many local, recoverable, parallel transactions. We evaluate LeanMarathon on two recent research papers spanning four Erdős problems (#1051, #1196, #164, #1217). Across three autonomous runs, it formalizes all seven target theorems with no sorry, proving 258 lemmas and theorems. These results show that reliable AI co-mathematics requires not only stronger provers, but durable harnesses that preserve target fidelity across long mathematical developments. The code can be found at https://github.com/YuanheZ/LeanMarathon.

39.0 · ML · Mar 3, 2023
Diffusion Models are Minimax Optimal Distribution Estimators

Kazusato Oko, Shunta Akiyama, Taiji Suzuki

While efficient distribution learning is no doubt behind the groundbreaking success of diffusion modeling, its theoretical guarantees are quite limited. In this paper, we provide the first rigorous analysis of the approximation and generalization abilities of diffusion modeling for well-known function spaces. The highlight of this paper is that when the true density function belongs to the Besov space and the empirical score matching loss is properly minimized, the generated data distribution achieves nearly minimax optimal estimation rates in the total variation distance and in the Wasserstein distance of order one. Furthermore, we extend our theory to demonstrate how diffusion models adapt to low-dimensional data distributions. We expect these results to advance the theoretical understanding of diffusion modeling and its ability to generate verisimilar outputs.

13.0 · LG · Jun 24, 2023
Quantifying the Optimization and Generalization Advantages of Graph Neural Networks Over Multilayer Perceptrons

Wei Huang, Yuan Cao, Haonan Wang et al.

Graph neural networks (GNNs) have demonstrated remarkable capabilities in learning from graph-structured data, often outperforming traditional Multilayer Perceptrons (MLPs) in numerous graph-based tasks. Although existing works have demonstrated the benefits of graph convolution through Laplacian smoothing, expressivity or separability, there remains a lack of quantitative analysis comparing GNNs and MLPs from an optimization and generalization perspective. This study aims to address this gap by examining the role of graph convolution through feature learning theory. Using a signal-noise data model, we conduct a comparative analysis of the optimization and generalization between two-layer graph convolutional networks (GCNs) and their MLP counterparts. Our approach tracks the trajectory of signal learning and noise memorization in GNNs, characterizing their post-training generalization. We reveal that GNNs significantly prioritize signal learning, thus enhancing the regime of low test error over MLPs by $D^{q-2}$ times, where $D$ denotes a node's expected degree and $q$ is the power of the ReLU activation function with $q>2$. This finding highlights a substantial and quantitative discrepancy between GNNs and MLPs in terms of optimization and generalization, a conclusion further supported by our empirical simulations on both synthetic and real-world datasets.

37.5 · ML · May 3, 2022
High-dimensional Asymptotics of Feature Learning: How One Gradient Step Improves the Representation

Jimmy Ba, Murat A. Erdogdu, Taiji Suzuki et al.

We study the first gradient descent step on the first-layer parameters $\boldsymbol{W}$ in a two-layer neural network: $f(\boldsymbol{x}) = \frac{1}{\sqrt{N}}\boldsymbol{a}^\top\sigma(\boldsymbol{W}^\top\boldsymbol{x})$, where $\boldsymbol{W}\in\mathbb{R}^{d\times N}, \boldsymbol{a}\in\mathbb{R}^{N}$ are randomly initialized, and the training objective is the empirical MSE loss: $\frac{1}{n}\sum_{i=1}^n (f(\boldsymbol{x}_i)-y_i)^2$. In the proportional asymptotic limit where $n,d,N\to\infty$ at the same rate, and an idealized student-teacher setting, we show that the first gradient update contains a rank-1 "spike", which results in an alignment between the first-layer weights and the linear component of the teacher model $f^*$. To characterize the impact of this alignment, we compute the prediction risk of ridge regression on the conjugate kernel after one gradient step on $\boldsymbol{W}$ with learning rate $\eta$, when $f^*$ is a single-index model. We consider two scalings of the first step learning rate $\eta$. For small $\eta$, we establish a Gaussian equivalence property for the trained feature map, and prove that the learned kernel improves upon the initial random features model, but cannot defeat the best linear model on the input. For sufficiently large $\eta$, in contrast, we prove that for certain $f^*$, the same ridge estimator on trained features can go beyond this "linear regime" and outperform a wide range of random features and rotationally invariant kernels. Our results demonstrate that even one gradient step can lead to a considerable advantage over random features, and highlight the role of learning rate scaling in the initial phase of training.
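For concreteness, the first gradient step can be written out by applying the chain rule to the empirical MSE loss stated above. The display is a sketch under that model only: the symbols $\boldsymbol{W}_0$, $\boldsymbol{W}_1$, and $\boldsymbol{G}$ are introduced here for illustration, $\odot$ denotes the entrywise product, and any extra learning-rate normalization used in the paper is not reflected.

```latex
\[
\boldsymbol{G}
  = \nabla_{\boldsymbol{W}} \frac{1}{n}\sum_{i=1}^n \bigl(f(\boldsymbol{x}_i)-y_i\bigr)^2 \Big|_{\boldsymbol{W}=\boldsymbol{W}_0}
  = \frac{2}{n\sqrt{N}} \sum_{i=1}^n \bigl(f(\boldsymbol{x}_i)-y_i\bigr)\,
    \boldsymbol{x}_i \bigl(\boldsymbol{a}\odot\sigma'(\boldsymbol{W}_0^\top\boldsymbol{x}_i)\bigr)^\top,
\qquad
\boldsymbol{W}_1 = \boldsymbol{W}_0 - \eta\,\boldsymbol{G}.
\]
```

Each summand is an outer product of an input $\boldsymbol{x}_i$ with a vector in $\mathbb{R}^N$, so $\boldsymbol{G}$ has the same shape $d\times N$ as $\boldsymbol{W}$.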

14.9 · LG · Jun 12, 2023
Convergence of mean-field Langevin dynamics: Time and space discretization, stochastic gradient, and variance reduction

Taiji Suzuki, Denny Wu, Atsushi Nitanda

The mean-field Langevin dynamics (MFLD) is a nonlinear generalization of the Langevin dynamics that incorporates a distribution-dependent drift, and it naturally arises from the optimization of two-layer neural networks via (noisy) gradient descent. Recent works have shown that MFLD globally minimizes an entropy-regularized convex functional in the space of measures. However, all prior analyses assumed the infinite-particle or continuous-time limit, and cannot handle stochastic gradient updates. We provide a general framework to prove a uniform-in-time propagation of chaos for MFLD that takes into account the errors due to finite-particle approximation, time-discretization, and stochastic gradient approximation. To demonstrate the wide applicability of this framework, we establish quantitative convergence rate guarantees to the regularized global optimal solution under (i) a wide range of learning problems such as neural networks in the mean-field regime and MMD minimization, and (ii) different gradient estimators including SGD and SVRG. Despite the generality of our results, we achieve an improved convergence rate in both the SGD and SVRG settings when specialized to the standard Langevin dynamics.

12.3 · LG · Feb 8, 2023
DIFF2: Differential Private Optimization via Gradient Differences for Nonconvex Distributed Learning

Tomoya Murata, Taiji Suzuki

Differentially private optimization for nonconvex smooth objectives is considered. In previous work, the best known utility bound is $\widetilde O(\sqrt{d}/(n\varepsilon_\mathrm{DP}))$ in terms of the squared full gradient norm, which is achieved by Differential Private Gradient Descent (DP-GD) as an instance, where $n$ is the sample size, $d$ is the problem dimensionality and $\varepsilon_\mathrm{DP}$ is the differential privacy parameter. To improve the best known utility bound, we propose a new differentially private optimization framework called DIFF2 (DIFFerential private optimization via gradient DIFFerences) that constructs a differentially private global gradient estimator with possibly quite small variance based on communicated gradient differences rather than gradients themselves. It is shown that DIFF2 with a gradient descent subroutine achieves the utility of $\widetilde O(d^{2/3}/(n\varepsilon_\mathrm{DP})^{4/3})$, which can be significantly better than the previous one in terms of the dependence on the sample size $n$. To the best of our knowledge, this is the first fundamental result to improve the standard utility $\widetilde O(\sqrt{d}/(n\varepsilon_\mathrm{DP}))$ for nonconvex objectives. Additionally, a more computation- and communication-efficient subroutine is combined with DIFF2 and its theoretical analysis is also given. Numerical experiments are conducted to validate the superiority of the DIFF2 framework.
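To see when the new rate helps, the two bounds quoted above can be compared numerically. The sketch below is an illustration only: it drops the logarithmic factors and constants hidden by $\widetilde O$, and the function names are hypothetical rather than taken from the paper. Dividing the two rates gives $d^{1/6}/(n\varepsilon_\mathrm{DP})^{1/3}$, so the DIFF2 rate is the smaller one once $n\varepsilon_\mathrm{DP} > \sqrt{d}$.

```python
import math


def dp_gd_rate(n: int, d: int, eps: float) -> float:
    """Previously best known rate sqrt(d) / (n * eps); constants and logs dropped."""
    return math.sqrt(d) / (n * eps)


def diff2_rate(n: int, d: int, eps: float) -> float:
    """DIFF2 rate d^(2/3) / (n * eps)^(4/3); constants and logs dropped."""
    return d ** (2 / 3) / (n * eps) ** (4 / 3)


def crossover_n(d: int, eps: float) -> float:
    """Sample size at which both rates coincide, i.e. n * eps = sqrt(d)."""
    return math.sqrt(d) / eps


if __name__ == "__main__":
    d, eps = 10_000, 1.0  # hypothetical problem size and privacy parameter
    print("crossover n:", crossover_n(d, eps))
    for n in (10, 1_000, 100_000):
        print(n, dp_gd_rate(n, d, eps), diff2_rate(n, d, eps))
```

Below the crossover sample size the older bound is the smaller of the two, which matches the abstract's wording that the improvement concerns the dependence on $n$.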

10.8 · ML · May 30, 2022
Excess Risk of Two-Layer ReLU Neural Networks in Teacher-Student Settings and its Superiority to Kernel Methods

Shunta Akiyama, Taiji Suzuki

While deep learning has outperformed other methods for various tasks, theoretical frameworks that explain why have not been fully established. To address this issue, we investigate the excess risk of two-layer ReLU neural networks in a teacher-student regression model, in which a student network learns an unknown teacher network through its outputs. In particular, we consider a student network that has the same width as the teacher network and is trained in two phases: first by noisy gradient descent and then by vanilla gradient descent. Our result shows that the student network provably reaches a near-global optimal solution and outperforms any kernel method estimator (more generally, linear estimators), including the neural tangent kernel approach, the random feature model, and other kernel methods, in the sense of the minimax optimal rate. The key concept inducing this superiority is the non-convexity of the neural network models. Even though the loss landscape is highly non-convex, the student network adaptively learns the teacher neurons.

10.7 · LG · Feb 12, 2023
Koopman-based generalization bound: New aspect for full-rank weights

Yuka Hashimoto, Sho Sonoda, Isao Ishikawa et al.

We propose a new bound for generalization of neural networks using Koopman operators. Whereas most existing works focus on low-rank weight matrices, we focus on full-rank weight matrices. Our bound is tighter than existing norm-based bounds when the condition numbers of weight matrices are small. In particular, it is completely independent of the width of the network if the weight matrices are orthogonal. Our bound does not contradict the existing bounds but complements them. As supported by several existing empirical results, low-rankness is not the only reason for generalization. Furthermore, our bound can be combined with the existing bounds to obtain a tighter bound. Our result sheds new light on understanding generalization of neural networks with full-rank weight matrices, and it provides a connection between operator-theoretic analysis and generalization of neural networks.

8.6 · ML · Mar 6, 2023
Primal and Dual Analysis of Entropic Fictitious Play for Finite-sum Problems

Atsushi Nitanda, Kazusato Oko, Denny Wu et al.

The entropic fictitious play (EFP) is a recently proposed algorithm that minimizes the sum of a convex functional and entropy in the space of measures -- such an objective naturally arises in the optimization of a two-layer neural network in the mean-field regime. In this work, we provide a concise primal-dual analysis of EFP in the setting where the learning problem exhibits a finite-sum structure. We establish quantitative global convergence guarantees for both the continuous-time and discrete-time dynamics based on properties of a proximal Gibbs measure introduced in Nitanda et al. (2022). Furthermore, our primal-dual framework entails a memory-efficient particle-based implementation of the EFP update, and also suggests a connection to gradient boosting methods. We illustrate the efficiency of our novel implementation in experiments including neural network optimization and image synthesis.

3.8 · LG · Aug 1, 2023 · Code
Learning Green's Function Efficiently Using Low-Rank Approximations

Kishan Wimalawarne, Taiji Suzuki, Sophie Langer

Learning the Green's function using deep learning models makes it possible to solve different classes of partial differential equations. A practical limitation of using deep learning for the Green's function is the repeated, computationally expensive Monte-Carlo integral approximation. We propose to learn the Green's function by low-rank decomposition, which results in a novel architecture that removes redundant computations by learning separately with domain data for evaluation and Monte-Carlo samples for integral approximation. Using experiments, we show that the proposed method improves computational time compared to MOD-Net while achieving accuracy comparable to both PINNs and MOD-Net.

27.3 · ML · Aug 22, 2024
Transformers are Minimax Optimal Nonparametric In-Context Learners

Juno Kim, Tai Nakamaki, Taiji Suzuki

In-context learning (ICL) of large language models has proven to be a surprisingly effective method of learning a new task from only a few demonstrative examples. In this paper, we study the efficacy of ICL from the viewpoint of statistical learning theory. We develop approximation and generalization error bounds for a transformer composed of a deep neural network and one linear attention layer, pretrained on nonparametric regression tasks sampled from general function spaces including the Besov space and piecewise $\gamma$-smooth class. We show that sufficiently trained transformers can achieve -- and even improve upon -- the minimax optimal estimation risk in context by encoding the most relevant basis representations during pretraining. Our analysis extends to high-dimensional or sequential data and distinguishes the pretraining and in-context generalization gaps. Furthermore, we establish information-theoretic lower bounds for meta-learners w.r.t. both the number of tasks and in-context examples. These findings shed light on the roles of task diversity and representation learning for ICL.

3.3 · LG · Sep 1, 2022
Versatile Single-Loop Method for Gradient Estimator: First and Second Order Optimality, and its Application to Federated Learning

Kazusato Oko, Shunta Akiyama, Tomoya Murata et al.

While variance reduction methods have shown great success in solving large scale optimization problems, many of them suffer from accumulated errors and therefore periodically require a full gradient computation. In this paper, we present a single-loop algorithm named SLEDGE (Single-Loop mEthoD for Gradient Estimator) for finite-sum nonconvex optimization, which does not require periodic refresh of the gradient estimator but achieves nearly optimal gradient complexity. Unlike existing methods, SLEDGE has the advantage of versatility: (i) second-order optimality, (ii) exponential convergence in the PL region, and (iii) smaller complexity under less heterogeneity of data. We build an efficient federated learning algorithm by exploiting these favorable properties. We show the first and second-order optimality of the output and also provide analysis under PL conditions. When the local budget is sufficiently large and clients are less (Hessian-)heterogeneous, the algorithm requires fewer communication rounds than existing methods such as FedAvg, SCAFFOLD, and Mime. The superiority of our method is verified in numerical experiments.

20.7 · LG · Sep 28, 2024
Unveil Benign Overfitting for Transformer in Vision: Training Dynamics, Convergence, and Generalization

Jiarui Jiang, Wei Huang, Miao Zhang et al.

Transformers have demonstrated great power in the recent development of large foundational models. In particular, the Vision Transformer (ViT) has brought revolutionary changes to the field of vision, achieving significant accomplishments on the experimental side. However, their theoretical capabilities, particularly in terms of generalization when trained to overfit training data, are still not fully understood. To address this gap, this work delves deeply into the benign overfitting perspective of transformers in vision. To this end, we study the optimization of a Transformer composed of a self-attention layer with softmax followed by a fully connected layer under gradient descent on a certain data distribution model. By developing techniques that address the challenges posed by softmax and the interdependent nature of multiple weights in transformer optimization, we successfully characterized the training dynamics and achieved generalization in post-training. Our results establish a sharp condition that can distinguish between the small test error phase and the large test error regime, based on the signal-to-noise ratio in the data model. The theoretical results are further verified by experimental simulation. To the best of our knowledge, this is the first work to characterize benign overfitting for Transformers.

18.8 · LG · May 10, 2024 · Code
State-Free Inference of State-Space Models: The Transfer Function Approach

Rom N. Parnichkun, Stefano Massaroli, Alessandro Moro et al.

We approach designing a state-space model for deep learning applications through its dual representation, the transfer function, and uncover a highly efficient sequence parallel inference algorithm that is state-free: unlike other proposed algorithms, state-free inference does not incur any significant memory or computational cost with an increase in state size. We achieve this using properties of the proposed frequency domain transfer function parametrization, which enables direct computation of its corresponding convolutional kernel's spectrum via a single Fast Fourier Transform. Our experimental results across multiple sequence lengths and state sizes illustrate, on average, a 35% training speed improvement over S4 layers -- parametrized in time-domain -- on the Long Range Arena benchmark, while delivering state-of-the-art downstream performance over other attention-free approaches. Moreover, we report improved perplexity in language modeling over a long convolutional Hyena baseline, by simply introducing our transfer function parametrization. Our code is available at https://github.com/ruke1ire/RTF.

11.4 · LG · Nov 10, 2025
Provable Benefit of Curriculum in Transformer Tree-Reasoning Post-Training

Dake Bu, Wei Huang, Andi Han et al.

Recent curriculum techniques in the post-training stage of LLMs have been widely observed to outperform non-curriculum approaches in enhancing reasoning performance, yet a principled understanding of why and to what extent they work remains elusive. To address this gap, we develop a theoretical framework grounded in the intuition that progressively learning through manageable steps is more efficient than directly tackling a hard reasoning task, provided each stage stays within the model's effective competence. Under mild complexity conditions linking consecutive curriculum stages, we show that curriculum post-training avoids the exponential complexity bottleneck. To substantiate this result, drawing insights from the Chain-of-Thoughts (CoTs) solving mathematical problems such as Countdown and parity, we model CoT generation as a states-conditioned autoregressive reasoning tree, define a uniform-branching base model to capture pretrained behavior, and formalize curriculum stages as either depth-increasing (longer reasoning chains) or hint-decreasing (shorter prefixes) subtasks. Our analysis shows that, under outcome-only reward signals, reinforcement learning finetuning achieves high accuracy with polynomial sample complexity, whereas direct learning suffers from an exponential bottleneck. We further establish analogous guarantees for test-time scaling, where curriculum-aware querying reduces both reward oracle calls and sampling cost from exponential to polynomial order.

1.8 · LG · Sep 12, 2022
Graph Polynomial Convolution Models for Node Classification of Non-Homophilous Graphs

Kishan Wimalawarne, Taiji Suzuki

We investigate efficient learning from higher-order graph convolution and learning directly from adjacency matrices for node classification. We revisit the scaled graph residual network, remove the ReLU activation from residual layers, and apply a single weight matrix at each residual layer. We show that the resulting model leads to new graph convolution models as a polynomial of the normalized adjacency matrix, the residual weight matrix, and the residual scaling parameter. Additionally, we propose adaptive learning between graph polynomial convolution models and learning directly from the adjacency matrix. Furthermore, we propose fully adaptive models to learn scaling parameters at each residual layer. We show that the generalization bounds of the proposed methods are polynomials of the eigenvalue spectrum, scaling parameters, and upper bounds of residual weights. By theoretical analysis, we argue that the proposed models can obtain improved generalization bounds by limiting the higher orders of convolution and direct learning from the adjacency matrix. Using a wide range of real data, we demonstrate that the proposed methods obtain improved accuracy for node classification of non-homophilous graphs.

2.0 · LG · Nov 15, 2023
Scalable Federated Learning for Clients with Different Input Image Sizes and Numbers of Output Categories

Shuhei Nitta, Taiji Suzuki, Albert Rodríguez Mulet et al.

Federated learning is a privacy-preserving training method that trains a model across multiple clients without sharing their confidential data. However, previous work on federated learning does not explore suitable neural network architectures for clients with different input image sizes and different numbers of output categories. In this paper, we propose an effective federated learning method named ScalableFL, where the depths and widths of the local models for each client are adjusted according to the clients' input image size and the numbers of output categories. In addition, we provide a new bound for the generalization gap of federated learning. In particular, this bound helps to explain the effectiveness of our scalable neural network approach. We demonstrate the effectiveness of ScalableFL in several heterogeneous client settings for both image classification and object detection tasks.

2.3 · CV · Sep 19, 2020
MSR-DARTS: Minimum Stable Rank of Differentiable Architecture Search

Kengo Machida, Kuniaki Uto, Koichi Shinoda et al.

In neural architecture search (NAS), differentiable architecture search (DARTS) has recently attracted much attention due to its high efficiency. It defines an over-parameterized network with mixed edges, each of which represents all operator candidates, and jointly optimizes the weights of the network and its architecture in an alternating manner. However, this method finds a model with the weights converging faster than the others, and such a model with fastest convergence often leads to overfitting. Accordingly, the resulting model cannot always be well-generalized. To overcome this problem, we propose a method called minimum stable rank DARTS (MSR-DARTS), for finding a model with the best generalization error by replacing architecture optimization with the selection process using the minimum stable rank criterion. Specifically, a convolution operator is represented by a matrix, and MSR-DARTS selects the one with the smallest stable rank. We evaluated MSR-DARTS on CIFAR-10 and ImageNet datasets. It achieves an error rate of 2.54% with 4.0M parameters within 0.3 GPU-days on CIFAR-10, and a top-1 error rate of 23.9% on ImageNet. The official code is available at https://github.com/mtaecchhi/msrdarts.git.
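The selection rule described above can be sketched in a few lines. The stable rank of a matrix $A$ is the standard quantity $\|A\|_F^2 / \|A\|_2^2$; everything else here is an assumption made for illustration: each candidate operator is taken to be already given as a plain matrix (the abstract does not say how a convolution is flattened), the spectral norm is approximated by power iteration, and the function names are hypothetical. This is not the official MSR-DARTS code.

```python
import math


def frobenius_sq(a):
    """Squared Frobenius norm: the sum of squared entries."""
    return sum(x * x for row in a for x in row)


def spectral_sq(a, iters=200):
    """Squared spectral norm, i.e. the top eigenvalue of A^T A, by power iteration."""
    rows, cols = len(a), len(a[0])
    # Fixed generic start; power iteration assumes it is not orthogonal
    # to the top right-singular direction of A.
    v = [1.0 + 0.01 * j for j in range(cols)]
    lam = 0.0
    for _ in range(iters):
        av = [sum(a[i][j] * v[j] for j in range(cols)) for i in range(rows)]   # A v
        w = [sum(a[i][j] * av[i] for i in range(rows)) for j in range(cols)]  # A^T (A v)
        lam = math.sqrt(sum(x * x for x in w))
        if lam == 0.0:
            return 0.0
        v = [x / lam for x in w]
    return lam


def stable_rank(a):
    """Stable rank ||A||_F^2 / ||A||_2^2, always between 1 and rank(A)."""
    spec = spectral_sq(a)
    if spec == 0.0:
        raise ValueError("stable rank is undefined for the zero matrix")
    return frobenius_sq(a) / spec


def select_min_stable_rank(candidates):
    """Return the name of the candidate operator with the smallest stable rank."""
    return min(candidates, key=lambda name: stable_rank(candidates[name]))


if __name__ == "__main__":
    ops = {  # hypothetical operator matrices on one edge
        "identity": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        "rank_one": [[1.0, 2.0], [2.0, 4.0]],
        "diag": [[2.0, 0.0], [0.0, 1.0]],
    }
    print(select_min_stable_rank(ops))
```

A rank-one matrix attains the minimum value 1, while an identity matrix attains its full rank, so the criterion favours operators whose energy is concentrated in a few singular directions.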

16.2 · LG · Jun 15, 2020 · Code
Optimization and Generalization Analysis of Transduction through Gradient Boosting and Application to Multi-scale Graph Neural Networks

Kenta Oono, Taiji Suzuki

Current graph neural networks (GNNs) are known to be difficult to make deep because of the problem known as over-smoothing. Multi-scale GNNs are a promising approach for mitigating the over-smoothing problem. However, there is little explanation, from the viewpoint of learning theory, of why they work empirically. In this study, we derive the optimization and generalization guarantees of transductive learning algorithms that include multi-scale GNNs. Using the boosting theory, we prove the convergence of the training error under weak learning-type conditions. By combining it with generalization gap bounds in terms of transductive Rademacher complexity, we show a test error bound for a specific type of multi-scale GNN that decreases with the number of node aggregations under some conditions. Our results offer theoretical explanations for the effectiveness of the multi-scale structure against the over-smoothing problem. We apply boosting algorithms to the training of multi-scale GNNs for real-world node prediction tasks. We confirm that its performance is comparable to existing GNNs, and the practical behaviors are consistent with theoretical observations. Code is available at https://github.com/delta2323/GB-GNN.

42.0 · LG · May 27, 2019 · Code
Graph Neural Networks Exponentially Lose Expressive Power for Node Classification

Kenta Oono, Taiji Suzuki

Graph Neural Networks (graph NNs) are a promising deep learning approach for analyzing graph-structured data. However, it is known that they do not improve (or sometimes worsen) their predictive performance as we pile up many layers and add nonlinearity. To tackle this problem, we investigate the expressive power of graph NNs via their asymptotic behaviors as the layer size tends to infinity. Our strategy is to generalize the forward propagation of a Graph Convolutional Network (GCN), which is a popular graph NN variant, as a specific dynamical system. In the case of a GCN, we show that when its weights satisfy the conditions determined by the spectra of the (augmented) normalized Laplacian, its output exponentially approaches the set of signals that carry information of the connected components and node degrees only for distinguishing nodes. Our theory enables us to relate the expressive power of GCNs with the topological information of the underlying graphs inherent in the graph spectra. To demonstrate this, we characterize the asymptotic behavior of GCNs on the Erdős–Rényi graph. We show that when the Erdős–Rényi graph is sufficiently dense and large, a broad range of GCNs on it suffer from the "information loss" in the limit of infinite layers with high probability. Based on the theory, we provide a principled guideline for weight normalization of graph NNs. We experimentally confirm that the proposed weight scaling enhances the predictive performance of GCNs in real data. Code is available at https://github.com/delta2323/gnn-asymptotics.
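The exponential collapse described above can be observed in a stripped-down setting. The sketch below is a simplification and not the paper's model: it propagates a single-channel signal with the augmented normalized adjacency $\tilde D^{-1/2}(A+I)\tilde D^{-1/2}$ only, with no weights and no ReLU, on a small hypothetical connected graph. For a connected graph the eigenvalue-1 direction of this matrix is $\tilde D^{1/2}\mathbf{1}$, which depends on node degrees only, and the code tracks the distance of the signal to that line.

```python
import math


def augmented_normalized_adjacency(edges, n):
    """Return (P, deg) with P = D~^{-1/2} (A + I) D~^{-1/2} for an undirected graph."""
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # A + I
    for u, v in edges:
        a[u][v] = a[v][u] = 1.0
    deg = [sum(row) for row in a]  # augmented degrees (self-loop included)
    p = [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]
    return p, deg


def distance_to_limit(x, deg):
    """Distance from x to the line spanned by sqrt(deg), the eigenvalue-1 direction."""
    e = [math.sqrt(dv) for dv in deg]
    norm_e = math.sqrt(sum(v * v for v in e))
    e = [v / norm_e for v in e]
    coef = sum(xi * ei for xi, ei in zip(x, e))
    return math.sqrt(sum((xi - coef * ei) ** 2 for xi, ei in zip(x, e)))


def smoothing_distances(edges, n, x, layers):
    """Apply x <- P x `layers` times and record the distance to the limit line."""
    p, deg = augmented_normalized_adjacency(edges, n)
    out = [distance_to_limit(x, deg)]
    for _ in range(layers):
        x = [sum(p[i][j] * x[j] for j in range(n)) for i in range(n)]
        out.append(distance_to_limit(x, deg))
    return out


if __name__ == "__main__":
    # Hypothetical graph: a triangle (0, 1, 2) with a pendant node 3.
    edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
    for depth, dist in enumerate(smoothing_distances(edges, 4, [1.0, 0.0, 0.0, -1.0], 20)):
        print(depth, dist)
```

The printed distances shrink geometrically with depth, which is the linear shadow of the "information loss" statement: after many layers only degree information distinguishes the nodes.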

31.2 · LG · Mar 26, 2024 · Code
Mechanistic Design and Scaling of Hybrid Architectures

Michael Poli, Armin W Thomas, Eric Nguyen et al.

The development of deep learning architectures is a resource-demanding process, due to a vast design space, long prototyping times, and high compute costs associated with at-scale model training and evaluation. We set out to simplify this process by grounding it in an end-to-end mechanistic architecture design (MAD) pipeline, encompassing small-scale capability unit tests predictive of scaling laws. Through a suite of synthetic token manipulation tasks such as compression and recall, designed to probe capabilities, we identify and test new hybrid architectures constructed from a variety of computational primitives. We experimentally validate the resulting architectures via an extensive compute-optimal and a new state-optimal scaling law analysis, training over 500 language models between 70M and 7B parameters. Surprisingly, we find MAD synthetics to correlate with compute-optimal perplexity, enabling accurate evaluation of new architectures via isolated proxy tasks. The new architectures found via MAD, based on simple ideas such as hybridization and sparsity, outperform state-of-the-art Transformer, convolutional, and recurrent architectures (Transformer++, Hyena, Mamba) in scaling, both at compute-optimal budgets and in overtrained regimes. Overall, these results provide evidence that performance on curated synthetic tasks can be predictive of scaling laws, and that an optimal architecture should leverage specialized layers via a hybrid topology.

30.0 · ML · Feb 2, 2024
Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape

Juno Kim, Taiji Suzuki

Large language models based on the Transformer architecture have demonstrated impressive capabilities to learn in context. However, existing theoretical studies on how this phenomenon arises are limited to the dynamics of a single layer of attention trained on linear regression tasks. In this paper, we study the optimization of a Transformer consisting of a fully connected layer followed by a linear attention layer. The MLP acts as a common nonlinear representation or feature map, greatly enhancing the power of in-context learning. We prove in the mean-field and two-timescale limit that the infinite-dimensional loss landscape for the distribution of parameters, while highly nonconvex, becomes quite benign. We also analyze the second-order stability of mean-field dynamics and show that Wasserstein gradient flow almost always avoids saddle points. Furthermore, we establish novel methods for obtaining concrete improvement rates both away from and near critical points. This represents the first saddle point analysis of mean-field dynamics in general and the techniques are of independent interest.

29.5 · LG · Oct 11, 2024
Transformers Provably Solve Parity Efficiently with Chain of Thought

Juno Kim, Taiji Suzuki

This work provides the first theoretical analysis of training transformers to solve complex problems by recursively generating intermediate states, analogous to fine-tuning for chain-of-thought (CoT) reasoning. We consider training a one-layer transformer to solve the fundamental $k$-parity problem, extending the work on RNNs by Wies et al. (2023). We establish three key results: (1) any finite-precision gradient-based algorithm, without intermediate supervision, requires substantial iterations to solve parity with finite samples. (2) In contrast, when intermediate parities are incorporated into the loss function, our model can learn parity in one gradient update when aided by \emph{teacher forcing}, where ground-truth labels of the reasoning chain are provided at each generation step. (3) Even without teacher forcing, where the model must generate CoT chains end-to-end, parity can be learned efficiently if augmented data is employed to internally verify the soundness of intermediate steps. Our findings, supported by numerical experiments, show that task decomposition and stepwise reasoning naturally arise from optimizing transformers with CoT; moreover, self-consistency checking can improve multi-step reasoning ability, aligning with empirical studies of CoT.
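As a small illustration of the task itself, the $k$-parity target and a chain of intermediate parities can be written down directly. The running-XOR chain below is one simple decomposition chosen here for illustration; the abstract does not spell out the exact chain the paper uses, and the function names are hypothetical.

```python
from functools import reduce
from operator import xor


def k_parity(bits, support):
    """Parity (XOR) of the bits at the k indices listed in `support`."""
    return reduce(xor, (bits[i] for i in support), 0)


def parity_chain(bits, support):
    """Intermediate parities: the running XOR after each index of the support."""
    chain, acc = [], 0
    for i in support:
        acc ^= bits[i]
        chain.append(acc)
    return chain


if __name__ == "__main__":
    bits = [1, 0, 1, 1, 0, 1]   # input of length d = 6
    support = [0, 2, 5]         # the k = 3 relevant coordinates
    print(parity_chain(bits, support), k_parity(bits, support))
```

The last entry of the chain always equals the final label, so supervising the chain (as with teacher forcing) exposes every intermediate step that the end-to-end label alone hides.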

23.8 · LG · Nov 4, 2024
Pretrained transformer efficiently learns low-dimensional target functions in-context

Kazusato Oko, Yujin Song, Taiji Suzuki et al.

Transformers can efficiently learn in-context from example demonstrations. Most existing theoretical analyses studied the in-context learning (ICL) ability of transformers for linear function classes, where it is typically shown that the minimizer of the pretraining loss implements one gradient descent step on the least squares objective. However, this simplified linear setting arguably does not demonstrate the statistical efficiency of ICL, since the pretrained transformer does not outperform directly solving linear regression on the test prompt. In this paper, we study ICL of a nonlinear function class via a transformer with a nonlinear MLP layer: given a class of single-index target functions $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$, where the index features $\boldsymbol{\beta}\in\mathbb{R}^d$ are drawn from an $r$-dimensional subspace, we show that a nonlinear transformer optimized by gradient descent (with a pretraining sample complexity that depends on the information exponent of the link functions $\sigma_*$) learns $f_*$ in-context with a prompt length that only depends on the dimension of the distribution of target functions $r$; in contrast, any algorithm that directly learns $f_*$ on the test prompt yields a statistical complexity that scales with the ambient dimension $d$. Our result highlights the adaptivity of the pretrained transformer to low-dimensional structures of the function class, which enables sample-efficient ICL that outperforms estimators that only have access to the in-context data.

19.6MLFeb 8, 2024
How do Transformers perform In-Context Autoregressive Learning?

Michael E. Sander, Raja Giryes, Taiji Suzuki et al.

Transformers have achieved state-of-the-art performance in language modeling tasks. However, the reasons behind their tremendous success are still unclear. In this paper, towards a better understanding, we train a Transformer model on a simple next token prediction task, where sequences are generated as a first-order autoregressive process $s_{t+1} = W s_t$. We show how a trained Transformer predicts the next token by first learning $W$ in-context, then applying a prediction mapping. We call the resulting procedure in-context autoregressive learning. More precisely, focusing on commuting orthogonal matrices $W$, we first show that a trained one-layer linear Transformer implements one step of gradient descent for the minimization of an inner objective function, when considering augmented tokens. When the tokens are not augmented, we characterize the global minima of a one-layer diagonal linear multi-head Transformer. Importantly, we exhibit orthogonality between heads and show that positional encoding captures trigonometric relations in the data. On the experimental side, we consider the general case of non-commuting orthogonal matrices and generalize our theoretical findings.
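The data-generating process $s_{t+1} = W s_t$ and the two-stage view (estimate $W$ from the context, then apply it) can be sketched in a toy setting. This assumes $W$ is a 2D rotation and uses a closed-form angle estimate in place of a trained Transformer; all function names are hypothetical.

```python
import math


def rotation(theta):
    """2x2 rotation matrix, a simple orthogonal W."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def apply(W, v):
    """Matrix-vector product for a 2x2 matrix."""
    return [W[0][0] * v[0] + W[0][1] * v[1],
            W[1][0] * v[0] + W[1][1] * v[1]]


def generate(W, s0, length):
    """Roll out the first-order autoregressive process s_{t+1} = W s_t."""
    seq = [s0]
    for _ in range(length - 1):
        seq.append(apply(W, seq[-1]))
    return seq


def estimate_angle(seq):
    """Recover the rotation angle from consecutive tokens in the context."""
    pairs = list(zip(seq, seq[1:]))
    dot = sum(a[0] * b[0] + a[1] * b[1] for a, b in pairs)
    cross = sum(a[0] * b[1] - a[1] * b[0] for a, b in pairs)
    return math.atan2(cross, dot)


true_theta = 0.7
context = generate(rotation(true_theta), [1.0, 0.0], 8)
theta_hat = estimate_angle(context)                       # step 1: learn W in-context
prediction = apply(rotation(theta_hat), context[-1])      # step 2: predict next token
target = apply(rotation(true_theta), context[-1])
```

The estimate is exact here because the sequence is noiseless and every consecutive pair is related by the same rotation.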

18.2LGNov 5, 2024
On the Comparison between Multi-modal and Single-modal Contrastive Learning

Wei Huang, Andi Han, Yongqiang Chen et al.

Multi-modal contrastive learning with language supervision has presented a paradigm shift in modern machine learning. By pre-training on a web-scale dataset, multi-modal contrastive learning can learn high-quality representations that exhibit impressive robustness and transferability. Despite its empirical success, the theoretical understanding is still in its infancy, especially regarding its comparison with single-modal contrastive learning. In this work, we introduce a feature learning theory framework that provides a theoretical foundation for understanding the differences between multi-modal and single-modal contrastive learning. Based on a data generation model consisting of signal and noise, our analysis is performed on a ReLU network trained with the InfoMax objective function. Through a trajectory-based optimization analysis and generalization characterization on downstream tasks, we identify the critical factor, which is the signal-to-noise ratio (SNR), that impacts the generalizability in downstream tasks of both multi-modal and single-modal contrastive learning. Through the cooperation between the two modalities, multi-modal learning can achieve better feature learning, leading to improvements in performance in downstream tasks compared to single-modal learning. Our analysis provides a unified framework that can characterize the optimization and generalization of both single-modal and multi-modal contrastive learning. Empirical experiments on both synthetic and real-world datasets further consolidate our theoretical findings.

22.9AIFeb 2, 2025
Metastable Dynamics of Chain-of-Thought Reasoning: Provable Benefits of Search, RL and Distillation

Juno Kim, Denny Wu, Jason Lee et al.

A key paradigm to improve the reasoning capabilities of large language models (LLMs) is to allocate more inference-time compute to search against a verifier or reward model. This process can then be utilized to refine the pretrained model or distill its reasoning patterns into more efficient models. In this paper, we study inference-time compute by viewing chain-of-thought (CoT) generation as a metastable Markov process: easy reasoning steps (e.g., algebraic manipulations) form densely connected clusters, while hard reasoning steps (e.g., applying a relevant theorem) create sparse, low-probability edges between clusters, leading to phase transitions at longer timescales. Under this framework, we prove that implementing a search protocol that rewards sparse edges improves CoT by decreasing the expected number of steps to reach different clusters. In contrast, we establish a limit on reasoning capability when the model is restricted to local information of the pretrained graph. We also show that the information gained by search can be utilized to obtain a better reasoning model: (1) the pretrained model can be directly finetuned to favor sparse edges via policy gradient methods, and moreover (2) a compressed metastable representation of the reasoning dynamics can be distilled into a smaller, more efficient model.

15.7LGMar 22, 2024
Mean-field Analysis on Two-layer Neural Networks from a Kernel Perspective

Shokichi Takakura, Taiji Suzuki

In this paper, we study the feature learning ability of two-layer neural networks in the mean-field regime through the lens of kernel methods. To focus on the dynamics of the kernel induced by the first layer, we utilize a two-timescale limit, where the second layer moves much faster than the first layer. In this limit, the learning problem is reduced to a minimization problem over the intrinsic kernel. Then, we show the global convergence of the mean-field Langevin dynamics and derive the time and particle discretization errors. We also demonstrate that two-layer neural networks can learn a union of multiple reproducing kernel Hilbert spaces more efficiently than any kernel method, and that neural networks acquire a data-dependent kernel which aligns with the target function. In addition, we develop a label noise procedure, which converges to the global optimum, and show that the degrees of freedom appear as an implicit regularization.

12.5LGNov 4, 2024
Provably Transformers Harness Multi-Concept Word Semantics for Efficient In-Context Learning

Dake Bu, Wei Huang, Andi Han et al.

Transformer-based large language models (LLMs) have displayed remarkable creative prowess and emergence capabilities. Existing empirical studies have revealed a strong connection between these LLMs' impressive emergence abilities and their in-context learning (ICL) capacity, allowing them to solve new tasks using only task-specific prompts without further fine-tuning. On the other hand, existing empirical and theoretical studies also show that there is a linear regularity of the multi-concept encoded semantic representation behind transformer-based LLMs. However, existing theoretical work fails to build up an understanding of the connection between this regularity and the innovative power of ICL. Additionally, prior work often focuses on simplified, unrealistic scenarios involving linear transformers or unrealistic loss functions, and achieves only linear or sub-linear convergence rates. In contrast, this work provides a fine-grained mathematical analysis to show how transformers leverage the multi-concept semantics of words to enable powerful ICL and excellent out-of-distribution ICL abilities, offering insights into how transformers innovate solutions for certain unseen tasks encoded with multiple cross-concept semantics. Inspired by empirical studies on the linear latent geometry of LLMs, the analysis is based on a concept-based low-noise sparse coding prompt model. Leveraging advanced techniques, this work showcases the exponential 0-1 loss convergence over the highly non-convex training dynamics, incorporating the challenges of softmax self-attention, ReLU-activated MLPs, and cross-entropy loss. Empirical simulations corroborate the theoretical findings.

16.9LGFeb 5, 2025
Direct Distributional Optimization for Provable Alignment of Diffusion Models

Ryotaro Kawata, Kazusato Oko, Atsushi Nitanda et al.

We introduce a novel alignment method for diffusion models from distribution optimization perspectives while providing rigorous convergence guarantees. We first formulate the problem as a generic regularized loss minimization over probability distributions and directly optimize the distribution using the Dual Averaging method. Next, we enable sampling from the learned distribution by approximating its score function via Doob's $h$-transform technique. The proposed framework is supported by rigorous convergence guarantees and an end-to-end bound on the sampling error, which imply that when the original distribution's score is known accurately, the complexity of sampling from shifted distributions is independent of isoperimetric conditions. This framework is broadly applicable to general distribution optimization problems, including alignment tasks in Reinforcement Learning with Human Feedback (RLHF), Direct Preference Optimization (DPO), and Kahneman-Tversky Optimization (KTO). We empirically validate its performance on synthetic and image datasets using the DPO objective.

14.0MLMay 25, 2025Code
On the Role of Label Noise in the Feature Learning Process

Andi Han, Wei Huang, Zhanpeng Zhou et al.

Deep learning with noisy labels presents significant challenges. In this work, we theoretically characterize the role of label noise from a feature learning perspective. Specifically, we consider a signal-noise data distribution, where each sample comprises a label-dependent signal and label-independent noise, and rigorously analyze the training dynamics of a two-layer convolutional neural network under this data setup, along with the presence of label noise. Our analysis identifies two key stages. In Stage I, the model perfectly fits all the clean samples (i.e., samples without label noise) while ignoring the noisy ones (i.e., samples with noisy labels). During this stage, the model learns the signal from the clean samples, which generalizes well on unseen data. In Stage II, as the training loss converges, the gradient in the direction of noise surpasses that of the signal, leading to overfitting on noisy samples. Eventually, the model memorizes the noise present in the noisy samples and degrades its generalization ability. Furthermore, our analysis provides a theoretical basis for two widely used techniques for tackling label noise: early stopping and sample selection. Experiments on both synthetic and real-world setups validate our theory.

14.4LGMay 12, 2025
Direct Density Ratio Optimization: A Statistically Consistent Approach to Aligning Large Language Models

Rei Higuchi, Taiji Suzuki

Aligning large language models (LLMs) with human preferences is crucial for safe deployment, yet existing methods assume specific preference models such as the Bradley-Terry model. This assumption leads to statistical inconsistency, where more data does not guarantee convergence to true human preferences. To address this critical gap, we introduce a novel alignment method, Direct Density Ratio Optimization (DDRO). DDRO directly estimates the density ratio between preferred and unpreferred output distributions, circumventing the need for explicit human preference modeling. We theoretically prove that DDRO is statistically consistent, ensuring convergence to the true preferred distribution as the data size grows, regardless of the underlying preference structure. Experiments demonstrate that DDRO achieves superior performance compared to existing methods on many major benchmarks. DDRO unlocks the potential for truly data-driven alignment, paving the way for more reliable and human-aligned LLMs.

12.3MLFeb 9, 2025
Propagation of Chaos for Mean-Field Langevin Dynamics and its Application to Model Ensemble

Atsushi Nitanda, Anzelle Lee, Damian Tan Xing Kai et al.

Mean-field Langevin dynamics (MFLD) is an optimization method derived by taking the mean-field limit of noisy gradient descent for two-layer neural networks in the mean-field regime. Recently, the propagation of chaos (PoC) for MFLD has gained attention as it provides a quantitative characterization of the optimization complexity in terms of the number of particles and iterations. Remarkable progress by Chen et al. (2022) showed that the approximation error due to finite particles remains uniform in time and diminishes as the number of particles increases. In this paper, by refining the defective log-Sobolev inequality -- a key result from that earlier work -- under the neural network training setting, we establish an improved PoC result for MFLD, which removes the exponential dependence on the regularization coefficient from the particle approximation term of the optimization complexity. As an application, we propose a PoC-based model ensemble strategy with theoretical guarantees.

7.1LGSep 28, 2025
Trained Mamba Emulates Online Gradient Descent in In-Context Linear Regression

Jiarui Jiang, Wei Huang, Miao Zhang et al.

State-space models (SSMs), particularly Mamba, have emerged as an efficient Transformer alternative with linear complexity for long-sequence modeling. Recent empirical works demonstrate that Mamba's in-context learning (ICL) capabilities, a critical capacity for large foundation models, are competitive with those of Transformers. However, theoretical understanding of Mamba's ICL remains limited, restricting deeper insights into its underlying mechanisms. Even fundamental tasks such as linear regression ICL, widely studied as a standard theoretical benchmark for Transformers, have not been thoroughly analyzed in the context of Mamba. To address this gap, we study the training dynamics of Mamba on the linear regression ICL task. By developing novel techniques tackling non-convex optimization with gradient descent related to Mamba's structure, we establish an exponential convergence rate to the ICL solution, and derive a loss bound that is comparable to the Transformer's. Importantly, our results reveal that Mamba can perform a variant of \textit{online gradient descent} to learn the latent function in context. This mechanism is different from that of the Transformer, which is typically understood to achieve ICL through gradient descent emulation. The theoretical results are verified by experimental simulation.

14.4LGJun 2, 2025
Mixture of Experts Provably Detect and Learn the Latent Cluster Structure in Gradient-Based Learning

Ryotaro Kawata, Kohsei Matsutani, Yuri Kinoshita et al.

Mixture of Experts (MoE), an ensemble of specialized models equipped with a router that dynamically distributes each input to appropriate experts, has achieved successful results in the field of machine learning. However, theoretical understanding of this architecture is falling behind due to its inherent complexity. In this paper, we theoretically study the sample and runtime complexity of MoE trained with stochastic gradient descent (SGD) when learning a regression task with an underlying cluster structure of single-index models. On the one hand, we prove that a vanilla neural network fails to detect such a latent organization as it can only process the problem as a whole. This is intrinsically related to the concept of the information exponent, which is low for each cluster but increases when we consider the entire task. On the other hand, we show that an MoE succeeds in dividing this problem into easier subproblems by leveraging the ability of each expert to weakly recover the simpler function corresponding to an individual cluster. To the best of our knowledge, this work is among the first to explore the benefits of the MoE framework by examining its SGD dynamics in the context of nonlinear regression.

13.0LGApr 28, 2025
Quantifying Memory Utilization with Effective State-Size

Rom N. Parnichkun, Neehal Tumma, Armin W. Thomas et al.

The need to develop a general framework for architecture analysis is becoming increasingly important, given the expanding design space of sequence models. To this end, we draw insights from classical signal processing and control theory, to develop a quantitative measure of \textit{memory utilization}: the internal mechanisms through which a model stores past information to produce future outputs. This metric, which we call \textbf{\textit{effective state-size}} (ESS), is tailored to the fundamental class of systems with \textit{input-invariant} and \textit{input-varying linear operators}, encompassing a variety of computational units such as variants of attention, convolutions, and recurrences. Unlike prior work on memory utilization, which either relies on raw operator visualizations (e.g. attention maps), or simply the total \textit{memory capacity} (i.e. cache size) of a model, our metrics provide highly interpretable and actionable measurements. In particular, we show how ESS can be leveraged to improve initialization strategies, inform novel regularizers and advance the performance-efficiency frontier through model distillation. Furthermore, we demonstrate that the effect of context delimiters (such as end-of-speech tokens) on ESS highlights cross-architectural differences in how large language models utilize their available memory to recall information. Overall, we find that ESS provides valuable insights into the dynamics that dictate memory utilization, enabling the design of more efficient and effective sequence models.

6.7CLApr 24, 2025
When Does Metadata Conditioning (NOT) Work for Language Model Pre-Training? A Study with Context-Free Grammars

Rei Higuchi, Ryotaro Kawata, Naoki Nishikawa et al.

The ability to acquire latent semantics is one of the key properties that determines the performance of language models. One convenient approach to invoke this ability is to prepend metadata (e.g. URLs, domains, and styles) at the beginning of texts in the pre-training data, making it easier for the model to access latent semantics before observing the entire text. Previous studies have reported that this technique improves the performance of trained models in downstream tasks; however, this improvement has been observed only in specific downstream tasks, without consistent enhancement in average next-token prediction loss. To understand this phenomenon, we closely investigate how prepending metadata during pre-training affects model performance by examining its behavior using artificial data. Interestingly, we found that this approach produces both positive and negative effects on the downstream tasks. We demonstrate that the effectiveness of the approach depends on whether latent semantics can be inferred from the downstream task's prompt. Specifically, through investigations using data generated by probabilistic context-free grammars, we show that training with metadata helps improve the model's performance when the given context is long enough to infer the latent semantics. In contrast, the technique negatively impacts performance when the context lacks the necessary information to make an accurate posterior inference.

13.4LGApr 30, 2024
Weighted Point Set Embedding for Multimodal Contrastive Learning Toward Optimal Similarity Metric

Toshimitsu Uesaka, Taiji Suzuki, Yuhta Takida et al.

In typical multimodal contrastive learning, such as CLIP, encoders produce one point in the latent representation space for each input. However, one-point representation has difficulty in capturing the relationship and the similarity structure of a huge amount of instances in the real world. For richer classes of the similarity, we propose the use of weighted point sets, namely, sets of pairs of weight and vector, as representations of instances. In this work, we theoretically show the benefit of our proposed method through a new understanding of the contrastive loss of CLIP, which we call symmetric InfoNCE. We clarify that the optimal similarity that minimizes symmetric InfoNCE is the pointwise mutual information, and show an upper bound of excess risk on downstream classification tasks of representations that achieve the optimal similarity. In addition, we show that our proposed similarity based on weighted point sets consistently achieves the optimal similarity. To verify the effectiveness of our proposed method, we demonstrate pretraining of text-image representation models and classification tasks on common benchmarks.

11.4LGOct 20, 2025
How Does Label Noise Gradient Descent Improve Generalization in the Low SNR Regime?

Wei Huang, Andi Han, Yujin Song et al.

The capacity of deep learning models is often large enough to both learn the underlying statistical signal and overfit to noise in the training set. This noise memorization can be harmful especially for data with a low signal-to-noise ratio (SNR), leading to poor generalization. Inspired by prior observations that label noise provides implicit regularization that improves generalization, in this work, we investigate whether introducing label noise to the gradient updates can enhance the test performance of neural network (NN) in the low SNR regime. Specifically, we consider training a two-layer NN with a simple label noise gradient descent (GD) algorithm, in an idealized signal-noise data setting. We prove that adding label noise during training suppresses noise memorization, preventing it from dominating the learning process; consequently, label noise GD enjoys rapid signal growth while the overfitting remains controlled, thereby achieving good generalization despite the low SNR. In contrast, we also show that NN trained with standard GD tends to overfit to noise in the same low SNR setting and establish a non-vanishing lower bound on its test error, thus demonstrating the benefit of introducing label noise in gradient-based training.

4.1LGOct 14, 2025
Mamba Can Learn Low-Dimensional Targets In-Context via Test-Time Feature Learning

Junsoo Oh, Wei Huang, Taiji Suzuki

Mamba, a recently proposed linear-time sequence model, has attracted significant attention for its computational efficiency and strong empirical performance. However, a rigorous theoretical understanding of its underlying mechanisms remains limited. In this work, we provide a theoretical analysis of Mamba's in-context learning (ICL) capability by focusing on tasks defined by low-dimensional nonlinear target functions. Specifically, we study in-context learning of a single-index model $y \approx g_*(\langle \boldsymbol{\beta}, \boldsymbol{x} \rangle)$, which depends on only a single relevant direction $\boldsymbol{\beta}$, referred to as the feature. We prove that Mamba, pretrained by gradient-based methods, can achieve efficient ICL via test-time feature learning, extracting the relevant direction directly from context examples. Consequently, we establish a test-time sample complexity that improves upon linear Transformers -- analyzed to behave like kernel methods -- and is comparable to nonlinear Transformers, which have been shown to surpass the Correlational Statistical Query (CSQ) lower bound and achieve a near information-theoretically optimal rate in previous works. Our analysis reveals the crucial role of the nonlinear gating mechanism in Mamba for feature extraction, highlighting it as the fundamental driver behind Mamba's ability to achieve both computational efficiency and high performance.

4.1LGJun 12, 2025
Generalization Bound of Gradient Flow through Training Trajectory and Data-dependent Kernel

Yilan Chen, Zhichao Wang, Wei Huang et al.

Gradient-based optimization methods have shown remarkable empirical success, yet their theoretical generalization properties remain only partially understood. In this paper, we establish a generalization bound for gradient flow that aligns with the classical Rademacher complexity bounds for kernel methods (specifically those based on the RKHS norm and kernel trace) through a data-dependent kernel called the loss path kernel (LPK). Unlike static kernels such as NTK, the LPK captures the entire training trajectory, adapting to both data and optimization dynamics, leading to tighter and more informative generalization guarantees. Moreover, the bound highlights how the norm of the training loss gradients along the optimization trajectory influences the final generalization performance. The key technical ingredients in our proof combine stability analysis of gradient flow with uniform convergence via Rademacher complexity. Our bound recovers existing kernel regression bounds for overparameterized neural networks and shows the feature learning capability of neural networks compared to kernel methods. Numerical experiments on real-world datasets validate that our bounds correlate well with the true generalization gap.

2.6LGOct 29, 2024
Dimensionality-induced information loss of outliers in deep neural networks

Kazuki Uematsu, Kosuke Haruki, Taiji Suzuki et al.

Out-of-distribution (OOD) detection is a critical issue for the stable and reliable operation of systems using a deep neural network (DNN). Although many OOD detection methods have been proposed, it remains unclear how the differences between in-distribution (ID) and OOD samples are generated by each processing step inside DNNs. We experimentally clarify this issue by investigating the layer dependence of feature representations from multiple perspectives. We find that intrinsic low dimensionalization of DNNs is essential for understanding how OOD samples become more distinct from ID samples as features propagate to deeper layers. Based on these observations, we provide a simple picture that consistently explains various properties of OOD samples. Specifically, low-dimensional weights eliminate most information from OOD samples, resulting in misclassifications due to excessive attention to dataset bias. In addition, we demonstrate the utility of dimensionality by proposing a dimensionality-aware OOD detection method based on alignment of features and weights, which consistently achieves high performance for various datasets with lower computational cost.

19.3LGJun 17, 2024
Learning sum of diverse features: computational hardness and efficient gradient-based training for ridge combinations

Kazusato Oko, Yujin Song, Taiji Suzuki et al.

We study the computational and sample complexity of learning a target function $f_*:\mathbb{R}^d\to\mathbb{R}$ with additive structure, that is, $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(\langle x, v_m\rangle)$, where $f_1,f_2,\dots,f_M:\mathbb{R}\to\mathbb{R}$ are nonlinear link functions of single-index models (ridge functions) with diverse and near-orthogonal index features $\{v_m\}_{m=1}^M$, and the number of additive tasks $M$ grows with the dimensionality $M\asymp d^{\gamma}$ for $\gamma\ge 0$. This problem setting is motivated by the classical additive model literature, the recent representation learning theory of two-layer neural networks, and large-scale pretraining where the model simultaneously acquires a large number of "skills" that are often localized in distinct parts of the trained network. We prove that a large subset of polynomial $f_*$ can be efficiently learned by gradient descent training of a two-layer neural network, with a polynomial statistical and computational complexity that depends on the number of tasks $M$ and the information exponent of $f_m$, despite the unknown link function and $M$ growing with the dimensionality. We complement this learnability guarantee with a computational hardness result by establishing statistical query (SQ) lower bounds for both the correlational SQ and full SQ algorithms.

11.5LGJun 6, 2024
Provably Neural Active Learning Succeeds via Prioritizing Perplexing Samples

Dake Bu, Wei Huang, Taiji Suzuki et al.

Neural Network-based active learning (NAL) is a cost-effective data selection technique that utilizes neural networks to select and train on a small subset of samples. While existing work successfully develops various effective or theory-justified NAL algorithms, the understanding of the two commonly used query criteria of NAL, uncertainty-based and diversity-based, remains in its infancy. In this work, we try to move one step forward by offering a unified explanation for the success of NAL under both query criteria from a feature learning view. Specifically, we consider a feature-noise data model comprising easy-to-learn or hard-to-learn features disrupted by noise, and conduct analysis over 2-layer NN-based NALs in the pool-based scenario. We provably show that both uncertainty-based and diversity-based NAL are inherently amenable to one and the same principle, i.e., striving to prioritize samples that contain yet-to-be-learned features. We further prove that this shared principle is the key to their success: achieving small test error with a small labeled set. In contrast, strategy-free passive learning exhibits a large test error due to the inadequate learning of yet-to-be-learned features, and therefore needs a significantly larger label complexity for a sufficient test error reduction. Experimental results validate our findings.

28.5LGJun 3, 2024
Neural network learns low-dimensional polynomials with SGD near the information-theoretic limit

Jason D. Lee, Kazusato Oko, Taiji Suzuki et al.

We study the problem of gradient descent learning of a single-index target function $f_*(\boldsymbol{x}) = \textstyle\sigma_*\left(\langle\boldsymbol{x},\boldsymbol{\theta}\rangle\right)$ under isotropic Gaussian data in $\mathbb{R}^d$, where the unknown link function $\sigma_*:\mathbb{R}\to\mathbb{R}$ has information exponent $p$ (defined as the lowest degree in the Hermite expansion). Prior works showed that gradient-based training of neural networks can learn this target with $n\gtrsim d^{\Theta(p)}$ samples, and such complexity is predicted to be necessary by the correlational statistical query lower bound. Surprisingly, we prove that a two-layer neural network optimized by an SGD-based algorithm (on the squared loss) learns $f_*$ with a complexity that is not governed by the information exponent. Specifically, for arbitrary polynomial single-index models, we establish a sample and runtime complexity of $n \simeq T = \Theta(d\!\cdot\! \mathrm{polylog}\, d)$, where $\Theta(\cdot)$ hides a constant only depending on the degree of $\sigma_*$; this dimension dependence matches the information theoretic limit up to polylogarithmic factors. More generally, we show that $n\gtrsim d^{(p_*-1)\vee 1}$ samples are sufficient to achieve low generalization error, where $p_* \le p$ is the \textit{generative exponent} of the link function. Core to our analysis is the reuse of minibatch in the gradient computation, which gives rise to higher-order information beyond correlational queries.
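The information exponent mentioned above (the lowest nonzero degree in the Hermite expansion of the link function) can be computed directly for polynomial links. The sketch below is an illustration and not code from the paper: it assumes the probabilists' Hermite basis $He_k$, ignores the constant term, and uses hypothetical helper names.

```python
def hermite_polys(n):
    """Probabilists' Hermite polynomials He_0..He_n as ascending coefficient lists."""
    polys = [[1.0]]
    if n >= 1:
        polys.append([0.0, 1.0])
    for k in range(1, n):
        # Recurrence: He_{k+1}(z) = z * He_k(z) - k * He_{k-1}(z)
        shifted = [0.0] + polys[k]
        prev = polys[k - 1] + [0.0] * (len(shifted) - len(polys[k - 1]))
        polys.append([a - k * b for a, b in zip(shifted, prev)])
    return polys


def hermite_coefficients(coeffs):
    """Rewrite a polynomial (ascending monomial coefficients) in the He_k basis."""
    n = len(coeffs) - 1
    basis = hermite_polys(n)
    remainder = [float(c) for c in coeffs]
    out = [0.0] * (n + 1)
    for k in range(n, -1, -1):
        c = remainder[k]  # He_k is monic, so the leading coefficient is its weight
        out[k] = c
        for j, b in enumerate(basis[k]):
            remainder[j] -= c * b
    return out


def information_exponent(coeffs, tol=1e-12):
    """Lowest degree k >= 1 with a nonzero Hermite coefficient, or None."""
    for k, c in enumerate(hermite_coefficients(coeffs)):
        if k >= 1 and abs(c) > tol:
            return k
    return None


cubic_hermite = [0, -3, 0, 1]   # z^3 - 3z = He_3(z)
plain_cubic = [0, 0, 0, 1]      # z^3 = He_3(z) + 3 He_1(z)
```

Here `information_exponent(cubic_hermite)` is 3 while `information_exponent(plain_cubic)` is 1, which shows how two links of the same degree can differ in difficulty for correlational methods.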

20.7LGMay 30, 2023
Approximation and Estimation Ability of Transformers for Sequence-to-Sequence Functions with Infinite Dimensional Input

Shokichi Takakura, Taiji Suzuki

Despite the great success of Transformer networks in various applications such as natural language processing and computer vision, their theoretical aspects are not well understood. In this paper, we study the approximation and estimation ability of Transformers as sequence-to-sequence functions with infinite dimensional inputs. Although inputs and outputs are both infinite dimensional, we show that when the target function has anisotropic smoothness, Transformers can avoid the curse of dimensionality due to their feature extraction ability and parameter sharing property. In addition, we show that even if the smoothness changes depending on each input, Transformers can estimate the importance of features for each input and extract important features dynamically. We then prove that Transformers achieve a similar convergence rate as in the case of fixed smoothness. Our theoretical results support the practical success of Transformers for high dimensional data.

4.3MLMay 13, 2023
Tight and fast generalization error bound of graph embedding in metric space

Atsushi Suzuki, Atsushi Nitanda, Taiji Suzuki et al.

Recent studies have experimentally shown that we can achieve in non-Euclidean metric space effective and efficient graph embedding, which aims to obtain the vertices' representations reflecting the graph's structure in the metric space. Specifically, graph embedding in hyperbolic space has experimentally succeeded in embedding graphs with hierarchical-tree structure, e.g., data in natural languages, social networks, and knowledge bases. However, recent theoretical analyses have shown a much higher upper bound on non-Euclidean graph embedding's generalization error than Euclidean one's, where a high generalization error indicates that the incompleteness and noise in the data can significantly damage learning performance. It implies that the existing bound cannot guarantee the success of graph embedding in non-Euclidean metric space in a practical training data size, which can prevent non-Euclidean graph embedding's application in real problems. This paper provides a novel upper bound of graph embedding's generalization error by evaluating the local Rademacher complexity of the model as a function set of the distances of representation couples. Our bound clarifies that the performance of graph embedding in non-Euclidean metric space, including hyperbolic space, is better than the existing upper bounds suggest. Specifically, our new upper bound is polynomial in the metric space's geometric radius $R$ and can be $O(\frac{1}{S})$ at the fastest, where $S$ is the training data size. Our bound is significantly tighter and faster than the existing one, which can be exponential to $R$ and $O(\frac{1}{\sqrt{S}})$ at the fastest. Specific calculations on example cases show that graph embedding in non-Euclidean metric space can outperform that in Euclidean space with much smaller training data than the existing bound has suggested.

15.1LGMar 30, 2022Code
Improved Convergence Rate of Stochastic Gradient Langevin Dynamics with Variance Reduction and its Application to Optimization

Yuri Kinoshita, Taiji Suzuki

The stochastic gradient Langevin dynamics is one of the most fundamental algorithms for solving sampling problems and the non-convex optimization problems that appear in several machine learning applications. In particular, its variance-reduced versions have recently gained attention. In this paper, we study two such variants, namely the Stochastic Variance Reduced Gradient Langevin Dynamics and the Stochastic Recursive Gradient Langevin Dynamics. We prove their convergence to the objective distribution in terms of KL-divergence under the sole assumptions of smoothness and a log-Sobolev inequality, which are weaker conditions than those used in prior works for these algorithms. With the batch size and the inner loop length set to $\sqrt{n}$, the gradient complexity to achieve an $\epsilon$-precision is $\tilde{O}((n+dn^{1/2}\epsilon^{-1})\gamma^2 L^2\alpha^{-2})$, which improves on all previous analyses. We also show some essential applications of our result to non-convex optimization.
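
To make the variance-reduced update concrete, here is a minimal scalar sketch of a Stochastic Variance Reduced Gradient Langevin Dynamics loop with batch size and inner loop length both set to roughly $\sqrt{n}$, as in the abstract. It is not the authors' code: the interface (`grads` as a list of per-sample gradient functions), the step size `eta`, the reading of `gamma` as an inverse temperature, sampling the minibatch with replacement, and the quadratic test problem are all assumptions made for illustration.

```python
import math
import random


def make_quadratic_grads(centers, curvatures):
    # Hypothetical test problem: f_i(x) = 0.5 * c_i * (x - a_i)^2,
    # so grad f_i(x) = c_i * (x - a_i).
    return [lambda x, a=a, c=c: c * (x - a) for a, c in zip(centers, curvatures)]


def svrg_ld(grads, x0, eta, gamma, n_steps, rng):
    """Variance-reduced Langevin sketch for F(x) = (1/n) * sum_i f_i(x), scalar x."""
    n = len(grads)
    # Batch size and inner loop length both ~ sqrt(n).
    batch = inner = max(1, math.isqrt(n))
    x, path = x0, []
    for k in range(n_steps):
        if k % inner == 0:
            # Refresh the snapshot and its full gradient once per inner loop.
            snap = x
            full_grad = sum(g(snap) for g in grads) / n
        # Minibatch indices, drawn with replacement for simplicity.
        idx = [rng.randrange(n) for _ in range(batch)]
        # SVRG estimator: snapshot gradient plus a minibatch correction.
        v = full_grad + sum(grads[i](x) - grads[i](snap) for i in idx) / batch
        # Langevin step: gradient move plus Gaussian noise scaled by 1 / gamma.
        x = x - eta * v + math.sqrt(2.0 * eta / gamma) * rng.gauss(0.0, 1.0)
        path.append(x)
    return path
```

On the quadratic test problem the iterates should fluctuate around the minimizer of $F$, with a spread that shrinks as `gamma` grows. The Stochastic Recursive Gradient variant mentioned in the abstract replaces the estimator `v` with a recursively updated one and is not shown here.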

5.8LGFeb 12, 2022
Escaping Saddle Points with Bias-Variance Reduced Local Perturbed SGD for Communication Efficient Nonconvex Distributed Learning

Tomoya Murata, Taiji Suzuki

In recent centralized nonconvex distributed learning and federated learning, local methods are one of the promising approaches to reduce communication time. However, existing work has mainly focused on first-order optimality guarantees. On the other hand, algorithms with second-order optimality guarantees, i.e., algorithms that escape saddle points, have been extensively studied in the non-distributed optimization literature. In this paper, we study a new local algorithm called Bias-Variance Reduced Local Perturbed SGD (BVR-L-PSGD), which combines the existing bias-variance reduced gradient estimator with parameter perturbation to find second-order optimal points in centralized nonconvex distributed optimization. BVR-L-PSGD enjoys second-order optimality with nearly the same communication complexity as the best known one of BVR-L-SGD for finding first-order optimal points. In particular, the communication complexity is better than that of non-local methods when the heterogeneity of the local datasets is smaller than the smoothness of the local loss. In an extreme case, the communication complexity approaches $\widetilde{\Theta}(1)$ as the heterogeneity of the local datasets goes to zero. Numerical results validate our theoretical findings.
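
The role of parameter perturbation in escaping saddle points can be seen on a toy problem. The sketch below is plain single-machine perturbed gradient descent, not BVR-L-PSGD: it has no local updates, no communication, and no bias-variance reduced estimator. The test function, the perturbation rule, and every parameter name (`radius`, `g_thresh`, `wait`) are assumptions chosen for illustration.

```python
import math
import random


def grad(x, y):
    # Toy objective: f(x, y) = 0.5 * x**2 + 0.25 * y**4 - 0.5 * y**2.
    # It has a saddle point at (0, 0) and local minima at (0, 1) and (0, -1).
    return x, y ** 3 - y


def descend(x, y, eta=0.1, steps=1000, radius=0.0, g_thresh=1e-3, wait=20, seed=0):
    """Gradient descent that adds a small random kick when the gradient is tiny.

    With radius=0.0 this is ordinary gradient descent.
    """
    rng = random.Random(seed)
    last_kick = -wait
    for k in range(steps):
        gx, gy = grad(x, y)
        if radius > 0.0 and math.hypot(gx, gy) < g_thresh and k - last_kick >= wait:
            # Near a first-order stationary point: perturb the parameters.
            x += rng.uniform(-radius, radius)
            y += rng.uniform(-radius, radius)
            last_kick = k
            gx, gy = grad(x, y)
        x, y = x - eta * gx, y - eta * gy
    return x, y
```

Started from `(1.0, 0.0)`, the unperturbed run (`radius=0.0`) slides along the x-axis into the saddle at the origin and stays there, because the y-component of the gradient is exactly zero on that axis. With a small `radius` such as `0.01`, the kick moves the iterate off the axis and the negative curvature in y carries it to one of the two minima.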

23.4MLJan 25, 2022
Convex Analysis of the Mean Field Langevin Dynamics

Atsushi Nitanda, Denny Wu, Taiji Suzuki

As an example of the nonlinear Fokker-Planck equation, the mean field Langevin dynamics has recently attracted attention due to its connection to (noisy) gradient descent on infinitely wide neural networks in the mean field regime, and hence the convergence property of the dynamics is of great theoretical interest. In this work, we give a concise and self-contained convergence rate analysis of the mean field Langevin dynamics with respect to the (regularized) objective function in both continuous and discrete time settings. The key ingredient of our proof is a proximal Gibbs distribution $p_q$ associated with the dynamics, which, in combination with techniques in [Vempala and Wibisono (2019)], allows us to develop a simple convergence theory parallel to classical results in convex optimization. Furthermore, we reveal that $p_q$ connects to the duality gap in the empirical risk minimization setting, which enables efficient empirical evaluation of the algorithm convergence.
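
For orientation, the objects named in the abstract can be written out as follows. This is a sketch in assumed notation that does not appear in the abstract itself: $F$ is a convex functional of the parameter distribution $q$, $\lambda > 0$ is the entropy regularization strength, $\frac{\delta F}{\delta q}$ is the first variation of $F$, and $W_t$ is a standard Brownian motion.

```latex
% Entropy-regularized objective over distributions q (assumed notation)
\mathcal{L}(q) = F(q) + \lambda\, \mathbb{E}_{q}[\log q]

% Mean field Langevin dynamics, with q_t the law of \theta_t
\mathrm{d}\theta_t = -\nabla \frac{\delta F}{\delta q}(q_t)(\theta_t)\,\mathrm{d}t
                     + \sqrt{2\lambda}\,\mathrm{d}W_t

% Proximal Gibbs distribution associated with q
p_q(\theta) \propto \exp\!\left(-\frac{1}{\lambda}\,
                     \frac{\delta F}{\delta q}(q)(\theta)\right)
```

Under these assumptions, a minimizer $q^\ast$ of $\mathcal{L}$ satisfies the fixed-point condition $q^\ast = p_{q^\ast}$, which is why the gap between $q$ and $p_q$ is a natural measure of progress.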