Keith Rush

LG
h-index: 117
18 papers
6,279 citations
Novelty: 54%
AI Score: 57

18 Papers

LG · Nov 12, 2022
Multi-Epoch Matrix Factorization Mechanisms for Private Machine Learning

Christopher A. Choquette-Choo, H. Brendan McMahan, Keith Rush et al. · deepmind

We introduce new differentially private (DP) mechanisms for gradient-based machine learning (ML) with multiple passes (epochs) over a dataset, substantially improving the achievable privacy-utility-computation tradeoffs. We formalize the problem of DP mechanisms for adaptive streams with multiple participations and introduce a non-trivial extension of online matrix factorization DP mechanisms to our setting. This includes establishing the necessary theory for sensitivity calculations and efficient computation of optimal matrices. For some applications, such as those with $> 10{,}000$ SGD steps, applying these optimal techniques becomes computationally expensive. We thus design an efficient Fourier-transform-based mechanism with only a minor utility loss. Extensive empirical evaluation on both example-level DP for image classification and user-level DP for language modeling demonstrates substantial improvements over all previous methods, including the widely-used DP-SGD. Though our primary application is to ML, our main DP results are applicable to arbitrary linear queries and hence may have much broader applicability.
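
To make the matrix factorization idea concrete, here is a minimal NumPy sketch. It is not the paper's multi-epoch mechanism: it assumes a single participation per example, a prefix-sum workload `A`, and contributions of norm at most 1, so that the sensitivity of a factorization `A = B C` is the largest column norm of `C`. The square-root factorization used for comparison is a standard illustrative choice, not the optimal matrix the abstract refers to, and all function names are hypothetical.

```python
import numpy as np

def prefix_sum_workload(n):
    """Lower-triangular all-ones matrix: row t sums the first t+1 inputs."""
    return np.tril(np.ones((n, n)))

def sqrt_factor(n):
    """Lower-triangular Toeplitz square root of the prefix-sum workload.

    Its first column holds the Taylor coefficients of (1 - x)^(-1/2).
    """
    coeffs = np.ones(n)
    for k in range(1, n):
        coeffs[k] = coeffs[k - 1] * (2 * k - 1) / (2 * k)
    mat = np.zeros((n, n))
    for k in range(n):
        mat += np.diag(np.full(n - k, coeffs[k]), -k)
    return mat

def sensitivity(c):
    """Largest column norm of C (single participation, unit-norm inputs)."""
    return np.linalg.norm(c, axis=0).max()

def mechanism_error(b, c):
    """Total squared error at unit noise multiplier: sens(C)^2 * ||B||_F^2."""
    return sensitivity(c) ** 2 * np.sum(b ** 2)

def release(b, c, x, sigma, rng):
    """Return B (C x + z) = A x + B z with z ~ N(0, (sigma * sens(C))^2 I)."""
    z = rng.normal(scale=sigma * sensitivity(c), size=x.shape)
    return b @ (c @ x + z)

n = 64
a = prefix_sum_workload(n)
root = sqrt_factor(n)
err_input_noise = mechanism_error(a, np.eye(n))    # B = A, C = I
err_output_noise = mechanism_error(np.eye(n), a)   # B = I, C = A
err_sqrt = mechanism_error(root, root)             # B = C = A^(1/2)
```

In this toy setting the correlated-noise factorization has a lower total error than either trivial choice, which is the kind of gain the mechanism is designed to capture.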

LG · Jun 13, 2023
(Amplified) Banded Matrix Factorization: A unified approach to private training

Christopher A. Choquette-Choo, Arun Ganesh, Ryan McKenna et al. · deepmind

Matrix factorization (MF) mechanisms for differential privacy (DP) have substantially improved the state of the art in privacy-utility-computation tradeoffs for ML applications in a variety of scenarios, but in both the centralized and federated settings there remain instances where either MF cannot be easily applied, or other algorithms provide better tradeoffs (typically, as $ε$ becomes small). In this work, we show how MF can subsume prior state-of-the-art algorithms in both federated and centralized training settings, across all privacy budgets. The key technique throughout is the construction of MF mechanisms with banded matrices (lower-triangular matrices with at most $\hat{b}$ nonzero bands including the main diagonal). For cross-device federated learning (FL), this enables multiple participations with a relaxed device participation schema compatible with practical FL infrastructure (as demonstrated by a production deployment). In the centralized setting, we prove that banded matrices enjoy the same privacy amplification results as the ubiquitous DP-SGD algorithm, but can provide strictly better performance in most scenarios; this lets us always at least match DP-SGD, and often outperform it.
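
One reason banded matrices interact well with repeated participation can be checked numerically. The sketch below assumes a scalar stream and a user who contributes a value of magnitude 1 at steps spaced exactly `num_bands` apart. The columns of a banded lower-triangular `C` that are at least `num_bands` apart have disjoint supports, so their contributions to the sensitivity add in squares. The random matrix and the function names are assumptions for illustration only, not the paper's construction.

```python
import numpy as np

def banded_lower_triangular(n, num_bands, rng):
    """Random lower-triangular matrix with `num_bands` nonzero diagonals."""
    c = np.tril(rng.normal(size=(n, n)))
    return c - np.tril(c, -num_bands)  # zero everything below the band

def participation_sensitivity(c, steps):
    """L2 norm of C u for a unit contribution at each step in `steps`."""
    u = np.zeros(c.shape[1])
    u[list(steps)] = 1.0
    return np.linalg.norm(c @ u)

rng = np.random.default_rng(0)
n, num_bands = 32, 4
c = banded_lower_triangular(n, num_bands, rng)

# Participations exactly `num_bands` steps apart.
steps = list(range(1, n, num_bands))
col_norms = np.linalg.norm(c, axis=0)

sens_sq = participation_sensitivity(c, steps) ** 2
sum_sq = np.sum(col_norms[steps] ** 2)  # equals sens_sq by orthogonality
```

With `k` participations this gives a squared sensitivity of at most `k` times the largest squared column norm, without any search over participation patterns.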

LG · Jul 10, 2024
Fine-Tuning Large Language Models with User-Level Differential Privacy

Zachary Charles, Arun Ganesh, Ryan McKenna et al.

We investigate practical and scalable algorithms for training large language models (LLMs) with user-level differential privacy (DP) in order to provably safeguard all the examples contributed by each user. We study two variants of DP-SGD with: (1) example-level sampling (ELS) and per-example gradient clipping, and (2) user-level sampling (ULS) and per-user gradient clipping. We derive a novel user-level DP accountant that allows us to compute provably tight privacy guarantees for ELS. Using this, we show that while ELS can outperform ULS in specific settings, ULS generally yields better results when each user has a diverse collection of examples. We validate our findings through experiments in synthetic mean estimation and LLM fine-tuning tasks under fixed compute budgets. We find that ULS is significantly better in settings where either (1) strong privacy guarantees are required, or (2) the compute budget is large. Notably, our focus on LLM-compatible training algorithms allows us to scale to models with hundreds of millions of parameters and datasets with hundreds of thousands of users.
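
The difference between the two variants is where clipping happens. The following sketch shows one plausible reading: ELS clips each sampled example gradient, while ULS first averages a sampled user's gradients and then clips once per user. The function names, the averaging inside each user, and the normalization are assumptions; the noise multiplier is taken as given and would have to come from an appropriate user-level accountant, which this sketch does not implement.

```python
import numpy as np

def clip(vec, clip_norm):
    """Scale `vec` down so that its L2 norm is at most `clip_norm`."""
    norm = np.linalg.norm(vec)
    return vec * min(1.0, clip_norm / max(norm, 1e-12))

def els_step(example_grads, clip_norm, noise_multiplier, rng):
    """Example-level sampling: clip every example gradient, sum, add noise."""
    total = sum(clip(g, clip_norm) for g in example_grads)
    noise = rng.normal(scale=noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(example_grads)

def uls_step(user_grads, clip_norm, noise_multiplier, rng):
    """User-level sampling: average within each user, then clip per user."""
    per_user = [clip(np.mean(grads, axis=0), clip_norm) for grads in user_grads]
    total = sum(per_user)
    noise = rng.normal(scale=noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(user_grads)

rng = np.random.default_rng(0)
examples = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]
users = [np.array([[3.0, 4.0], [3.0, 4.0]]), np.array([[0.3, 0.4]])]

# With the noise switched off, only the clipping behaviour is visible.
els_update = els_step(examples, clip_norm=1.0, noise_multiplier=0.0, rng=rng)
uls_update = uls_step(users, clip_norm=1.0, noise_multiplier=0.0, rng=rng)
```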

CL · Apr 23
Decoupled DiLoCo for Resilient Distributed Pre-training

Arthur Douillard, Keith Rush, Yani Donchev et al.

Modern large-scale language model pre-training relies heavily on the single program multiple data (SPMD) paradigm, which requires tight coupling across accelerators. Due to this coupling, transient slowdowns, hardware failures, and synchronization overhead stall the entire computation, wasting significant compute time at scale. While recent distributed methods like DiLoCo reduced communication bandwidth, they remained fundamentally synchronous and vulnerable to these system stalls. To address this, we introduce Decoupled DiLoCo, an evolution of the DiLoCo framework designed to break the lock-step synchronization barrier and go beyond SPMD to maximize training goodput. Decoupled DiLoCo partitions compute across multiple independent "learners" that execute local inner optimization steps. These learners asynchronously communicate parameter fragments to a central synchronizer, which circumvents failed or straggling learners by aggregating updates using a minimum quorum, an adaptive grace window, and dynamic token-weighted merging. Inspired by "chaos engineering", we achieve significantly improved training efficiency in failure-prone environments with millions of simulated chips with strictly zero global downtime, while maintaining competitive model performance across text and vision tasks, for both dense and mixture-of-expert architectures.
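
A toy version of the synchronizer's merge step might look like the sketch below. Everything here is a hypothetical reading of the abstract: the `LearnerReport` fields, the `merge_reports` function, and the rule "merge as soon as a quorum of reports is present" are assumptions, and the adaptive grace window is omitted entirely.

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class LearnerReport:
    delta: np.ndarray  # update to one parameter fragment (hypothetical field)
    tokens: int        # tokens processed since the learner's last sync
    arrived: bool      # whether the report reached the synchronizer in time

def merge_reports(reports, min_quorum):
    """Token-weighted average of the arrived reports, or None below quorum."""
    ready = [r for r in reports if r.arrived and r.tokens > 0]
    if len(ready) < min_quorum:
        return None  # not enough learners yet; nothing is applied
    weights = np.array([r.tokens for r in ready], dtype=float)
    weights /= weights.sum()
    return sum(w * r.delta for w, r in zip(weights, ready))

reports = [
    LearnerReport(delta=np.array([1.0, 1.0]), tokens=100, arrived=True),
    LearnerReport(delta=np.array([3.0, 3.0]), tokens=300, arrived=True),
    LearnerReport(delta=np.array([100.0, 100.0]), tokens=500, arrived=False),
]
merged = merge_reports(reports, min_quorum=2)   # the straggler is skipped
blocked = merge_reports(reports, min_quorum=3)  # quorum not met
```

The point of the sketch is only that a straggling or failed learner does not block the merge once the quorum is met, and that learners which processed more tokens carry more weight.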

LG · Feb 2, 2023
Gradient Descent with Linearly Correlated Noise: Theory and Applications to Differential Privacy

Anastasia Koloskova, Ryan McKenna, Zachary Charles et al.

We study gradient descent under linearly correlated noise. Our work is motivated by recent practical methods for optimization with differential privacy (DP), such as DP-FTRL, which achieve strong performance in settings where privacy amplification techniques are infeasible (such as in federated learning). These methods inject privacy noise through a matrix factorization mechanism, making the noise linearly correlated over iterations. We propose a simplified setting that distills key facets of these methods and isolates the impact of linearly correlated noise. We analyze the behavior of gradient descent in this setting, for both convex and non-convex functions. Our analysis is demonstrably tighter than prior work and recovers multiple important special cases exactly (including anticorrelated perturbed gradient descent). We use our results to develop new, effective matrix factorizations for differentially private optimization, and highlight the benefits of these factorizations theoretically and empirically.
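
The effect of linear correlation is easiest to see in the anticorrelated special case. The sketch below assumes a scalar problem with zero gradient, so that only the noise moves the iterate, and compares independent noise `z_t` with the linearly correlated noise `w_t = z_t - z_{t-1}`. The setup and names are illustrative assumptions, not the paper's analysis.

```python
import numpy as np

def run_gd(grad_fn, x0, lr, noise):
    """Gradient descent where step t uses the gradient plus noise[t]."""
    x = float(x0)
    for w in noise:
        x -= lr * (grad_fn(x) + w)
    return x

def flat_gradient(x):
    """Zero gradient everywhere, which isolates the effect of the noise."""
    return 0.0

rng = np.random.default_rng(0)
steps, lr = 1000, 0.1
z = rng.normal(size=steps)

independent = z                           # fresh noise at every step
anticorrelated = np.diff(z, prepend=0.0)  # w_t = z_t - z_{t-1}, with z_{-1} = 0

x_indep = run_gd(flat_gradient, 0.0, lr, independent)
x_anti = run_gd(flat_gradient, 0.0, lr, anticorrelated)
```

Because the anticorrelated noise telescopes, the final iterate is displaced by a single noise draw, whereas independent noise accumulates across all steps.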

LG · Jan 18, 2023
Federated Automatic Differentiation

Keith Rush, Zachary Charles, Zachary Garrett

Federated learning (FL) is a general framework for learning across an axis of group partitioned data (heterogeneous clients) while preserving data privacy, under the orchestration of a central server. FL methods often compute gradients of loss functions purely locally (i.e., entirely at each client, or entirely at the server), typically using automatic differentiation (AD) techniques. We propose a federated automatic differentiation (FAD) framework that 1) enables computing derivatives of functions involving client and server computation as well as communication between them and 2) operates in a manner compatible with existing federated technology. In other words, FAD computes derivatives across communication boundaries. We show, in analogy with traditional AD, that FAD may be implemented using various accumulation modes, which introduce distinct computation-communication trade-offs and systems requirements. Further, we show that a broad class of federated computations is closed under these various modes of FAD, implying in particular that if the original computation can be implemented using privacy-preserving primitives, its derivative may be computed using only these same primitives. We then show how FAD can be used to create algorithms that dynamically learn components of the algorithm itself. In particular, we show that FedAvg-style algorithms can exhibit significantly improved performance by using FAD to adjust the server optimization step automatically, or by using FAD to learn weighting schemes for computing weighted averages across clients.
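
A hand-written reverse pass illustrates what "derivatives across communication boundaries" can mean. The sketch assumes a computation of the form broadcast, client-local loss, server mean, and uses the fact that the transpose of a broadcast is a sum and the transpose of a mean is a scaled broadcast. The primitive names and the quadratic client loss are hypothetical and do not correspond to any particular library's API.

```python
import numpy as np

def federated_broadcast(server_value, num_clients):
    """Send one server value to every client."""
    return [server_value for _ in range(num_clients)]

def federated_sum(client_values):
    """Aggregate client values at the server by summation."""
    return sum(client_values)

def client_loss(x, data):
    features, target = data
    return 0.5 * (features @ x - target) ** 2

def client_loss_grad(x, data):
    features, target = data
    return (features @ x - target) * features

def forward(x, client_data):
    """Broadcast x, evaluate local losses, return their mean at the server."""
    xs = federated_broadcast(x, len(client_data))
    losses = [client_loss(xi, d) for xi, d in zip(xs, client_data)]
    return federated_sum(losses) / len(client_data)

def reverse_mode_grad(x, client_data):
    """Gradient of `forward` using only broadcast and sum communication."""
    n = len(client_data)
    xs = federated_broadcast(x, n)
    # Transpose of the mean: broadcast the output cotangent scaled by 1/n.
    cotangents = federated_broadcast(1.0 / n, n)
    # Each client applies its local vector-Jacobian product.
    local = [ct * client_loss_grad(xi, d)
             for ct, xi, d in zip(cotangents, xs, client_data)]
    # Transpose of the broadcast: sum client cotangents at the server.
    return federated_sum(local)

client_data = [
    (np.array([1.0, 2.0]), 1.0),
    (np.array([0.0, 1.0]), -1.0),
    (np.array([3.0, 1.0]), 2.0),
]
x = np.array([0.5, -0.5])
grad = reverse_mode_grad(x, client_data)
```

The backward pass uses the same two communication primitives as the forward pass, which is a small instance of the closure property the abstract describes.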

CL · Jul 7, 2025
Gemini 2.5: Pushing the Frontier with Advanced Reasoning, Multimodality, Long Context, and Next Generation Agentic Capabilities

Gheorghe Comanici, Eric Bieber, Mike Schaekermann et al. · amazon-science, baidu

In this report, we introduce the Gemini 2.X model family: Gemini 2.5 Pro and Gemini 2.5 Flash, as well as our earlier Gemini 2.0 Flash and Flash-Lite models. Gemini 2.5 Pro is our most capable model yet, achieving SoTA performance on frontier coding and reasoning benchmarks. In addition to its incredible coding and reasoning skills, Gemini 2.5 Pro is a thinking model that excels at multimodal understanding and it is now able to process up to 3 hours of video content. Its unique combination of long context, multimodal and reasoning capabilities can be combined to unlock new agentic workflows. Gemini 2.5 Flash provides excellent reasoning abilities at a fraction of the compute and latency requirements and Gemini 2.0 Flash and Flash-Lite provide high performance at low latency and cost. Taken together, the Gemini 2.X model generation spans the full Pareto frontier of model capability vs cost, allowing users to explore the boundaries of what is possible with complex agentic problem solving.

DC · Mar 11, 2024 · Code
DrJAX: Scalable and Differentiable MapReduce Primitives in JAX

Keith Rush, Zachary Charles, Zachary Garrett et al.

We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.

LG · Feb 5, 2021 · Code
Federated Reconstruction: Partially Local Federated Learning

Karan Singhal, Hakim Sidahmed, Zachary Garrett et al.

Personalization methods in federated learning aim to balance the benefits of federated and local training for data availability, communication cost, and robustness to client heterogeneity. Approaches that require clients to communicate all model parameters can be undesirable due to privacy and communication constraints. Other approaches require always-available or stateful clients, impractical in large-scale cross-device settings. We introduce Federated Reconstruction, the first model-agnostic framework for partially local federated learning suitable for training and inference at scale. We motivate the framework via a connection to model-agnostic meta learning, empirically demonstrate its performance over existing approaches for collaborative filtering and next word prediction, and release an open-source library for evaluating approaches in this setting. We also describe the successful deployment of this approach at scale for federated collaborative filtering in a mobile keyboard application.

LG · Mar 19
Global Convergence of Multiplicative Updates for the Matrix Mechanism: A Collaborative Proof with Gemini 3

Keith Rush

We analyze a fixed-point iteration $v \leftarrow ϕ(v)$ arising in the optimization of a regularized nuclear norm objective involving the Hadamard product structure, posed by Denisov et al. in the context of an optimization problem over the space of algorithms in private machine learning. We prove that the iteration $v^{(k+1)} = \text{diag}((D_{v^{(k)}}^{1/2} M D_{v^{(k)}}^{1/2})^{1/2})$ converges monotonically to the unique global optimizer of the potential function $J(v) = 2 \text{Tr}((D_v^{1/2} M D_v^{1/2})^{1/2}) - \sum v_i$, closing a problem left open there. The bulk of this proof was provided by Gemini 3, subject to some corrections and interventions. Gemini 3 also sketched the initial version of this note. Thus, it represents as much a commentary on the practical use of AI in mathematics as it represents the closure of a small gap in the literature. As such, we include a small narrative description of the prompting process, and some resulting principles for working with AI to prove mathematics.
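
The iteration and the potential can be written down directly from the two formulas in the abstract. The NumPy sketch below assumes that $M$ is symmetric positive definite and that $D_v$ denotes the diagonal matrix with $v$ on its diagonal; the starting point `v = 1`, the step count, and the function names are arbitrary choices for illustration.

```python
import numpy as np

def psd_sqrt(mat):
    """Symmetric PSD square root via an eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    return (eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))) @ eigvecs.T

def scaled(v, m):
    """D_v^{1/2} M D_v^{1/2} without forming the diagonal matrices."""
    s = np.sqrt(v)
    return s[:, None] * m * s[None, :]

def phi(v, m):
    """One step of the iteration: diag((D_v^{1/2} M D_v^{1/2})^{1/2})."""
    return np.diag(psd_sqrt(scaled(v, m)))

def potential(v, m):
    """J(v) = 2 Tr((D_v^{1/2} M D_v^{1/2})^{1/2}) - sum_i v_i."""
    return 2.0 * np.trace(psd_sqrt(scaled(v, m))) - np.sum(v)

def iterate(m, num_steps=200):
    """Run the fixed-point iteration from v = 1 and record J along the way."""
    v = np.ones(m.shape[0])
    history = [potential(v, m)]
    for _ in range(num_steps):
        v = phi(v, m)
        history.append(potential(v, m))
    return v, history

# For a diagonal M the update reduces to v_i <- sqrt(v_i * M_ii),
# whose fixed point is v_i = M_ii.
v_diag, j_diag = iterate(np.diag([1.0, 4.0, 9.0]))
v_sym, _ = iterate(np.array([[2.0, 1.0], [1.0, 2.0]]))
```

The diagonal case is a useful sanity check because the fixed point and the value of `J` there (the trace of `M`) are available in closed form.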

CL · Jan 30, 2025
Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch

Arthur Douillard, Yanislav Donchev, Keith Rush et al.

Training of large language models (LLMs) is typically distributed across a large number of accelerators to reduce training time. Since internal states and parameter gradients need to be exchanged at each and every single gradient step, all devices need to be co-located using low-latency high-bandwidth communication links to support the required high volume of exchanged bits. Recently, distributed algorithms like DiLoCo have relaxed this co-location constraint: accelerators can be grouped into "workers", where synchronizations between workers only occur infrequently. This in turn means that workers can afford being connected by lower bandwidth communication links without affecting learning quality. However, in these methods, communication across workers still requires the same peak bandwidth as before, as the synchronizations require all parameters to be exchanged across all workers. In this paper, we improve DiLoCo in three ways. First, we synchronize only subsets of parameters in sequence, rather than all at once, which greatly reduces peak bandwidth. Second, we allow workers to continue training while synchronizing, which decreases wall clock time. Third, we quantize the data exchanged by workers, which further reduces bandwidth across workers. By properly combining these modifications, we show experimentally that we can distribute training of billion-scale parameters and reach similar quality as before, while reducing required bandwidth by two orders of magnitude.
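
The first of the three modifications, synchronizing fragments in sequence, can be sketched as a round-robin schedule. This toy version makes several simplifying assumptions: parameters are a flat vector split into contiguous fragments, synchronization is a plain average rather than an outer-optimizer step, local training is omitted, and neither overlap nor quantization is modeled. The function names are hypothetical.

```python
import numpy as np

def fragment_slices(num_params, num_fragments):
    """Split parameter indices into contiguous, nearly equal fragments."""
    bounds = np.linspace(0, num_params, num_fragments + 1).astype(int)
    return [slice(lo, hi) for lo, hi in zip(bounds[:-1], bounds[1:])]

def sync_fragment(workers, frag):
    """Average one fragment across workers; return values sent per worker."""
    mean = np.mean([w[frag] for w in workers], axis=0)
    for w in workers:
        w[frag] = mean
    return mean.size

rng = np.random.default_rng(0)
num_params, num_fragments = 12, 3
workers = [rng.normal(size=num_params) for _ in range(4)]
frags = fragment_slices(num_params, num_fragments)

peak_values_sent = 0
for outer_step in range(num_fragments):
    # Local training steps would run here; they are omitted in this sketch.
    frag = frags[outer_step % num_fragments]  # round-robin schedule
    peak_values_sent = max(peak_values_sent, sync_fragment(workers, frag))
```

Each synchronization moves only one fragment, so the peak amount exchanged per step is the fragment size rather than the full parameter count.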

LG · Mar 12, 2025
Communication-Efficient Language Model Training Scales Reliably and Robustly: Scaling Laws for DiLoCo

Zachary Charles, Gabriel Teston, Lucio Dery et al.

As we scale to more massive machine learning models, the frequent synchronization demands inherent in data-parallel approaches create significant slowdowns, posing a critical challenge to further scaling. Recent work develops an approach (DiLoCo) that relaxes synchronization demands without compromising model quality. However, these works do not carefully analyze how DiLoCo's behavior changes with model size. In this work, we study the scaling law behavior of DiLoCo when training LLMs under a fixed compute budget. We focus on how algorithmic factors, including number of model replicas, hyperparameters, and token budget affect training in ways that can be accurately predicted via scaling laws. We find that DiLoCo scales both predictably and robustly with model size. When well-tuned, DiLoCo scales better than data-parallel training with model size, and can outperform data-parallel training even at small model sizes. Our results showcase a more general set of benefits of DiLoCo than previously documented, including increased optimal batch sizes, improved downstream generalization with scale, and improved evaluation loss for a fixed token budget.

LG · Jun 9, 2025
Correlated Noise Mechanisms for Differentially Private Learning

Krishna Pillutla, Jalaj Upadhyay, Christopher A. Choquette-Choo et al.

This monograph explores the design and analysis of correlated noise mechanisms for differential privacy (DP), focusing on their application to private training of AI and machine learning models via the core primitive of estimation of weighted prefix sums. While typical DP mechanisms inject independent noise into each step of a stochastic gradient (SGD) learning algorithm in order to protect the privacy of the training data, a growing body of recent research demonstrates that introducing (anti-)correlations in the noise can significantly improve privacy-utility trade-offs by carefully canceling out some of the noise added on earlier steps in subsequent steps. Such correlated noise mechanisms, known variously as matrix mechanisms, factorization mechanisms, and DP-Follow-the-Regularized-Leader (DP-FTRL) when applied to learning algorithms, have also been influential in practice, with industrial deployment at a global scale.

LG · Feb 16, 2022
Improved Differential Privacy for SGD via Optimal Private Linear Operators on Adaptive Streams

Sergey Denisov, Brendan McMahan, Keith Rush et al.

Motivated by recent applications requiring differential privacy over adaptive streams, we investigate the question of optimal instantiations of the matrix mechanism in this setting. We prove fundamental theoretical results on the applicability of matrix factorizations to adaptive streams, and provide a parameter-free fixed-point algorithm for computing optimal factorizations. We instantiate this framework with respect to concrete matrices which arise naturally in machine learning, and train user-level differentially private models with the resulting optimal mechanisms, yielding significant improvements in a notable problem in federated learning with user-level differential privacy.

OC · Sep 8, 2021
Iterated Vector Fields and Conservatism, with Applications to Federated Learning

Zachary Charles, Keith Rush

We study whether iterated vector fields (vector fields composed with themselves) are conservative. We give explicit examples of vector fields for which this self-composition preserves conservatism. Notably, this includes gradient vector fields of loss functions associated with some generalized linear models. As we show, characterizing the set of vector fields satisfying this condition leads to non-trivial geometric questions. In the context of federated learning, we show that when clients have loss functions whose gradients satisfy this condition, federated averaging is equivalent to gradient descent on a surrogate loss function. We leverage this to derive novel convergence results for federated learning. By contrast, we demonstrate that when the client losses violate this property, federated averaging can yield behavior which is fundamentally distinct from centralized optimization. Finally, we discuss theoretical and practical questions our analytical framework raises for federated learning.

LG · Aug 14, 2020
Fast Dimension Independent Private AdaGrad on Publicly Estimated Subspaces

Peter Kairouz, Mónica Ribero, Keith Rush et al.

We revisit the problem of empirical risk minimization (ERM) with differential privacy. We show that noisy AdaGrad, given appropriate knowledge and conditions on the subspace from which gradients can be drawn, achieves a regret comparable to traditional AdaGrad plus a well-controlled term due to noise. We show a convergence rate of $O(\text{Tr}(G_T)/T)$, where $G_T$ captures the geometry of the gradient subspace. Since $\text{Tr}(G_T)=O(\sqrt{T})$, we can obtain faster rates for convex and Lipschitz functions, compared to the $O(1/\sqrt{T})$ rate achieved by known versions of noisy (stochastic) gradient descent with comparable noise variance. In particular, we show that if the gradients lie in a known constant rank subspace, and assuming algorithmic access to an envelope which bounds decaying sensitivity, one can achieve faster convergence to an excess empirical risk of $\tilde O(1/εn)$, where $ε$ is the privacy budget and $n$ the number of samples. Letting $p$ be the problem dimension, this result implies that, by running noisy AdaGrad, we can bypass the DP-SGD bound $\tilde O(\sqrt{p}/εn)$ in $T=(εn)^{2/(1+2α)}$ iterations, where $α\geq 0$ is a parameter controlling gradient norm decay, instead of the rate achieved by SGD of $T=ε^2n^2$. Our results operate with general convex functions in both constrained and unconstrained minimization. Along the way, we do a perturbation analysis of noisy AdaGrad of independent interest. Our utility guarantee for the private ERM problem follows as a corollary to the regret guarantee of noisy AdaGrad.

LG · Feb 29, 2020
Adaptive Federated Optimization

Sashank Reddi, Zachary Charles, Manzil Zaheer et al.

Federated learning is a distributed machine learning paradigm in which a large number of clients coordinate with a central server to learn a model without sharing their own training data. Standard federated optimization methods such as Federated Averaging (FedAvg) are often difficult to tune and exhibit unfavorable convergence behavior. In non-federated settings, adaptive optimization methods have had notable success in combating such issues. In this work, we propose federated versions of adaptive optimizers, including Adagrad, Adam, and Yogi, and analyze their convergence in the presence of heterogeneous data for general non-convex settings. Our results highlight the interplay between client heterogeneity and communication efficiency. We also perform extensive experiments on these methods and show that the use of adaptive optimizers can significantly improve the performance of federated learning.
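
A federated Adam-style server update can be sketched by treating the average client model delta as a pseudo-gradient. The version below is a toy under stated assumptions: each client holds the quadratic loss `0.5 * ||x - center||^2`, all clients participate in every round, there is no bias correction, and the hyperparameter values are arbitrary. It should be read as one common way to write such an update, not as the paper's exact algorithm.

```python
import numpy as np

def client_update(x, center, local_steps=5, local_lr=0.1):
    """Local gradient steps on 0.5 * ||x - center||^2; returns the model delta."""
    local = x.copy()
    for _ in range(local_steps):
        local -= local_lr * (local - center)
    return local - x

def fed_adam(centers, rounds=100, server_lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
    """Server-side Adam applied to the averaged client deltas."""
    x = np.zeros(centers.shape[1])
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for _ in range(rounds):
        # The average client delta acts as a negative pseudo-gradient.
        delta = np.mean([client_update(x, c) for c in centers], axis=0)
        m = beta1 * m + (1 - beta1) * delta
        v = beta2 * v + (1 - beta2) * delta ** 2
        x = x + server_lr * m / (np.sqrt(v) + tau)
    return x

# Heterogeneous clients: each has a different optimum.
centers = np.array([[1.0, 2.0], [3.0, 0.0], [2.0, 4.0]])
x_final = fed_adam(centers)
global_optimum = centers.mean(axis=0)  # minimizer of the average client loss
```

Swapping the second-moment update for an Adagrad or Yogi rule changes only the line that updates `v`, which is why the adaptive variants share one server-side template.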

LG · Sep 27, 2019
Improving Federated Learning Personalization via Model Agnostic Meta Learning

Yihan Jiang, Jakub Konečný, Keith Rush et al.

Federated Learning (FL) refers to learning a high-quality global model based on decentralized data storage, without ever copying the raw data. A natural scenario arises with data created on mobile phones by the activity of their users. Given the typical data heterogeneity in such situations, it is natural to ask how the global model can be personalized for each such device individually. In this work, we point out that the setting of Model Agnostic Meta Learning (MAML), where one optimizes for a fast, gradient-based, few-shot adaptation to a heterogeneous distribution of tasks, has a number of similarities with the objective of personalization for FL. We present FL as a natural source of practical applications for MAML algorithms, and make the following observations. 1) The popular FL algorithm, Federated Averaging, can be interpreted as a meta learning algorithm. 2) Careful fine-tuning can yield a global model with higher accuracy, which is at the same time easier to personalize. However, solely optimizing for the global model accuracy yields a weaker personalization result. 3) A model trained using a standard datacenter optimization method is much harder to personalize, compared to one trained using Federated Averaging, supporting the first claim. These results raise new questions for FL, MAML, and broader ML research.