Image classification for the early detection of diabetic retinopathy in patients. Classification is performed on retinal images captured using fundus photography. This project uses a custom ResNet18 model built from scratch in PyTorch.
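For orientation, here is a minimal sketch of a ResNet18 written from scratch in PyTorch. It follows the standard ResNet18 layout: a 7x7 stem, four stages of two residual BasicBlocks each, global average pooling, and a linear classifier. It is not the repository's implementation. The class names, the body/head split, and the default of 5 output classes (the APTOS labels grade severity from 0 to 4) are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """Two 3x3 convolutions with a residual (skip) connection."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        # When the shape changes, project the input with a 1x1 conv so it can be added.
        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))


class ResNet18(nn.Module):
    """Stem, four stages of two BasicBlocks each, then a linear head."""

    def __init__(self, num_classes: int = 5):  # 5 classes assumed (grades 0-4)
        super().__init__()
        self.body = nn.Sequential(
            # Stem: 7x7 conv with stride 2, then 3x3 max pool with stride 2.
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
            # Four stages; every stage after the first halves the spatial size.
            BasicBlock(64, 64), BasicBlock(64, 64),
            BasicBlock(64, 128, stride=2), BasicBlock(128, 128),
            BasicBlock(128, 256, stride=2), BasicBlock(256, 256),
            BasicBlock(256, 512, stride=2), BasicBlock(512, 512),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.head = nn.Linear(512, num_classes)

    def forward(self, x):
        return self.head(self.body(x))


model = ResNet18(num_classes=5)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 5])
```

With 5 outputs, this layout has 11,179,077 parameters. That matches torchvision's resnet18 once its 1000-class head is replaced by a 5-class one.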
Disclaimer: Not intended for medical diagnosis. This project analyzes medical images for demonstration purposes only. Always consult your doctor or another qualified healthcare professional for a diagnosis.
See the APTOS 2019 Blindness Detection competition on Kaggle for the full overview and data description.
To download the data using the Kaggle API:
kaggle competitions download -c aptos2019-blindness-detection
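The CLI saves the competition files as a single archive. Assuming that archive lands in the working directory as aptos2019-blindness-detection.zip, the following standard-library sketch unpacks it into the data directory that the scripts read by default. The function name is hypothetical.

```python
import zipfile
from pathlib import Path
from typing import List


def extract_competition_zip(archive: str = "aptos2019-blindness-detection.zip",
                            dest: str = "data") -> List[str]:
    """Unpack the Kaggle archive into `dest` and return the sorted member names."""
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)  # create data/ if it does not exist
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)
        return sorted(zf.namelist())


# Example (assumed default archive name and location):
# extract_competition_zip()
```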
Training and test data are expected in the data directory by default. For the expected parameters, run
python train.py -h
or
python infer.py -h
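The following sketch shows one way to pair a CSV with its image directory. It assumes the Kaggle layout: each row has an id_code (plus a diagnosis grade for training data), and each image is stored as <id_code>.png. The function name load_samples is hypothetical and is not part of train.py or infer.py.

```python
import csv
from pathlib import Path
from typing import List, Optional, Tuple


def load_samples(csv_path: str, image_dir: str,
                 ext: str = ".png") -> List[Tuple[Path, Optional[int]]]:
    """Pair each CSV row with its image path and, if present, its diagnosis label."""
    samples = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            image_path = Path(image_dir) / (row["id_code"] + ext)
            # A test CSV has no 'diagnosis' column, so the label is None there.
            label = int(row["diagnosis"]) if "diagnosis" in row else None
            samples.append((image_path, label))
    return samples


# Example (assumed default locations):
# train_samples = load_samples("data/train.csv", "data/train_images")
```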
Clone the repository:
git clone https://github.com/thatgeeman/retinopathy_classification_resnet18
Set up the environment and install the dependencies:
pip install pipenv
cd retinopathy_classification_resnet18
pipenv install --python 3.8
pipenv shell
To train the model on the data in data/train.csv, with images located in data/train_images:
python train.py 2 10 --csv data/train.csv --data data/train_images
Here, the first positional argument is the number of epochs to train the model with the body parameters frozen, and the second is the number of epochs to train the full model.
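The two-stage schedule can be sketched as follows. This illustration is not the code in train.py and rests on several assumptions. It assumes the model exposes body and head submodules, uses Adam with arbitrarily chosen learning rates, and takes a hypothetical train_one_epoch callback. TinyNet is a toy stand-in so the example runs quickly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Enable or disable gradient updates for every parameter in a module."""
    for p in module.parameters():
        p.requires_grad = trainable


def two_phase_fit(model: nn.Module, train_one_epoch,
                  frozen_epochs: int, full_epochs: int) -> None:
    # Phase 1: freeze the (assumed) body so only the head is updated.
    set_trainable(model.body, False)
    set_trainable(model.head, True)
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
    for _ in range(frozen_epochs):
        train_one_epoch(model, opt)
    # Phase 2: unfreeze everything; a new optimizer now includes the body parameters.
    set_trainable(model, True)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    for _ in range(full_epochs):
        train_one_epoch(model, opt)


class TinyNet(nn.Module):
    """Hypothetical toy model with the assumed body/head layout."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.head = nn.Linear(4, 5)

    def forward(self, x):
        return self.head(self.body(x))


def train_one_epoch(model: nn.Module, opt: torch.optim.Optimizer) -> None:
    """Hypothetical 'epoch': a single step on random data, for demonstration only."""
    x, y = torch.randn(16, 8), torch.randint(0, 5, (16,))
    opt.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
```

Under these assumptions, python train.py 2 10 corresponds to two_phase_fit(model, train_one_epoch, frozen_epochs=2, full_epochs=10). The second phase builds a fresh optimizer because the first one only holds the head parameters. Reusing it would leave the unfrozen body untouched.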
To run an inference cycle using a saved checkpoint:
python infer.py checkpoints/model_c15.pth
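The following sketch shows what a single inference step might involve. It assumes the checkpoint is a plain state_dict saved with torch.save(model.state_dict(), path). The real checkpoints/model_c15.pth may store additional keys, such as optimizer state, and the loading step would then need adjusting. The predict function is hypothetical.

```python
import torch
import torch.nn as nn


def predict(model: nn.Module, checkpoint_path: str, images: torch.Tensor) -> torch.Tensor:
    """Load weights from a checkpoint and return the predicted class index per image."""
    state = torch.load(checkpoint_path, map_location="cpu")  # assumed: a bare state_dict
    model.load_state_dict(state)
    model.eval()  # use BatchNorm running statistics and disable dropout
    with torch.no_grad():
        logits = model(images)
    return logits.argmax(dim=1)  # highest-scoring grade for each image


# Example (hypothetical model and batch):
# preds = predict(model, "checkpoints/model_c15.pth", batch_of_images)
```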