LG · AI · Jul 28, 2024

NAVIX: Scaling MiniGrid Environments with JAX

arXiv:2407.19396v1 · 14 citations · h-index: 9
Originality: Incremental advance
AI Analysis

This work addresses a bottleneck for researchers in deep reinforcement learning by enabling faster and more scalable experiments, though it is incremental as it re-implements an existing environment.

The authors tackled the problem of slow environment simulations in deep reinforcement learning by re-implementing MiniGrid in JAX, achieving over 200,000x speed improvements in batch mode and reducing experiment times from one week to 15 minutes.

As Deep Reinforcement Learning (Deep RL) research moves towards solving large-scale worlds, efficient environment simulations become crucial for rapid experimentation. However, most existing environments struggle to scale to high throughput, setting back meaningful progress. Interactions are typically computed on the CPU, limiting training speed and throughput, due to slower computation and communication overhead when distributing the task across multiple machines. Ultimately, Deep RL training is CPU-bound, and developing batched, fast, and scalable environments has become a frontier for progress. Among the most used Reinforcement Learning (RL) environments, MiniGrid is at the foundation of several studies on exploration, curriculum learning, representation learning, diversity, meta-learning, credit assignment, and language-conditioned RL, and still suffers from the limitations described above. In this work, we introduce NAVIX, a re-implementation of MiniGrid in JAX. NAVIX achieves over 200 000x speed improvements in batch mode, supporting up to 2048 agents in parallel on a single Nvidia A100 80 GB. This reduces experiment times from one week to 15 minutes, promoting faster design iterations and more scalable RL model development.
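The batching idea described in the abstract can be illustrated with a minimal sketch. This is not NAVIX's API or the authors' implementation: it is a toy gridworld written in plain JAX, and every name in it (`GRID_SIZE`, `MOVES`, `GOAL`, `reset`, `step`, `rollout`) is hypothetical. It only shows how a single-agent step function, once written as a pure function, might be vectorised with `jax.vmap` and compiled with `jax.jit` so that 2048 environments advance together on one accelerator, with no episode termination or observation logic.

```python
import jax
import jax.numpy as jnp

GRID_SIZE = 8  # hypothetical square grid, not a MiniGrid layout
# Moves for actions 0..3: right, down, left, up, as (row, col) deltas.
MOVES = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])
GOAL = jnp.array([GRID_SIZE - 1, GRID_SIZE - 1])


def reset(key):
    """Sample a random start position for one agent."""
    return jax.random.randint(key, (2,), 0, GRID_SIZE)


def step(pos, action):
    """Move one agent, clip at the walls, reward 1.0 when on the goal cell."""
    new_pos = jnp.clip(pos + MOVES[action], 0, GRID_SIZE - 1)
    reward = jnp.all(new_pos == GOAL).astype(jnp.float32)
    return new_pos, reward


# vmap lifts the single-agent functions to a batch; jit compiles them.
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))


def rollout(key, num_envs=2048, num_steps=100):
    """Run a uniformly random policy in num_envs environments on device."""
    reset_key, action_key = jax.random.split(key)
    pos = batched_reset(jax.random.split(reset_key, num_envs))
    actions = jax.random.randint(action_key, (num_steps, num_envs), 0, 4)

    def body(carry_pos, act):
        next_pos, reward = batched_step(carry_pos, act)
        return next_pos, reward

    # scan keeps the whole time loop inside one compiled computation.
    final_pos, rewards = jax.lax.scan(body, pos, actions)
    return final_pos, rewards


final_pos, rewards = rollout(jax.random.PRNGKey(0))
print(final_pos.shape, rewards.shape)
```

Because the environment state lives in device arrays and the loop runs under `jax.lax.scan`, no data crosses back to the host between steps, which is the kind of overhead the abstract attributes to CPU-based simulation.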

Code Implementations: 1 repo
