LG | Apr 15, 2024

Lowering PyTorch's Memory Consumption for Selective Differentiation

arXiv:2404.12406v2 | 2 citations | h-index: 8
Originality: Incremental advance
AI Analysis

The paper addresses memory limitations faced by deep learning practitioners, particularly in fine-tuning, but is incremental because it builds on existing PyTorch infrastructure.

The paper tackles the problem of high memory consumption in PyTorch's automatic differentiation by exploiting parameter differentiability information to discard unnecessary computation graph elements, resulting in reduced memory usage without affecting runtime.
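A large part of the computation graph's memory consists of the tensors that autograd saves for the backward pass. One way to observe them, assuming PyTorch 1.10 or later (for `torch.autograd.graph.saved_tensors_hooks`), is to tally their sizes in a pack hook. The helper name `saved_tensor_bytes` is hypothetical; this is a measurement sketch, not part of the paper's code.

```python
import torch


def saved_tensor_bytes(fn, *args):
    """Run fn(*args); return (bytes autograd saved for backward, output).

    Hypothetical helper: the pack hook is called once for every tensor an
    autograd node stores for the backward pass.
    """
    total = 0

    def pack_hook(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # store the tensor unchanged

    def unpack_hook(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        out = fn(*args)
    return total, out


x = torch.randn(64, 128)
# No input requires grad: no graph is recorded, so nothing is saved.
print(saved_tensor_bytes(torch.exp, x)[0])  # 0
# Once the input requires grad, exp stores a tensor for its backward.
print(saved_tensor_bytes(torch.exp, x.requires_grad_())[0] > 0)  # True
```

Differentiability flags therefore already decide whether anything is stored at all; the paper's point is that PyTorch does not use them at a finer granularity, per parameter.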

Memory is a limiting resource for many deep learning tasks. Besides the neural network weights, one of the main memory consumers is the computation graph that automatic differentiation (AD) builds up for backpropagation. We observe that PyTorch's current AD implementation neglects information about parameter differentiability when storing the computation graph. This information is nonetheless useful for reducing memory whenever gradients are requested for only a subset of the parameters, as is the case in many modern fine-tuning tasks. Specifically, inputs to layers that act linearly in their parameters (dense, convolution, or normalization layers) can be discarded whenever the parameters are marked as non-differentiable. We provide a drop-in, differentiability-agnostic implementation of such layers and demonstrate its ability to reduce memory without affecting run time.
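The key observation can be sketched for a dense layer. For y = xWᵀ + b, the input gradient grad_out · W needs only W, while the weight gradient grad_outᵀ · x needs x, so x does not have to be stored when W is frozen. The following is a minimal sketch under that reasoning, not the paper's implementation; `MemSaveLinearFn`, `MemSaveLinear` and `input_is_saved` are hypothetical names, and convolution or normalization layers would need analogous custom backward passes.

```python
import torch
import torch.nn.functional as F


class MemSaveLinearFn(torch.autograd.Function):
    """y = x @ W^T + b that keeps x for backward only if W needs a gradient.

    Hypothetical sketch, not the paper's implementation.
    """

    @staticmethod
    def forward(ctx, x, weight, bias):
        # dL/dx = grad_out @ W needs only W; dL/dW = grad_out^T @ x needs x.
        # Hence x is stored only when the weight is differentiable.
        ctx.save_for_backward(x if weight.requires_grad else None, weight)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_w = grad_b = None
        # Collapse leading batch dims so this works for any (..., features) input.
        g2d = grad_out.reshape(-1, grad_out.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_x = grad_out @ weight
        if ctx.needs_input_grad[1]:
            grad_w = g2d.T @ x.reshape(-1, x.shape[-1])
        if ctx.needs_input_grad[2]:
            grad_b = g2d.sum(dim=0)
        return grad_x, grad_w, grad_b


class MemSaveLinear(torch.nn.Linear):
    """Drop-in replacement for nn.Linear (hypothetical): same parameters, same outputs."""

    def forward(self, x):
        return MemSaveLinearFn.apply(x, self.weight, self.bias)


def input_is_saved(layer, x):
    """Return True if autograd stored `x` itself while building the graph for layer(x)."""
    saved_ptrs = []

    def pack(t):
        saved_ptrs.append(t.data_ptr())
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        out = layer(x).sum()
    out.backward()
    return x.data_ptr() in saved_ptrs


layer = MemSaveLinear(8, 4)
layer.weight.requires_grad_(False)  # frozen weight, e.g. during fine-tuning
x = torch.randn(16, 8, requires_grad=True)  # input still needs grad (trainable layers upstream)
print(input_is_saved(layer, x))  # False: only W is kept
layer.weight.requires_grad_(True)
print(input_is_saved(layer, x))  # True: x is needed for dL/dW
```

Saving `weight` costs no extra memory because it is a reference to a parameter that stays alive anyway; the saving comes from dropping `x`, whose size grows with the batch size.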

Code Implementations: 1 repo
