A PyTorch implementation of MedSegDiff, a diffusion probabilistic model designed for medical image segmentation.

deepmancer/medseg-diffusion

MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model 🚀

PyTorch Python Jupyter Notebook

Welcome to MedSegDiff. This project provides a step-by-step implementation of the MedSegDiff paper from scratch using PyTorch. The paper presents MedSegDiff as the first Diffusion Probabilistic Model (DPM) designed for general medical image segmentation, with applications such as segmenting tumors and cancerous lesions.


📖 Overview

MedSegDiff applies Diffusion Probabilistic Models (DPMs) to medical image segmentation. It combines dynamic conditional encoding with a Feature Frequency Parser (FF-Parser), which filters features in Fourier space, to improve segmentation accuracy across medical imaging modalities. This repository is a resource for understanding and implementing the method, particularly for challenging cases such as tumors and cancerous lesions.

⚙️ Methodology

MedSegDiff Overview
An illustration of MedSegDiff. The time step encoding is omitted for clarity.

At its core, MedSegDiff utilizes a U-Net architecture for learning and segmentation tasks. The step estimation function is conditioned on the raw image prior, described by:

$$\epsilon_\theta(x_t, I, t) = D\big((\mathbf{E_t^I} + \mathbf{E_t^x},\, t),\, t\big)$$

Here, $\mathbf{E_t^I}$ represents the conditional feature embedding (the raw image embedding), and $\mathbf{E_t^x}$ is the segmentation map feature embedding at the current step. These embeddings are combined and passed through a U-Net decoder for reconstruction. Training is governed by the loss function:

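$$\mathcal{L} = \mathbb{E}_{mask_0,\, \epsilon,\, t}\left[\left\lVert \epsilon - \epsilon_\theta\left(\sqrt{\bar{\alpha}_t}\, mask_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\; I,\; t\right)\right\rVert^2\right]$$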

The architecture is a modified ResUNet, which pairs a ResNet encoder with a U-Net decoder.
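As a rough illustration of the training objective, the sketch below implements the standard DDPM noise-prediction loss conditioned on the raw image. It is a minimal sketch under stated assumptions, not this repository's training code: the function name `diffusion_loss`, the `model(x_t, image, t)` call signature, and the precomputed `alphas_cumprod` tensor are all hypothetical names chosen for the example.

```python
import torch
import torch.nn.functional as F


def diffusion_loss(model, mask0, image, alphas_cumprod):
    """Noise-prediction MSE loss (standard DDPM objective).

    Assumptions: `model(x_t, image, t)` returns a tensor shaped like `x_t`,
    and `alphas_cumprod` is a 1-D tensor holding the cumulative product of
    (1 - beta_t) for every diffusion step.
    """
    batch = mask0.shape[0]
    num_steps = alphas_cumprod.shape[0]

    # Draw one random diffusion step per sample.
    t = torch.randint(0, num_steps, (batch,), device=mask0.device)

    # Noise the clean mask: x_t = sqrt(a_bar_t) * mask0 + sqrt(1 - a_bar_t) * eps.
    noise = torch.randn_like(mask0)
    a_bar = alphas_cumprod[t].view(batch, 1, 1, 1)
    x_t = a_bar.sqrt() * mask0 + (1.0 - a_bar).sqrt() * noise

    # The network predicts the added noise, conditioned on the raw image.
    predicted_noise = model(x_t, image, t)
    return F.mse_loss(predicted_noise, noise)
```

Any module that accepts the noisy mask, the conditioning image, and the step index can be passed as `model`; the loss itself does not depend on the network architecture.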

🧠 Dynamic Encoding Process

  1. FF-Parser Input: The segmentation map undergoes initial processing through the Feature Frequency Parser (FF-Parser), which refines feature representation by reducing high-frequency noise.

    FF-Parser Illustration
    Illustration of the FF-Parser. FFT denotes Fast Fourier Transform.

  2. Attentive Fusion: The denoised feature map is then combined with prior image embeddings using an attentive-like mechanism to enhance regional attention and feature saliency.

  3. Iterative Refinement: This enriched feature map undergoes further refinement, culminating at the bottleneck phase.

  4. Bottleneck Convergence: Finally, the processed feature map is integrated with the U-Net encoder's outputs, resulting in an improved segmentation map.
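The FF-Parser step in the list above can be sketched as a small PyTorch module: transform the feature map with a 2D FFT, multiply it element-wise by a learnable map, and transform back. This is a minimal sketch, not necessarily how this repository implements it. The class name `FFParser`, the per-channel weight of shape `(channels, height, width)`, the initialization to ones, and the `"ortho"` normalization are assumptions made for the example.

```python
import torch
import torch.nn as nn


class FFParser(nn.Module):
    """Frequency-domain feature filter (illustrative sketch)."""

    def __init__(self, channels, height, width):
        super().__init__()
        # Learnable real-valued map applied to every frequency component.
        # Initialized to ones (an assumption), so the untrained module is
        # an identity mapping.
        self.weight = nn.Parameter(torch.ones(channels, height, width))

    def forward(self, x):
        # x: (batch, channels, height, width), real-valued.
        freq = torch.fft.fft2(x, norm="ortho")        # to Fourier space
        freq = freq * self.weight                      # modulate frequencies
        out = torch.fft.ifft2(freq, norm="ortho")     # back to spatial domain
        return out.real
```

Because the weight is fixed to a given spatial size, one such module would be needed per feature-map resolution in the encoder.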

⏳ Time Encoding Block

  • Sinusoidal Embedding Calculation: Sinusoidal timestep embeddings are calculated and passed through a linear layer, followed by SiLU activation, and another linear layer.

    Time Embedding Illustration

  • Integration into Residual Blocks: The resulting time features are injected into the residual blocks, so that each block is conditioned on the current diffusion step.
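The time encoding described above (sinusoidal embedding, linear layer, SiLU, linear layer) can be sketched as follows. This is a minimal sketch: the class names `SinusoidalEmbedding` and `TimeMLP`, the base frequency of 10000, and the layer widths are assumptions rather than values taken from this repository.

```python
import math

import torch
import torch.nn as nn


class SinusoidalEmbedding(nn.Module):
    """Maps integer timesteps to fixed sinusoidal features."""

    def __init__(self, dim):
        super().__init__()
        assert dim % 2 == 0, "embedding dimension must be even"
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Geometrically spaced frequencies from 1 down toward 1/10000.
        exponent = torch.arange(half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(10000.0) * exponent)
        args = t.float()[:, None] * freqs[None, :]
        # First half holds sines, second half holds cosines.
        return torch.cat([args.sin(), args.cos()], dim=-1)


class TimeMLP(nn.Module):
    """Sinusoidal embedding -> Linear -> SiLU -> Linear."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            SinusoidalEmbedding(dim),
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t):
        # t: (batch,) tensor of timestep indices.
        return self.net(t)
```

Calling `TimeMLP(64, 256)(torch.tensor([0, 10, 500]))` returns one 256-dimensional time feature per timestep.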

🛠️ Encoder & Decoder Blocks

  • Initial Convolution: Separate initial convolutional layers process the mask and the input image before they enter the encoder.

  • Residual Blocks: Each ResNet block consists of two consecutive convolutional layers with SiLU activation and Group Normalization, and is used throughout the network. Without the residual connection, the block reduces to a plain convolutional block.

  • Attention Mechanism: A sub-module that combines Layer Normalization, multi-head attention, residual connections, and a feed-forward network.
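The residual and attention blocks listed above might look roughly like this in PyTorch. This is a sketch under stated assumptions, not the repository's exact code: the class names, the order of normalization and activation, the way the time feature is added between the two convolutions, the group and head counts, and the GELU feed-forward layer are all illustrative choices.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Two conv layers with GroupNorm + SiLU, a time shift, and a skip path."""

    def __init__(self, in_ch, out_ch, time_dim, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_ch)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.act = nn.SiLU()
        # A 1x1 conv matches channel counts on the skip path when needed.
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(self.act(self.norm1(x)))
        # Broadcast the projected time feature over the spatial dimensions.
        h = h + self.time_proj(self.act(t_emb))[:, :, None, None]
        h = self.conv2(self.act(self.norm2(h)))
        return h + self.skip(x)  # dropping this sum leaves a plain conv block


class AttentionBlock(nn.Module):
    """LayerNorm -> multi-head self-attention -> feed-forward, with residuals."""

    def __init__(self, channels, heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        self.ff = nn.Sequential(
            nn.Linear(channels, 4 * channels),
            nn.GELU(),
            nn.Linear(4 * channels, channels),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (b, h*w, c): one token per pixel
        normed = self.norm1(tokens)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        tokens = tokens + attn_out
        tokens = tokens + self.ff(self.norm2(tokens))
        return tokens.transpose(1, 2).reshape(b, c, h, w)
```

Both blocks preserve spatial size, so they can be stacked freely between the downsampling and upsampling stages of a U-Net. Channel counts must be divisible by `groups` and `heads` respectively.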

🔄 Review of Diffusion Process

  • Forward Diffusion Process: Gradually transforms a segmentation label into a noisy mask sequence, converging to a Gaussian distribution as time increases.

    Forward Diffusion Process

  • Reverse Diffusion Process: Starting from Gaussian noise, the model iteratively denoises the data, removing at each step the noise that the forward process added, until a segmentation mask is recovered.

    Reverse Diffusion Process
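The two processes above can be sketched with the standard DDPM formulas. This is a minimal sketch and not this repository's sampler: the linear beta schedule, the function names, and the choice of `beta_t` as the reverse-step variance are assumptions made for the example.

```python
import torch


def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances (one common choice, assumed here)."""
    return torch.linspace(beta_start, beta_end, timesteps)


def q_sample(x0, t, alphas_cumprod, noise=None):
    """Forward process: jump from the clean mask x0 directly to step t.

    x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise
    `t` is a (batch,) tensor of step indices.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise


def p_sample_step(x_t, t, eps_pred, betas, alphas_cumprod):
    """Reverse process: one denoising step from index t to t - 1.

    `t` is a Python int shared by the whole batch, and `eps_pred` is the
    noise predicted by the network for x_t.
    """
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # Posterior mean obtained by removing the predicted noise.
    mean = (x_t - beta_t / (1.0 - alphas_cumprod[t]).sqrt() * eps_pred) / alpha_t.sqrt()
    if t == 0:
        return mean  # no noise is added at the final step
    return mean + beta_t.sqrt() * torch.randn_like(x_t)
```

With `betas = linear_beta_schedule(1000)` and `alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)`, sampling would loop `t` from 999 down to 0, calling the network for `eps_pred` at each step.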

🎯 Results

According to the MedSegDiff paper, the method outperforms the compared state-of-the-art methods on three segmentation tasks: brain tumor segmentation, optic cup segmentation, and thyroid nodule segmentation.

Evaluation Results
Visual comparison of top general medical image segmentation methods.

SOTA Comparison
Comparison of MedSegDiff with state-of-the-art segmentation methods. The best results are highlighted in bold.

🚀 Installation

To get started with MedSegDiff, follow these simple steps:

git clone https://github.com/alirezaheidari-cs/DiffusionMedSeg.git
cd DiffusionMedSeg
pip install -r requirements.txt

📚 Citations

If you find this work helpful, please consider citing the following papers:

@article{Wu2022MedSegDiffMI,
    title   = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
    author  = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
    title   = {simple diffusion: End-to-end diffusion for high resolution images},
    author  = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
    year    = {2023}
}

📝 License

This project is licensed under the MIT License. For detailed information, please refer to the LICENSE file.


We hope this repository aids your research and development in medical image segmentation. Feel free to contribute or raise issues. Let's advance medical technology together! 💡