Scalable Spatiotemporal Inference with Biased Scan Attention Transformer Neural Processes
This work addresses scalability issues in applications spanning geology, epidemiology, climate, and robotics, although it is incremental, building on existing Neural Process architectures.
The paper tackles the scalability-accuracy trade-off in Neural Processes for spatiotemporal modeling by proposing the Biased Scan Attention Transformer Neural Process (BSA-TNP). BSA-TNP matches or exceeds the accuracy of the best existing models while often training faster, and it scales to over 1M test points (with 100K context points) in under a minute on a single GPU.
Neural Processes (NPs) are a rapidly evolving class of models designed to directly model the posterior predictive distribution of stochastic processes. While early architectures were developed primarily as a scalable alternative to Gaussian Processes (GPs), modern NPs tackle far more complex and data-hungry applications spanning geology, epidemiology, climate, and robotics. These applications have placed increasing pressure on the scalability of these models, and many architectures compromise accuracy for scalability. In this paper, we demonstrate that this trade-off is often unnecessary, particularly when modeling fully or partially translation-invariant processes. We propose a versatile new architecture, the Biased Scan Attention Transformer Neural Process (BSA-TNP), which introduces Kernel Regression Blocks (KRBlocks), group-invariant attention biases, and memory-efficient Biased Scan Attention (BSA). BSA-TNP is able to: (1) match or exceed the accuracy of the best models while often training in a fraction of the time, (2) exhibit translation invariance, enabling learning at multiple resolutions simultaneously, (3) transparently model processes that evolve in both space and time, (4) support high-dimensional fixed effects, and (5) scale gracefully, running inference on over 1M test points with 100K context points in under a minute on a single 24GB GPU.
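The abstract does not spell out the exact form of the attention biases or of BSA, so the following is only a minimal pure-Python sketch of the two ideas it names, under stated assumptions. First, it assumes a translation-invariant attention bias that depends only on the displacement between a target location and a context location; here that is a hypothetical squared-distance penalty, `relative_bias`, with a hypothetical `lengthscale` parameter. Second, it assumes that "scan" means streaming over the context points in chunks with an online (running-max) softmax, in the style of memory-efficient attention, so per-query memory does not grow with the number of context points. All function names are illustrative and are not the paper's implementation.

```python
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def relative_bias(xq: Sequence[float], xk: Sequence[float], lengthscale: float) -> float:
    """Hypothetical bias that depends only on xq - xk, hence translation invariant."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(xq, xk))
    return -sq_dist / (2.0 * lengthscale ** 2)


def biased_attention_naive(q_feats, q_locs, k_feats, k_locs, values, lengthscale=1.0) -> Matrix:
    """Reference version: materialise every score for a query, then apply softmax."""
    scale = 1.0 / math.sqrt(len(q_feats[0]))
    dim = len(values[0])
    out = []
    for qf, ql in zip(q_feats, q_locs):
        scores = [scale * dot(qf, kf) + relative_bias(ql, kl, lengthscale)
                  for kf, kl in zip(k_feats, k_locs)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)])
    return out


def biased_attention_scan(q_feats, q_locs, k_feats, k_locs, values,
                          lengthscale=1.0, chunk=4) -> Matrix:
    """Chunked scan over context points with an online softmax.

    Only `chunk` scores per query exist at any time; earlier partial sums are
    rescaled whenever a larger running maximum is found.
    """
    scale = 1.0 / math.sqrt(len(q_feats[0]))
    dim = len(values[0])
    out = []
    for qf, ql in zip(q_feats, q_locs):
        running_max = -math.inf
        normaliser = 0.0
        acc = [0.0] * dim
        for start in range(0, len(k_feats), chunk):
            stop = start + chunk
            scores = [scale * dot(qf, kf) + relative_bias(ql, kl, lengthscale)
                      for kf, kl in zip(k_feats[start:stop], k_locs[start:stop])]
            new_max = max(running_max, max(scores))
            # Rescale previous partial sums to the new maximum (exp(-inf) == 0.0).
            correction = math.exp(running_max - new_max)
            normaliser *= correction
            acc = [a * correction for a in acc]
            for s, v in zip(scores, values[start:stop]):
                w = math.exp(s - new_max)
                normaliser += w
                for d in range(dim):
                    acc[d] += w * v[d]
            running_max = new_max
        out.append([a / normaliser for a in acc])
    return out


def random_matrix(n: int, d: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


if __name__ == "__main__":
    random.seed(0)
    k_feats, k_locs, values = random_matrix(10, 4), random_matrix(10, 2), random_matrix(10, 2)
    q_feats, q_locs = random_matrix(3, 4), random_matrix(3, 2)
    ref = biased_attention_naive(q_feats, q_locs, k_feats, k_locs, values)
    scan = biased_attention_scan(q_feats, q_locs, k_feats, k_locs, values, chunk=3)
    print(max(abs(a - b) for r, s in zip(ref, scan) for a, b in zip(r, s)))
```

The feature vectors in this sketch are not derived from absolute coordinates, and the bias sees only `xq - xk`. As a result, shifting every target and context location by the same offset leaves the output unchanged up to floating-point rounding, which is the property the abstract links to learning at multiple resolutions. The chunked scan should agree with the naive version for any `chunk` size; chunking changes only peak memory, not the result.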