LG · SE · Apr 27, 2023

JaxPruner: A concise library for sparsity research

MILA
arXiv:2304.14082v3 · 19 citations · h-index: 35 · Has Code
Originality: Synthesis-oriented
AI Analysis

This is an incremental contribution that provides a useful tool for researchers working on sparse neural networks.

The paper introduces JaxPruner, a JAX-based library for pruning and sparse training in machine learning, designed to accelerate research by providing concise implementations with minimal overhead and easy integration with existing tools.
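As a rough illustration of how small a single pruning algorithm can be, the sketch below implements one-shot magnitude pruning. It zeroes the smallest-magnitude fraction of each weight array. It is written with NumPy rather than JAX so it runs without extra dependencies, and the same logic carries over to `jax.numpy`. `magnitude_mask` and `prune_params` are hypothetical helpers, not JaxPruner functions. Applying the same sparsity to every array (a uniform distribution) is also an assumption made for brevity.

```python
import numpy as np


def magnitude_mask(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Return a 0/1 mask keeping the largest-magnitude (1 - sparsity) fraction of weights."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    n_prune = int(round(sparsity * weights.size))
    mask = np.ones(weights.size, dtype=weights.dtype)
    if n_prune > 0:
        # Indices of the n_prune smallest-magnitude entries (ties broken by position).
        prune_idx = np.argsort(np.abs(weights).ravel(), kind="stable")[:n_prune]
        mask[prune_idx] = 0
    return mask.reshape(weights.shape)


def prune_params(params: dict, sparsity: float) -> dict:
    """Apply the same (uniform) sparsity level to every parameter array in a flat dict."""
    return {name: w * magnitude_mask(w, sparsity) for name, w in params.items()}
```

For example, `magnitude_mask(np.array([[0.1, -2.0], [0.5, 3.0]]), 0.5)` keeps `-2.0` and `3.0` and masks the other two entries.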

This paper introduces JaxPruner, an open-source JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases: Scenic, t5x, Dopamine, and FedJAX, and provide baseline experiments on popular benchmarks.
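Optax optimizers are `GradientTransformation` objects, which are pairs of `init` and `update` functions. A pruning algorithm that honours that interface can therefore drop into an existing training loop. The sketch below shows one way to do this. It is a minimal illustration under stated assumptions, not JaxPruner's actual mechanism. `Transform`, `sgd`, `with_pruning`, and `apply_updates` are hypothetical NumPy stand-ins that only mirror the shape of the Optax interface. Periodic magnitude-based mask updates are likewise an assumed choice of algorithm.

```python
from typing import Callable, Dict, NamedTuple

import numpy as np

Params = Dict[str, np.ndarray]


class Transform(NamedTuple):
    """Hypothetical stand-in with the same (init, update) shape as an Optax transformation."""
    init: Callable
    update: Callable


def sgd(lr: float) -> Transform:
    """Stateless SGD: each update is -lr * gradient."""
    def init(params):
        return ()

    def update(grads, state, params=None):
        return {k: -lr * g for k, g in grads.items()}, state

    return Transform(init, update)


def magnitude_mask(w: np.ndarray, sparsity: float) -> np.ndarray:
    """0/1 mask that zeroes the smallest-magnitude `sparsity` fraction of `w`."""
    n_prune = int(round(sparsity * w.size))
    mask = np.ones(w.size, dtype=w.dtype)
    mask[np.argsort(np.abs(w).ravel(), kind="stable")[:n_prune]] = 0
    return mask.reshape(w.shape)


class PruneState(NamedTuple):
    step: int
    masks: Params
    inner: object  # state of the wrapped optimizer


def with_pruning(inner: Transform, sparsity: float, update_every: int) -> Transform:
    """Hypothetical wrapper: run `inner`, then fold the pruning mask into the updates."""
    def init(params):
        masks = {k: np.ones_like(w) for k, w in params.items()}
        return PruneState(0, masks, inner.init(params))

    def update(grads, state, params):
        updates, inner_state = inner.update(grads, state.inner, params)
        new_params = {k: params[k] + updates[k] for k in params}
        masks = state.masks
        if state.step % update_every == 0:
            # Recompute masks from the magnitudes of the post-update weights.
            masks = {k: magnitude_mask(w, sparsity) for k, w in new_params.items()}
        # Rewrite updates so that params + updates equals the masked new params.
        updates = {k: new_params[k] * masks[k] - params[k] for k in params}
        return updates, PruneState(state.step + 1, masks, inner_state)

    return Transform(init, update)


def apply_updates(params: Params, updates: Params) -> Params:
    """Mirror of the usual Optax-style step: new params = params + updates."""
    return {k: params[k] + updates[k] for k in params}
```

Usage might look like this: `tx = with_pruning(sgd(lr=0.1), sparsity=0.5, update_every=10)`, then `state = tx.init(params)` followed by the usual `tx.update` / `apply_updates` loop. The mask is folded into the returned updates, so a loop that only adds updates to parameters needs no changes. That is the kind of drop-in composition the abstract attributes to the Optax-based API.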

Code Implementations: 1 repo
