Clustering Aware Classification for Risk Prediction and Subtyping in Clinical Data
This paper addresses risk prediction and subtyping in clinical data, where heterogeneous subpopulations can degrade classifier performance. The approach is incremental, since it builds on existing clustering and classification ideas.
The paper aims to improve classification performance on heterogeneous data by incorporating cluster structure. It proposes DeepCAC, a method that learns latent embeddings and clusters so that a separate classifier can be trained for each subpopulation, and it reports better results than previous methods on synthetic and real datasets.
In data containing heterogeneous subpopulations, classification performance benefits from incorporating the knowledge of cluster structure in the classifier. Previous methods for such combined clustering and classification either 1) are classifier-specific and not generic, or 2) perform clustering and classifier training independently, which may not form clusters that benefit classifier performance. The question of how to perform clustering to improve the performance of classifiers trained on the clusters has received scant attention in previous literature, despite its importance in several real-world applications. In this paper, first, we theoretically analyze the generalization performance of classifiers trained on clustered data and find conditions under which clustering can potentially aid classification. This motivates the design of a simple k-means-based classification algorithm called Clustering Aware Classification (CAC) and its neural variant DeepCAC. DeepCAC leverages deep representation learning to learn latent embeddings and finds clusters in a manner that makes the clustered data suitable for training classifiers for each underlying subpopulation. Our experiments on synthetic and real benchmark datasets demonstrate the efficacy of DeepCAC over previous methods for combined clustering and classification.
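
To make the setting concrete, the following is a minimal sketch of the "cluster, then classify" baseline that the abstract describes as option 2), where clustering and classifier training are performed independently. It is not the authors' CAC or DeepCAC algorithm. The function names, the farthest-first k-means initialisation, the nearest-class-mean classifier used inside each cluster, and the toy data are all assumptions chosen for illustration.

```python
# Illustrative baseline only: k-means first, then one simple classifier per
# cluster. All names and the toy data are hypothetical, not from the paper.

def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def mean(points):
    """Coordinate-wise mean of a non-empty list of points."""
    return [sum(col) / len(points) for col in zip(*points)]

def nearest(p, centers):
    """Index of the center closest to point p."""
    return min(range(len(centers)), key=lambda j: sq_dist(p, centers[j]))

def kmeans(X, k, iters=20):
    """Plain k-means with deterministic farthest-first initialisation."""
    centers = [list(X[0])]
    while len(centers) < k:
        far = max(X, key=lambda p: min(sq_dist(p, c) for c in centers))
        centers.append(list(far))
    for _ in range(iters):
        groups = [[] for _ in range(k)]
        for p in X:
            groups[nearest(p, centers)].append(p)
        # Keep the previous center if a cluster becomes empty.
        centers = [mean(g) if g else centers[j] for j, g in enumerate(groups)]
    return centers

def fit(X, y, k):
    """Cluster X, then fit a nearest-class-mean classifier in each cluster."""
    centers = kmeans(X, k)
    buckets = [{} for _ in range(k)]
    for p, label in zip(X, y):
        buckets[nearest(p, centers)].setdefault(label, []).append(p)
    models = [{label: mean(pts) for label, pts in b.items()} for b in buckets]
    return centers, models

def predict(p, centers, models):
    """Route p to its cluster, then apply that cluster's classifier."""
    model = models[nearest(p, centers)]
    return min(model, key=lambda label: sq_dist(p, model[label]))

def accuracy(X, y, k):
    """Training accuracy of the k-cluster model (for illustration only)."""
    centers, models = fit(X, y, k)
    hits = sum(predict(p, centers, models) == t for p, t in zip(X, y))
    return hits / len(X)

# Toy data: two subpopulations (x near 0 and x near 10) whose label rule
# flips, so one global nearest-class-mean classifier cannot separate them.
X = [[0, 1], [0, 2], [1, 1], [0, -1], [0, -2], [1, -1],
     [10, 1], [10, 2], [11, 1], [10, -1], [10, -2], [11, -1]]
y = [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]

centers, models = fit(X, y, 2)
```

On this toy data, `accuracy(X, y, 2)` fits every training point, while `accuracy(X, y, 1)`, which reduces to a single global classifier, gets only half of them right. The clusters here happen to align with the label structure. The abstract's point is that independently found clusters need not do so in general, which is the gap CAC and DeepCAC are designed to address.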