CV | Apr 19, 2023 | Code
A Robust and Interpretable Deep Learning Framework for Multi-modal Registration via Keypoints
Alan Q. Wang, Evan M. Yu, Adrian V. Dalca et al.
We present KeyMorph, a deep learning-based image registration framework that relies on automatically detecting corresponding keypoints. State-of-the-art deep learning methods for registration are often not robust to large misalignments, are not interpretable, and do not incorporate the symmetries of the problem. In addition, most models produce only a single prediction at test time. Our core insight, which addresses these shortcomings, is that corresponding keypoints between images can be used to obtain the optimal transformation via a differentiable closed-form expression. We use this observation to drive the end-to-end learning of keypoints tailored for the registration task, without knowledge of ground-truth keypoints. This framework not only leads to substantially more robust registration but also yields better interpretability, since the keypoints reveal which parts of the image are driving the final alignment. Moreover, KeyMorph can be designed to be equivariant under image translations and/or symmetric with respect to the input image ordering. Finally, we show how multiple deformation fields, corresponding to different transformation variants, can be computed efficiently and in closed form at test time. We demonstrate the proposed framework in solving 3D affine and spline-based registration of multi-modal brain MRI scans. In particular, we show registration accuracy that surpasses current state-of-the-art methods, especially in the context of large displacements. Our code is available at https://github.com/alanqrwang/keymorph.
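The closed-form step mentioned in the abstract above can be illustrated with a plain least-squares affine fit between two sets of corresponding keypoints. This is a minimal sketch, not the KeyMorph implementation: the function names (`solve_linear`, `fit_affine`, `apply_affine`) are hypothetical, the keypoints are hand-picked 2D points rather than network outputs, and the unweighted least-squares objective is an assumption.

```python
def solve_linear(a, b):
    """Solve a @ x = b by Gauss-Jordan elimination with partial pivoting.

    `a` is n x n and `b` is n x m (several right-hand sides), both as nested lists.
    """
    n = len(a)
    # Augmented matrix [a | b].
    m = [list(a[i]) + list(b[i]) for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("keypoints are degenerate (e.g. collinear)")
        m[col], m[pivot] = m[pivot], m[col]
        scale = m[col][col]
        m[col] = [v / scale for v in m[col]]
        for r in range(n):
            if r != col:
                factor = m[r][col]
                m[r] = [vr - factor * vc for vr, vc in zip(m[r], m[col])]
    return [row[n:] for row in m]


def fit_affine(moving, fixed):
    """Least-squares affine map taking moving keypoints to fixed keypoints.

    Returns a d x (d+1) matrix A such that fixed ≈ A @ [moving; 1].
    """
    hom = [list(p) + [1.0] for p in moving]  # homogeneous coordinates
    k = len(hom[0])
    d = len(fixed[0])
    # Normal equations: (H^T H) A^T = H^T Q, with H = hom and Q = fixed.
    hth = [[sum(h[i] * h[j] for h in hom) for j in range(k)] for i in range(k)]
    htq = [[sum(h[i] * q[j] for h, q in zip(hom, fixed)) for j in range(d)]
           for i in range(k)]
    a_t = solve_linear(hth, htq)  # shape (d+1) x d
    return [[a_t[i][j] for i in range(k)] for j in range(d)]


def apply_affine(a, point):
    """Apply a d x (d+1) affine matrix to a single point."""
    hom = list(point) + [1.0]
    return [sum(w * x for w, x in zip(row, hom)) for row in a]


# Hypothetical corresponding keypoints related by x' = 2x + 1, y' = 3y - 1.
moving = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
fixed = [(1.0, -1.0), (3.0, -1.0), (1.0, 2.0), (3.0, 2.0)]
A = fit_affine(moving, fixed)
```

Because every step is ordinary arithmetic on the keypoint coordinates, the same computation written in an autodiff framework would let gradients flow back to whatever produced the keypoints, which is the property the abstract relies on.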
CV | Mar 26 | Code
Diffusion MRI Transformer with a Diffusion Space Rotary Positional Embedding (D-RoPE)
Gustavo Chau Loo Kung, Mohammad Abbasi, Camila Blank et al.
Diffusion Magnetic Resonance Imaging (dMRI) plays a critical role in studying microstructural changes in the brain. It is, therefore, widely used in clinical practice; yet progress in learning general-purpose representations from dMRI has been limited. A key challenge is that existing deep learning approaches are not well-suited to capture the unique properties of diffusion signals. Brain dMRI is normally composed of several brain volumes, each with different attenuation characteristics dependent on the direction and strength of the diffusion-sensitized gradients. Thus, there is a need to jointly model spatial, diffusion-weighting, and directional dependencies in dMRI. Furthermore, varying acquisition protocols (e.g., differing numbers of directions) limit traditional models. To address these gaps, we introduce a diffusion space rotary positional embedding (D-RoPE) plugged into our dMRI transformer to capture both the spatial structure and directional characteristics of diffusion data, enabling robust and transferable representations across diverse acquisition settings and an arbitrary number of diffusion directions. After self-supervised masked autoencoding pretraining, tests on several downstream tasks show that the learned representations and the pretrained model can provide competitive or superior performance compared to several baselines (even compared to a fully trained baseline); the finetuned features from our pretrained encoder resulted in a 6% higher accuracy in classifying mild cognitive impairment and a 0.05 increase in the correlation coefficient when predicting cognitive scores. Code is available at: github.com/gustavochau/D-RoPE.
LG | Oct 2, 2023
A Framework for Interpretability in Machine Learning for Medical Imaging
Alan Q. Wang, Batuhan K. Karaman, Heejong Kim et al.
Interpretability for machine learning models in medical imaging (MLMI) is an important direction of research. However, there is a general sense of murkiness in what interpretability means. Why does the need for interpretability in MLMI arise? What goals does one actually seek to address when interpretability is needed? To answer these questions, we identify a need to formalize the goals and elements of interpretability in MLMI. By reasoning about real-world tasks and goals common in both medical image analysis and its intersection with machine learning, we identify five core elements of interpretability: localization, visual recognizability, physical attribution, model transparency, and actionability. From this, we arrive at a framework for interpretability in MLMI, which serves as a step-by-step guide to approaching interpretability in this context. Overall, this paper formalizes interpretability needs in the context of medical imaging, and our applied perspective clarifies concrete MLMI-specific goals and considerations in order to guide method design and improve real-world usage. Our goal is to provide practical and didactic information for model designers and practitioners, inspire developers of models in the medical imaging field to reason more deeply about what interpretability is achieving, and suggest future directions of interpretability research.
CV | Dec 7, 2022
A Flexible Nadaraya-Watson Head Can Offer Explainable and Calibrated Classification
Alan Q. Wang, Mert R. Sabuncu
In this paper, we empirically analyze a simple, non-learnable, and nonparametric Nadaraya-Watson (NW) prediction head that can be used with any neural network architecture. In the NW head, the prediction is a weighted average of labels from a support set. The weights are computed from distances between the query feature and support features. This is in contrast to the dominant approach of using a learnable classification head (e.g., a fully-connected layer) on the features, which can be challenging to interpret and can yield poorly calibrated predictions. Our empirical results on an array of computer vision tasks demonstrate that the NW head can yield better calibration with comparable accuracy compared to its parametric counterpart, particularly in data-limited settings. To further increase inference-time efficiency, we propose a simple approach that involves a clustering step run on the training set to create a relatively small distilled support set. Furthermore, we explore two means of interpretability/explainability that fall naturally from the NW head. The first is the label weights, and the second is our novel concept of the "support influence function," which is an easy-to-compute metric that quantifies the influence of a support element on the prediction for a given query. As we demonstrate in our experiments, the influence function can allow the user to debug a trained model. We believe that the NW head is a flexible, interpretable, and highly useful building block that can be used in a range of applications.
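The NW head described in the abstract above, a distance-weighted average of support labels, can be sketched in a few lines. This is an illustrative sketch under stated assumptions: the function name `nw_head`, the softmax over negative squared Euclidean distances, and the `temperature` parameter are choices made here for the example, not details confirmed by the abstract.

```python
import math


def nw_head(query, support_feats, support_labels, num_classes, temperature=1.0):
    """Nadaraya-Watson style prediction from a labeled support set.

    Assumption: weights are a softmax over negative squared Euclidean
    distances between the query feature and each support feature.
    Returns (class_probabilities, per_support_weights).
    """
    scores = [
        -sum((q - s) ** 2 for q, s in zip(query, feat)) / temperature
        for feat in support_feats
    ]
    # Subtract the max score before exponentiating for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # The prediction is the weighted average of one-hot support labels.
    probs = [0.0] * num_classes
    for w, label in zip(weights, support_labels):
        probs[label] += w
    return probs, weights


# Hypothetical 2D features: two support points of class 0 near the query,
# one support point of class 1 far away.
probs, weights = nw_head(
    query=[0.0, 0.0],
    support_feats=[[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]],
    support_labels=[0, 0, 1],
    num_classes=2,
)
```

The returned `weights` are the per-support label weights, so inspecting them shows which support elements dominate a given prediction.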
LG | Sep 9, 2024
Adapting to Shifting Correlations with Unlabeled Data Calibration
Minh Nguyen, Alan Q. Wang, Heejong Kim et al.
Distribution shifts between sites can seriously degrade model performance since models are prone to exploiting unstable correlations. Thus, many methods try to find features that are stable across sites and discard unstable features. However, unstable features might have complementary information that, if used appropriately, could increase accuracy. More recent methods try to adapt to unstable features at the new sites to achieve higher accuracy. However, they make unrealistic assumptions or fail to scale to multiple confounding features. We propose Generalized Prevalence Adjustment (GPA for short), a flexible method that adjusts model predictions to the shifting correlations between prediction target and confounders to safely exploit unstable features. GPA can infer the interaction between target and confounders in new sites using unlabeled samples from those sites. We evaluate GPA on several real and synthetic datasets, and show that it outperforms competitive baselines.
LG | Oct 24, 2023
Robust Learning via Conditional Prevalence Adjustment
Minh Nguyen, Alan Q. Wang, Heejong Kim et al.
Healthcare data often come from multiple sites in which the correlations between confounding variables can vary widely. If deep learning models exploit these unstable correlations, they might fail catastrophically in unseen sites. Although many methods have been proposed to tackle unstable correlations, each has its limitations. For example, adversarial training forces models to completely ignore unstable correlations, but doing so may lead to poor predictive performance. Other methods (e.g., invariant risk minimization [4]) try to learn domain-invariant representations that rely only on stable associations by assuming a causal data-generating process (input X causes class label Y). Thus, they may be ineffective for anti-causal tasks (Y causes X), which are common in computer vision. We propose a method called CoPA (Conditional Prevalence-Adjustment) for anti-causal tasks. CoPA assumes that (1) the generation mechanism is stable, i.e., label Y and confounding variable(s) Z generate X, and (2) the unstable conditional prevalence in each site E fully accounts for the unstable correlations between X and Y. Our crucial observation is that confounding variables are routinely recorded in healthcare settings and the prevalence can be readily estimated, for example, from a set of (Y, Z) samples (no need for corresponding samples of X). CoPA can work even if there is a single training site, a scenario which is often overlooked by existing methods. Our experiments on synthetic and real data show CoPA beating competitive baselines.
LG | Sep 23, 2023
Learning Invariant Representations with a Nonparametric Nadaraya-Watson Head
Alan Q. Wang, Minh Nguyen, Mert R. Sabuncu
Machine learning models will often fail when deployed in an environment with a data distribution that is different than the training distribution. When multiple environments are available during training, many methods exist that learn representations which are invariant across the different distributions, with the hope that these representations will be transportable to unseen domains. In this work, we present a nonparametric strategy for learning invariant representations based on the recently-proposed Nadaraya-Watson (NW) head. The NW head makes a prediction by comparing the learned representations of the query to the elements of a support set that consists of labeled data. We demonstrate that by manipulating the support set, one can encode different causal assumptions. In particular, restricting the support set to a single environment encourages the model to learn invariant features that do not depend on the environment. We present a causally-motivated setup for our modeling and training strategy and validate on three challenging real-world domain generalization tasks in computer vision.
CV | Dec 24, 2025 | Code
A Tool Bottleneck Framework for Clinically-Informed and Interpretable Medical Image Understanding
Christina Liu, Alan Q. Wang, Joy Hsu et al.
Recent tool-use frameworks powered by vision-language models (VLMs) improve image understanding by grounding model predictions with specialized tools. Broadly, these frameworks leverage VLMs and a pre-specified toolbox to decompose the prediction task into multiple tool calls (often deep learning models) which are composed to make a prediction. The dominant approach to composing tools is using text, via function calls embedded in VLM-generated code or natural language. However, these methods often perform poorly on medical image understanding, where salient information is encoded as spatially-localized features that are difficult to compose or fuse via text alone. To address this, we propose a tool-use framework for medical image understanding called the Tool Bottleneck Framework (TBF), which composes VLM-selected tools using a learned Tool Bottleneck Model (TBM). For a given image and task, TBF leverages an off-the-shelf medical VLM to select tools from a toolbox that each extract clinically-relevant features. Instead of text-based composition, these tools are composed by the TBM, which computes and fuses the tool outputs using a neural network before outputting the final prediction. We propose a simple and effective strategy for TBMs to make predictions with any arbitrary VLM tool selection. Overall, our framework not only improves tool-use in medical imaging contexts, but also yields more interpretable, clinically-grounded predictors. We evaluate TBF on tasks in histopathology and dermatology and find that these advantages enable our framework to perform on par with or better than deep learning-based classifiers, VLMs, and state-of-the-art tool-use frameworks, with particular gains in data-limited regimes. Our code is available at https://github.com/christinaliu2020/tool-bottleneck-framework.
CV | May 22, 2024 | Code
BrainMorph: A Foundational Keypoint Model for Robust and Flexible Brain MRI Registration
Alan Q. Wang, Rachit Saluja, Heejong Kim et al.
We present a keypoint-based foundation model for general purpose brain MRI registration, based on the recently-proposed KeyMorph framework. Our model, called BrainMorph, serves as a tool that supports multi-modal, pairwise, and scalable groupwise registration. BrainMorph is trained on a massive dataset of over 100,000 3D volumes, skull-stripped and non-skull-stripped, from nearly 16,000 unique healthy and diseased subjects. BrainMorph is robust to large misalignments, interpretable via interrogating automatically-extracted keypoints, and enables rapid and controllable generation of many plausible transformations with different alignment types and different degrees of nonlinearity at test-time. We demonstrate the superiority of BrainMorph in solving 3D rigid, affine, and nonlinear registration on a variety of multi-modal brain MRI scans of healthy and diseased subjects, in both the pairwise and groupwise setting. In particular, we show registration accuracy and speeds that surpass many classical and learning-based methods, especially in the context of large initial misalignments and large group settings. All code and models are available at https://github.com/alanqrwang/brainmorph.
CV | Feb 22, 2022 | Code
Computing Multiple Image Reconstructions with a Single Hypernetwork
Alan Q. Wang, Adrian V. Dalca, Mert R. Sabuncu
Deep learning based techniques achieve state-of-the-art results in a wide range of image reconstruction tasks like compressed sensing. These methods almost always have hyperparameters, such as the weight coefficients that balance the different terms in the optimized loss function. The typical approach is to train the model for a hyperparameter setting determined with some empirical or theoretical justification. Thus, at inference time, the model can only compute reconstructions corresponding to the pre-determined hyperparameter values. In this work, we present a hypernetwork-based approach, called HyperRecon, to train reconstruction models that are agnostic to hyperparameter settings. At inference time, HyperRecon can efficiently produce diverse reconstructions, which would each correspond to different hyperparameter values. In this framework, the user is empowered to select the most useful output(s) based on their own judgement. We demonstrate our method in compressed sensing, super-resolution and denoising tasks, using two large-scale and publicly-available MRI datasets. Our code is available at https://github.com/alanqrwang/hyperrecon.
IV | Feb 6, 2022 | Code
Hyper-Convolutions via Implicit Kernels for Medical Imaging
Tianyu Ma, Alan Q. Wang, Adrian V. Dalca et al.
The convolutional neural network (CNN) is one of the most commonly used architectures for computer vision tasks. The key building block of a CNN is the convolutional kernel that aggregates information from the pixel neighborhood and shares weights across all pixels. A standard CNN's capacity, and thus its performance, is directly related to the number of learnable kernel weights, which is determined by the number of channels and the kernel size (support). In this paper, we present the hyper-convolution, a novel building block that implicitly encodes the convolutional kernel using spatial coordinates. Hyper-convolutions decouple kernel size from the total number of learnable parameters, enabling a more flexible architecture design. We demonstrate in our experiments that replacing regular convolutions with hyper-convolutions can improve performance with fewer parameters, and increase robustness against noise. We provide our code here: https://github.com/tym002/Hyper-Convolution
IV | Jan 6, 2021 | Code
Regularization-Agnostic Compressed Sensing MRI Reconstruction with Hypernetworks
Alan Q. Wang, Adrian V. Dalca, Mert R. Sabuncu
Reconstructing under-sampled k-space measurements in Compressed Sensing MRI (CS-MRI) is classically solved with regularized least-squares. Recently, deep learning has been used to amortize this optimization by training reconstruction networks on a dataset of under-sampled measurements. Here, a crucial design choice is the regularization function(s) and corresponding weight(s). In this paper, we explore a novel strategy of using a hypernetwork to generate the parameters of a separate reconstruction network as a function of the regularization weight(s), resulting in a regularization-agnostic reconstruction model. At test time, for a given under-sampled image, our model can rapidly compute reconstructions with different amounts of regularization. We analyze the variability of these reconstructions, especially in situations when the overall quality is similar. Finally, we propose and empirically demonstrate an efficient and data-driven way of maximizing reconstruction performance given limited hypernetwork capacity. Our code is publicly available at https://github.com/alanqrwang/RegAgnosticCSMRI.
IV | Jul 29, 2020 | Code
Neural Network-based Reconstruction in Compressed Sensing MRI Without Fully-sampled Training Data
Alan Q. Wang, Adrian V. Dalca, Mert R. Sabuncu
Compressed Sensing MRI (CS-MRI) has shown promise in reconstructing under-sampled MR images, offering the potential to reduce scan times. Classical techniques minimize a regularized least-squares cost function using an expensive iterative optimization procedure. Recently, deep learning models have been developed that model the iterative nature of classical techniques by unrolling iterations in a neural network. While exhibiting superior performance, these methods require large quantities of ground-truth images and have been shown to be non-robust to unseen data. In this paper, we explore a novel strategy to train an unrolled reconstruction network in an unsupervised fashion by adopting a loss function widely used in classical optimization schemes. We demonstrate that this strategy achieves lower loss and is computationally cheap compared to classical optimization solvers while also exhibiting superior robustness compared to supervised models. Code is available at https://github.com/alanqrwang/HQSNet.
IV | Jul 26, 2019 | Code
Deep-learning-based Optimization of the Under-sampling Pattern in MRI
Cagla D. Bahadir, Alan Q. Wang, Adrian V. Dalca et al.
In compressed sensing MRI (CS-MRI), k-space measurements are under-sampled to achieve accelerated scan times. CS-MRI presents two fundamental problems: (1) where to sample and (2) how to reconstruct an under-sampled scan. In this paper, we tackle both problems simultaneously for the specific case of 2D Cartesian sampling, using a novel end-to-end learning framework that we call LOUPE (Learning-based Optimization of the Under-sampling PattErn). Our method trains a neural network model on a set of full-resolution MRI scans, which are retrospectively under-sampled on a 2D Cartesian grid and forwarded to an anti-aliasing (a.k.a. reconstruction) model that computes a reconstruction, which is in turn compared with the input. This formulation enables a data-driven optimized under-sampling pattern at a given sparsity level. In our experiments, we demonstrate that LOUPE-optimized under-sampling masks are data-dependent, varying significantly with the imaged anatomy, and perform well with different reconstruction methods. We present empirical results obtained with a large-scale, publicly available knee MRI dataset, where LOUPE offered superior reconstruction quality across different conditions. Even with an aggressive 8-fold acceleration rate, LOUPE's reconstructions contained much of the anatomical detail that was missed by alternative masks and reconstruction methods. Our experiments also show how LOUPE yielded optimal under-sampling patterns that were significantly different for brain vs knee MRI scans. Our code is made freely available at https://github.com/cagladbahadir/LOUPE/.
CV | Oct 25, 2025
Discovering Latent Graphs with GFlowNets for Diverse Conditional Image Generation
Bailey Trang, Parham Saremi, Alan Q. Wang et al.
Capturing diversity is crucial in conditional and prompt-based image generation, particularly when conditions contain uncertainty that can lead to multiple plausible outputs. To generate diverse images reflecting this diversity, traditional methods often modify random seeds, making it difficult to discern meaningful differences between samples, or diversify the input prompt, which is limited in verbally interpretable diversity. We propose Rainbow, a novel conditional image generation framework, applicable to any pretrained conditional generative model, that addresses inherent condition/prompt uncertainty and generates diverse plausible images. Rainbow is based on a simple yet effective idea: decomposing the input condition into diverse latent representations, each capturing an aspect of the uncertainty and generating a distinct image. First, we integrate a latent graph, parameterized by Generative Flow Networks (GFlowNets), into the prompt representation computation. Second, leveraging GFlowNets' advanced graph sampling capabilities to capture uncertainty and output diverse trajectories over the graph, we produce multiple trajectories that collectively represent the input condition, leading to diverse condition representations and corresponding output images. Evaluations on natural image and medical image datasets demonstrate Rainbow's improvement in both diversity and fidelity across image synthesis, image generation, and counterfactual generation tasks.
CV | Jun 12, 2025
RealKeyMorph: Keypoints in Real-world Coordinates for Resolution-agnostic Image Registration
Mina C. Moghadam, Alan Q. Wang, Omer Taub et al.
Many real-world settings require registration of a pair of medical images that differ in spatial resolution, which may arise from differences in image acquisition parameters like pixel spacing, slice thickness, and field-of-view. However, all previous machine learning-based registration techniques resample images onto a fixed resolution. This is suboptimal because resampling can introduce artifacts due to interpolation. To address this, we present RealKeyMorph (RKM), a resolution-agnostic method for image registration. RKM is an extension of KeyMorph, a registration framework which works by training a network to learn corresponding keypoints for a given pair of images, after which a closed-form keypoint matching step is used to derive the transformation that aligns them. To avoid resampling and enable operating on the raw data, RKM outputs keypoints in real-world coordinates of the scanner. To do this, we leverage the affine matrix produced by the scanner (e.g., MRI machine) that encodes the mapping from voxel coordinates to real world coordinates. By transforming keypoints into real-world space and integrating this into the training process, RKM effectively enables the extracted keypoints to be resolution-agnostic. In our experiments, we demonstrate the advantages of RKM on the registration task for orthogonal 2D stacks of abdominal MRIs, as well as 3D volumes with varying resolutions in brain datasets.
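The voxel-to-world mapping that the abstract above relies on is a single 4x4 affine applied in homogeneous coordinates. The sketch below shows why keypoints expressed in world coordinates do not depend on voxel spacing. The function name `voxel_to_world` and both example matrices (spacings of 1x1x1 mm and 1x1x3 mm with a shared origin) are hypothetical values chosen for illustration, not taken from the paper.

```python
def voxel_to_world(affine, ijk):
    """Map voxel indices (i, j, k) to world coordinates with a 4x4 affine."""
    hom = list(ijk) + [1.0]  # homogeneous voxel coordinate
    # Only the first three rows are needed; the last row is [0, 0, 0, 1].
    return [sum(a * x for a, x in zip(row, hom)) for row in affine[:3]]


# Hypothetical scan A: isotropic 1 mm voxels.
affine_a = [
    [1.0, 0.0, 0.0, -90.0],
    [0.0, 1.0, 0.0, -126.0],
    [0.0, 0.0, 1.0, -72.0],
    [0.0, 0.0, 0.0, 1.0],
]
# Hypothetical scan B: same origin, but 3 mm slice thickness.
affine_b = [
    [1.0, 0.0, 0.0, -90.0],
    [0.0, 1.0, 0.0, -126.0],
    [0.0, 0.0, 3.0, -72.0],
    [0.0, 0.0, 0.0, 1.0],
]

# The same anatomical location has different voxel indices in each scan...
world_a = voxel_to_world(affine_a, (10, 20, 15))
world_b = voxel_to_world(affine_b, (10, 20, 5))
# ...but identical world coordinates: [-80.0, -106.0, -57.0] in both cases.
```

Working in this shared world space is what lets keypoints from scans with different resolutions be compared directly, without first resampling either image onto a common grid.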
IV | May 17, 2021
Joint Optimization of Hadamard Sensing and Reconstruction in Compressed Sensing Fluorescence Microscopy
Alan Q. Wang, Aaron K. LaViolette, Leo Moon et al.
Compressed sensing fluorescence microscopy (CS-FM) proposes a scheme whereby fewer measurements are collected during sensing and reconstruction is performed to recover the image. Much work has gone into optimizing the sensing and reconstruction portions separately. We propose a method of jointly optimizing both sensing and reconstruction end-to-end under a total measurement constraint, enabling learning of the optimal sensing scheme concurrently with the parameters of a neural network-based reconstruction network. We train our model on a rich dataset of confocal, two-photon, and wide-field microscopy images comprising a variety of biological samples. We show that our method outperforms several baseline sensing schemes and a regularized regression reconstruction algorithm.