No More Sliding Window: Efficient 3D Medical Image Segmentation with Differentiable Top-k Patch Sampling
This work addresses slow inference in 3D medical image segmentation, offering a model-agnostic efficiency gain that is incremental but impactful.
The paper tackles the inefficiency of sliding window inference in 3D medical image segmentation by proposing NMSW, which uses differentiable Top-k patch sampling to avoid redundant computation. NMSW achieves accuracy competitive with sliding window inference while cutting computational complexity by 91% and speeding up inference by up to 11.1x.
3D models surpass 2D models in CT/MRI segmentation by effectively capturing inter-slice relationships. However, the added depth dimension substantially increases memory consumption. While patch-based training alleviates memory constraints, it significantly slows inference because of the sliding window (SW) approach. We propose No-More-Sliding-Window (NMSW), a novel end-to-end trainable framework that improves the inference efficiency of generic 3D segmentation backbones by eliminating the need for SW. NMSW employs a differentiable Top-k module to sample only the most relevant patches, thereby minimizing redundant computation. When patch-level predictions are insufficient, the framework leverages coarse global predictions to refine the results. Evaluated across 3 tasks using 3 segmentation backbones, NMSW achieves competitive accuracy compared to SW inference while reducing computational complexity by 91% (88.0 to 8.00 TMACs). Moreover, it delivers a 9.1x faster inference on the H100 GPU (99.0 to 8.3 sec) and an 11.1x faster inference on the Xeon Gold CPU (2110 to 189 sec). NMSW is model-agnostic and further boosts efficiency when integrated with existing efficient segmentation backbones. The code is available at https://github.com/Youngseok0001/open_nmsw.
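To illustrate why sampling only a few patches reduces work, the following is a minimal sketch and not the authors' implementation. It counts the patches a sliding window would visit and then keeps only the k highest-scoring ones. The volume size, patch size, 50% overlap, value of k, and the random score map are all assumptions made for illustration. In NMSW the scores would come from a learned module, and training relies on a differentiable relaxation of Top-k that is not shown here (this sketch does hard selection only). The function names are hypothetical.

```python
import numpy as np


def sliding_window_starts(size, patch, overlap=0.5):
    """Start indices of patches along one axis for sliding window inference."""
    stride = max(1, int(patch * (1 - overlap)))
    starts = list(range(0, max(size - patch, 0) + 1, stride))
    if starts[-1] + patch < size:  # make sure the trailing edge is covered
        starts.append(size - patch)
    return starts


def topk_patches(scores, k):
    """Grid indices of the k highest-scoring patches (hard, non-differentiable Top-k)."""
    flat = np.argsort(scores.ravel())[::-1][:k]
    return [tuple(int(i) for i in np.unravel_index(f, scores.shape)) for f in flat]


# Hypothetical setup: a 256^3 volume, 128^3 patches, 50% overlap.
volume_shape, patch = (256, 256, 256), 128
grid = [sliding_window_starts(s, patch) for s in volume_shape]
n_sw = int(np.prod([len(g) for g in grid]))  # patches a sliding window would process

# Stand-in for learned patch-relevance scores, one per grid position.
rng = np.random.default_rng(0)
scores = rng.random([len(g) for g in grid])

selected = topk_patches(scores, k=3)
print(f"sliding window: {n_sw} patches, Top-k: {len(selected)} patches")
```

Under these assumed settings the backbone would run on 3 patches instead of 27. The actual savings reported in the abstract depend on the authors' own patch sizes, overlap, and choice of k.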