Kolmogorov-Arnold Attention: Is Learnable Attention Better For Vision Transformers?
This work aims to enhance attention mechanisms in vision Transformers for computer vision tasks. It is incremental in that it builds on existing Kolmogorov-Arnold networks, and its generalization does not scale well to larger models.
The paper aims to improve token interactions in vision Transformers by designing Kolmogorov-Arnold Attention (KArAt), an attention mechanism built on learnable activation functions, and finds that in some cases it matches or outperforms traditional softmax attention on CIFAR-10, CIFAR-100, and ImageNet-1K.
Kolmogorov-Arnold networks (KANs) are a remarkable innovation built from learnable activation functions, with the potential to capture more complex relationships from data. Presently, KANs are deployed by replacing multilayer perceptrons (MLPs) in deep networks, including advanced architectures such as vision Transformers (ViTs). This work asks whether KANs can also learn token interactions. In this paper, we design the first learnable attention for ViTs, called Kolmogorov-Arnold Attention (KArAt), which can operate on any basis, from Fourier, wavelets, and splines to rational functions. However, learnable activations in the attention cause a memory explosion. To remedy this, we propose a modular version of KArAt that uses a low-rank approximation. By adopting the Fourier basis, Fourier-KArAt and its variants, in some cases, outperform their traditional softmax counterparts, or show comparable performance, on CIFAR-10, CIFAR-100, and ImageNet-1K. We also deploy Fourier-KArAt to ConViT and Swin-Transformer, and use it in detection and segmentation with ViT-Det. We dissect the performance of these architectures by analyzing their loss landscapes, weight distributions, optimizer paths, attention visualizations, and transferability to other datasets. KArAt's learnable activation yields a better attention score across all ViTs, indicating improved token-to-token interactions and contributing to enhanced inference. Still, its generalizability does not scale with larger ViTs, although many factors, including the present computing interface, affect the relative performance of parameter- and memory-heavy KArAts. We note that the goal of this paper is not to produce efficient attention or to challenge the traditional activations; by designing KArAt, we are the first to show that attention can be learned, and we encourage researchers to explore KArAt in conjunction with more advanced architectures.
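The abstract does not spell out how the learnable activation or the low-rank approximation is parameterized, so the following is only a minimal NumPy sketch of one plausible reading. It assumes that each row of attention scores is passed through a Fourier-basis KAN layer mapping N tokens to a small rank r, followed by a plain linear map back to N tokens. The names `fourier_kan_layer`, `low_rank_fourier_attention`, `a`, `b`, `W`, `r`, and `G` are hypothetical and are not taken from the paper.

```python
import numpy as np

def fourier_kan_layer(x, a, b):
    """Learnable activation on a Fourier basis (illustrative assumption).

    x: (..., n_in) inputs; a, b: (n_out, n_in, G) cosine/sine coefficients.
    Output j is sum_i sum_k a[j,i,k]*cos(k*x_i) + b[j,i,k]*sin(k*x_i).
    """
    G = a.shape[-1]
    k = np.arange(1, G + 1)            # frequencies 1..G
    angles = x[..., None] * k          # (..., n_in, G)
    return (np.einsum("...ig,oig->...o", np.cos(angles), a)
            + np.einsum("...ig,oig->...o", np.sin(angles), b))

def low_rank_fourier_attention(Q, K, V, a, b, W):
    """Replace the row-wise softmax with a learned N -> r -> N operator."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)              # (N, N) scaled dot products
    hidden = fourier_kan_layer(scores, a, b)   # (N, r) learnable activation
    attn = hidden @ W                          # (N, N) linear map back to N
    return attn @ V, attn

rng = np.random.default_rng(0)
N, d, r, G = 8, 4, 3, 2                        # tokens, head dim, rank, grid size
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
a = 0.1 * rng.standard_normal((r, N, G))
b = 0.1 * rng.standard_normal((r, N, G))
W = 0.1 * rng.standard_normal((r, N))

out, attn = low_rank_fourier_attention(Q, K, V, a, b, W)

# Coefficient counts per head: a full N -> N Fourier layer vs. the low-rank version.
full_params = 2 * N * N * G
low_rank_params = 2 * r * N * G + r * N
```

Under these assumptions, a full N-to-N Fourier layer needs 2·N²·G coefficients per head, while the low-rank version needs 2·r·N·G + r·N, which is the kind of saving the memory remark in the abstract suggests. Unlike softmax, the learned attention rows here are not constrained to be non-negative or to sum to one.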