LG, AI · Oct 16, 2024

Flash Inference: Near Linear Time Inference for Long Convolution Sequence Models and Beyond

arXiv:2410.12982v2 · 3 citations · h-index: 96 · ICLR
Originality: Highly original
AI Analysis

This work targets the inference bottleneck of long convolution sequence models, offering substantial speedups for autoregressive generation.

The paper tackles the quadratic inference cost of long convolution sequence models (LCSMs) such as Hyena. It proposes a method that speeds up exact inference to quasilinear $O(L\log^2 L)$ time, achieving up to $7.8\times$ end-to-end improvement and $110\times$ within the position-mixing part.
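A short count shows where the $\log^2 L$ factor can come from. It assumes a dyadic tiling in the style of relaxed polynomial multiplication; this summary does not spell out the exact tile shapes. At generation step $i$, the block of the $2^k$ most recent inputs, with $2^k$ the largest power of two dividing $i$, is convolved with the filter in a single FFT at cost $O(2^k (k+1))$, and roughly $L/2^{k+1}$ of the steps $i \le L$ use that block size:

$$
\sum_{k=0}^{\lfloor \log_2 L \rfloor} \frac{L}{2^{k+1}}\, O\big(2^k (k+1)\big)
= O\Big(L \sum_{k=0}^{\lfloor \log_2 L \rfloor} (k+1)\Big)
= O(L \log^2 L),
$$

compared with $\sum_{t=1}^{L} t = O(L^2)$ when each new output's convolution is recomputed from scratch.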

While transformers have been at the core of most recent advancements in sequence generative models, their computational cost remains quadratic in sequence length. Several subquadratic architectures have been proposed to address this computational issue. Some of them, including long convolution sequence models (LCSMs), such as Hyena, address this issue at training time but remain quadratic during inference. We propose a method for speeding up LCSMs' exact inference to quasilinear $O(L\log^2L)$ time, identify the key properties that make this possible, and propose a general framework that exploits these. Our approach, inspired by previous work on relaxed polynomial interpolation, is based on a tiling which helps decrease memory movement and share computation. It has the added benefit of allowing for almost complete parallelization across layers of the position-mixing part of the architecture. Empirically, we provide a proof-of-concept implementation for Hyena, which achieves up to $7.8\times$ end-to-end improvement over standard inference by improving $110\times$ within the position-mixing part.
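The tiling idea can be illustrated on a single toy channel. The following NumPy sketch makes several assumptions: a scalar sequence, a causal filter `h` of length `L`, and a hypothetical `step` function that stands in for everything between one mixer output and the next input (other layers, sampling, embedding). `naive_generate` recomputes each output's convolution from scratch. `tiled_generate` flushes blocks of already-known inputs into an accumulator of future outputs with one FFT convolution per step, following the dyadic pattern of relaxed polynomial multiplication. It illustrates that technique only. It is not the authors' implementation and omits the cross-layer parallelism described in the abstract.

```python
import numpy as np


def fft_linear_conv(a, b):
    """Full linear convolution of two real 1-D arrays via zero-padded FFT."""
    n_out = len(a) + len(b) - 1
    n_fft = 1 << max(1, (n_out - 1).bit_length())  # power of two >= n_out, at least 2
    spec = np.fft.rfft(a, n_fft) * np.fft.rfft(b, n_fft)
    return np.fft.irfft(spec, n_fft)[:n_out]


def naive_generate(h, x0, step, L):
    """Quadratic baseline: y[t] = sum_{s<=t} h[t-s] * x[s], recomputed at every step."""
    x, y = np.zeros(L), np.zeros(L)
    x[0] = x0
    for t in range(L):
        if t > 0:
            x[t] = step(y[t - 1])  # next input depends on the previous output
        y[t] = np.dot(h[t::-1], x[: t + 1])
    return y


def tiled_generate(h, x0, step, L):
    """Same outputs as naive_generate, using a dyadic tiling of the (s, t) pairs."""
    h_pad = np.concatenate([h[:L], np.zeros(2 * L)])  # keeps h_pad[1:2B] in range
    x, y = np.zeros(L), np.zeros(L)
    acc = np.zeros(L)  # acc[t]: already-flushed contributions of x[s] for s < t
    x[0] = x0
    for t in range(L):
        if t > 0:
            x[t] = step(y[t - 1])
        # Every pair (s, t) with s < t was flushed at an earlier step; only s == t is left.
        y[t] = acc[t] + h_pad[0] * x[t]
        if t + 1 >= L:
            break
        i = t + 1
        B = i & -i  # largest power of two dividing i (1-indexed position)
        # Flush sources x[t-B+1 .. t] into targets y[t+1 .. t+B]; the lags involved
        # are 1 .. 2B-1, i.e. h_pad[1 : 2B].
        c = fft_linear_conv(x[t - B + 1 : t + 1], h_pad[1 : 2 * B])
        hi = min(t + 1 + B, L)
        acc[t + 1 : hi] += c[B - 1 : B - 1 + (hi - t - 1)]
    return y


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    L = 256
    h = rng.standard_normal(L) / L  # hypothetical random filter
    y_ref = naive_generate(h, 0.5, np.tanh, L)
    y_fast = tiled_generate(h, 0.5, np.tanh, L)
    print("max |difference|:", np.max(np.abs(y_ref - y_fast)))
```

At step `t`, the sketch does one FFT convolution whose size is proportional to `B`, the largest power of two dividing `t + 1`. Large blocks are therefore rare. The naive loop, by contrast, performs `t + 1` multiply-adds at every step. The outputs agree up to floating-point error because each pair `(s, t)` with `s < t` lands in exactly one flushed block, and it does so before `y[t]` is needed.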

