Enhanced Masked Image Modeling to Avoid Model Collapse on Multi-modal MRI Datasets
This work addresses the specific problem of model collapse in self-supervised learning for medical imaging, offering incremental improvements for multi-modal MRI analysis.
The paper tackles model collapse in masked image modeling (MIM) when applied to multi-modal MRI datasets, introducing an enhanced MIM (E-MIM) with hybrid mask patterns and a pyramid barlow twins module to prevent collapse, resulting in stable training and improved performance on segmentation and classification tasks.
Multi-modal magnetic resonance imaging (MRI) provides information about lesions from different views for computer-aided diagnosis. Deep learning algorithms are well suited to identifying specific anatomical structures, segmenting lesions, and classifying diseases. Manual labels, however, are scarce because annotation is expensive, and this hinders further improvements in accuracy. Self-supervised learning, particularly masked image modeling (MIM), has shown promise in exploiting unlabeled data. However, we observe model collapse when applying MIM to multi-modal MRI datasets, and a collapsed model yields no improvement on downstream tasks. We analyze and address model collapse in two forms: complete collapse and dimensional collapse. We find that complete collapse occurs because, on multi-modal MRI datasets, the loss value of a collapsed model falls below the normally converged loss value. Based on this finding, we introduce the hybrid mask pattern (HMP) masking strategy, which raises the collapsed loss above the normally converged loss and thereby avoids complete collapse. Additionally, we reveal that dimensional collapse stems from insufficient feature uniformity in MIM, and we mitigate it by introducing the pyramid barlow twins (PBT) module as an explicit regularization method. Overall, we construct the enhanced MIM (E-MIM), which combines HMP and the PBT module, to avoid model collapse on multi-modal MRI. Experiments on three multi-modal MRI datasets validate the effectiveness of our approach in preventing both types of model collapse. By preventing model collapse, training becomes more stable, resulting in improved performance on segmentation and classification tasks. The code is available at https://github.com/LinxuanHan/E-MIM.
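The complete-collapse argument can be illustrated with a small calculation. Under a mean-squared-error reconstruction loss, a fully collapsed model that outputs one constant for every masked voxel does best by predicting the mean of the targets, so its loss equals the variance of the masked targets. The Python sketch below compares that collapsed loss on a hypothetical background-dominated, MRI-like slice and on a hypothetical dense texture. The synthetic images, the 75% random mask, and the function name `collapsed_mse` are illustrative assumptions, not the paper's data or code.

```python
import numpy as np


def collapsed_mse(targets: np.ndarray) -> float:
    """MSE of a collapsed model that predicts one constant for all masked voxels.

    The MSE-optimal constant is the target mean, so this equals the
    population variance of the masked targets.
    """
    const = targets.mean()
    return float(np.mean((targets - const) ** 2))


rng = np.random.default_rng(0)
size = 64
yy, xx = np.mgrid[:size, :size]

# Hypothetical MRI-like slice: zero background plus one bright disk (~11% of pixels).
foreground = (yy - 32) ** 2 + (xx - 32) ** 2 < 12 ** 2
mri_like = np.where(foreground, rng.uniform(0.5, 1.0, size=(size, size)), 0.0)

# Hypothetical dense texture: informative intensities everywhere.
dense = rng.uniform(0.0, 1.0, size=(size, size))

# The same random 75% mask is applied to both images.
mask = rng.random((size, size)) < 0.75

print(f"collapsed loss, MRI-like: {collapsed_mse(mri_like[mask]):.4f}")
print(f"collapsed loss, dense:    {collapsed_mse(dense[mask]):.4f}")
```

Because most voxels in the MRI-like slice are zero background, its collapsed loss comes out lower than that of the dense image. This toy example does not measure a normally converged loss; it only shows why a background-dominated volume makes the collapsed solution cheap.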
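Dimensional collapse means the learned embeddings occupy only a low-dimensional subspace of the feature space. One common way to check for it, not necessarily the measurement used in the paper, is to inspect the singular value spectrum of a batch of embeddings and summarize it with an entropy-based effective rank. The Python sketch below assumes the embeddings are collected into an N x D NumPy array; the function names and the synthetic "healthy" and "collapsed" embeddings are hypothetical.

```python
import numpy as np


def singular_spectrum(embeddings: np.ndarray) -> np.ndarray:
    """Singular values (descending) of the centered (N x D) embedding matrix."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    return np.linalg.svd(centered, compute_uv=False)


def effective_rank(embeddings: np.ndarray) -> float:
    """Entropy-based effective rank: exp(H(p)) with p = s / sum(s)."""
    s = singular_spectrum(embeddings)
    p = s / s.sum()
    p = p[p > 0]  # skip exact zeros to avoid log(0)
    return float(np.exp(-np.sum(p * np.log(p))))


rng = np.random.default_rng(0)
n, d, k = 512, 64, 4

# Hypothetical well-spread embeddings: full-rank Gaussian features.
healthy = rng.normal(size=(n, d))
# Hypothetical dimensionally collapsed embeddings: confined to a k-dim subspace.
collapsed = rng.normal(size=(n, k)) @ rng.normal(size=(k, d))

print(f"healthy:   {effective_rank(healthy):.1f} / {d}")
print(f"collapsed: {effective_rank(collapsed):.1f} / {d}")
```

An effective rank close to D indicates features spread across many directions, while a rank far below D (at most about 4 for the collapsed batch here) signals dimensional collapse.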
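The PBT module builds on Barlow Twins, which regularizes features by pushing the cross-correlation matrix between the embeddings of two views toward the identity: diagonal terms toward 1 (invariance) and off-diagonal terms toward 0 (decorrelation, which spreads information across dimensions). The Python sketch below implements the standard Barlow Twins loss with the off-diagonal weight 5e-3 used in the original Barlow Twins paper. The `pyramid_barlow_twins_loss` wrapper, which simply averages the loss over several feature levels, is a hypothetical reading of "pyramid" and not the authors' PBT implementation; the synthetic views are also illustrative.

```python
import numpy as np


def barlow_twins_loss(z_a: np.ndarray, z_b: np.ndarray,
                      lambd: float = 5e-3, eps: float = 1e-6) -> float:
    """Barlow Twins loss for two (N x D) embedding batches of the same samples."""
    n = z_a.shape[0]
    # Standardize every feature dimension across the batch (zero mean, unit std).
    z_a = (z_a - z_a.mean(axis=0)) / (z_a.std(axis=0) + eps)
    z_b = (z_b - z_b.mean(axis=0)) / (z_b.std(axis=0) + eps)
    c = z_a.T @ z_b / n                            # D x D cross-correlation matrix
    diag = np.diag(c)
    on_diag = np.sum((1.0 - diag) ** 2)            # invariance term: diagonal -> 1
    off_diag = np.sum(c ** 2) - np.sum(diag ** 2)  # redundancy term: off-diagonal -> 0
    return float(on_diag + lambd * off_diag)


def pyramid_barlow_twins_loss(levels_a, levels_b, lambd: float = 5e-3) -> float:
    """HYPOTHETICAL 'pyramid' variant: average the Barlow Twins loss over feature levels.

    Each level is assumed to be already pooled to an (N x D_l) array. Uniform
    averaging over levels is an illustrative assumption, not the paper's PBT module.
    """
    losses = [barlow_twins_loss(a, b, lambd) for a, b in zip(levels_a, levels_b)]
    return float(np.mean(losses))


rng = np.random.default_rng(0)
n, d = 256, 32

# Two noisy "views" of decorrelated features.
base = rng.normal(size=(n, d))
view_a = base + 0.1 * rng.normal(size=(n, d))
view_b = base + 0.1 * rng.normal(size=(n, d))

# Collapsed features: every dimension is a noisy copy of one shared signal.
shared = rng.normal(size=(n, 1))
col_a = np.repeat(shared, d, axis=1) + 0.1 * rng.normal(size=(n, d))
col_b = np.repeat(shared, d, axis=1) + 0.1 * rng.normal(size=(n, d))

healthy = barlow_twins_loss(view_a, view_b)
collapsed = barlow_twins_loss(col_a, col_b)
print(f"healthy={healthy:.3f} collapsed={collapsed:.3f}")
```

Collapsed features, where every dimension is a near copy of one signal, produce off-diagonal correlations close to 1 and therefore a much larger loss than decorrelated features; this is the kind of penalty an explicit regularizer relies on to discourage dimensional collapse.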