Deep learning pipeline for automated segmentation of Multiple Sclerosis (MS) lesions from 3D brain MRI scans using volumetric convolutional neural networks.
Four 3D CNN architectures are implemented:
- HighRes3DNet — Residual network with dilated convolutions (dilation rates 1, 2, 4) for multi-scale feature extraction. Based on Li et al., 2017.
- SmallHighRes3DNet — Compact variant with fewer residual blocks, suitable for limited GPU memory.
- BigHighRes3DNet — Extended variant with additional residual blocks for higher capacity.
- V-Net — Encoder-decoder architecture with skip connections and transposed convolutions, adapted from Milletari et al., 2016.
All models operate on 3D patches and output per-voxel binary segmentation maps via softmax.
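To make the multi-scale effect of the dilation rates concrete, the sketch below computes the per-axis receptive field of a stack of stride-1 convolutions. The helper `receptive_field` is hypothetical, and the layer count (one 3x3x3 convolution per dilation rate) is an assumption for illustration; the networks in `models/` may stack more layers per rate.

```python
def receptive_field(kernel_size, dilations):
    """Per-axis receptive field (in voxels) of stacked stride-1 convolutions."""
    rf = 1
    for d in dilations:
        # A kernel of width k with dilation d spans (k - 1) * d + 1 voxels,
        # so each layer widens the receptive field by (k - 1) * d.
        rf += (kernel_size - 1) * d
    return rf


# Assumption: one 3x3x3 convolution per dilation rate.
print(receptive_field(3, [1, 2, 4]))  # 15
print(receptive_field(3, [1, 1, 1]))  # 7
```

With the same number of layers and parameters, the dilated stack sees 15 voxels along each axis instead of 7, which is what makes dilation attractive for lesions of varying size.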
├── run.py # Training script (CLI)
├── predict.py # Inference and evaluation (CLI)
├── data_gen.py # 3D patch extraction (random and strided)
├── losses.py # Dice loss
├── metrics.py # Dice coefficient, precision, recall, F1
├── models/
│ ├── HighRes3DNet.py
│ ├── SmallHighRes3DNet.py
│ ├── BigHighRes3DNet.py
│ └── VNet.py
└── help_functions/
    └── data_folder_structure.py # Data organization utilities
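The tree above lists `data_gen.py` as doing random and strided patch extraction. As a rough illustration of the difference, the sketch below computes only the patch corner coordinates for both modes. The function names, the border handling, and the example shapes are assumptions, not the actual `data_gen.py` code.

```python
import itertools
import random


def strided_origins(volume_shape, patch_shape, stride):
    """Return the corner index of every patch on a regular grid."""
    axes = []
    for size, patch, step in zip(volume_shape, patch_shape, stride):
        if patch > size:
            raise ValueError("patch is larger than the volume")
        starts = list(range(0, size - patch + 1, step))
        # Add a final patch flush with the border so no voxels are skipped.
        if starts[-1] != size - patch:
            starts.append(size - patch)
        axes.append(starts)
    return list(itertools.product(*axes))


def random_origins(volume_shape, patch_shape, count, seed=None):
    """Return `count` uniformly sampled patch corners that fit in the volume."""
    rng = random.Random(seed)
    return [
        tuple(rng.randint(0, size - patch)
              for size, patch in zip(volume_shape, patch_shape))
        for _ in range(count)
    ]


# Hypothetical 64x64x40 volume cut into 32x32x32 patches with stride 16.
grid = strided_origins((64, 64, 40), (32, 32, 32), (16, 16, 16))
print(len(grid))  # 18
```

Given a corner `(x, y, z)` and a NumPy volume, the patch itself would be `volume[x:x+32, y:y+32, z:z+32]`. Strided extraction covers the whole volume and suits inference, while random sampling is a common choice for building a training set.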
pip install -r requirements.txt
Dependencies: TensorFlow, Keras, TensorLayer, nibabel, NumPy, Click.
Input data should be organized as:
raw_data/
├── train/
│ ├── images/ # 3D NIfTI brain scans (.nii)
│ └── masks/ # Binary lesion masks (.nii), named <subject>_mask.nii
└── test/
    ├── images/
    └── masks/
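Before extracting patches, it can help to check that every scan has a matching mask. The sketch below pairs files using the `<subject>_mask.nii` naming convention from the layout above. The helper `pair_images_and_masks` is hypothetical, and it assumes each image is named `<subject>.nii`, which the layout does not state explicitly.

```python
from pathlib import Path


def pair_images_and_masks(split_dir):
    """Pair each image in <split_dir>/images with its mask in <split_dir>/masks."""
    split_dir = Path(split_dir)
    pairs = []
    missing = []
    for image_path in sorted((split_dir / "images").glob("*.nii")):
        # Assumption: the image file name without ".nii" is the subject ID.
        subject = image_path.stem
        mask_path = split_dir / "masks" / f"{subject}_mask.nii"
        if mask_path.is_file():
            pairs.append((image_path, mask_path))
        else:
            missing.append(image_path.name)
    if missing:
        raise FileNotFoundError(f"no mask found for: {', '.join(missing)}")
    return pairs
```

Calling `pair_images_and_masks("raw_data/train")` would return a sorted list of `(image, mask)` paths, or fail early with the names of scans that lack a mask.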
Extract random 3D patches from NIfTI volumes:
python data_gen.py
Train a model:
python run.py <train_images.npy> <train_masks.npy> <test_images.npy> <test_masks.npy> <output_dir>
The best model weights (by validation loss) are saved to <output_dir>/weights.h5.
Run inference and evaluation on the test set:
python predict.py <test_folder> <weights_path>
This prints the Dice coefficient, recall, and F1 score on the test set.
- Dice Coefficient (Sorensen-Dice) — overlap between predicted and ground truth masks
- Precision — fraction of predicted lesion voxels that are correct
- Recall — fraction of true lesion voxels that are detected
- F1 Score — harmonic mean of precision and recall
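The four metrics above can all be derived from the voxel counts of true positives, false positives, and false negatives. The following is a minimal pure-Python sketch on flattened binary masks; `segmentation_metrics` is a hypothetical helper, not the code in `metrics.py`, and returning 0.0 for empty masks is a convention chosen here.

```python
def segmentation_metrics(predicted, truth):
    """Compute Dice, precision, recall and F1 from two flat binary masks."""
    if len(predicted) != len(truth):
        raise ValueError("masks must have the same number of voxels")
    tp = sum(1 for p, t in zip(predicted, truth) if p and t)
    fp = sum(1 for p, t in zip(predicted, truth) if p and not t)
    fn = sum(1 for p, t in zip(predicted, truth) if not p and t)

    # Degenerate cases (no predicted or no true lesion voxels) return 0.0.
    denom = 2 * tp + fp + fn
    dice = 2 * tp / denom if denom else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum else 0.0
    return {"dice": dice, "precision": precision, "recall": recall, "f1": f1}


# Toy example: 2 true positives, 2 false positives, 1 false negative.
predicted = [1, 1, 1, 1, 0, 0, 0, 0]
truth = [1, 1, 0, 0, 1, 0, 0, 0]
print(segmentation_metrics(predicted, truth))
```

Here precision is 2/4 and recall is 2/3, and both Dice and F1 come out to 4/7. For hard binary masks the two are algebraically identical, since 2TP / (2TP + FP + FN) is exactly the harmonic mean of precision and recall.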