ML · Jun 27, 2022
AutoInit: Automatic Initialization via Jacobian Tuning
Tianyu He, Darshil Doshi, Andrey Gromov
Good initialization is essential for training Deep Neural Networks (DNNs). Often such an initialization is found by trial and error, which has to be repeated every time an architecture is substantially modified, or it is inherited from smaller networks, which leads to sub-optimal initialization. In this work we introduce a new, cheap algorithm that finds a good initialization automatically for general feed-forward DNNs. The algorithm uses the Jacobian between adjacent network blocks to tune the network hyperparameters to criticality. We solve the dynamics of the algorithm for fully connected networks with ReLU activations and derive conditions for its convergence. We then extend the discussion to more general architectures with BatchNorm and residual connections. Finally, we apply our method to ResMLP and VGG architectures, where the automatic one-shot initialization found by our method shows good performance on vision tasks.
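As a rough illustration of Jacobian-based tuning, the sketch below uses a toy setting of our own choosing: a bias-free ReLU MLP with weights drawn from $\mathcal{N}(0, \sigma_w^2/N)$ and a single hyperparameter $\sigma_w$. The helper names `adjacent_jacobian_ratio` and `tune_sigma_w` and the rescaling rule are hypothetical and are not the AutoInit algorithm itself; the sketch only shows the idea of measuring the adjacent-layer Jacobian norm and rescaling the weights until that norm is close to 1.

```python
import numpy as np


def adjacent_jacobian_ratio(sigma_w, width=512, depth=8, seed=0):
    """Layer-averaged (1/N) * ||d h^{l+1} / d h^l||_F^2 for a random bias-free ReLU MLP.

    Toy assumption: weights ~ N(0, sigma_w^2 / N), no biases, no normalization.
    """
    rng = np.random.default_rng(seed)
    h = rng.standard_normal(width)  # preactivations of the first hidden layer
    ratios = []
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * sigma_w / np.sqrt(width)
        # d h^{l+1} / d h^l = W @ diag(relu'(h)); broadcasting scales column j by relu'(h_j)
        J = W * (h > 0)
        ratios.append(np.sum(J**2) / width)
        h = W @ np.maximum(h, 0.0)
    return float(np.mean(ratios))


def tune_sigma_w(sigma_w=0.5, steps=20, tol=0.02):
    """Hypothetical loop: rescale sigma_w until the Jacobian ratio is about 1 (criticality)."""
    for step in range(steps):
        chi = adjacent_jacobian_ratio(sigma_w, seed=step)  # fresh random network each step
        if abs(chi - 1.0) < tol:
            break
        # For ReLU the ratio scales as sigma_w^2, so dividing by sqrt(chi) targets chi = 1
        sigma_w /= np.sqrt(chi)
    return sigma_w


print(tune_sigma_w() ** 2)  # should be close to 2
```

For ReLU the measured ratio is about $\sigma_w^2/2$, so the loop settles near $\sigma_w^2 \approx 2$, the familiar He initialization. BatchNorm and residual connections, which the abstract also mentions, are left out of this sketch.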
LG · Oct 19, 2023
To grok or not to grok: Disentangling generalization and memorization on corrupted algorithmic datasets
Darshil Doshi, Aritra Das, Tianyu He et al.
Robust generalization is a major challenge in deep learning, particularly when the number of trainable parameters is very large. In general, it is very difficult to know whether the network has memorized a particular set of examples or understood the underlying rule (or both). Motivated by this challenge, we study an interpretable model in which generalizing representations are understood analytically and are easily distinguishable from memorizing ones. Namely, we consider multi-layer perceptron (MLP) and Transformer architectures trained on modular arithmetic tasks, where ($\xi \cdot 100\%$) of labels are corrupted (i.e., some results of the modular operations in the training set are incorrect). We show that (i) it is possible for the network to both memorize the corrupted labels and achieve $100\%$ generalization at the same time; (ii) the memorizing neurons can be identified and pruned, lowering the accuracy on corrupted data and improving the accuracy on uncorrupted data; (iii) regularization methods such as weight decay, dropout and BatchNorm force the network to ignore the corrupted data during optimization and achieve $100\%$ accuracy on the uncorrupted dataset; and (iv) the effect of these regularization methods is ("mechanistically") interpretable: weight decay and dropout force all the neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of generalizing ones. Finally, we show that in the presence of regularization, the training dynamics involves two consecutive stages: first, the network undergoes grokking dynamics, reaching high accuracy on both the train and test sets; second, it unlearns the memorizing representations, and the train accuracy suddenly drops from $100\%$ to $100(1-\xi)\%$.
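The corrupted-label setup can be sketched as a small data pipeline. This is an assumed construction for illustration only (modular addition, a random train/test split, and uniformly random wrong labels on the train split); the function name `corrupted_modular_addition` and its defaults are hypothetical.

```python
import numpy as np


def corrupted_modular_addition(p=97, train_frac=0.5, xi=0.1, seed=0):
    """Hypothetical pipeline: z = (x + y) mod p, with a fraction xi of train labels corrupted."""
    rng = np.random.default_rng(seed)
    x, y = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    x, y = x.ravel(), y.ravel()
    z = (x + y) % p  # true labels for all p^2 input pairs
    perm = rng.permutation(p * p)
    n_train = int(train_frac * p * p)
    tr, te = perm[:n_train], perm[n_train:]
    z_train = z[tr].copy()
    n_bad = int(round(xi * n_train))
    bad = rng.choice(n_train, size=n_bad, replace=False)  # indices into the train split
    # A nonzero offset in [1, p-1] guarantees every corrupted label is actually wrong
    z_train[bad] = (z_train[bad] + rng.integers(1, p, size=n_bad)) % p
    return (x[tr], y[tr], z_train), (x[te], y[te], z[te]), bad


train, test, bad = corrupted_modular_addition(p=97, train_frac=0.5, xi=0.1)
x_tr, y_tr, z_tr = train
# Fraction of train labels that still follow the rule (about 1 - xi)
print(np.mean((x_tr + y_tr) % 97 == z_tr))
```

A network that has learned only the rule $z = x + y \bmod p$ would score $100\%$ on the test split and $100(1-\xi)\%$ on the train split, which is the train accuracy reached after the unlearning stage.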
LG · Feb 14, 2025
(How) Can Transformers Predict Pseudo-Random Numbers?
Tao Tao, Darshil Doshi, Dayal Singh Kalra et al.
Transformers excel at discovering patterns in sequential data, yet their fundamental limitations and learning mechanisms remain crucial topics of investigation. In this paper, we study the ability of Transformers to learn pseudo-random number sequences from linear congruential generators (LCGs), defined by the recurrence relation $x_{t+1} = a x_t + c \;\mathrm{mod}\; m$. We find that with sufficient architectural capacity and training data variety, Transformers can perform in-context prediction of LCG sequences with unseen moduli ($m$) and parameters ($a, c$). By analyzing the embedding layers and attention patterns, we uncover how Transformers develop algorithmic structures to learn these sequences in two scenarios of increasing complexity. First, we investigate how Transformers learn LCG sequences with unseen ($a, c$) but a fixed modulus, and demonstrate successful learning up to $m = 2^{32}$. We find that models learn to factorize $m$ and use digit-wise number representations to make sequential predictions. In the second, more challenging scenario of unseen moduli, we show that Transformers can generalize to unseen moduli up to $m_{\text{test}} = 2^{16}$. In this case, the model employs a two-step strategy: first estimating the unknown modulus from the context, then using prime factorizations to generate predictions. For this task, we observe a sharp transition in accuracy at a critical depth $d = 3$. We also find that the number of in-context sequence elements needed to reach high accuracy scales sublinearly with the modulus.
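One reason a digit-wise view is natural for $m = 2^{32}$: since $(a x_t + c) \bmod 2^k$ depends only on $x_t \bmod 2^k$, the lowest $k$ bits of an LCG modulo a power of two form their own LCG modulo $2^k$, and under the Hull–Dobell conditions ($c$ odd, $a \equiv 1 \bmod 4$) that lower-bit sequence has period exactly $2^k$. This is a standard LCG property shown for illustration, not a reconstruction of the algorithm the models learn; the constants $a = 1664525$, $c = 1013904223$ are a commonly used full-period choice assumed here, and the helper names are hypothetical.

```python
def lcg_sequence(a, c, m, x0, length):
    """Generate x_0, ..., x_{length-1} with x_{t+1} = (a * x_t + c) mod m."""
    xs = [x0 % m]
    for _ in range(length - 1):
        xs.append((a * xs[-1] + c) % m)
    return xs


def low_bits_period(seq, k):
    """Period of x_t mod 2^k, i.e. of the lowest k bits.

    Assumes the sequence is purely periodic, which holds when a is odd
    (the map x -> a*x + c is then a bijection modulo 2^k).
    """
    mask = (1 << k) - 1
    first = seq[0] & mask
    for t in range(1, len(seq)):
        if (seq[t] & mask) == first:
            return t
    return None  # period is longer than the generated sequence


seq = lcg_sequence(a=1664525, c=1013904223, m=2**32, x0=42, length=300)
print([low_bits_period(seq, k) for k in range(1, 9)])  # [2, 4, 8, 16, 32, 64, 128, 256]
```

The short, predictable periods of the low digits are what make a base-2 digit decomposition easier to exploit than the raw 32-bit value.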
LG · Jun 5, 2024
Grokking Modular Polynomials
Darshil Doshi, Tianyu He, Aritra Das et al.
Neural networks readily learn a subset of the modular arithmetic tasks, while failing to generalize on the rest. This limitation persists regardless of the choice of architecture and training strategy. On the other hand, an analytical solution for the weights of Multi-layer Perceptron (MLP) networks that generalize on the modular addition task is known in the literature. In this work, we (i) extend the class of analytical solutions to include modular multiplication as well as modular addition with many terms. Additionally, we show that real networks trained on these datasets learn similar solutions upon generalization (grokking). (ii) We combine these "expert" solutions to construct networks that generalize on arbitrary modular polynomials. (iii) We hypothesize a classification of modular polynomials into learnable and non-learnable via neural network training, and provide experimental evidence supporting our claims.
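For reference, the known modular-addition solution mentioned above has a simple closed form: a one-hidden-layer MLP with quadratic activation whose input and readout weights are cosines of $2\pi \sigma_k n / p$ with a frequency $\sigma_k$ and random phases per neuron. The sketch below is our reconstruction of that form under stated assumptions (odd $p$, each frequency $1, \dots, (p-1)/2$ repeated `n_per_freq` times, uniformly random phases); the function names are hypothetical, and it does not include the extensions to multiplication or many-term sums described in the abstract.

```python
import numpy as np


def modular_addition_mlp(p, n_per_freq=400, seed=0):
    """Periodic weights for a 1-hidden-layer MLP with quadratic activation (p odd)."""
    rng = np.random.default_rng(seed)
    freqs = np.repeat(np.arange(1, (p - 1) // 2 + 1), n_per_freq)  # sigma_k for each neuron
    theta = 2 * np.pi * freqs / p
    phi1 = rng.uniform(0, 2 * np.pi, size=freqs.size)
    phi2 = rng.uniform(0, 2 * np.pi, size=freqs.size)
    n = np.arange(p)
    U = np.cos(np.outer(theta, n) + phi1[:, None])            # (N, p), acts on one-hot x
    V = np.cos(np.outer(theta, n) + phi2[:, None])            # (N, p), acts on one-hot y
    W = np.cos(-np.outer(n, theta) - (phi1 + phi2)[None, :])  # (p, N), readout
    return U, V, W


def predict(U, V, W, x, y):
    """Forward pass on one-hot inputs: selecting columns of U and V replaces the matmul."""
    h = U[:, x] + V[:, y]
    return int(np.argmax(W @ h**2))  # quadratic activation, then linear readout


p = 11
U, V, W = modular_addition_mlp(p)
acc = np.mean([predict(U, V, W, x, y) == (x + y) % p for x in range(p) for y in range(p)])
print(acc)  # fraction of all p^2 input pairs classified correctly
```

Expanding $h_k^2$ shows why this works: each neuron's output for class $q$ contains the term $\tfrac12\cos\big(2\pi\sigma_k(x+y-q)/p\big)$, which peaks at $q = (x+y) \bmod p$, while the terms carrying random phases largely cancel when summed over many neurons.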
LG · Jun 4, 2024
Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks
Tianyu He, Darshil Doshi, Aritra Das et al.
Large language models can solve tasks that were not present in the training set. This capability is believed to be due to in-context learning and skill composition. In this work, we study the emergence of in-context learning and skill composition in a collection of modular arithmetic tasks. Specifically, we consider a finite collection of linear modular functions $z = a \, x + b \, y \;\mathrm{mod}\; p$ labeled by the vector $(a, b) \in \mathbb{Z}_p^2$. We use some of these tasks for pre-training and the rest for out-of-distribution testing. We empirically show that a GPT-style transformer exhibits a transition from in-distribution to out-of-distribution generalization as the number of pre-training tasks increases. We find that the smallest model capable of out-of-distribution generalization requires two transformer blocks, while for deeper models, the out-of-distribution generalization phase is transient, necessitating early stopping. Finally, we perform an interpretability study of the pre-trained models, revealing highly structured representations in both attention heads and MLPs, and discuss the learned algorithms. Notably, we find an algorithmic shift in deeper models as we go from few to many in-context examples.
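A minimal data-generation sketch for this task family, assuming a uniformly random split of $\mathbb{Z}_p^2$ into pre-training and held-out task vectors and a flat `x y z` token layout for each in-context example. Both choices and the function names `split_tasks` and `make_sequence` are hypothetical and not necessarily the paper's sequence format.

```python
import numpy as np


def split_tasks(p, n_pretrain, seed=0):
    """Randomly split all task vectors (a, b) in Z_p^2 into pre-training and held-out sets."""
    rng = np.random.default_rng(seed)
    tasks = np.array([(a, b) for a in range(p) for b in range(p)])
    perm = rng.permutation(len(tasks))
    return tasks[perm[:n_pretrain]], tasks[perm[n_pretrain:]]


def make_sequence(task, p, n_examples, rng):
    """One in-context sequence for z = a*x + b*y mod p, flattened as x1 y1 z1 x2 y2 z2 ..."""
    a, b = task
    x = rng.integers(0, p, size=n_examples)
    y = rng.integers(0, p, size=n_examples)
    z = (a * x + b * y) % p
    return np.stack([x, y, z], axis=1).ravel()


pretrain, held_out = split_tasks(p=29, n_pretrain=256)
rng = np.random.default_rng(1)
seq = make_sequence(pretrain[0], p=29, n_examples=8, rng=rng)
print(seq.reshape(-1, 3))  # one (x, y, z) example per row
```

Pre-training would draw sequences only from `pretrain`, and out-of-distribution evaluation only from `held_out`, so the model must infer $(a, b)$ from the in-context examples rather than recall it.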
LG · Nov 23, 2021
Critical Initialization of Wide and Deep Neural Networks through Partial Jacobians: General Theory and Applications
Darshil Doshi, Tianyu He, Andrey Gromov
Deep neural networks are notorious for defying theoretical treatment. However, when the number of parameters in each layer tends to infinity, the network function is a Gaussian process (GP) and a quantitatively predictive description is possible. The Gaussian approximation allows one to formulate criteria for selecting hyperparameters, such as the variances of weights and biases, as well as the learning rate. These criteria rely on the notion of criticality defined for deep neural networks. In this work we describe a new practical way to diagnose criticality. We introduce partial Jacobians of a network, defined as derivatives of preactivations in layer $l$ with respect to preactivations in layer $l_0 \leq l$. We derive recurrence relations for the norms of partial Jacobians and use these relations to analyze the criticality of deep fully connected neural networks with LayerNorm and/or residual connections. We derive and implement a simple and cheap numerical test that allows one to select an optimal initialization for a broad class of deep neural networks containing fully connected, convolutional and normalization layers. Using these tools we show quantitatively that proper stacking of LayerNorm (applied to preactivations) and residual connections leads to an architecture that is critical for any initialization. Finally, we apply our methods to analyze ResNet and MLP-Mixer architectures, demonstrating the everywhere-critical regime.
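A quick numerical check of criticality via partial Jacobians, in the simplest case we could pick: a bias-free ReLU MLP with weights from $\mathcal{N}(0, \sigma_w^2/N)$. Writing $\mathcal{J}^{l_0, l} = \frac{1}{N}\lVert \partial h^{l}/\partial h^{l_0}\rVert_F^2$, the infinite-width recurrence in this setting reduces to $\mathcal{J}^{l_0, l+1} = \chi\, \mathcal{J}^{l_0, l}$ with $\chi = \sigma_w^2/2$, so the norm should scale like $(\sigma_w^2/2)^{l-l_0}$ and stay near 1 only at $\sigma_w^2 = 2$. The function name `partial_jacobian_norm`, the width and the depth are our assumptions; LayerNorm and residual connections are not modeled.

```python
import numpy as np


def partial_jacobian_norm(sigma_w2, width=256, depth=10, seed=0):
    """(1/N) * ||d h^{l0+depth} / d h^{l0}||_F^2 for one random bias-free ReLU MLP."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal(width)  # preactivations at layer l0
    J = np.eye(width)               # d h^{l0} / d h^{l0}
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * np.sqrt(sigma_w2 / width)
        # Chain rule: d h^{l+1} / d h^{l0} = W @ diag(relu'(h^l)) @ d h^l / d h^{l0}
        J = (W * (h > 0)) @ J
        h = W @ np.maximum(h, 0.0)
    return float(np.sum(J**2) / width)


for sigma_w2 in (1.0, 2.0, 4.0):
    # Roughly (sigma_w2 / 2) ** 10: about 1e-3, about 1, about 1e3
    print(sigma_w2, partial_jacobian_norm(sigma_w2))
```

Because ReLU masks do not change when all weights are rescaled, the three calls share the same masks and differ exactly by the factor $(\sigma_w^2/2)^{10}$; the finite-width fluctuations show up only in how far the $\sigma_w^2 = 2$ value sits from 1.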