GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent
This work addresses a problem that matters to machine learning practitioners: greedy algorithms can produce inaccurate decision trees. The contribution is incremental, as it builds on existing gradient-based optimization techniques.
The paper tackles the problem of learning decision trees, which is non-convex and non-differentiable. It introduces a novel method that uses gradient descent, with backpropagation and a straight-through operator, to optimize all tree parameters jointly. The method outperforms existing methods on binary classification benchmarks and achieves competitive results on multi-class tasks.
Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation, to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results for multi-class tasks. The method is available under: https://github.com/s-marton/GradTree
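The two ingredients named in the abstract, a dense tree representation and a straight-through operator, can be illustrated with a minimal sketch. This is not the authors' implementation (see the repository for that). The function names `hard_split_st` and `predict_dense_tree`, the breadth-first node layout, the per-node feature logits, and the convention "go left when the feature value is at most the threshold" are all assumptions made for illustration. The sketch only shows the forward pass of a hard, axis-aligned tree and the surrogate gradient a straight-through operator would pass backward; it does not include a training loop.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def hard_split_st(z):
    """Straight-through split (illustrative, hypothetical helper).

    Forward value: the hard 0/1 decision obtained by rounding sigmoid(z).
    Surrogate gradient: the sigmoid derivative, used in place of the
    zero-almost-everywhere gradient of the rounding step.
    """
    s = sigmoid(z)
    hard = 1.0 if s >= 0.5 else 0.0
    surrogate_grad = s * (1.0 - s)
    return hard, surrogate_grad


def predict_dense_tree(x, feature_logits, thresholds, leaf_values):
    """Evaluate a complete binary tree stored densely (assumed layout).

    Internal nodes are stored in breadth-first order, so a tree of depth d
    has 2**d - 1 internal nodes and 2**d leaves. Each internal node holds
    one logit per feature and one threshold.
    """
    n_internal = len(thresholds)
    node = 0
    while node < n_internal:
        logits = feature_logits[node]
        # Axis-aligned split: pick the single feature with the largest logit.
        feat = max(range(len(logits)), key=logits.__getitem__)
        # Assumed convention: go left when x[feat] <= threshold.
        go_left, _ = hard_split_st(thresholds[node] - x[feat])
        node = 2 * node + 1 if go_left == 1.0 else 2 * node + 2
    return leaf_values[node - n_internal]


# Depth-2 example: 3 internal nodes, 4 leaves, 2 features (made-up values).
feature_logits = [[2.0, 0.1], [0.0, 1.0], [0.3, 0.2]]
thresholds = [0.5, 0.2, 0.7]
leaf_values = [0, 1, 1, 0]

pred_a = predict_dense_tree([0.3, 0.9], feature_logits, thresholds, leaf_values)
pred_b = predict_dense_tree([0.8, 0.0], feature_logits, thresholds, leaf_values)
```

Because every node of the complete tree always has a feature-logit vector and a threshold, all parameters live in fixed-size arrays. That fixed shape is what makes it possible, in an automatic-differentiation framework, to update them jointly by gradient descent while the forward pass still makes hard decisions.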