MPAX: Mathematical Programming in JAX
MPAX provides a versatile and efficient tool for researchers and practitioners who need to solve LPs in machine learning contexts, though the contribution is incremental because it builds on existing methods.
The paper addresses the problem of integrating linear programming into machine learning workflows. It introduces MPAX, a toolbox that implements state-of-the-art first-order methods in JAX, which brings advantages such as hardware acceleration and batch solving; the experiments demonstrate these advantages.
This paper presents MPAX (Mathematical Programming in JAX), a versatile and efficient toolbox for integrating linear programming (LP) into machine learning workflows. MPAX implements two state-of-the-art first-order methods, restarted average primal-dual hybrid gradient and reflected restarted Halpern primal-dual hybrid gradient, to solve LPs in JAX. Building on JAX provides native support for hardware acceleration, along with features such as batch solving, automatic differentiation, and device parallelism. Extensive numerical experiments demonstrate the advantages of MPAX over existing solvers. The solver is available at https://github.com/MIT-Lu-Lab/MPAX.
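To make the idea of a first-order LP method with batch solving concrete, here is a minimal sketch in JAX. It is not MPAX's API or implementation: it uses the plain, unrestarted primal-dual hybrid gradient (PDHG) iteration that the restarted and Halpern variants build on, and the function name `pdhg_lp`, the fixed step-size rule, the iteration count, and the toy problem are all assumptions chosen for illustration. The LP is assumed to be in the standard form min c^T x subject to A x = b, x >= 0.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=2000):
    """Plain PDHG sketch for: min c^T x  s.t.  A x = b, x >= 0.

    Hypothetical helper for illustration only; no restarts, averaging,
    or Halpern steps are applied.
    """
    # Step sizes are chosen so that tau * sigma * ||A||_2^2 < 1.
    step = 0.9 / jnp.linalg.norm(A, ord=2)
    tau = sigma = step

    def body(_, state):
        x, y = state
        # Primal step: move along -(c - A^T y), then project onto x >= 0.
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual step evaluated at the extrapolated point 2 * x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(A.shape[1])
    y0 = jnp.zeros(A.shape[0])
    return jax.lax.fori_loop(0, num_iters, body, (x0, y0))


# Toy problem (assumed): min x1 + 2*x2  s.t.  x1 + x2 = b, x >= 0.
# The optimum is x = (b, 0) with objective value b.
c = jnp.array([1.0, 2.0])
A = jnp.array([[1.0, 1.0]])
b_batch = jnp.array([[1.0], [2.0], [3.0]])  # three right-hand sides

# Batch solving: share c and A, map over the leading axis of b.
batched_solve = jax.jit(jax.vmap(pdhg_lp, in_axes=(None, None, 0)))
x_batch, y_batch = batched_solve(c, A, b_batch)
objectives = x_batch @ c  # should be close to [1, 2, 3]
```

Because the solver is written with `jax.lax.fori_loop` and pure array operations, `jax.vmap` can map it over a batch of right-hand sides and `jax.jit` can compile the result for whichever accelerator JAX is running on. A production solver would add restarts, a convergence check, and preconditioning rather than a fixed iteration count.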