This repository contains a complete training pipeline built around a custom implementation of the U-Net model, trained on the Brain MRI segmentation dataset. The main difference from the model in the original paper is that dropout is replaced with batch normalization. The goal is a segmentation model that detects anomalous brain regions and supports doctors in their assessment. A simple API is included.
File structure:
- src (all scripts)
  - data_preparation.py - loads the whole dataset or a specific patient and splits it into train and test sets
  - evaluate_model.py - evaluates the model with the default model metric (binary IoU)
  - predict.py - makes predictions on the data and saves the images in the output folder
  - unet_model_recipe.py - the complete U-Net model architecture
  - unet_training.py - model training pipeline
  - api (FastAPI, one prediction endpoint)
    - api.py - simple API for making predictions on brain images; outputs a segmentation mask (without thresholding)
    - api_test.py - API test that calls the endpoint with an image from the dataset
- notebooks (notebooks and analysis)
  - model_predictions_analysis.ipynb - checks model predictions and the output for a specific patient
- models (pretrained models)
  - unet_brain_segmentation.h5 - pretrained U-Net model for brain segmentation
- data (raw data)
Tested with Python 3.10.4.
Libraries used for training:
- tensorflow
- scikit-learn
- numpy
- matplotlib
Libraries used for the API:
- fastapi
- uvicorn
You can install all of them using pip:
pip install -r requirements.txt
Download the dataset from Kaggle: https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
The default data structure looks like this:
data
-> kaggle_3m
--->patient_1
--->patient_2
--->...
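A minimal sketch of how this folder layout could be walked and split, using only the standard library. It is not the code from data_preparation.py: the function names are hypothetical, the `_mask.tif` suffix is an assumption about how the Kaggle files are named, and splitting by patient (rather than by image) is a design choice made here for illustration.

```python
import random
from pathlib import Path


def list_patient_dirs(data_root="data/kaggle_3m"):
    """Return the patient folders under data_root, sorted by name."""
    return sorted(p for p in Path(data_root).iterdir() if p.is_dir())


def pair_images_with_masks(patient_dir):
    """Pair each image with its mask, assuming masks end in '_mask.tif'."""
    pairs = []
    for mask_path in sorted(Path(patient_dir).glob("*_mask.tif")):
        image_name = mask_path.name.replace("_mask.tif", ".tif")
        image_path = mask_path.with_name(image_name)
        if image_path.exists():  # skip masks that have no matching image
            pairs.append((image_path, mask_path))
    return pairs


def split_patients(patient_dirs, test_fraction=0.2, seed=42):
    """Shuffle patients reproducibly and split them into (train, test)."""
    shuffled = list(patient_dirs)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]
```

Splitting at the patient level keeps all slices of one patient in the same set, which avoids near-identical neighbouring slices ending up in both train and test.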
Model training:
python src/unet_training.py
Model evaluation:
python src/evaluate_model.py
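For readers unfamiliar with the metric, here is a rough pure-Python sketch of a binary IoU (intersection over union) for the foreground class. The function name, the 0.5 threshold, and the handling of empty masks are assumptions; the metric used by evaluate_model.py may be defined differently (for example, averaged over background and foreground).

```python
def binary_iou(y_true, y_pred, threshold=0.5):
    """Foreground IoU between a ground-truth mask and predicted probabilities.

    Both arguments are flat sequences of equal length: y_true holds 0/1
    labels, y_pred holds probabilities that are binarized with `threshold`.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    intersection = 0
    union = 0
    for truth, prob in zip(y_true, y_pred):
        is_true = truth >= 0.5
        is_pred = prob >= threshold
        intersection += int(is_true and is_pred)
        union += int(is_true or is_pred)
    # Two empty masks agree perfectly; returning 1.0 is a convention chosen here.
    return intersection / union if union else 1.0


# One pixel is shared, three pixels are marked in either mask: IoU = 1/3.
print(binary_iou([1, 1, 0, 0], [0.9, 0.4, 0.6, 0.1]))
```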
Make predictions on the whole dataset:
python src/predict.py
Run API server:
python src/api/api.py
Test API endpoint (make prediction):
python src/api/api_test.py
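Because the API returns the segmentation mask without thresholding, a client has to binarize it before using it as a mask. The sketch below assumes the response is JSON with the probabilities stored as nested lists under a `mask` field; both the field name and the response shape are hypothetical and should be checked against api.py.

```python
import json


def threshold_mask(mask, threshold=0.5):
    """Turn a 2D list of probabilities into a 2D list of 0/1 values."""
    return [[1 if value >= threshold else 0 for value in row] for row in mask]


# Hypothetical response body; the real field name and shape may differ.
response_text = '{"mask": [[0.02, 0.71], [0.49, 0.93]]}'
probabilities = json.loads(response_text)["mask"]
binary_mask = threshold_mask(probabilities)
print(binary_mask)  # [[0, 1], [0, 1]]
```

The 0.5 cut-off is only a common default; a lower value marks more pixels as anomalous, a higher one fewer.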
A simple Dockerfile is included for running the API server.
Build
docker build -t brain_segmentation .
Run (in background)
docker run -d -it --name brain_segmentation-run -p 8000:8000 brain_segmentation
Results, as reported by the model's default metric (binary IoU):
Train set: 0.9047
Test set: 0.8806