JaxDecompiler: Redefining Gradient-Informed Software Design
This tool simplifies reverse engineering and customization for JAX users, particularly in deep learning, although the contribution is incremental because it builds on existing JAX capabilities.
The paper tackles the problem of editing JAX functions by introducing JaxDecompiler, a tool that transforms any JAX function into editable Python code. The results show that the decompiled code performs similarly to the original.
Among numerical libraries that compute gradients for gradient-descent optimization, JAX stands out for its richer feature set, with computations accelerated through an intermediate representation called the Jaxpr language. However, Jaxpr code cannot be edited directly. This article introduces JaxDecompiler, a tool that transforms any JAX function into editable Python code. It is especially useful for editing the functions produced by JAX's gradient transformation (jax.grad). JaxDecompiler simplifies reverse engineering, understanding, customizing, and interoperating with software developed with JAX. We highlight its capabilities, emphasize its practical applications, especially in deep learning and in gradient-informed software more generally, and demonstrate that the decompiled code runs at a speed similar to the original.
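To make the idea of "decompiling a Jaxpr" concrete, here is a minimal sketch. It is not JaxDecompiler's implementation or API. The names `decompile`, `INFIX`, and the generated identifiers (`v0`, `P3`, `LIT5`, ...) are hypothetical. The sketch traces a function such as the output of jax.grad with `jax.make_jaxpr`, walks the resulting equations, and emits a Python function as source text. As a simplifying assumption, it renders only a few parameter-free binary primitives as infix operators. Every other primitive falls back to re-binding the original JAX primitive with its original parameters, which mirrors how JAX's own Jaxpr interpreter evaluates equations.

```python
import itertools

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical translation table: parameter-free binary primitives shown as infix
# operators. Anything else falls back to re-binding the original primitive.
INFIX = {"mul": "*", "add": "+", "add_any": "+", "sub": "-"}


def decompile(fn, *example_args, name="decompiled"):
    """Trace `fn` to a Jaxpr and return (python_function, python_source)."""
    closed = jax.make_jaxpr(fn)(*example_args)
    jaxpr = closed.jaxpr
    env = {}    # globals of the generated code: primitives, params, constants
    names = {}  # Jaxpr variable -> Python identifier
    fresh = itertools.count()

    def ref(v):
        # Literals are stored in `env`; variables get SSA-style names v0, v1, ...
        if type(v).__name__ == "Literal":
            key = f"LIT{next(fresh)}"
            env[key] = v.val
            return key
        if v not in names:
            names[v] = f"v{len(names)}"
        return names[v]

    for var, const in zip(jaxpr.constvars, closed.consts):
        env[ref(var)] = const  # captured constants become globals of the code

    lines = [f"def {name}({', '.join(ref(v) for v in jaxpr.invars)}):"]
    for k, eqn in enumerate(jaxpr.eqns):
        args = [ref(v) for v in eqn.invars]
        outs = [ref(v) for v in eqn.outvars]
        prim = eqn.primitive
        if prim.name in INFIX and not eqn.params and len(args) == 2:
            rhs = f"{args[0]} {INFIX[prim.name]} {args[1]}"
        else:
            # Fallback: re-bind the primitive the way JAX's eval_jaxpr does.
            subfuns, bind_params = prim.get_bind_params(eqn.params)
            env[f"P{k}"], env[f"S{k}"], env[f"K{k}"] = prim, subfuns, bind_params
            call_args = [f"*S{k}", *args, f"**K{k}"]
            rhs = f"P{k}.bind({', '.join(call_args)})"
        lhs = f"[{', '.join(outs)}]" if prim.multiple_results else outs[0]
        lines.append(f"    {lhs} = {rhs}  # {prim.name}")

    results = [ref(v) for v in jaxpr.outvars]
    ret = results[0] if len(results) == 1 else "(" + "".join(r + ", " for r in results) + ")"
    lines.append(f"    return {ret}")
    src = "\n".join(lines)
    exec(src, env)  # compile the emitted source into a callable
    return env[name], src


# Usage: decompile the gradient of f(x) = sin(x) * x.
def f(x):
    return lax.sin(x) * x


print(jax.make_jaxpr(jax.grad(f))(jnp.float32(1.0)))  # the Jaxpr being decompiled
grad_f, src = decompile(jax.grad(f), jnp.float32(1.0), name="grad_f")
print(src)  # editable Python source for the gradient
x = jnp.float32(0.7)
print(grad_f(x), jax.grad(f)(x))
```

The emitted source can be read, edited, and re-executed like ordinary Python, which is the property the abstract attributes to JaxDecompiler. A full decompiler would translate every primitive into readable JAX calls rather than relying on the generic re-binding fallback.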