FLASH-MAXSIM: IO-Aware Fused Kernels for Late-Interaction Scoring
For practitioners of late-interaction retrieval (e.g., ColBERT, ColPali), Flash-MaxSim removes the memory bottleneck that limits batch sizes in inference and training, enabling larger-scale retrieval without sacrificing accuracy.
Flash-MaxSim is an IO-aware fused GPU kernel for late-interaction retrieval that eliminates materialization of the full similarity tensor, achieving up to 3.9x speedup on A100 (4.7x on H100) and up to 16x less inference memory while preserving exact ranking.
Late-interaction retrieval (ColBERT, ColPali) scores a query against a document with the MaxSim operator: for each query token, take the maximum similarity over all document tokens, then sum these maxima over the query tokens. The standard implementation materializes the full query-token x document-token similarity tensor in GPU memory. For visual ColPali at 10K documents, this tensor alone occupies 21 GB in FP16, only to be reduced to one score per document and then discarded. It exhausts a 40 GB GPU and bounds the achievable batch size in both inference and training. We present Flash-MaxSim, an IO-aware fused GPU kernel that computes exactly the same scores without ever materializing the tensor. It streams query and document tiles through on-chip SRAM and folds the row-maximum reduction into the same pass. We extend the IO-aware principle to the training backward pass, where an inverse-grid CSR construction reuses the forward argmax for an atomic-free, destination-owned gradient reduction. We also extend it to INT8xINT8 quantization and to variable-length (padding-free) scoring. Flash-MaxSim is up to 3.9x faster on an A100 (4.7x on an H100) than naive PyTorch at matched precision, uses up to 16x less inference memory and ~28x less training memory, unlocks corpus and batch sizes that exhaust PyTorch entirely, preserves the exact ranking (100% top-20 agreement with an FP32 reference)
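The sketch below uses NumPy to illustrate the MaxSim definition and the streaming idea. It assumes plain arrays: `q` of shape [n_q, dim] and `docs` of shape [n_docs, n_d, dim]. The tile size `tile_d` is a hypothetical parameter. `maxsim_naive` builds the full similarity tensor, as the standard implementation does. `maxsim_streaming` walks over document-token tiles and keeps only a running row maximum per query token, so at most one [n_q, tile_d] block of similarities exists at a time. This is a CPU illustration of the reduction order under those assumptions, not the Flash-MaxSim CUDA kernel. Both function names are ours.

```python
import numpy as np


def maxsim_naive(q: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """Reference MaxSim that materializes the full [n_docs, n_q, n_d] tensor."""
    sim = np.einsum("qk,ndk->nqd", q, docs)  # full similarity tensor
    return sim.max(axis=2).sum(axis=1)       # max over doc tokens, sum over query tokens


def maxsim_streaming(q: np.ndarray, docs: np.ndarray, tile_d: int = 32) -> np.ndarray:
    """Same scores, computed tile by tile with a running row maximum.

    Only one [n_q, tile_d] block of similarities exists at a time; the tile
    is a stand-in for what a fused GPU kernel would keep in on-chip SRAM.
    """
    n_docs, n_d, _ = docs.shape
    scores = np.empty(n_docs, dtype=np.result_type(q, docs))
    for i in range(n_docs):
        running_max = np.full(q.shape[0], -np.inf, dtype=scores.dtype)
        for start in range(0, n_d, tile_d):
            tile = docs[i, start:start + tile_d]          # [<=tile_d, dim]
            block = q @ tile.T                            # [n_q, <=tile_d] partial similarities
            np.maximum(running_max, block.max(axis=1), out=running_max)
        scores[i] = running_max.sum()                     # one score per document
    return scores


rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64)).astype(np.float32)
docs = rng.standard_normal((5, 100, 64)).astype(np.float32)
print(np.allclose(maxsim_naive(q, docs), maxsim_streaming(q, docs), rtol=1e-5))
```

Because max and sum are applied in the same pass, the result does not depend on `tile_d` (up to floating-point rounding). Only peak memory changes.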
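The backward pass needs only the forward argmax. Each query token's maximum selects one document token j*(i), so the gradient of the score with respect to q_i is d_{j*(i)}. The gradient with respect to d_j is the sum of all q_i whose argmax is j. The minimal sketch below handles a single query-document pair. It rests on our own reading of "inverse-grid CSR", which is an assumption: query tokens are grouped by the document token they selected, so each row of the document gradient has exactly one writer. `build_inverse_csr` and `maxsim_backward` are hypothetical names. The sequential loop stands in for per-destination GPU work, and the `np.add.at` reference plays the role of the atomic scatter-add that the grouping avoids.

```python
import numpy as np


def maxsim_forward(q: np.ndarray, d: np.ndarray):
    """Score one query-document pair and keep the argmax for the backward pass."""
    sim = q @ d.T                                   # [n_q, n_d]
    argmax = sim.argmax(axis=1)                     # winning doc token per query token
    score = sim[np.arange(q.shape[0]), argmax].sum()
    return score, argmax


def build_inverse_csr(argmax: np.ndarray, n_d: int):
    """Group query tokens by the document token they selected (CSR keyed by doc token).

    Query tokens that picked doc token j are indices[indptr[j]:indptr[j + 1]].
    """
    indices = np.argsort(argmax, kind="stable")
    counts = np.bincount(argmax, minlength=n_d)
    indptr = np.concatenate(([0], np.cumsum(counts)))
    return indptr, indices


def maxsim_backward(q, d, argmax, grad_score=1.0):
    """Gradients of score = sum_i max_j <q_i, d_j>, using only the saved argmax."""
    grad_q = grad_score * d[argmax]                 # each query token receives its winner
    indptr, indices = build_inverse_csr(argmax, d.shape[0])
    grad_d = np.zeros_like(d)
    for j in range(d.shape[0]):                     # destination-owned: one writer per row
        rows = indices[indptr[j]:indptr[j + 1]]
        if rows.size:
            grad_d[j] = grad_score * q[rows].sum(axis=0)
    return grad_q, grad_d


rng = np.random.default_rng(1)
q = rng.standard_normal((16, 32))
d = rng.standard_normal((40, 32))
score, argmax = maxsim_forward(q, d)
grad_q, grad_d = maxsim_backward(q, d, argmax)

# Reference: the scatter-add an atomic-based kernel would perform.
ref_grad_d = np.zeros_like(d)
np.add.at(ref_grad_d, argmax, q)
print(np.allclose(grad_d, ref_grad_d), np.allclose(grad_q, d[argmax]))
```

Grouping by destination turns a many-to-one scatter into independent per-row reductions. This is what allows the reduction to run without atomics.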
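The sketch below illustrates INT8xINT8 scoring. It assumes symmetric per-token quantization with one FP32 scale per embedding vector; this scheme is our assumption, since the abstract does not specify one. Integer products are accumulated in INT32 and rescaled by the two token scales before the max and the sum. For brevity, the `einsum` materializes the full tensor. A fused kernel would instead apply the rescale, max, and sum per tile. `quantize_per_token` and `maxsim_int8` are hypothetical names.

```python
import numpy as np


def quantize_per_token(x: np.ndarray):
    """Symmetric INT8 quantization with one scale per token vector (last axis)."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)        # avoid division by zero for all-zero tokens
    xq = np.clip(np.rint(x / scale), -127, 127).astype(np.int8)
    return xq, scale.astype(np.float32)


def maxsim_int8(q: np.ndarray, docs: np.ndarray) -> np.ndarray:
    qq, qs = quantize_per_token(q)                  # [n_q, dim] int8, [n_q, 1] scales
    dq, ds = quantize_per_token(docs)               # [n_docs, n_d, dim] int8, [n_docs, n_d, 1]
    # INT8 x INT8 products accumulated in INT32 (upcast first so NumPy does not overflow int8).
    acc = np.einsum("qk,ndk->nqd", qq.astype(np.int32), dq.astype(np.int32))
    # Rescale: sim[n, i, j] ~= qs[i] * ds[n, j] * acc[n, i, j]
    sim = acc.astype(np.float32) * qs[None, :, :] * ds[:, None, :, 0]
    return sim.max(axis=2).sum(axis=1)              # same max-then-sum as FP MaxSim


rng = np.random.default_rng(2)
q = rng.standard_normal((8, 64)).astype(np.float32)
docs = rng.standard_normal((6, 50, 64)).astype(np.float32)
exact = np.einsum("qk,ndk->nqd", q, docs).max(axis=2).sum(axis=1)
approx = maxsim_int8(q, docs)
print(np.max(np.abs(approx - exact) / np.abs(exact)))
```

The printed value is the largest relative score error introduced by quantization on this random data. Any ranking guarantees would have to be checked on real embeddings.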
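Variable-length (padding-free) scoring can be illustrated by packing all document tokens into one [total_tokens, dim] array with an `offsets` vector. This resembles the cumulative sequence-length offsets used by FlashAttention's variable-length interface. The layout is our assumption, not a documented Flash-MaxSim API. The padded baseline shows why padding costs more than memory: padded slots must be masked to -inf. Otherwise a zero padding vector (similarity 0) wins the max whenever all real similarities for a query token are negative. `maxsim_packed` and `maxsim_padded` are hypothetical names.

```python
import numpy as np


def maxsim_packed(q: np.ndarray, tokens: np.ndarray, offsets: np.ndarray) -> np.ndarray:
    """Padding-free MaxSim: document i is tokens[offsets[i]:offsets[i + 1]] (must be non-empty)."""
    n_docs = len(offsets) - 1
    scores = np.empty(n_docs, dtype=np.result_type(q, tokens))
    for i in range(n_docs):
        seg = tokens[offsets[i]:offsets[i + 1]]       # only real tokens, no padding
        scores[i] = (q @ seg.T).max(axis=1).sum()
    return scores


def maxsim_padded(q: np.ndarray, padded: np.ndarray, lengths: np.ndarray) -> np.ndarray:
    """Padded baseline: [n_docs, max_len, dim] plus a mask for the padding slots."""
    sim = np.einsum("qk,ndk->nqd", q, padded)
    valid = np.arange(padded.shape[1])[None, :] < lengths[:, None]   # [n_docs, max_len]
    sim = np.where(valid[:, None, :], sim, -np.inf)                  # padding must never win the max
    return sim.max(axis=2).sum(axis=1)


rng = np.random.default_rng(3)
lengths = np.array([5, 12, 3, 9])
offsets = np.concatenate(([0], np.cumsum(lengths)))
tokens = rng.standard_normal((offsets[-1], 16))
q = rng.standard_normal((4, 16))

# Build the padded layout from the packed one (zeros in the padding slots).
padded = np.zeros((len(lengths), lengths.max(), 16))
for i, n in enumerate(lengths):
    padded[i, :n] = tokens[offsets[i]:offsets[i + 1]]

print(np.allclose(maxsim_packed(q, tokens, offsets), maxsim_padded(q, padded, lengths)))
```

In this toy batch, the packed layout stores 29 token vectors while the padded one stores 48. That gap grows with the spread of document lengths.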