LG · Aug 9, 2022
Training Overparametrized Neural Networks in Sublinear Time
Yichuan Deng, Hang Hu, Zhao Song et al.
The success of deep learning comes at a tremendous computational and energy cost, and the scalability of training massively overparametrized neural networks is becoming a real barrier to the progress of artificial intelligence (AI). Despite the popularity and low cost-per-iteration of traditional backpropagation via gradient descent, stochastic gradient descent (SGD) has a prohibitively slow convergence rate in non-convex settings, both in theory and in practice. To mitigate this cost, recent works have proposed to employ alternative (Newton-type) training methods with much faster convergence rate, albeit with higher cost-per-iteration. For a typical neural network with $m=\mathrm{poly}(n)$ parameters and input batch of $n$ datapoints in $\mathbb{R}^d$, the previous work of [Brand, Peng, Song, and Weinstein, ITCS'2021] requires $\sim mnd + n^3$ time per iteration. In this paper, we present a novel training method that requires only $m^{1-α} n d + n^3$ amortized time in the same overparametrized regime, where $α\in (0.01,1)$ is some fixed constant. This method relies on a new and alternative view of neural networks, as a set of binary search trees, where each iteration corresponds to modifying a small subset of the nodes in the tree. We believe this view would have further applications in the design and analysis of deep neural networks (DNNs).
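The abstract does not say what the trees store, so the following is only a rough sketch of one way a tree over neurons could make an iteration touch few nodes. It assumes (hypothetically) that each leaf holds one neuron's pre-activation for a fixed input and that each internal node holds the maximum of its children; the class name `MaxTree` and the threshold query are illustrative choices, not the paper's construction.

```python
import math

class MaxTree:
    """Complete binary tree; leaf i holds a value, each internal node the max of its children."""

    def __init__(self, values):
        self.n = len(values)
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        # Unused padding leaves stay at -inf so they never pass a threshold test.
        self.tree = [-math.inf] * (2 * self.size)
        self.tree[self.size:self.size + self.n] = list(values)
        for p in range(self.size - 1, 0, -1):
            self.tree[p] = max(self.tree[2 * p], self.tree[2 * p + 1])

    def update(self, i, value):
        """Change leaf i and repair only the O(log n) nodes on its root path."""
        p = self.size + i
        self.tree[p] = value
        p //= 2
        while p >= 1:
            self.tree[p] = max(self.tree[2 * p], self.tree[2 * p + 1])
            p //= 2

    def above(self, threshold):
        """Return indices of leaves with value > threshold, pruning subtrees whose max is too small."""
        found, stack = [], [1]
        while stack:
            p = stack.pop()
            if self.tree[p] <= threshold:
                continue
            if p >= self.size:
                found.append(p - self.size)
            else:
                stack.extend((2 * p + 1, 2 * p))  # left child is popped first
        return found
```

Under these assumptions, a query that reports $k$ leaves visits on the order of $k \log m$ nodes rather than all $m$, and changing one leaf rewrites a single root path.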
LG · Mar 15, 2025
Changing Base Without Losing Pace: A GPU-Efficient Alternative to MatMul in DNNs
Nir Ailon, Akhiad Bercovich, Yahel Uffenheimer et al.
Modern AI relies on huge matrix multiplications (MatMuls), whose computation poses a scalability problem for inference and training. We propose an alternative, GPU-native bilinear operator to MatMuls in neural networks, which offers a three-way tradeoff between speed, accuracy, and parameter count. In particular, this operator requires substantially fewer FLOPs to evaluate ($\ll n^3$), yet increases the parameter count compared to MatMul ($\gg n^2$). We call this operator Strassen-Tile (STL). The key idea behind STL is a local learnable change-of-basis, applied on tiles of the weight and activation matrices, followed by an element-wise product between the tiles, implemented simultaneously via MatMul. The key technical question we study is how to optimize the change-of-basis of a given layer, which is a highly non-convex problem. We show that theory-backed initializations (inspired by fast matrix and polynomial multiplication) lead to substantially better accuracy than random SGD initialization. This phenomenon motivates further algorithmic study of STL optimization in DNNs. Our experiments demonstrate that STL can approximate 4x4 MatMul of tiles while reducing FLOPs by a factor of 2.66, and can improve ImageNet-1K accuracy of SoTA T2T-ViT-7 (4.3M parameters) while lowering FLOPs. Even with non-CUDA optimized PyTorch code, STL achieves wall-clock speedups in the compute-bound regime. These results, together with its theoretical grounds, suggest STL as a promising building block for scalable and cost-efficient AI.
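To make the "change of basis, then element-wise product" shape concrete, here is a minimal sketch using the classical Strassen scheme for 2x2 matrices. It is not STL itself: the encoding and decoding matrices below are the fixed textbook ones rather than learned, the tile is 2x2 rather than 4x4, and the names `ENC_A`, `ENC_B`, `DEC`, and `strassen_tile` are made up for this illustration.

```python
# Classical Strassen 2x2 multiplication in "encode, multiply element-wise, decode" form.
# Tiles are flattened row-major as [x11, x12, x21, x22].

ENC_A = [  # change of basis applied to the left tile (7 x 4)
    [1, 0, 0, 1],
    [0, 0, 1, 1],
    [1, 0, 0, 0],
    [0, 0, 0, 1],
    [1, 1, 0, 0],
    [-1, 0, 1, 0],
    [0, 1, 0, -1],
]
ENC_B = [  # change of basis applied to the right tile (7 x 4)
    [1, 0, 0, 1],
    [1, 0, 0, 0],
    [0, 1, 0, -1],
    [-1, 0, 1, 0],
    [0, 0, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
]
DEC = [  # maps the 7 products back to the 4 output entries (4 x 7)
    [1, 0, 0, 1, -1, 0, 1],
    [0, 0, 1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0, 0, 0],
    [1, -1, 1, 0, 0, 1, 0],
]

def matvec(matrix, vector):
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]

def strassen_tile(A, B):
    """Multiply two 2x2 tiles using 7 scalar products instead of 8."""
    a = [A[0][0], A[0][1], A[1][0], A[1][1]]
    b = [B[0][0], B[0][1], B[1][0], B[1][1]]
    encoded_a = matvec(ENC_A, a)
    encoded_b = matvec(ENC_B, b)
    products = [x * y for x, y in zip(encoded_a, encoded_b)]  # element-wise product
    c = matvec(DEC, products)
    return [[c[0], c[1]], [c[2], c[3]]]
```

With these fixed matrices the result equals the exact 2x2 product. Replacing them with trainable matrices would, as the abstract describes for STL, trade exactness for a tunable balance between FLOPs, accuracy, and parameter count.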
DS · Jan 1, 2022
The Complexity of Dynamic Least-Squares Regression
Shunhua Jiang, Binghui Peng, Omri Weinstein
We settle the complexity of dynamic least-squares regression (LSR), where rows and labels $(\mathbf{A}^{(t)}, \mathbf{b}^{(t)})$ can be adaptively inserted and/or deleted, and the goal is to efficiently maintain an $ε$-approximate solution to $\min_{\mathbf{x}^{(t)}} \| \mathbf{A}^{(t)} \mathbf{x}^{(t)} - \mathbf{b}^{(t)} \|_2$ for all $t\in [T]$. We prove sharp separations ($d^{2-o(1)}$ vs. $\sim d$) between the amortized update time of: (i) Fully vs. Partially dynamic $0.01$-LSR; (ii) High vs. low-accuracy LSR in the partially-dynamic (insertion-only) setting. Our lower bounds follow from a gap-amplification reduction -- reminiscent of iterative refinement -- from the exact version of the Online Matrix Vector Conjecture (OMv) [HKNS15], to constant approximate OMv over the reals, where the $i$-th online product $\mathbf{H}\mathbf{v}^{(i)}$ only needs to be computed to $0.1$-relative error. All previous fine-grained reductions from OMv to its approximate versions only show hardness for inverse polynomial approximation $ε= n^{-ω(1)}$ (additive or multiplicative). This result is of independent interest in fine-grained complexity and for the investigation of the OMv Conjecture, which is still widely open.
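As a point of reference for the insertion-only setting, the sketch below maintains a least-squares solution under row insertions with the classical recursive least-squares (Sherman-Morrison) update, which costs $O(d^2)$ arithmetic operations per inserted row. This is a textbook baseline and not an algorithm from the paper. The class name `InsertOnlyLSR` and the small `ridge` term, which keeps the initial matrix invertible and makes the result a lightly regularized solution, are assumptions of this example.

```python
class InsertOnlyLSR:
    """Maintains x = (ridge*I + A^T A)^{-1} A^T b as rows (a, b) are inserted."""

    def __init__(self, d, ridge=1e-6):
        self.d = d
        self.x = [0.0] * d
        # P tracks the inverse of (ridge*I + A^T A); it starts as (1/ridge) * I.
        self.P = [[(1.0 / ridge if i == j else 0.0) for j in range(d)] for i in range(d)]

    def insert(self, a, b):
        d = self.d
        Pa = [sum(self.P[i][j] * a[j] for j in range(d)) for i in range(d)]
        denom = 1.0 + sum(a[i] * Pa[i] for i in range(d))
        gain = [v / denom for v in Pa]
        residual = b - sum(a[i] * self.x[i] for i in range(d))
        for i in range(d):
            self.x[i] += gain[i] * residual
            for j in range(d):
                # Rank-one downdate of the inverse (Sherman-Morrison).
                self.P[i][j] -= gain[i] * Pa[j]
        return self.x


solver = InsertOnlyLSR(d=2)
for row, label in [((1.0, 0.0), 2.0), ((0.0, 1.0), -3.0), ((1.0, 1.0), -1.0), ((2.0, 1.0), 1.0)]:
    solver.insert(row, label)
```

Deletions could be handled by the analogous rank-one update with the opposite sign, but that variant is not shown here.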
LG · Jun 20, 2020
Training (Overparametrized) Neural Networks in Near-Linear Time
Jan van den Brand, Binghui Peng, Zhao Song et al.
The slow convergence rate and pathological curvature issues of first-order gradient methods for training deep neural networks initiated an ongoing effort for developing faster $\mathit{second}$-$\mathit{order}$ optimization algorithms beyond SGD, without compromising the generalization error. Despite their remarkable convergence rate ($\mathit{independent}$ of the training batch size $n$), second-order algorithms incur a daunting slowdown in the $\mathit{cost}$ $\mathit{per}$ $\mathit{iteration}$ (inverting the Hessian matrix of the loss function), which renders them impractical. Very recently, this computational overhead was mitigated by the works of [ZMG19, CGH+19], yielding an $O(mn^2)$-time second-order algorithm for training two-layer overparametrized neural networks of polynomial width $m$. We show how to speed up the algorithm of [CGH+19], achieving an $\tilde{O}(mn)$-time backpropagation algorithm for training (mildly overparametrized) ReLU networks, which is near-linear in the dimension ($mn$) of the full gradient (Jacobian) matrix. The centerpiece of our algorithm is to reformulate the Gauss-Newton iteration as an $\ell_2$-regression problem, and then use a Fast-JL type dimension reduction to $\mathit{precondition}$ the underlying Gram matrix in time independent of $M$, allowing us to find a sufficiently good approximate solution via $\mathit{first}$-$\mathit{order}$ conjugate gradient. Our result provides a proof-of-concept that advanced machinery from randomized linear algebra -- which led to recent breakthroughs in $\mathit{convex}$ $\mathit{optimization}$ (ERM, LPs, Regression) -- can be carried over to the realm of deep learning as well.
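For orientation, a Gauss-Newton step and its regression form can be written as follows. The symbols $W_t$ (weights at iteration $t$), $J_t \in \mathbb{R}^{n \times p}$ (Jacobian of the $n$ predictions with respect to the $p$ weights), $f_t$ (predictions), and $y$ (labels) are notation assumed here, not taken from the abstract. This is the textbook minimum-norm form for the squared loss and may differ in detail from the paper's formulation.

$$
W_{t+1} = W_t - J_t^\top g_t, \qquad g_t \in \arg\min_{g \in \mathbb{R}^n} \left\| J_t J_t^\top g - (f_t - y) \right\|_2 .
$$

When the Gram matrix $J_t J_t^\top$ is invertible, $g_t = (J_t J_t^\top)^{-1}(f_t - y)$, so an approximate regression solver such as preconditioned conjugate gradient can stand in for the explicit inversion.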
DS · Apr 9, 2019
Lower Bounds for Oblivious Near-Neighbor Search
Kasper Green Larsen, Tal Malkin, Omri Weinstein et al.
We prove an $Ω(d \lg n/ (\lg\lg n)^2)$ lower bound on the dynamic cell-probe complexity of statistically $\mathit{oblivious}$ approximate-near-neighbor search ($\mathsf{ANN}$) over the $d$-dimensional Hamming cube. For the natural setting of $d = Θ(\log n)$, our result implies an $\tilde{Ω}(\lg^2 n)$ lower bound, which is a quadratic improvement over the highest (non-oblivious) cell-probe lower bound for $\mathsf{ANN}$. This is the first super-logarithmic $\mathit{unconditional}$ lower bound for $\mathsf{ANN}$ against general (non black-box) data structures. We also show that any oblivious $\mathit{static}$ data structure for decomposable search problems (like $\mathsf{ANN}$) can be obliviously dynamized with $O(\log n)$ overhead in update and query time, strengthening a classic result of Bentley and Saxe (Algorithmica, 1980).
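The dynamization result builds on the classical logarithmic method of Bentley and Saxe, which can be sketched as follows for a simple decomposable problem, set membership. The static structure here is just a sorted list, and the class name `LogMethodSet` is hypothetical. This plain version is not oblivious, so it shows only the classical technique and not the paper's strengthening.

```python
from bisect import bisect_left

class LogMethodSet:
    """Insert-only membership structure built from static sorted lists of sizes 2^i."""

    def __init__(self):
        # levels[i] is either None or a sorted list holding exactly 2**i items.
        self.levels = []

    def insert(self, item):
        carry = [item]
        i = 0
        while True:
            if i == len(self.levels):
                self.levels.append(None)
            if self.levels[i] is None:
                self.levels[i] = carry
                return
            # Level is full: rebuild a static structure of twice the size and carry it up,
            # exactly like a carry in a binary counter.
            carry = sorted(carry + self.levels[i])
            self.levels[i] = None
            i += 1

    def contains(self, item):
        # Decomposability: the answer over the union is the OR of the answers per level.
        for level in self.levels:
            if level is None:
                continue
            k = bisect_left(level, item)
            if k < len(level) and level[k] == item:
                return True
        return False
```

After $n$ insertions the occupied levels mirror the binary representation of $n$, so a query consults at most about $\log_2 n$ static structures.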
DS · Aug 12, 2018
Local Decodability of the Burrows-Wheeler Transform
Sandip Sinha, Omri Weinstein
The Burrows-Wheeler Transform (BWT) is among the most influential discoveries in text compression and DNA storage. It is a reversible preprocessing step that rearranges an $n$-letter string into runs of identical characters (by exploiting context regularities), resulting in highly compressible strings, and is the basis of the \texttt{bzip} compression program. Alas, the decoding process of BWT is inherently sequential and requires $Ω(n)$ time even to retrieve a \emph{single} character. We study the succinct data structure problem of locally decoding short substrings of a given text under its \emph{compressed} BWT, i.e., with small additive redundancy $r$ over the \emph{Move-To-Front} (\texttt{bzip}) compression. The celebrated BWT-based FM-index (FOCS '00), as well as other related literature, yield a trade-off of $r=\tilde{O}(n/\sqrt{t})$ bits, when a single character is to be decoded in $O(t)$ time. We give a near-quadratic improvement $r=\tilde{O}(n\lg(t)/t)$. As a by-product, we obtain an \emph{exponential} (in $t$) improvement on the redundancy of the FM-index for counting pattern-matches on compressed text. In the interesting regime where the text compresses to $n^{1-o(1)}$ bits, these results provide an $\exp(t)$ \emph{overall} space reduction. For the local decoding problem of BWT, we also prove an $Ω(n/t^2)$ cell-probe lower bound for "symmetric" data structures. We achieve our main result by designing a compressed partial-sums (Rank) data structure over BWT. The key component is a \emph{locally-decodable} Move-to-Front (MTF) code: with only $O(1)$ extra bits per block of length $n^{Ω(1)}$, the decoding time of a single character can be decreased from $Ω(n)$ to $O(\lg n)$. This result is of independent interest in algorithmic information theory.
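For readers unfamiliar with the two transforms, the sketch below gives a naive BWT, its inverse, and Move-to-Front coding. It is the textbook construction, not the paper's data structure. It assumes a sentinel character `$` that does not occur in the text and sorts before every text character, and it builds the BWT by sorting all rotations, which is far slower than practical implementations.

```python
def bwt(text, sentinel="$"):
    """Burrows-Wheeler Transform: last column of the sorted rotations of text + sentinel."""
    s = text + sentinel
    rotations = sorted(s[i:] + s[:i] for i in range(len(s)))
    return "".join(rotation[-1] for rotation in rotations)

def inverse_bwt(last):
    """Invert the BWT; assumes the sentinel is the unique smallest character."""
    # A stable sort of positions by character links each entry of the first column
    # to the same text character in the last column.
    order = sorted(range(len(last)), key=lambda i: last[i])
    row = order[0]  # the row whose rotation starts at the first text character
    out = []
    for _ in range(len(last) - 1):
        # Each character is found only after following the link from the previous one.
        out.append(last[order[row]])
        row = order[row]
    return "".join(out)

def mtf_encode(text, alphabet):
    """Move-to-Front: emit each symbol's current position, then move it to the front."""
    table = list(alphabet)
    codes = []
    for ch in text:
        k = table.index(ch)
        codes.append(k)
        table.insert(0, table.pop(k))
    return codes

def mtf_decode(codes, alphabet):
    table = list(alphabet)
    out = []
    for k in codes:
        ch = table.pop(k)
        out.append(ch)
        table.insert(0, ch)
    return "".join(out)


transformed = bwt("banana")
codes = mtf_encode(transformed, sorted(set(transformed)))
```

The loop in `inverse_bwt` recovers the text one character at a time by chasing links, which illustrates why retrieving a single character from a plain BWT takes time linear in $n$. Runs of identical characters in the BWT become zeros under Move-to-Front, which is what makes the output highly compressible.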