DiagrammaticLearning: A Graphical Language for Compositional Training Regimes
This work offers a novel framework for simplifying and unifying complex training workflows in deep learning, potentially benefiting practitioners by enabling easier composition and manipulation of model components.
The authors introduced learning diagrams, a graphical language for representing compositional training regimes. Each diagram compiles to a loss function that trains its component models to agree in their predictions. They demonstrated the language on setups such as few-shot multi-task learning and knowledge distillation, and provided a library implementation with a category-theoretic foundation.
Motivated by deep learning regimes with multiple interacting yet distinct model components, we introduce learning diagrams, graphical depictions of training setups that capture parameterized learning as data rather than code. A learning diagram compiles to a unique loss function on which component models are trained. The result of training on this loss is a collection of models whose predictions "agree" with one another. We show that a number of popular learning setups such as few-shot multi-task learning, knowledge distillation, and multi-modal learning can be depicted as learning diagrams. We further implement learning diagrams in a library that allows users to build diagrams of PyTorch and Flux.jl models. By implementing some classic machine learning use cases, we demonstrate how learning diagrams allow practitioners to build complicated models as compositions of smaller components, identify relationships between workflows, and manipulate models during or after training. Leveraging a category-theoretic framework, we introduce a rigorous semantics for learning diagrams that puts such operations on a firm mathematical foundation.
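To make "a diagram compiles to a loss" concrete, here is a minimal pure-Python sketch under stated assumptions. It is not the authors' library API, which wraps PyTorch and Flux.jl models. The class LearningDiagram and its methods add_arrow, endpoints, run, and compile_loss are hypothetical names chosen for illustration. Spaces are named objects, models are arrows between them, and each declared pair of parallel paths contributes a disagreement term to the loss. Measuring disagreement as squared Euclidean distance is an assumption; the paper's semantics may define agreement differently. The usage lines encode knowledge distillation: a student built as the composite "encoder then head" should agree with a teacher.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical sketch: not the paper's library API.
# Assumption: "agreement" is measured by squared Euclidean distance.
Vector = Tuple[float, ...]
Model = Callable[[Vector], Vector]
Path = List[str]


class LearningDiagram:
    """Toy learning diagram: objects are named spaces, arrows are models."""

    def __init__(self) -> None:
        # arrow name -> (source space, target space, model)
        self.arrows: Dict[str, Tuple[str, str, Model]] = {}

    def add_arrow(self, name: str, src: str, tgt: str, model: Model) -> None:
        self.arrows[name] = (src, tgt, model)

    def endpoints(self, path: Path) -> Tuple[str, str]:
        # Check that consecutive arrows compose; return (source, target) of the path.
        if not path:
            raise ValueError("empty path")
        for first, second in zip(path, path[1:]):
            if self.arrows[first][1] != self.arrows[second][0]:
                raise ValueError(f"{first} does not compose with {second}")
        return self.arrows[path[0]][0], self.arrows[path[-1]][1]

    def run(self, path: Path, x: Vector) -> Vector:
        # Apply arrows left to right: the path [f, g] computes g(f(x)).
        for name in path:
            x = self.arrows[name][2](x)
        return x

    def compile_loss(
        self, equations: Sequence[Tuple[Path, Path]]
    ) -> Callable[[Sequence[Vector]], float]:
        # Each equation states that two parallel paths should agree;
        # reject pairs whose sources or targets differ.
        for lhs, rhs in equations:
            if self.endpoints(lhs) != self.endpoints(rhs):
                raise ValueError(f"paths {lhs} and {rhs} are not parallel")

        def loss(batch: Sequence[Vector]) -> float:
            # Mean over the batch of summed squared disagreement across equations.
            total = 0.0
            for x in batch:
                for lhs, rhs in equations:
                    a, b = self.run(lhs, x), self.run(rhs, x)
                    total += sum((ai - bi) ** 2 for ai, bi in zip(a, b))
            return total / len(batch)

        return loss


# Knowledge distillation as a diagram: teacher X -> Y, student X -> Z -> Y.
diagram = LearningDiagram()
diagram.add_arrow("teacher", "X", "Y", lambda x: (2.0 * x[0],))
diagram.add_arrow("encoder", "X", "Z", lambda x: (3.0 * x[0],))
diagram.add_arrow("head", "Z", "Y", lambda z: (z[0] - 1.0,))

loss = diagram.compile_loss([(["teacher"], ["encoder", "head"])])
print(loss([(1.0,), (3.0,)]))  # 2.0
```

In a real workflow the arrows would be differentiable models (for example PyTorch modules) and the compiled loss would be minimized by gradient descent; the toy lambdas here only keep the arithmetic easy to check. Keeping the diagram as data, separate from the loss it compiles to, is what lets the same structure be inspected or rewired before or after training.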