Fast Differentiable Clipping-Aware Normalization and Rescaling
This work removes a bottleneck in training neural networks on inputs with normalized perturbations, offering a practical improvement for researchers and practitioners in machine learning.
The paper tackles the problem of efficiently rescaling a perturbation so that it still has a desired norm after the perturbed input has been clipped to the data domain. It replaces slow iterative methods with a fast, differentiable analytical algorithm. The method works for any p-norm and is implemented in popular machine learning frameworks.
Rescaling a vector $\vec{\delta} \in \mathbb{R}^n$ to a desired length is a common operation in many areas such as data science and machine learning. When the rescaled perturbation $\eta\vec{\delta}$ is added to a starting point $\vec{x} \in D$ (where $D$ is the data domain, e.g. $D = [0, 1]^n$), the resulting vector $\vec{v} = \vec{x} + \eta\vec{\delta}$ will in general not lie in $D$. To enforce that the perturbed vector $\vec{v}$ is in $D$, the values of $\vec{v}$ can be clipped to $D$. This subsequent element-wise clipping to the data domain does, however, reduce the effective perturbation size and thus interferes with the rescaling of $\vec{\delta}$. The optimal rescaling $\eta$, i.e. the one for which the clipped perturbation has the desired norm $\epsilon$, $\|\operatorname{clip}_D(\vec{x} + \eta\vec{\delta}) - \vec{x}\|_p = \epsilon$, can be approximated iteratively using a binary search. However, such an iterative approach is slow and non-differentiable. Here we show that the optimal rescaling can be found analytically using a fast and differentiable algorithm. Our algorithm works for any p-norm and can be used to train neural networks on inputs with normalized perturbations. We provide native implementations for PyTorch, TensorFlow, JAX, and NumPy based on EagerPy.
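To make the rescaling problem concrete, here is a minimal NumPy sketch, assuming $D = [0, 1]^n$ and a finite $p \ge 1$. It contains an iterative bisection baseline and one analytical way to obtain $\eta$: sort the values of $\eta$ at which each coordinate starts being clipped, note that between two consecutive breakpoints the clipped norm raised to the $p$-th power has the form $c + \eta^p s$, and solve that expression directly on the segment that contains $\epsilon^p$. The function names (`clipped_norm`, `rescale_binary_search`, `rescale_analytic`), the target norm `eps`, and the bisection bounds are illustrative assumptions; this is not presented as the paper's exact algorithm or as the API of its EagerPy-based implementations.

```python
import numpy as np


def clipped_norm(x, delta, eta, p):
    """p-norm of the perturbation left after clipping x + eta * delta to [0, 1]^n."""
    v = np.clip(x + eta * delta, 0.0, 1.0)
    return np.linalg.norm(v - x, ord=p)


def rescale_binary_search(x, delta, eps, p, hi=1e6, steps=100):
    """Iterative baseline: bisect on eta until the clipped norm is about eps.

    Assumes (hypothetically) that eps is reachable and the solution lies in [0, hi].
    """
    lo = 0.0
    for _ in range(steps):
        mid = 0.5 * (lo + hi)
        if clipped_norm(x, delta, mid, p) < eps:
            lo = mid  # norm still too small: eta must grow
        else:
            hi = mid  # norm reached eps: eta can shrink
    return 0.5 * (lo + hi)


def rescale_analytic(x, delta, eps, p):
    """Closed-form eta for D = [0, 1]^n and finite p >= 1 (illustrative sketch)."""
    mask = delta != 0  # zero coordinates never contribute to the norm
    a = np.abs(delta[mask])
    # Largest change coordinate i can make before hitting the boundary of [0, 1].
    room = np.where(delta[mask] > 0, 1.0 - x[mask], x[mask])
    # Coordinate i is clipped once eta exceeds t_i = room_i / |delta_i|.
    t = room / a
    order = np.argsort(t)
    t, a, room = t[order], a[order], room[order]
    # If coordinates 0..k-1 are clipped, the clipped norm to the power p is
    #   f(eta) = sum_{i<k} room_i^p + eta^p * sum_{i>=k} |delta_i|^p.
    clipped = np.concatenate(([0.0], np.cumsum(room ** p)))  # prefix sums, k = 0..m
    free = np.concatenate((np.cumsum((a ** p)[::-1])[::-1], [0.0]))  # suffix sums
    # f is nondecreasing in eta, so evaluate it at each breakpoint and locate eps^p.
    f_at_t = clipped[:-1] + t ** p * free[:-1]
    k = np.searchsorted(f_at_t, eps ** p)
    if k == len(t):
        raise ValueError("eps exceeds the largest norm reachable inside [0, 1]^n")
    # On the segment containing the solution, f depends on eta through eta^p only.
    return ((eps ** p - clipped[k]) / free[k]) ** (1.0 / p)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.uniform(0.0, 1.0, size=1000)
    delta = rng.normal(size=1000)
    for p in (1.0, 2.0):
        eta = rescale_analytic(x, delta, 5.0, p)
        print(p, eta, clipped_norm(x, delta, eta, p), rescale_binary_search(x, delta, 5.0, p))
```

Both functions should return the same $\eta$ up to floating-point error, but the analytical version needs only one sort and a few cumulative sums instead of many norm evaluations. Because the last step is a closed-form expression in $\vec{x}$, $\vec{\delta}$, and $\epsilon$, writing the same steps with an autodiff framework's array operations would give an $\eta$ that is differentiable almost everywhere; this NumPy sketch itself does not track gradients.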