ML · Mar 3, 2023
Diffusion Models are Minimax Optimal Distribution Estimators
Kazusato Oko, Shunta Akiyama, Taiji Suzuki
While efficient distribution learning is no doubt behind the groundbreaking success of diffusion modeling, its theoretical guarantees are quite limited. In this paper, we provide the first rigorous analysis of the approximation and generalization abilities of diffusion modeling for well-known function spaces. The highlight of this paper is that when the true density function belongs to the Besov space and the empirical score matching loss is properly minimized, the generated data distribution achieves the nearly minimax optimal estimation rates in the total variation distance and in the Wasserstein distance of order one. Furthermore, we extend our theory to demonstrate how diffusion models adapt to low-dimensional data distributions. We expect these results to advance the theoretical understanding of diffusion modeling and its ability to generate verisimilar outputs.
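The abstract's training objective is the empirical score matching loss. Here is a minimal sketch of its denoising form, estimated by Monte Carlo. It assumes a 1-D standard Gaussian as the data distribution and an Ornstein-Uhlenbeck forward process. Neither assumption comes from the paper, and `dsm_loss`, `true_score`, and `wrong_score` are hypothetical names. Under these assumptions the marginal at every time stays N(0, 1), so the true score is -x, and it should score a lower loss than a misspecified one.

```python
import math
import random


def dsm_loss(score_fn, t, n_samples=20000, seed=0):
    """Monte Carlo estimate of the denoising score matching loss at time t.

    Assumed forward process: x_t = m_t * x_0 + s_t * eps with
    m_t = exp(-t) and s_t^2 = 1 - exp(-2t) (Ornstein-Uhlenbeck).
    The regression target is the conditional score -eps / s_t.
    """
    rng = random.Random(seed)
    m_t = math.exp(-t)
    s_t = math.sqrt(1.0 - math.exp(-2.0 * t))
    total = 0.0
    for _ in range(n_samples):
        x0 = rng.gauss(0.0, 1.0)  # toy data distribution: N(0, 1)
        eps = rng.gauss(0.0, 1.0)
        xt = m_t * x0 + s_t * eps
        target = -eps / s_t
        total += (score_fn(xt, t) - target) ** 2
    return total / n_samples


def true_score(x, t):
    # For N(0, 1) data the OU marginal stays N(0, 1), whose score is -x.
    return -x


def wrong_score(x, t):
    # A deliberately misspecified score for comparison.
    return -0.5 * x


print(dsm_loss(true_score, t=0.5), dsm_loss(wrong_score, t=0.5))
```

Both calls use the same seed, so the gap between the two losses reflects the score error E[(s(x) - ∇log p_t(x))²] rather than sampling noise.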
ML · Mar 6, 2023
Primal and Dual Analysis of Entropic Fictitious Play for Finite-sum Problems
Atsushi Nitanda, Kazusato Oko, Denny Wu et al.
The entropic fictitious play (EFP) is a recently proposed algorithm that minimizes the sum of a convex functional and entropy in the space of measures -- such an objective naturally arises in the optimization of a two-layer neural network in the mean-field regime. In this work, we provide a concise primal-dual analysis of EFP in the setting where the learning problem exhibits a finite-sum structure. We establish quantitative global convergence guarantees for both the continuous-time and discrete-time dynamics based on properties of a proximal Gibbs measure introduced in Nitanda et al. (2022). Furthermore, our primal-dual framework entails a memory-efficient particle-based implementation of the EFP update, and also suggests a connection to gradient boosting methods. We illustrate the efficiency of our novel implementation in experiments including neural network optimization and image synthesis.
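The following minimal sketch makes the EFP update concrete. It runs on a finite state space, which is an assumption: the paper works with measures over network parameters and uses its own particle-based implementation. It also assumes F(μ) = ⟨V, μ⟩ + ½ μᵀKμ with K positive semidefinite, so F is convex. At each step it forms the proximal Gibbs measure ∝ exp(−(δF/δμ)/λ) and mixes it into the current iterate with step size η. The names `gibbs`, `efp`, `V`, `K`, `lam`, and `eta` are illustrative.

```python
import math


def gibbs(mu, V, K, lam):
    """Proximal Gibbs measure, proportional to exp(-(dF/dmu) / lam)."""
    n = len(mu)
    # First variation of F(mu) = <V, mu> + 0.5 * mu^T K mu (K symmetric).
    grad = [V[i] + sum(K[i][j] * mu[j] for j in range(n)) for i in range(n)]
    g_min = min(grad)  # shift for numerical stability; cancels after normalizing
    w = [math.exp(-(g - g_min) / lam) for g in grad]
    z = sum(w)
    return [wi / z for wi in w]


def efp(V, K, lam=0.5, eta=0.1, steps=2000):
    """Discrete-time EFP: mu <- (1 - eta) * mu + eta * gibbs(mu)."""
    n = len(V)
    mu = [1.0 / n] * n  # start from the uniform distribution
    for _ in range(steps):
        g = gibbs(mu, V, K, lam)
        mu = [(1 - eta) * m + eta * gi for m, gi in zip(mu, g)]
    return mu


V = [0.0, 1.0, 0.5]
K = [[1.0, 0.5, 0.0], [0.5, 1.0, 0.5], [0.0, 0.5, 1.0]]  # PSD, so F is convex
mu = efp(V, K)
print(mu)
```

At convergence μ equals its own proximal Gibbs measure, which is the first-order optimality condition for F(μ) + λ·(negative entropy).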
LG · Sep 1, 2022
Versatile Single-Loop Method for Gradient Estimator: First and Second Order Optimality, and its Application to Federated Learning
Kazusato Oko, Shunta Akiyama, Tomoya Murata et al.
While variance reduction methods have shown great success in solving large-scale optimization problems, many of them suffer from accumulated errors and therefore require periodic full gradient computations. In this paper, we present a single-loop algorithm named SLEDGE (Single-Loop mEthoD for Gradient Estimator) for finite-sum nonconvex optimization, which does not require periodic refresh of the gradient estimator but achieves nearly optimal gradient complexity. Unlike existing methods, SLEDGE has the advantage of versatility: (i) second-order optimality, (ii) exponential convergence in the PL region, and (iii) smaller complexity under less heterogeneity of data. We build an efficient federated learning algorithm by exploiting these favorable properties. We show the first- and second-order optimality of the output and also provide analysis under PL conditions. When the local budget is sufficiently large and clients are less (Hessian-)heterogeneous, the algorithm requires fewer communication rounds than existing methods such as FedAvg, SCAFFOLD, and Mime. The superiority of our method is verified in numerical experiments.
LG · Nov 4, 2024
Pretrained transformer efficiently learns low-dimensional target functions in-context
Kazusato Oko, Yujin Song, Taiji Suzuki et al.
Transformers can efficiently learn in-context from example demonstrations. Most existing theoretical analyses studied the in-context learning (ICL) ability of transformers for linear function classes, where it is typically shown that the minimizer of the pretraining loss implements one gradient descent step on the least squares objective. However, this simplified linear setting arguably does not demonstrate the statistical efficiency of ICL, since the pretrained transformer does not outperform directly solving linear regression on the test prompt. In this paper, we study ICL of a nonlinear function class via a transformer with a nonlinear MLP layer: given a class of \textit{single-index} target functions $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$, where the index features $\boldsymbol{\beta}\in\mathbb{R}^d$ are drawn from an $r$-dimensional subspace, we show that a nonlinear transformer optimized by gradient descent (with a pretraining sample complexity that depends on the \textit{information exponent} of the link functions $\sigma_*$) learns $f_*$ in-context with a prompt length that depends only on the dimension $r$ of the distribution of target functions; in contrast, any algorithm that directly learns $f_*$ on the test prompt yields a statistical complexity that scales with the ambient dimension $d$. Our result highlights the adaptivity of the pretrained transformer to low-dimensional structures of the function class, which enables sample-efficient ICL that outperforms estimators that only have access to the in-context data.
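The single-index ICL setting can be illustrated by a data generator. This minimal sketch makes three assumptions: Gaussian inputs x ~ N(0, I_d), an r-dimensional subspace spanned by the first r coordinates, and the Hermite polynomial He₃ as a hypothetical link function. The helpers `sample_index`, `he3`, and `make_prompt` are illustrative names, not the paper's code. Each prompt shares one index vector β, and every β lives in the same r-dimensional subspace.

```python
import math
import random


def sample_index(d, r, rng):
    """Draw a unit-norm index vector beta supported on the first r coordinates.

    Assumption: the r-dimensional subspace is span(e_1, ..., e_r).
    """
    g = [rng.gauss(0.0, 1.0) for _ in range(r)]
    norm = math.sqrt(sum(v * v for v in g))
    return [v / norm for v in g] + [0.0] * (d - r)


def he3(z):
    """Illustrative link function He_3(z) = z^3 - 3z (information exponent 3)."""
    return z ** 3 - 3 * z


def make_prompt(beta, n, rng, link=he3):
    """Return n in-context pairs (x_i, y_i), x_i ~ N(0, I_d), y_i = link(<x_i, beta>)."""
    prompt = []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in beta]
        y = link(sum(a * b for a, b in zip(x, beta)))
        prompt.append((x, y))
    return prompt


rng = random.Random(0)
beta = sample_index(d=64, r=4, rng=rng)  # one task; a fresh beta per prompt
prompt = make_prompt(beta, n=32, rng=rng)
```

A learner that ignores the shared subspace must search all d = 64 directions for β. A learner that has picked up the subspace during pretraining only needs to resolve r = 4 coordinates from the prompt.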
LG · Jan 8, 2025
A Statistical Theory of Contrastive Pre-training and Multimodal Generative AI
Kazusato Oko, Licong Lin, Yuhang Cai et al.
Multi-modal generative AI systems, such as those combining vision and language, rely on contrastive pre-training to learn representations across different modalities. While their practical benefits are widely acknowledged, a rigorous theoretical understanding of the contrastive pre-training framework remains limited. This paper develops a theoretical framework to explain the success of contrastive pre-training in downstream tasks, such as zero-shot classification, conditional diffusion models, and vision-language models. We introduce the concept of approximate sufficient statistics, a generalization of the classical sufficient statistics, and show that near-minimizers of the contrastive pre-training loss are approximately sufficient, making them adaptable to diverse downstream tasks. We further propose the Joint Generative Hierarchical Model for the joint distribution of images and text, showing that transformers can efficiently approximate relevant functions within this model via belief propagation. Building on this framework, we derive sample complexity guarantees for multi-modal learning based on contrastive pre-trained representations. Numerical simulations validate these theoretical findings, demonstrating the strong generalization performance of contrastively pre-trained transformers in various multi-modal tasks.
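The abstract does not spell out the contrastive pre-training loss. A common choice, assumed here, is the symmetric InfoNCE objective used by CLIP-style image-text models. In this minimal sketch, row i of each batch is the positive pair and every other row serves as a negative. The name `clip_loss` and the temperature value are illustrative.

```python
import math


def log_softmax_at(row, i):
    """Numerically stable log-softmax of entry i of a list of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return row[i] - lse


def clip_loss(img, txt, temperature=0.1):
    """Symmetric InfoNCE over paired embeddings (assumed L2-normalized)."""
    n = len(img)
    # Similarity logits: sim[i][j] = <img_i, txt_j> / temperature.
    sim = [[sum(a * b for a, b in zip(img[i], txt[j])) / temperature
            for j in range(n)] for i in range(n)]
    # Image-to-text: classify the matching caption among all captions.
    i2t = -sum(log_softmax_at(sim[i], i) for i in range(n)) / n
    # Text-to-image: classify the matching image among all images.
    cols = [[sim[i][j] for i in range(n)] for j in range(n)]
    t2i = -sum(log_softmax_at(cols[j], j) for j in range(n)) / n
    return 0.5 * (i2t + t2i)


eye = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
shuffled = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
print(clip_loss(eye, eye), clip_loss(eye, shuffled))
```

Aligned pairs give a loss close to zero, and mismatched pairs give a large one. Minimizing this loss therefore pushes each encoder to keep whatever information identifies its partner modality, which is the intuition behind the approximate-sufficiency view in the abstract.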
LG · Feb 5, 2025
Direct Distributional Optimization for Provable Alignment of Diffusion Models
Ryotaro Kawata, Kazusato Oko, Atsushi Nitanda et al.
We introduce a novel alignment method for diffusion models from a distribution optimization perspective while providing rigorous convergence guarantees. We first formulate the problem as a generic regularized loss minimization over probability distributions and directly optimize the distribution using the Dual Averaging method. Next, we enable sampling from the learned distribution by approximating its score function via Doob's $h$-transform technique. The proposed framework is supported by rigorous convergence guarantees and an end-to-end bound on the sampling error, which imply that when the original distribution's score is known accurately, the complexity of sampling from shifted distributions is independent of isoperimetric conditions. This framework is broadly applicable to general distribution optimization problems, including alignment tasks in Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO), and Kahneman-Tversky Optimization (KTO). We empirically validate its performance on synthetic and image datasets using the DPO objective.
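For intuition about the "shifted distributions" targeted by alignment, consider a standard identity in a finite setting. The minimizer of E_q[−r] + β·KL(q‖p_ref) is q* ∝ p_ref·exp(r/β), and its optimal value is −β log Z, where Z is the normalizer. This KL-regularized reward objective underlies RLHF and DPO. The paper reaches such targets with Dual Averaging and samples them via Doob's h-transform in continuous space. The sketch below only evaluates the closed form on a small discrete set. `tilt`, `objective`, and the numbers are hypothetical.

```python
import math


def tilt(p_ref, reward, beta):
    """Minimizer of E_q[-reward] + beta * KL(q || p_ref) over distributions q."""
    w = [p * math.exp(r / beta) for p, r in zip(p_ref, reward)]
    z = sum(w)
    return [wi / z for wi in w]


def objective(q, p_ref, reward, beta):
    """Regularized loss: expected negative reward plus beta * KL(q || p_ref)."""
    kl = sum(qi * math.log(qi / pi) for qi, pi in zip(q, p_ref) if qi > 0)
    return -sum(qi * ri for qi, ri in zip(q, reward)) + beta * kl


p_ref = [0.5, 0.3, 0.2]   # hypothetical reference ("pretrained") distribution
reward = [0.0, 1.0, 2.0]  # hypothetical reward per outcome
beta = 0.5
q_star = tilt(p_ref, reward, beta)
print(q_star)
```

Smaller β shifts more mass toward high-reward outcomes, while larger β keeps q* close to p_ref.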
CL · Apr 24, 2025
When Does Metadata Conditioning (NOT) Work for Language Model Pre-Training? A Study with Context-Free Grammars
Rei Higuchi, Ryotaro Kawata, Naoki Nishikawa et al.
The ability to acquire latent semantics is one of the key properties that determines the performance of language models. One convenient approach to invoke this ability is to prepend metadata (e.g. URLs, domains, and styles) at the beginning of texts in the pre-training data, making it easier for the model to access latent semantics before observing the entire text. Previous studies have reported that this technique actually improves the performance of trained models in downstream tasks; however, this improvement has been observed only in specific downstream tasks, without consistent enhancement in average next-token prediction loss. To understand this phenomenon, we closely investigate how prepending metadata during pre-training affects model performance by examining its behavior using artificial data. Interestingly, we found that this approach produces both positive and negative effects on the downstream tasks. We demonstrate that the effectiveness of the approach depends on whether latent semantics can be inferred from the downstream task's prompt. Specifically, through investigations using data generated by probabilistic context-free grammars, we show that training with metadata helps improve the model's performance when the given context is long enough to infer the latent semantics. In contrast, the technique negatively impacts performance when the context lacks the necessary information to make an accurate posterior inference.
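The following sketch illustrates the data format under study: documents sampled from a probabilistic context-free grammar, with an optional metadata token prepended. The two grammars, the tokens `<m0>`/`<m1>`, and all rule probabilities are hypothetical. They only show how a latent variable (which grammar generated the text) can be exposed as a prefix, and they do not reproduce the paper's grammars.

```python
import random

# Hypothetical toy PCFGs: the latent "metadata" selects the rule probabilities.
GRAMMARS = {
    "<m0>": {"S": [(["A", "S"], 0.4), (["a"], 0.6)], "A": [(["a"], 0.9), (["b"], 0.1)]},
    "<m1>": {"S": [(["B", "S"], 0.4), (["b"], 0.6)], "B": [(["b"], 0.9), (["a"], 0.1)]},
}


def expand(symbol, rules, rng):
    """Recursively expand a nonterminal; symbols without rules are terminals."""
    if symbol not in rules:
        return [symbol]
    options, weights = zip(*rules[symbol])
    rhs = rng.choices(options, weights=weights, k=1)[0]
    out = []
    for s in rhs:
        out.extend(expand(s, rules, rng))
    return out


def sample_document(rng, with_metadata=True):
    """Pick a latent grammar, sample a sentence, and optionally prepend its tag."""
    meta = rng.choice(sorted(GRAMMARS))
    tokens = expand("S", GRAMMARS[meta], rng)
    return [meta] + tokens if with_metadata else tokens


rng = random.Random(0)
print(sample_document(rng))
print(sample_document(rng, with_metadata=False))
```

Without the prefix, a model must infer the latent grammar from the token statistics. Short documents carry little evidence for that inference, which matches the abstract's point that the benefit depends on how much context is available.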
LG · Jun 17, 2024
Learning sum of diverse features: computational hardness and efficient gradient-based training for ridge combinations
Kazusato Oko, Yujin Song, Taiji Suzuki et al.
We study the computational and sample complexity of learning a target function $f_*:\mathbb{R}^d\to\mathbb{R}$ with additive structure, that is, $f_*(x) = \frac{1}{\sqrt{M}}\sum_{m=1}^M f_m(\langle x, v_m\rangle)$, where $f_1,f_2,\ldots,f_M:\mathbb{R}\to\mathbb{R}$ are nonlinear link functions of single-index models (ridge functions) with diverse and near-orthogonal index features $\{v_m\}_{m=1}^M$, and the number of additive tasks $M$ grows with the dimensionality as $M\asymp d^\gamma$ for $\gamma\ge 0$. This problem setting is motivated by the classical additive model literature, the recent representation learning theory of two-layer neural networks, and large-scale pretraining where the model simultaneously acquires a large number of "skills" that are often localized in distinct parts of the trained network. We prove that a large subset of polynomial $f_*$ can be efficiently learned by gradient descent training of a two-layer neural network, with a polynomial statistical and computational complexity that depends on the number of tasks $M$ and the information exponent of $f_m$, despite the unknown link functions and $M$ growing with the dimensionality. We complement this learnability guarantee with a computational hardness result by establishing statistical query (SQ) lower bounds for both correlational SQ and full SQ algorithms.
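The additive target f_*(x) = (1/√M) Σ_m f_m(⟨x, v_m⟩) is straightforward to evaluate. This minimal sketch assumes exactly orthogonal standard-basis index features, whereas the paper allows diverse, near-orthogonal ones. The links He₂ and He₃ are hypothetical choices, as are the names `ridge_combination`, `he2`, and `he3`.

```python
import math


def he2(z):
    """Hermite polynomial He_2(z) = z^2 - 1."""
    return z * z - 1


def he3(z):
    """Hermite polynomial He_3(z) = z^3 - 3z."""
    return z ** 3 - 3 * z


def ridge_combination(x, vs, links):
    """f_*(x) = (1/sqrt(M)) * sum_m f_m(<x, v_m>) for M ridge components."""
    M = len(vs)
    total = 0.0
    for v, f in zip(vs, links):
        total += f(sum(a * b for a, b in zip(x, v)))
    return total / math.sqrt(M)


vs = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]  # M = 2 orthogonal features in d = 4
links = [he2, he3]
x = [1.0, 2.0, 0.0, 0.0]
# He_2(1) + He_3(2) = 0 + 2, so f_*(x) = 2 / sqrt(2) = sqrt(2).
value = ridge_combination(x, vs, links)
```

The 1/√M scaling keeps the variance of f_* of order one as M grows with d, assuming each component has unit-order variance and the features are near-orthogonal.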
LG · Jun 3, 2024
Neural network learns low-dimensional polynomials with SGD near the information-theoretic limit
Jason D. Lee, Kazusato Oko, Taiji Suzuki et al.
We study the problem of gradient descent learning of a single-index target function $f_*(\boldsymbol{x}) = \textstyle\sigma_*\left(\langle\boldsymbol{x},\boldsymbol{\theta}\rangle\right)$ under isotropic Gaussian data in $\mathbb{R}^d$, where the unknown link function $\sigma_*:\mathbb{R}\to\mathbb{R}$ has information exponent $p$ (defined as the lowest degree in the Hermite expansion). Prior works showed that gradient-based training of neural networks can learn this target with $n\gtrsim d^{\Theta(p)}$ samples, and such complexity is predicted to be necessary by the correlational statistical query lower bound. Surprisingly, we prove that a two-layer neural network optimized by an SGD-based algorithm (on the squared loss) learns $f_*$ with a complexity that is not governed by the information exponent. Specifically, for arbitrary polynomial single-index models, we establish a sample and runtime complexity of $n \simeq T = \Theta(d\cdot\mathrm{polylog}\, d)$, where $\Theta(\cdot)$ hides a constant depending only on the degree of $\sigma_*$; this dimension dependence matches the information-theoretic limit up to polylogarithmic factors. More generally, we show that $n\gtrsim d^{(p_*-1)\vee 1}$ samples are sufficient to achieve low generalization error, where $p_* \le p$ is the \textit{generative exponent} of the link function. Core to our analysis is the reuse of minibatches in the gradient computation, which gives rise to higher-order information beyond correlational queries.
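For a polynomial link, the information exponent p can be computed exactly. The k-th Hermite coefficient is proportional to E[σ(z) He_k(z)] for z ~ N(0, 1), and p is the smallest k ≥ 1 where it is nonzero (degree 0 is excluded, following the usual convention). This minimal sketch uses integer arithmetic together with the Gaussian moments E[z^n] = (n−1)!! for even n and 0 for odd n. Polynomials are coefficient lists, lowest degree first. All function names are illustrative, and the sketch does not compute the generative exponent p_*.

```python
def hermite(k):
    """Coefficients of the probabilists' Hermite polynomial He_k (lowest degree first)."""
    prev, cur = [1], [0, 1]
    if k == 0:
        return prev
    for j in range(1, k):
        # Recurrence: He_{j+1}(z) = z * He_j(z) - j * He_{j-1}(z).
        nxt = [0] + cur
        for i, c in enumerate(prev):
            nxt[i] -= j * c
        prev, cur = cur, nxt
    return cur


def gaussian_moment(n):
    """E[z^n] for z ~ N(0, 1): 0 for odd n, (n-1)!! for even n."""
    if n % 2:
        return 0
    out = 1
    for m in range(n - 1, 0, -2):
        out *= m
    return out


def poly_mul(p, q):
    """Multiply two polynomials given as coefficient lists."""
    out = [0] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            out[i + j] += a * b
    return out


def hermite_coeff(sigma, k):
    """E[sigma(z) He_k(z)], which equals k! times the k-th Hermite coefficient."""
    prod = poly_mul(sigma, hermite(k))
    return sum(c * gaussian_moment(n) for n, c in enumerate(prod))


def information_exponent(sigma):
    """Smallest k >= 1 with a nonzero Hermite coefficient; None for constants."""
    for k in range(1, len(sigma)):
        if hermite_coeff(sigma, k) != 0:
            return k
    return None


print(information_exponent([0, -3, 0, 1]))  # He_3(z) = z^3 - 3z
```

Note that z³ = He₃ + 3He₁ has information exponent 1 while He₃ itself has exponent 3. Per the abstract, prior correlational-query analyses predict d^{Θ(p)} samples, so the gap between these two links matters.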