SocialJax: An Evaluation Suite for Multi-agent Reinforcement Learning in Sequential Social Dilemmas
This work addresses the high computational cost that researchers face in multi-agent reinforcement learning. Its contribution is incremental: it focuses on efficiency improvements rather than new algorithmic breakthroughs.
The paper tackles the computational inefficiency of evaluating multi-agent reinforcement learning in sequential social dilemmas by introducing SocialJax, a suite of environments and algorithms implemented in JAX whose training pipeline achieves at least a 50x speed-up in real-time performance compared to Melting Pot RLlib baselines.
Sequential social dilemmas pose a significant challenge in the field of multi-agent reinforcement learning (MARL), requiring environments that accurately reflect the tension between individual and collective interests. Previous benchmarks and environments, such as Melting Pot, provide an evaluation protocol that measures generalization to new social partners in various test scenarios. However, running reinforcement learning algorithms in traditional environments requires substantial computational resources. In this paper, we introduce SocialJax, a suite of sequential social dilemma environments and algorithms implemented in JAX. JAX is a high-performance numerical computing library for Python that enables significant improvements in operational efficiency. Our experiments demonstrate that the SocialJax training pipeline achieves at least a 50x speed-up in real-time performance compared to Melting Pot RLlib baselines. Additionally, we validate the effectiveness of baseline algorithms within SocialJax environments. Finally, we use Schelling diagrams to verify the social dilemma properties of these environments, ensuring that they accurately capture the dynamics of social dilemmas.
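To make the Schelling-diagram check concrete, the following is a minimal sketch in JAX. It does not use SocialJax or any of its environments: it assumes a toy one-shot public goods game as a stand-in, and the names N_AGENTS, MULTIPLIER, COST, payoffs, and schelling_curves are hypothetical. A Schelling diagram plots the payoff of a cooperator and of a defector against the number of other agents who cooperate; here both curves are computed in one vectorized, JIT-compiled call.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameters (not taken from the paper).
N_AGENTS = 4       # group size
MULTIPLIER = 2.0   # public-goods multiplier; 1 < MULTIPLIER < N_AGENTS
COST = 1.0         # cost paid by each cooperator


def payoffs(num_other_cooperators):
    """Return (cooperator payoff, defector payoff) for k other cooperators."""
    k = num_other_cooperators.astype(jnp.float32)
    # The pot is multiplied and split equally among all agents.
    share_if_cooperate = MULTIPLIER * COST * (k + 1.0) / N_AGENTS
    share_if_defect = MULTIPLIER * COST * k / N_AGENTS
    return share_if_cooperate - COST, share_if_defect


@jax.jit
def schelling_curves():
    """Evaluate both payoff curves for k = 0 .. N_AGENTS - 1 in one batch."""
    ks = jnp.arange(N_AGENTS)
    return jax.vmap(payoffs)(ks)


coop, defect = schelling_curves()

# Dilemma checks for this toy game:
# defecting pays more for every number of other cooperators...
defection_dominates = bool(jnp.all(defect > coop))
# ...yet full cooperation beats full defection.
cooperation_is_better = bool(coop[-1] > defect[0])
print(defection_dominates, cooperation_is_better)
```

In the temporally extended environments the paper describes, the two curves would presumably be estimated from episode returns rather than from a closed-form payoff; the sketch only illustrates what the diagram compares.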