Robust Conditional Conformal Prediction via Branched Normalizing Flow
For practitioners using conformal prediction under distribution shift, this work addresses unreliable conditional coverage at individual test inputs.
The paper bounds the conditional coverage error of conformal prediction under distribution shift using Wasserstein distance and proposes Branched Normalizing Flow (BNF) to mitigate it. BNF consistently improves conditional coverage robustness across nine datasets.
Conformal prediction (CP) constructs prediction sets with marginal coverage guarantees under the assumption that the calibration and test distributions are identical. However, under distribution shift, existing approaches primarily align marginal conformal score distributions, which is sufficient to preserve marginal coverage but does not control the conditional coverage error at individual test inputs. As a consequence, CP can remain unreliable in regions where the conditional score distributions are mismatched. In this work, we bound the conditional invalidity of CP under distribution shift in terms of the Wasserstein distance between the calibration and test distributions. This result highlights the role of invertible transport in mitigating conditional coverage degradation. Motivated by this insight, we introduce Branched Normalizing Flow (BNF), a two-branch architecture that normalizes a test input to the calibration distribution and transforms the prediction set of the normalized input back to the test distribution while preserving conditional guarantees. Empirically, BNF consistently improves conditional coverage robustness on nine datasets across a wide range of confidence levels.
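To make the gap between marginal and conditional coverage concrete, the following minimal sketch runs standard split conformal prediction on synthetic data. It is not BNF and involves no distribution shift; it only illustrates that a set calibrated for marginal coverage 1 - alpha can over-cover some inputs and under-cover others. The function names (`conformal_quantile`, `simulate`), the absolute-residual score, the constant toy predictor, and the heteroscedastic data-generating process are all assumptions made for illustration, not details from the paper.

```python
import numpy as np


def conformal_quantile(scores, alpha):
    """Finite-sample-corrected (1 - alpha) quantile of calibration scores.

    Returns the ceil((n + 1) * (1 - alpha))-th smallest score, or +inf when
    that rank exceeds n (too few calibration points for the requested level).
    """
    scores = np.asarray(scores, dtype=float)
    n = scores.size
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:
        return np.inf
    return np.sort(scores)[k - 1]


def simulate(n_cal=2000, n_test=20000, alpha=0.1, seed=0):
    """Split CP on hypothetical heteroscedastic data (illustrative only)."""
    rng = np.random.default_rng(seed)

    def sample(n):
        # Assumed toy model: x ~ U(0, 1), y | x ~ N(0, (0.1 + 0.9 x)^2).
        x = rng.uniform(0.0, 1.0, size=n)
        y = rng.normal(0.0, 0.1 + 0.9 * x)
        return x, y

    x_cal, y_cal = sample(n_cal)
    x_test, y_test = sample(n_test)

    # Toy predictor: always predict 0, the true conditional mean.
    # Conformal score: absolute residual |y - prediction|.
    q = conformal_quantile(np.abs(y_cal), alpha)

    # Prediction set for every test input is the interval [-q, q].
    covered = np.abs(y_test) <= q
    return {
        "q": q,
        "marginal": covered.mean(),
        "low_x": covered[x_test < 0.2].mean(),   # low-noise region
        "high_x": covered[x_test > 0.8].mean(),  # high-noise region
    }


if __name__ == "__main__":
    for name, value in simulate().items():
        print(f"{name}: {value:.3f}")
```

In this toy setting the marginal coverage should sit near the 0.9 target, while coverage restricted to the low-noise region is close to 1 and coverage in the high-noise region falls well below 0.9. This is the kind of input-dependent miscoverage that a purely marginal guarantee does not rule out, even before any shift between calibration and test distributions is introduced.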