LG · ML · Nov 1, 2023

Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures

arXiv:2311.00636v2 · 45 citations · h-index: 15
Originality: Incremental advance
AI Analysis

This work aims to speed up training for researchers and practitioners who use complex neural network architectures. It is an incremental advance, since it extends an existing optimization method to new settings.

The authors apply Kronecker-Factored Approximate Curvature (K-FAC) to modern neural networks with weight-sharing layers, such as transformers and graph neural networks, by developing two variants, expand and reduce. Both variants reached a fixed validation metric target in 50-75% of the steps needed by a first-order reference run, which translated into a comparable improvement in wall-clock time.

The core components of many modern neural network architectures, such as transformers, convolutional, or graph neural networks, can be expressed as linear layers with $\textit{weight-sharing}$. Kronecker-Factored Approximate Curvature (K-FAC), a second-order optimisation method, has shown promise to speed up neural network training and thereby reduce computational costs. However, there is currently no framework to apply it to generic architectures, specifically ones with linear weight-sharing layers. In this work, we identify two different settings of linear weight-sharing layers which motivate two flavours of K-FAC -- $\textit{expand}$ and $\textit{reduce}$. We show that they are exact for deep linear networks with weight-sharing in their respective setting. Notably, K-FAC-reduce is generally faster than K-FAC-expand, which we leverage to speed up automatic hyperparameter selection via optimising the marginal likelihood for a Wide ResNet. Finally, we observe little difference between these two K-FAC variations when using them to train both a graph neural network and a vision transformer. However, both variations are able to reach a fixed validation metric target in $50$-$75\%$ of the number of steps of a first-order reference run, which translates into a comparable improvement in wall-clock time. This highlights the potential of applying K-FAC to modern neural network architectures.
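The difference between the two flavours can be made concrete with a small sketch of the input-side Kronecker factor for a single linear weight-sharing layer. This is an illustration under stated assumptions, not the authors' implementation: the abstract does not give the formulas, so the function names, the nested-list layout `acts[n][r]` (example `n`, shared position `r`, such as a token in a sequence), and the normalization constants are all hypothetical choices made here. The gradient-side factor would presumably be handled in an analogous way.

```python
# Sketch only: input-side Kronecker factor for one linear weight-sharing layer.
# acts[n][r] is the input vector of example n at shared position r.
# All names and the normalization constants are assumptions for illustration.

def outer(u, v):
    """Outer product of two vectors, returned as a nested list."""
    return [[ui * vj for vj in v] for ui in u]

def zeros(d):
    """d x d matrix of zeros."""
    return [[0.0] * d for _ in range(d)]

def add_into(acc, m, scale):
    """In-place update: acc += scale * m."""
    for i, row in enumerate(m):
        for j, x in enumerate(row):
            acc[i][j] += scale * x

def input_factor_expand(acts):
    """'Expand' style: treat every (example, position) pair as its own sample.

    Forms N * R outer products, averaged over all of them.
    """
    n, r, d = len(acts), len(acts[0]), len(acts[0][0])
    factor = zeros(d)
    for example in acts:
        for a in example:
            add_into(factor, outer(a, a), 1.0 / (n * r))
    return factor

def input_factor_reduce(acts):
    """'Reduce' style: average over the shared positions first.

    Forms only N outer products, one per example.
    """
    n, r, d = len(acts), len(acts[0]), len(acts[0][0])
    factor = zeros(d)
    for example in acts:
        mean = [sum(a[k] for a in example) / r for k in range(d)]
        add_into(factor, outer(mean, mean), 1.0 / n)
    return factor

# One example (N = 1) with two shared positions (R = 2) of dimension 2.
acts = [[[1.0, 0.0], [0.0, 1.0]]]
print(input_factor_expand(acts))
print(input_factor_reduce(acts))
```

In this sketch the expand version loops over all `N * R` position vectors, while the reduce version forms a single outer product per example, which is one way to see why a reduce-style factor can be cheaper to compute. When there is only one shared position (`R = 1`), the two functions return the same matrix.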

