JaxWildfire: A GPU-Accelerated Wildfire Simulator for Reinforcement Learning
This work addresses a bottleneck for researchers applying reinforcement learning to wildfire management. The contribution is incremental, since it targets simulation speed rather than a new RL method.
The authors tackle the slow speed of existing wildfire simulators, which limits the training of reinforcement learning agents, by introducing JaxWildfire, a GPU-accelerated simulator that achieves a 6-35x speedup over existing software and enables gradient-based optimization of simulator parameters.
Artificial intelligence methods are increasingly being explored for managing wildfires and other natural hazards. In particular, reinforcement learning (RL) is a promising path towards improving outcomes in such uncertain decision-making scenarios and moving beyond reactive strategies. However, training RL agents requires many environment interactions, and the speed of existing wildfire simulators is a severely limiting factor. We introduce $\texttt{JaxWildfire}$, a simulator underpinned by a principled probabilistic fire spread model based on cellular automata. It is implemented in JAX and supports vectorized simulation via $\texttt{vmap}$, allowing high simulation throughput on GPUs. We demonstrate that $\texttt{JaxWildfire}$ achieves a 6-35x speedup over existing software and enables gradient-based optimization of simulator parameters. Furthermore, we show that $\texttt{JaxWildfire}$ can be used to train RL agents to learn wildfire suppression policies. Our work is an important step towards enabling the advancement of RL techniques for managing natural hazards.
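To make the vectorization idea concrete, the following is a minimal sketch of a toy probabilistic cellular-automaton fire model batched with vmap. It is not the JaxWildfire implementation or its API: the cell encoding, the 4-neighbour rule, the single spread probability p_spread, and the function names step and simulate are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical cell encoding (an assumption, not taken from the paper).
UNBURNT, BURNING, BURNT = 0, 1, 2


def step(grid, key, p_spread=0.3):
    """One step of a toy probabilistic fire-spread cellular automaton."""
    burning = (grid == BURNING).astype(jnp.float32)
    # Count burning 4-neighbours via shifted views of a zero-padded grid.
    padded = jnp.pad(burning, 1)
    n_burning = (padded[:-2, 1:-1] + padded[2:, 1:-1]
                 + padded[1:-1, :-2] + padded[1:-1, 2:])
    # Each burning neighbour ignites the cell independently with p_spread.
    p_ignite = 1.0 - (1.0 - p_spread) ** n_burning
    ignite = jax.random.uniform(key, grid.shape) < p_ignite
    new_grid = jnp.where((grid == UNBURNT) & ignite, BURNING, grid)
    # Cells that were burning at the start of the step burn out.
    return jnp.where(grid == BURNING, BURNT, new_grid)


def simulate(key, grid0, n_steps=20):
    """Roll the automaton forward n_steps with one random key per step."""
    keys = jax.random.split(key, n_steps)

    def body(grid, k):
        return step(grid, k), None

    final, _ = jax.lax.scan(body, grid0, keys)
    return final


# One ignition point in the centre of a 32x32 grid.
grid0 = jnp.zeros((32, 32), dtype=jnp.int32).at[16, 16].set(BURNING)

# 64 independent rollouts in one call: vmap maps over the keys only,
# while the initial grid is shared across the batch.
keys = jax.random.split(jax.random.PRNGKey(0), 64)
batched_simulate = jax.jit(jax.vmap(simulate, in_axes=(0, None)))
finals = batched_simulate(keys, grid0)

# Fraction of cells touched by fire in each rollout.
burnt_fraction = (finals != UNBURNT).mean(axis=(1, 2))
```

Because each rollout is a pure function of its random key, the batch dimension can be scaled up without changing simulate itself, which is the property that makes this style of simulator suitable for collecting many RL environment interactions in parallel on a GPU.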