Learning Deep Nearest Neighbor Representations Using Differentiable Boundary Trees
This work addresses the need for better representations in kNN-based methods, offering a new approach that builds incrementally on existing boundary tree algorithms.
The paper tackles the problem of learning effective representations for nearest neighbor methods by introducing differentiable boundary trees, which enable deep neural networks to learn representations that result in efficient and interpretable trees.
k-nearest neighbor (kNN) methods have been gaining popularity in recent years thanks to advances in hardware and in the efficiency of algorithms. There is a plethora of methods to choose from today, each with its own advantages and disadvantages. One requirement shared by all kNN-based methods is the need for a good representation and distance measure between samples. We introduce a new method called the differentiable boundary tree, which allows for learning deep kNN representations. We build on the recently proposed boundary tree algorithm, which allows for efficient nearest neighbor classification, regression and retrieval. By modelling traversals in the tree as stochastic events, we are able to form a differentiable cost function associated with the tree's predictions. Using a deep neural network to transform the data and back-propagating through the tree allows us to learn good representations for kNN methods. We demonstrate that our method is able to learn suitable representations that yield very efficient trees with a clearly interpretable structure.
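The idea of turning tree traversals into stochastic events can be illustrated with a small, self-contained sketch. This is not the authors' implementation. It assumes squared Euclidean distance, nodes with unlimited children, a greedy path, and a softmax over negative distances at each node, computed over the node itself and its children. It also assumes one plausible form of the cost: the log-probabilities of the transitions taken, plus the log of the class mass in the final neighbourhood. `BoundaryTree`, `closest`, `train` and `soft_loss` are hypothetical names. The input vectors stand in for the output of the deep network that would transform the data.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    # Assumed metric: squared Euclidean distance in the (learned) feature space.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def softmax_neg(dists: List[float]) -> List[float]:
    # Softmax over negative distances, shifted by the minimum for stability.
    m = min(dists)
    w = [math.exp(-(d - m)) for d in dists]
    s = sum(w)
    return [wi / s for wi in w]


class BoundaryTree:
    """Hypothetical boundary tree with unlimited children per node."""

    def __init__(self, x: Vector, y: int) -> None:
        self.points: List[Vector] = [x]
        self.labels: List[int] = [y]
        self.children: List[List[int]] = [[]]

    def _candidates(self, node: int) -> List[int]:
        # At each node, the node itself competes with its children.
        return [node] + self.children[node]

    def closest(self, x: Vector) -> int:
        # Greedy traversal: move to the closest candidate until the node wins.
        node = 0
        while True:
            cands = self._candidates(node)
            best = min(cands, key=lambda i: sq_dist(x, self.points[i]))
            if best == node:
                return node
            node = best

    def train(self, x: Vector, y: int) -> None:
        # Add x as a child of the node it reaches only if that node's label differs.
        node = self.closest(x)
        if self.labels[node] != y:
            self.points.append(x)
            self.labels.append(y)
            self.children.append([])
            self.children[node].append(len(self.points) - 1)

    def soft_loss(self, x: Vector, y: int, num_classes: int) -> Tuple[float, List[float]]:
        # Treat each transition as a stochastic event with softmax probabilities;
        # follow the most likely transition and accumulate its log-probability.
        log_path = 0.0
        node = 0
        while True:
            cands = self._candidates(node)
            probs = softmax_neg([sq_dist(x, self.points[i]) for i in cands])
            k = max(range(len(cands)), key=probs.__getitem__)
            if cands[k] == node:
                break
            log_path += math.log(probs[k])
            node = cands[k]
        # Class distribution: softmax mass of the final neighbourhood, per label.
        class_probs = [0.0] * num_classes
        for i, p in zip(cands, probs):
            class_probs[self.labels[i]] += p
        # Floor avoids log(0) when no candidate carries the true label.
        loss = -(log_path + math.log(max(class_probs[y], 1e-12)))
        return loss, class_probs


# Tiny demo on hand-picked 2-D points (illustrative data only).
tree = BoundaryTree([0.0, 0.0], 0)
for xi, yi in [([1.0, 0.0], 0), ([5.0, 5.0], 1), ([5.0, 6.0], 1), ([0.0, 1.0], 0)]:
    tree.train(xi, yi)
```

On the demo data, only `[5.0, 5.0]` is added to the tree, because it is the only point whose reached node carries a different label. This is the boundary tree's source of efficiency. Every term in `soft_loss` is a smooth function of the distances. In a real model, an autodiff framework would compute the gradient with respect to the network producing the vectors. The greedy choice of path itself is not differentiated in this sketch.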