Efficient and Scalable Implementation of Differentially Private Deep Learning without Shortcuts
This work addresses the computational bottlenecks that researchers and practitioners face when implementing differentially private deep learning, benchmarking existing methods and offering incremental improvements to them.
The paper tackles the computational inefficiency of implementing differentially private stochastic gradient descent (DP-SGD) with correct Poisson subsampling. It finds that the naive implementation with Opacus in PyTorch has a throughput 2.6 to 8 times lower than that of SGD, and that efficient clipping methods like Ghost Clipping can roughly halve this cost. It also proposes a JAX implementation that performs comparably to the efficient PyTorch-based methods, and it studies scaling on up to 80 GPUs, where DP-SGD scales better than SGD.
Differentially private stochastic gradient descent (DP-SGD) is the standard algorithm for training machine learning models under differential privacy (DP). The most common DP-SGD privacy accountants rely on Poisson subsampling to ensure the theoretical DP guarantees. Implementing computationally efficient DP-SGD with Poisson subsampling is not trivial, which leads many implementations to take a shortcut by using a computationally faster subsampling scheme. We quantify the computational cost of training deep learning models under DP by implementing and benchmarking efficient methods with the correct Poisson subsampling. We find that the naive implementation of DP-SGD with Opacus in PyTorch has a throughput between 2.6 and 8 times lower than that of SGD. However, efficient gradient clipping implementations like Ghost Clipping can roughly halve this cost. We propose an alternative computationally efficient implementation of DP-SGD with JAX that uses Poisson subsampling and performs comparably to efficient clipping optimizations based on PyTorch. We study the scaling behavior using up to 80 GPUs and find that DP-SGD scales better than SGD. We share our library at https://github.com/DPBayes/Towards-Efficient-Scalable-Training-DP-DL.
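To make the role of Poisson subsampling concrete, the following is a minimal pure-Python sketch of a single DP-SGD step. It is not the Opacus or JAX implementation benchmarked in the paper. The function names (`poisson_subsample`, `clip`, `dp_sgd_step`), the toy one-parameter model, and the choice to normalize by the expected batch size `q * n` are assumptions made for illustration only.

```python
import math
import random


def poisson_subsample(n, q, rng):
    """Include each of the n example indices independently with probability q.

    The resulting batch size is random, unlike fixed-size shuffled batches.
    """
    return [i for i in range(n) if rng.random() < q]


def clip(grad, max_norm):
    """Scale a per-example gradient so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-12))
    return [g * scale for g in grad]


def dp_sgd_step(params, data, per_example_grad, q, max_norm,
                noise_multiplier, lr, rng):
    """One illustrative DP-SGD step: subsample, clip, sum, add noise, update."""
    batch = poisson_subsample(len(data), q, rng)
    summed = [0.0] * len(params)
    for i in batch:
        g = clip(per_example_grad(params, data[i]), max_norm)
        summed = [s + gi for s, gi in zip(summed, g)]
    # Gaussian noise with standard deviation noise_multiplier * max_norm.
    sigma = noise_multiplier * max_norm
    noisy = [s + rng.gauss(0.0, sigma) for s in summed]
    # Assumption: normalize by the expected batch size, not the realized one.
    expected_batch = q * len(data)
    new_params = [p - lr * g / expected_batch for p, g in zip(params, noisy)]
    return new_params, len(batch)


def squared_error_grad(params, example):
    """Gradient of 0.5 * (w * x - y) ** 2 with respect to the single weight w."""
    x, y = example
    return [(params[0] * x - y) * x]


if __name__ == "__main__":
    rng = random.Random(0)
    data = [(1.0, 2.0)] * 1000  # hypothetical toy dataset
    params = [0.0]
    for step in range(5):
        params, batch_size = dp_sgd_step(
            params, data, squared_error_grad,
            q=0.1, max_norm=1.0, noise_multiplier=1.0, lr=0.1, rng=rng,
        )
        print(step, batch_size, params)
```

The realized batch size differs from step to step, which is one reason an efficient implementation is not trivial: code that is compiled or tuned for a fixed batch shape generally needs padding, masking, or a similar workaround to handle it.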