LAST: Scalable Lattice-Based Speech Modelling in JAX
This work addresses performance and scalability problems in lattice-based speech recognition and is aimed at researchers and practitioners; its contribution is incremental, since it builds on existing WFSA algorithms.
The authors tackle the challenge of scalable lattice-based speech modelling in JAX by developing LAST, a library of differentiable WFSA algorithms, and demonstrate its effectiveness with benchmarks on TPUv3 and V100 GPU.
We introduce LAST, a LAttice-based Speech Transducer library in JAX. With an emphasis on flexibility, ease of use, and scalability, LAST implements the differentiable weighted finite state automaton (WFSA) algorithms needed for training and inference, and these algorithms scale to large WFSAs such as a recognition lattice over an entire utterance. Although these WFSA algorithms are well known in the literature, new challenges arise from the performance characteristics of modern hardware architectures and from nuances in automatic differentiation. We describe a suite of generally applicable techniques employed in LAST to address these challenges, and demonstrate their effectiveness with benchmarks on TPUv3 and V100 GPU.
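To make "differentiable WFSA algorithms" concrete, the following is a minimal sketch, not LAST's API, of one core operation: the forward (shortest-distance) computation in the log semiring. It assumes a hypothetical dense, frame-synchronous acyclic lattice with `S` states per frame and log arc weights `arcs[t, i, j]` for the arc from state `i` at frame `t` to state `j` at frame `t + 1`; all function names are illustrative. Differentiating the log total weight with `jax.grad` gives arc posteriors, the quantity that gradient-based lattice training relies on.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def shortest_distance(init, arcs, final):
    """Log-semiring shortest distance (forward algorithm) over a hypothetical
    frame-synchronous acyclic WFSA.

    init:  [S]       log weight of starting in each state at frame 0
    arcs:  [T, S, S] log weight of arc i -> j between frames t and t + 1
    final: [S]       log final weight of each state at frame T
    Returns the log of the total weight of all paths (a scalar).
    """
    def step(alpha, w):
        # Extend every partial path by one arc, then log-sum over source states i.
        new_alpha = logsumexp(alpha[:, None] + w, axis=0)
        return new_alpha, None

    # lax.scan iterates over the leading (time) axis of `arcs`.
    alpha_T, _ = jax.lax.scan(step, init, arcs)
    return logsumexp(alpha_T + final)


# d(log Z) / d(log arc weight) is the posterior probability of each arc.
arc_posteriors = jax.grad(shortest_distance, argnums=1)


def brute_force_logz(init, arcs, final):
    """Reference check: enumerate every state sequence (tiny T and S only)."""
    T, S, _ = arcs.shape
    scores = []
    for path in itertools.product(range(S), repeat=T + 1):
        score = init[path[0]] + final[path[-1]]
        for t in range(T):
            score = score + arcs[t, path[t], path[t + 1]]
        scores.append(score)
    return logsumexp(jnp.stack(scores))
```

Because every path crosses exactly one arc per frame, the posteriors at each frame sum to one, which is a convenient sanity check. Note that differentiating through `jax.lax.scan` stores per-step intermediates for the backward pass; wrapping `step` in `jax.checkpoint` is one generic way to trade recomputation for memory, though this sketch does not claim that LAST uses that technique.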