Stein-Rule Shrinkage for Stochastic Gradient Estimation in High Dimensions
This work addresses inefficient gradient estimation in high-dimensional deep learning, offering practitioners a principled, incremental improvement that integrates into optimizers such as Adam.
The paper tackles the suboptimality of standard stochastic gradients in high-dimensional settings by introducing a Stein-rule shrinkage estimator that adaptively contracts noisy mini-batch gradients toward a momentum-based target. Under a Gaussian noise model and for dimension p>=3, the estimator uniformly dominates the standard stochastic gradient under squared error loss, and it yields consistent improvements over Adam on CIFAR10 and CIFAR100 with label noise in the large-batch regime.
Stochastic gradient methods are central to large-scale learning, yet their analysis typically treats mini-batch gradients as unbiased estimators of the population gradient. In high-dimensional settings, however, classical results from statistical decision theory show that unbiased estimators are generally inadmissible under quadratic loss, suggesting that standard stochastic gradients may be suboptimal from a risk perspective. In this work, we formulate stochastic gradient computation as a high-dimensional estimation problem and introduce a decision-theoretic framework based on Stein-rule shrinkage. We construct a shrinkage gradient estimator that adaptively contracts noisy mini-batch gradients toward a stable restricted estimator derived from historical momentum. The shrinkage intensity is determined in a data-driven manner using an online estimate of gradient noise variance, leveraging second-moment statistics commonly maintained by adaptive optimization methods. Under a Gaussian noise model and for dimension p>=3, we show that the proposed estimator uniformly dominates the standard stochastic gradient under squared error loss and is minimax-optimal in the classical decision-theoretic sense. We further demonstrate how this estimator can be incorporated into the Adam optimizer, yielding a practical algorithm with negligible additional computational cost. Empirical evaluations on CIFAR10 and CIFAR100, across multiple levels of label noise, show consistent improvements over Adam in the large-batch regime. Ablation studies indicate that the gains arise primarily from selectively applying shrinkage to high-dimensional convolutional layers, while indiscriminate shrinkage across all parameters degrades performance. These results illustrate that classical shrinkage principles provide a principled and effective approach to improving stochastic gradient estimation in modern deep learning.
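To make the shrinkage step concrete, the following is a minimal NumPy sketch, not the paper's implementation. It assumes the classical positive-part James-Stein form, in which the mini-batch gradient g is contracted toward a momentum target m by the factor max(0, 1 - (p - 2) * sigma2 / ||g - m||^2). The function names stein_shrink and noise_variance_from_moments, and the variance proxy mean(v - m^2) built from Adam-style moment estimates, are assumptions introduced here for illustration; the exact estimator and noise estimate used in the paper may differ.

```python
import numpy as np


def stein_shrink(g, m, sigma2):
    """Positive-part James-Stein shrinkage of gradient g toward target m.

    g      : noisy mini-batch gradient (any shape, treated as one vector)
    m      : shrinkage target, e.g. a momentum buffer (same shape as g)
    sigma2 : assumed per-coordinate noise variance of g
    """
    g = np.asarray(g, dtype=float)
    m = np.asarray(m, dtype=float)
    p = g.size
    # The dominance result requires p >= 3; below that, return g unchanged.
    if p < 3:
        return g.copy()
    diff = g - m
    sq_dist = float(np.sum(diff ** 2))
    if sq_dist == 0.0:
        return g.copy()
    # Positive-part factor: never shrink past the target.
    factor = max(0.0, 1.0 - (p - 2) * sigma2 / sq_dist)
    return m + factor * diff


def noise_variance_from_moments(m, v):
    """Hypothetical noise proxy from Adam-style moments (an assumption).

    m : running average of gradients (first moment)
    v : running average of squared gradients (second moment)
    Returns the mean of (v - m^2), floored at zero.
    """
    m = np.asarray(m, dtype=float)
    v = np.asarray(v, dtype=float)
    return float(max(0.0, np.mean(v - m ** 2)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    p, sigma2 = 50, 1.0
    theta = 0.1 * np.ones(p)   # "true" gradient in this toy setting
    target = np.zeros(p)       # stand-in for a momentum buffer
    raw, shrunk = 0.0, 0.0
    for _ in range(500):
        g = theta + rng.normal(scale=np.sqrt(sigma2), size=p)
        raw += np.sum((g - theta) ** 2)
        shrunk += np.sum((stein_shrink(g, target, sigma2) - theta) ** 2)
    print("mean squared error, raw gradient:   ", raw / 500)
    print("mean squared error, shrunk gradient:", shrunk / 500)
```

In line with the ablation finding above, a sketch like this would be applied per layer to selected high-dimensional tensors (for example, convolutional weights) rather than to every parameter. The toy experiment in the main block uses a fixed target and known noise variance, which is simpler than the online setting the abstract describes.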