Gradient-Weight Alignment as a Train-Time Proxy for Generalization in Classification Tasks
The proposed metric offers a validation-set-free approach to model analysis and addresses a practical need for robust training-time metrics in supervised classification. The contribution is incremental, however, because it builds on existing gradient-based methods.
The paper tackles the problem of monitoring generalization during deep-learning training. It introduces Gradient-Weight Alignment (GWA), a metric that quantifies the coherence between per-sample gradients and model weights, and shows that GWA accurately predicts optimal early-stopping points and identifies influential training samples.
Robust validation metrics remain essential in contemporary deep learning, not only to detect overfitting and poor generalization, but also to monitor training dynamics. In the supervised classification setting, we investigate whether interactions between training data and model weights can yield a metric that both tracks generalization during training and attributes performance to individual training samples. We introduce Gradient-Weight Alignment (GWA), which quantifies the coherence between per-sample gradients and model weights. We show that effective learning corresponds to coherent alignment, while misalignment indicates deteriorating generalization. GWA is efficiently computable during training and reflects both sample-specific contributions and dataset-wide learning dynamics. Extensive experiments show that GWA accurately predicts optimal early stopping, enables principled model comparisons, and identifies influential training samples, providing a validation-set-free approach for model analysis directly from the training data.
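To make the idea of "alignment between per-sample gradients and weights" concrete, here is a minimal sketch under stated assumptions. It does not reproduce the paper's exact definition of GWA. As a hypothetical stand-in, it scores each sample by the cosine similarity between that sample's descent direction (the negative cross-entropy gradient) and the current weights of a bias-free softmax regression, and it summarizes the dataset by the mean score. The names `per_sample_alignment` and `dataset_alignment` are illustrative only. For this linear model, the per-sample gradient is the outer product of the input and the probability error, so the score can be computed for the whole batch without materializing any per-sample gradient.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def per_sample_alignment(W, X, y, eps=1e-12):
    """Hypothetical per-sample alignment score (not the paper's exact GWA).

    Cosine similarity between each sample's descent direction -grad_i and W
    for a bias-free softmax regression.
    W: (d, k) weights, X: (n, d) inputs, y: (n,) integer labels.
    """
    k = W.shape[1]
    P = softmax(X @ W)                    # (n, k) predicted probabilities
    D = P - np.eye(k)[y]                  # (n, k); grad_i = outer(x_i, D_i)
    # Frobenius inner product <outer(x_i, D_i), W> = x_i^T W D_i, vectorized.
    inner = np.sum((X @ W) * D, axis=1)
    # ||outer(x_i, D_i)||_F = ||x_i|| * ||D_i||
    grad_norm = np.linalg.norm(X, axis=1) * np.linalg.norm(D, axis=1)
    cos = inner / (grad_norm * np.linalg.norm(W) + eps)
    return -cos                           # negate: use the descent direction

def dataset_alignment(W, X, y):
    # Hypothetical dataset-wide summary: the mean per-sample score.
    return float(per_sample_alignment(W, X, y).mean())

# Usage on synthetic data: track the summary during plain gradient descent.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
W = rng.normal(scale=0.01, size=(5, 2))
history = []
for epoch in range(50):
    G = X.T @ (softmax(X @ W) - np.eye(2)[y]) / len(X)  # mean gradient
    W -= 0.5 * G
    history.append(dataset_alignment(W, X, y))

scores = per_sample_alignment(W, X, y)
most_misaligned = np.argsort(scores)[:5]  # lowest scores first
print(most_misaligned)
```

In a real network the per-sample gradients would have to come from automatic differentiation rather than a closed form. How the paper aggregates per-sample scores and turns them into an early-stopping signal is not specified here, so the mean used above is only one plausible choice.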