Disentangling Neural Disjunctive Normal Form Models
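This work addresses a specific bottleneck in neuro-symbolic learning, the performance lost when neural DNF models are translated into symbolic rules, and offers researchers an incremental improvement in interpretability and post-translation performance.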
The paper tackles the performance degradation that neural Disjunctive Normal Form (DNF) models suffer during symbolic translation. It shows that part of this degradation stems from entangled knowledge representation in the learned weights, and it proposes a disentanglement method that splits nodes encoding nested rules into smaller independent nodes, yielding classification results closer to those of the pre-translation models.
Neural Disjunctive Normal Form (DNF) based models are powerful and interpretable approaches to neuro-symbolic learning and have shown promising results in classification and reinforcement learning settings without prior knowledge of the tasks. However, their performance is degraded by the thresholding of the post-training symbolic translation process. We show here that part of the performance degradation during translation is due to its failure to disentangle the learned knowledge represented in the form of the networks' weights. We address this issue by proposing a new disentanglement method; by splitting nodes that encode nested rules into smaller independent nodes, we are able to better preserve the models' performance. Through experiments on binary, multiclass, and multilabel classification tasks (including those requiring predicate invention), we demonstrate that our disentanglement method provides compact and interpretable logical representations for the neural DNF-based models, with performance closer to that of their pre-translation counterparts. Our code is available at https://github.com/kittykg/disentangling-ndnf-classification.
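To make the thresholding problem concrete, the following toy sketch shows how a single conjunctive node can encode a nested rule that no single weight threshold recovers, while a split into two smaller nodes does. It is an illustration under stated assumptions, not the method from the paper or its repository: the node form (tanh activation with bias max|w| - sum|w| over inputs in {-1, 1}), the saturation magnitude 6.0, the example weights, and the hand-chosen split are all assumptions introduced here, and the helper names `conj_node`, `threshold`, `truth_table`, and `mismatches` are hypothetical.

```python
from itertools import product
from math import tanh


def conj_node(weights, x):
    """Assumed semi-symbolic conjunctive node over inputs in {-1, 1}.

    Output is tanh(w . x + bias) with bias = max|w| - sum|w|;
    a positive output is read as "true".
    """
    abs_w = [abs(w) for w in weights]
    bias = max(abs_w) - sum(abs_w)
    return tanh(sum(w * xi for w, xi in zip(weights, x)) + bias)


def threshold(weights, tau, magnitude=6.0):
    """Zero out weights with |w| <= tau; saturate the rest to +/- magnitude."""
    return [0.0 if abs(w) <= tau else magnitude * (1 if w > 0 else -1)
            for w in weights]


def truth_table(fn, n_inputs=3):
    """Map every input in {-1, 1}^n to the boolean reading of fn's output."""
    return {x: fn(x) > 0 for x in product((-1, 1), repeat=n_inputs)}


def mismatches(table_a, table_b):
    """Count inputs on which two truth tables disagree."""
    return sum(table_a[x] != table_b[x] for x in table_a)


# Hypothetical learned weights: this node behaves like x1 AND (x2 OR x3).
entangled = [6.0, 2.0, 2.0]
original = truth_table(lambda x: conj_node(entangled, x))

# A low threshold keeps all weights, giving x1 AND x2 AND x3.
low = truth_table(lambda x: conj_node(threshold(entangled, 1.0), x))
# A high threshold drops the small weights, giving just x1.
high = truth_table(lambda x: conj_node(threshold(entangled, 3.0), x))

# Hand-chosen split into two independent conjunctions, (x1 AND x2) and
# (x1 AND x3), combined by taking the maximum as a stand-in for disjunction.
split = truth_table(lambda x: max(conj_node([6.0, 6.0, 0.0], x),
                                  conj_node([6.0, 0.0, 6.0], x)))

print(mismatches(original, low))    # 2
print(mismatches(original, high))   # 1
print(mismatches(original, split))  # 0
```

In this toy case both thresholds change the node's behaviour on at least one input, whereas the two split nodes together reproduce the original truth table exactly. How the paper detects nested rules and chooses the split is not specified in the abstract, so the split above is written by hand.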