KTAN: Knowledge Transfer Adversarial Network
This work addresses the need for efficient deep learning models, offering a method to compress networks while maintaining performance, though it is incremental as it builds on existing knowledge distillation techniques.
The authors address the computation and storage costs of deep convolutional neural networks by proposing a knowledge transfer adversarial network that holistically transfers both intermediate representations and probability distributions from a teacher network to a student network. They report that the method significantly improves student performance on image classification and object detection tasks.
To reduce the large computation and storage cost of a deep convolutional neural network, knowledge distillation methods pioneered the transfer of the generalization ability of a large (teacher) deep network to a lightweight (student) network. However, these methods mostly focus on transferring the probability distribution of the softmax layer in a teacher network and thus neglect the intermediate representations. In this paper, we propose a knowledge transfer adversarial network to better train a student network. Our technique holistically considers both the intermediate representations and the probability distributions of a teacher network. To transfer the knowledge of intermediate representations, we set high-level teacher feature maps as a target, toward which the student feature maps are trained. Specifically, we introduce a Teacher-to-Student layer that makes our framework suitable for various student structures. The intermediate representations help the student network capture the transferred generalization ability better than the probability distribution alone. Furthermore, we incorporate an adversarial learning process by employing a discriminator network, which can fully exploit the spatial correlation of feature maps in training a student network. The experimental results demonstrate that the proposed method can significantly improve the performance of a student network on both image classification and object detection tasks.
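The three training signals described above (softened probability distributions, teacher feature maps as a target, and a discriminator on feature maps) can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the abstract does not give the loss formulas, so the KL-based distillation term, the 1x1 channel projection standing in for the Teacher-to-Student layer, the mean-squared feature loss, the non-saturating adversarial term, and the names `alpha`, `beta`, and `toy_discriminator` are all hypothetical choices made here for clarity.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with a temperature (assumed hyperparameter)."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, temperature=4.0):
    """KL(teacher || student) on softened distributions (assumed form)."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    return sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)


def teacher_to_student(teacher_fmap, weights):
    """Assumed stand-in for the Teacher-to-Student layer: a 1x1 projection.

    teacher_fmap is indexed [C_t][H][W]; weights is indexed [C_s][C_t].
    Returns a target feature map indexed [C_s][H][W].
    """
    height, width = len(teacher_fmap[0]), len(teacher_fmap[0][0])
    return [
        [
            [sum(w * teacher_fmap[c][i][j] for c, w in enumerate(row))
             for j in range(width)]
            for i in range(height)
        ]
        for row in weights
    ]


def feature_loss(target_fmap, student_fmap):
    """Mean squared error between target and student feature maps."""
    diffs = [
        (t - s) ** 2
        for t_chan, s_chan in zip(target_fmap, student_fmap)
        for t_row, s_row in zip(t_chan, s_chan)
        for t, s in zip(t_row, s_row)
    ]
    return sum(diffs) / len(diffs)


def adversarial_loss(discriminator, student_fmap):
    """Non-saturating generator loss: the student tries to make the
    discriminator score its feature map as teacher-like (close to 1)."""
    return -math.log(discriminator(student_fmap))


def toy_discriminator(fmap):
    """Hypothetical placeholder: sigmoid of the mean activation.
    A real discriminator would be a trained convolutional network."""
    values = [v for chan in fmap for row in chan for v in row]
    return 1.0 / (1.0 + math.exp(-sum(values) / len(values)))


def student_loss(teacher_logits, student_logits, target_fmap, student_fmap,
                 discriminator, alpha=1.0, beta=1.0):
    """Weighted sum of the three terms; alpha and beta are assumed weights."""
    return (kd_loss(teacher_logits, student_logits)
            + alpha * feature_loss(target_fmap, student_fmap)
            + beta * adversarial_loss(discriminator, student_fmap))
```

In use, `teacher_to_student` would first map the teacher feature maps to the student's channel count, and `student_loss` would then be evaluated on each batch while the discriminator is trained in alternation to tell teacher-derived maps from student maps. A learned convolutional discriminator, unlike the placeholder here, is what would let the adversarial term reflect spatial correlation in the feature maps.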