CO | IM | LG | May 10, 2023

CosmoPower-JAX: high-dimensional Bayesian inference with differentiable cosmological emulators

arXiv:2305.06347v2 | 30 citations
Originality: Highly original
AI Analysis

This work addresses the computational bottleneck in cosmological parameter estimation for next-generation surveys, accelerating inference by orders of magnitude.

The paper accelerates high-dimensional Bayesian inference in cosmology by introducing CosmoPower-JAX, a JAX-based framework that builds neural emulators of cosmological power spectra. It reports a speed-up of order $10^3$ for a 37-parameter cosmic shear analysis, and converged posteriors in 3 days, rather than an estimated 6 years, for a 157-parameter joint analysis of three surveys.

We present CosmoPower-JAX, a JAX-based implementation of the CosmoPower framework, which accelerates cosmological inference by building neural emulators of cosmological power spectra. We show how, using the automatic differentiation, batch evaluation and just-in-time compilation features of JAX, and running the inference pipeline on graphics processing units (GPUs), parameter estimation can be accelerated by orders of magnitude with advanced gradient-based sampling techniques. These can be used to efficiently explore high-dimensional parameter spaces, such as those needed for the analysis of next-generation cosmological surveys. We showcase the accuracy and computational efficiency of CosmoPower-JAX on two simulated Stage IV configurations. We first consider a single survey performing a cosmic shear analysis totalling 37 model parameters. We validate the contours derived with CosmoPower-JAX and a Hamiltonian Monte Carlo sampler against those derived with a nested sampler and without emulators, obtaining a speed-up factor of $\mathcal{O}(10^3)$. We then consider a combination of three Stage IV surveys, each performing a joint cosmic shear and galaxy clustering (3x2pt) analysis, for a total of 157 model parameters. Even with such a high-dimensional parameter space, CosmoPower-JAX provides converged posterior contours in 3 days, as opposed to the estimated 6 years required by standard methods. CosmoPower-JAX is fully written in Python, and we make it publicly available to help the cosmological community meet the accuracy requirements set by next-generation surveys.
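The three JAX features the abstract names (automatic differentiation, batch evaluation, and just-in-time compilation) can be illustrated with a minimal sketch. It does not use the CosmoPower-JAX API: the "emulator" is a small, untrained multilayer perceptron with random weights, the data are noiseless mock outputs, and every name and value here (`emulate`, `log_posterior`, `hmc_step`, the layer sizes, the noise level, the step size) is an assumption made for illustration. The sampler is a plain Hamiltonian Monte Carlo step with leapfrog integration, which may differ from the gradient-based sampler used in the paper.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: 5 parameters, 32 hidden units, 20 spectrum bins.
N_PARAMS, N_HIDDEN, N_MODES = 5, 32, 20
SIGMA = 0.2                        # assumed Gaussian noise level on each bin
STEP_SIZE, N_LEAPFROG = 0.02, 10   # assumed HMC tuning, not from the paper


def init_mlp(key):
    """Random weights standing in for a trained emulator."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (N_PARAMS, N_HIDDEN)) / jnp.sqrt(N_PARAMS),
        "b1": jnp.zeros(N_HIDDEN),
        "W2": jax.random.normal(k2, (N_HIDDEN, N_MODES)) / jnp.sqrt(N_HIDDEN),
        "b2": jnp.zeros(N_MODES),
    }


def emulate(weights, theta):
    """Toy emulator: parameters -> mock (log) power spectrum."""
    hidden = jnp.tanh(theta @ weights["W1"] + weights["b1"])
    return hidden @ weights["W2"] + weights["b2"]


weights = init_mlp(jax.random.PRNGKey(0))
theta_true = jnp.full(N_PARAMS, 0.3)
data = emulate(weights, theta_true)  # noiseless mock data vector


def log_posterior(theta):
    """Gaussian likelihood plus a unit Gaussian prior (up to a constant)."""
    resid = (emulate(weights, theta) - data) / SIGMA
    return -0.5 * jnp.sum(resid**2) - 0.5 * jnp.sum(theta**2)


# Automatic differentiation + JIT: compiled gradient of the log-posterior.
grad_log_post = jax.jit(jax.grad(log_posterior))
# Batch evaluation: emulate many parameter vectors in a single call.
batched_emulate = jax.jit(jax.vmap(emulate, in_axes=(None, 0)))


@jax.jit
def hmc_step(key, theta):
    """One HMC transition: leapfrog trajectory, then Metropolis accept/reject."""
    key_mom, key_acc = jax.random.split(key)
    p0 = jax.random.normal(key_mom, theta.shape)
    grad_fn = jax.grad(log_posterior)

    def leapfrog(_, state):
        q, p = state
        p = p + 0.5 * STEP_SIZE * grad_fn(q)  # half step in momentum
        q = q + STEP_SIZE * p                 # full step in position
        p = p + 0.5 * STEP_SIZE * grad_fn(q)  # half step in momentum
        return q, p

    q_new, p_new = jax.lax.fori_loop(0, N_LEAPFROG, leapfrog, (theta, p0))
    h_old = -log_posterior(theta) + 0.5 * jnp.sum(p0**2)
    h_new = -log_posterior(q_new) + 0.5 * jnp.sum(p_new**2)
    accept = jnp.log(jax.random.uniform(key_acc)) < h_old - h_new
    return jnp.where(accept, q_new, theta), accept


def run_chain(key, theta0, n_steps):
    """Run n_steps HMC transitions with lax.scan; returns samples and accept flags."""
    def body(theta, step_key):
        theta, accept = hmc_step(step_key, theta)
        return theta, (theta, accept)

    keys = jax.random.split(key, n_steps)
    _, (samples, accepts) = jax.lax.scan(body, theta0, keys)
    return samples, accepts


samples, accepts = run_chain(jax.random.PRNGKey(1), jnp.zeros(N_PARAMS), 500)
spectra = batched_emulate(weights, samples)  # one spectrum per posterior sample
print("acceptance rate:", float(jnp.mean(accepts)))
```

Because the whole pipeline is composed of JAX-traceable functions, the same code runs unchanged on a GPU when one is available. In a real analysis the random-weight network would be replaced by a trained emulator and the toy likelihood by the survey likelihood.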

Code Implementations: 1 repo