From c58a82e8bda6caf517a73a84ca33fd58362f8f87 Mon Sep 17 00:00:00 2001 From: bmandracchia Date: Mon, 26 Aug 2024 18:06:22 +0200 Subject: [PATCH] update losses, metrics --- bioMONAI/_modidx.py | 5 +- bioMONAI/core.py | 2 +- bioMONAI/losses.py | 78 ++++++++--------------- bioMONAI/metrics.py | 40 +++++++++++- nbs/00_core.ipynb | 16 ++--- nbs/03_losses.ipynb | 143 +++++++++++++++---------------------------- nbs/06_metrics.ipynb | 53 +++++++++++++++- 7 files changed, 175 insertions(+), 162 deletions(-) diff --git a/bioMONAI/_modidx.py b/bioMONAI/_modidx.py index 63d8c0e..c2f734c 100644 --- a/bioMONAI/_modidx.py +++ b/bioMONAI/_modidx.py @@ -54,15 +54,14 @@ 'bioMONAI.losses.DiceLoss.__init__': ('losses.html#diceloss.__init__', 'bioMONAI/losses.py'), 'bioMONAI.losses.DiceLoss.forward': ('losses.html#diceloss.forward', 'bioMONAI/losses.py'), 'bioMONAI.losses.FRCLoss': ('losses.html#frcloss', 'bioMONAI/losses.py'), - 'bioMONAI.losses.FRCM': ('losses.html#frcm', 'bioMONAI/losses.py'), - 'bioMONAI.losses.SSIM': ('losses.html#ssim', 'bioMONAI/losses.py'), 'bioMONAI.losses.get_fourier_ring_correlations': ( 'losses.html#get_fourier_ring_correlations', 'bioMONAI/losses.py'), 'bioMONAI.losses.get_radial_masks': ('losses.html#get_radial_masks', 'bioMONAI/losses.py'), 'bioMONAI.losses.radial_mask': ('losses.html#radial_mask', 'bioMONAI/losses.py'), 'bioMONAI.losses.seventh_fourier_ring_correlation': ( 'losses.html#seventh_fourier_ring_correlation', 'bioMONAI/losses.py')}, - 'bioMONAI.metrics': {'bioMONAI.metrics.foo': ('metrics.html#foo', 'bioMONAI/metrics.py')}, + 'bioMONAI.metrics': { 'bioMONAI.metrics.FRCM': ('metrics.html#frcm', 'bioMONAI/metrics.py'), + 'bioMONAI.metrics.SSIM': ('metrics.html#ssim', 'bioMONAI/metrics.py')}, 'bioMONAI.nets': { 'bioMONAI.nets.DnCNN': ('nets.html#dncnn', 'bioMONAI/nets.py'), 'bioMONAI.nets.DnCNN.__init__': ('nets.html#dncnn.__init__', 'bioMONAI/nets.py'), 'bioMONAI.nets.DnCNN.forward': ('nets.html#dncnn.forward', 'bioMONAI/nets.py'), diff --git a/bioMONAI/core.py b/bioMONAI/core.py index a1179cf..912ea2a 100644 --- a/bioMONAI/core.py +++ b/bioMONAI/core.py @@ -18,7 +18,7 @@ from torch import from_numpy as torch_from_numpy # %% ../nbs/00_core.ipynb 6 -from fastai.vision.all import BypassNewMeta, DisplayedTransform +from fastai.vision.all import BypassNewMeta, DisplayedTransform, store_attr from fastai.data.all import delegates, hasattrs, Path, List, L, typedispatch # %% ../nbs/00_core.ipynb 8 diff --git a/bioMONAI/losses.py b/bioMONAI/losses.py index 2b3fb74..52965f8 100644 --- a/bioMONAI/losses.py +++ b/bioMONAI/losses.py @@ -1,14 +1,22 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_losses.ipynb. # %% auto 0 -__all__ = ['SSIMMetric', 'CombinedLoss', 'DiceLoss', 'radial_mask', 'get_radial_masks', 'get_fourier_ring_correlations', - 'FRCLoss', 'seventh_fourier_ring_correlation', 'SSIM', 'FRCM'] +__all__ = ['CombinedLoss', 'DiceLoss', 'radial_mask', 'get_radial_masks', 'get_fourier_ring_correlations', 'FRCLoss', + 'seventh_fourier_ring_correlation'] # %% ../nbs/03_losses.ipynb 4 -from fastai.vision.all import * +from .core import store_attr # %% ../nbs/03_losses.ipynb 5 +import numpy as np + from monai.losses import SSIMLoss +import torch.nn as nn +from torch import abs, sqrt, div, sigmoid, complex64, where, isinf, zeros_like, real, isnan +from torch.fft import fftshift +from torch.fft import fft2 + +from .core import torch_from_numpy # %% ../nbs/03_losses.ipynb 8 class CombinedLoss: @@ -58,7 +66,7 @@ def __init__(self, smooth=1): def forward(self, inputs, targets): # Make sure the inputs are probabilities - inputs = torch.sigmoid(inputs) + inputs = sigmoid(inputs) # Flatten tensors inputs = inputs.view(-1) @@ -175,23 +183,23 @@ def get_fourier_ring_correlations(image1, image2): freq_nyq = len(spatial_frequency) # Transform tensor to complex - image1 = image1.to(torch.complex64) - image2 = image2.to(torch.complex64) + image1 = image1.to(complex64) + image2 = image2.to(complex64) # Transofrm array dimensions to (freq_nyq, width. height) image1 = image1.unsqueeze(0).repeat(freq_nyq, 1, 1) image2 = image2.unsqueeze(0).repeat(freq_nyq, 1, 1) # Convert spatial frequency and radial masks to torch.tensor - spatial_frequency = torch.from_numpy(spatial_frequency) - radial_masks = torch.from_numpy(radial_masks) + spatial_frequency = torch_from_numpy(spatial_frequency) + radial_masks = torch_from_numpy(radial_masks) # Transform tensor to complex - radial_masks = radial_masks.to(torch.complex64) + radial_masks = radial_masks.to(complex64) # Compute fourier transform - fft_image1 = torch.fft.fftshift(torch.fft.fft2(image1)) - fft_image2 = torch.fft.fftshift(torch.fft.fft2(image2)) + fft_image1 = fftshift(fft2(image1)) + fft_image2 = fftshift(fft2(image2)) # Get elements only in the ring t1 = fft_image1 * radial_masks @@ -201,19 +209,19 @@ def get_fourier_ring_correlations(image1, image2): t2_conj = t2.conj() # Numerator - numerador = torch.abs(torch.real((t1 * t2_conj).sum(dim=(1,2)))) + numerador = abs(real((t1 * t2_conj).sum(dim=(1,2)))) # Denominator - denominador_1 = ((torch.abs(t1) * torch.abs(t1)).sum(dim=(1,2))) - denominador_2 = ((torch.abs(t2) * torch.abs(t2)).sum(dim=(1,2))) - denominador = torch.sqrt(denominador_1 * denominador_2) + denominador_1 = ((abs(t1) * abs(t1)).sum(dim=(1,2))) + denominador_2 = ((abs(t2) * abs(t2)).sum(dim=(1,2))) + denominador = sqrt(denominador_1 * denominador_2) # Fourier shell correlation - FRC = torch.div(numerador, denominador) + FRC = div(numerador, denominador) # Remove possible inf and NaN. - FRC = torch.where(torch.isinf(FRC), torch.zeros_like(FRC), FRC) # inf - FRC = torch.where(torch.isnan(FRC), torch.zeros_like(FRC), FRC) # NaN + FRC = where(isinf(FRC), zeros_like(FRC), FRC) # inf + FRC = where(isnan(FRC), zeros_like(FRC), FRC) # NaN return FRC , spatial_frequency @@ -271,37 +279,3 @@ def exponential_func(x, a, b, c): cutoff_frequency = (exponential_func((1/7), *params)) return cutoff_frequency - -# %% ../nbs/03_losses.ipynb 22 -def SSIM(x, y, spatial_dims=2): - return 1 - SSIMLoss(spatial_dims)(x,y) - -SSIMMetric = AvgMetric(SSIM) - -# %% ../nbs/03_losses.ipynb 23 -def FRCM(image1, image2): - - - """ - Compute the area under the Fourier Ring Correlation (FRC) curve between two images. - - #### Args: - - image1 (torch.Tensor): The first input image. - - image2 (torch.Tensor): The second input image. - - #### Returns: - - float: The area under the FRC curve. - """ - - # Calculate the Fourier Ring Correlation and spatial frequency - FRC, spatial_frequency = get_fourier_ring_correlations(image1, image2) - - # Convert to numpy - FRC = FRC.numpy() - spatial_frequency = spatial_frequency.numpy() - - # Compute the area under the curve using trapezoidal integration - area = np.trapz(FRC, spatial_frequency) - - return area - diff --git a/bioMONAI/metrics.py b/bioMONAI/metrics.py index d080a8e..e599359 100644 --- a/bioMONAI/metrics.py +++ b/bioMONAI/metrics.py @@ -1,7 +1,43 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_metrics.ipynb. # %% auto 0 -__all__ = ['foo'] +__all__ = ['SSIMMetric', 'SSIM', 'FRCM'] # %% ../nbs/06_metrics.ipynb 3 -def foo(): pass +from numpy import trapz +from fastai.vision.all import AvgMetric +from .losses import SSIMLoss, get_fourier_ring_correlations + +# %% ../nbs/06_metrics.ipynb 4 +def SSIM(x, y, spatial_dims=2): + return 1 - SSIMLoss(spatial_dims)(x,y) + +SSIMMetric = AvgMetric(SSIM) + +# %% ../nbs/06_metrics.ipynb 5 +def FRCM(image1, image2): + + + """ + Compute the area under the Fourier Ring Correlation (FRC) curve between two images. + + #### Args: + - image1 (torch.Tensor): The first input image. + - image2 (torch.Tensor): The second input image. + + #### Returns: + - float: The area under the FRC curve. + """ + + # Calculate the Fourier Ring Correlation and spatial frequency + FRC, spatial_frequency = get_fourier_ring_correlations(image1, image2) + + # Convert to numpy + FRC = FRC.numpy() + spatial_frequency = spatial_frequency.numpy() + + # Compute the area under the curve using trapezoidal integration + area = trapz(FRC, spatial_frequency) + + return area + diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index 70b3352..d6d9dd2 100755 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -65,12 +65,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "#| export\n", - "from fastai.vision.all import BypassNewMeta, DisplayedTransform\n", + "from fastai.vision.all import BypassNewMeta, DisplayedTransform, store_attr\n", "from fastai.data.all import delegates, hasattrs, Path, List, L, typedispatch" ] }, @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -104,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ diff --git a/nbs/03_losses.ipynb b/nbs/03_losses.ipynb index 3408486..2439fe4 100644 --- a/nbs/03_losses.ipynb +++ b/nbs/03_losses.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -37,27 +37,35 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#| export\n", - "from fastai.vision.all import *" + "from bioMONAI.core import store_attr" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "#| export\n", - "from monai.losses import SSIMLoss" + "import numpy as np\n", + "\n", + "from monai.losses import SSIMLoss\n", + "import torch.nn as nn\n", + "from torch import abs, sqrt, div, sigmoid, complex64, where, isinf, zeros_like, real, isnan\n", + "from torch.fft import fftshift\n", + "from torch.fft import fft2\n", + "\n", + "from bioMONAI.core import torch_from_numpy" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -105,7 +113,7 @@ " similarity.\" IEEE transactions on image processing 13.4 (2004): 600-612." ] }, - "execution_count": 8, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -123,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -151,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -191,7 +199,7 @@ " def forward(self, inputs, targets):\n", " \n", " # Make sure the inputs are probabilities\n", - " inputs = torch.sigmoid(inputs)\n", + " inputs = sigmoid(inputs)\n", "\n", " # Flatten tensors\n", " inputs = inputs.view(-1)\n", @@ -212,22 +220,23 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dice Loss: 0.5015131235122681\n" + "Dice Loss: 0.4994280934333801\n" ] } ], "source": [ "# inputs and targets must be equally dimensional tensors\n", + "from torch import randn, randint\n", "\n", - "inputs = torch.randn((1, 1, 256, 256)) # Input\n", - "targets = torch.randint(0, 2, (1, 1, 256, 256)).float() # Ground Truth\n", + "inputs = randn((1, 1, 256, 256)) # Input\n", + "targets = randint(0, 2, (1, 1, 256, 256)).float() # Ground Truth\n", "\n", "# Initialize\n", "dice_loss = DiceLoss()\n", @@ -254,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -298,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -345,7 +354,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -379,23 +388,23 @@ " freq_nyq = len(spatial_frequency)\n", " \n", " # Transform tensor to complex\n", - " image1 = image1.to(torch.complex64)\n", - " image2 = image2.to(torch.complex64)\n", + " image1 = image1.to(complex64)\n", + " image2 = image2.to(complex64)\n", "\n", " # Transofrm array dimensions to (freq_nyq, width. height)\n", " image1 = image1.unsqueeze(0).repeat(freq_nyq, 1, 1)\n", " image2 = image2.unsqueeze(0).repeat(freq_nyq, 1, 1)\n", "\n", " # Convert spatial frequency and radial masks to torch.tensor\n", - " spatial_frequency = torch.from_numpy(spatial_frequency)\n", - " radial_masks = torch.from_numpy(radial_masks)\n", + " spatial_frequency = torch_from_numpy(spatial_frequency)\n", + " radial_masks = torch_from_numpy(radial_masks)\n", "\n", " # Transform tensor to complex\n", - " radial_masks = radial_masks.to(torch.complex64)\n", + " radial_masks = radial_masks.to(complex64)\n", " \n", " # Compute fourier transform\n", - " fft_image1 = torch.fft.fftshift(torch.fft.fft2(image1))\n", - " fft_image2 = torch.fft.fftshift(torch.fft.fft2(image2))\n", + " fft_image1 = fftshift(fft2(image1))\n", + " fft_image2 = fftshift(fft2(image2))\n", "\n", " # Get elements only in the ring\n", " t1 = fft_image1 * radial_masks\n", @@ -405,26 +414,26 @@ " t2_conj = t2.conj()\n", "\n", " # Numerator\n", - " numerador = torch.abs(torch.real((t1 * t2_conj).sum(dim=(1,2))))\n", + " numerador = abs(real((t1 * t2_conj).sum(dim=(1,2))))\n", "\n", " # Denominator \n", - " denominador_1 = ((torch.abs(t1) * torch.abs(t1)).sum(dim=(1,2)))\n", - " denominador_2 = ((torch.abs(t2) * torch.abs(t2)).sum(dim=(1,2))) \n", - " denominador = torch.sqrt(denominador_1 * denominador_2)\n", + " denominador_1 = ((abs(t1) * abs(t1)).sum(dim=(1,2)))\n", + " denominador_2 = ((abs(t2) * abs(t2)).sum(dim=(1,2))) \n", + " denominador = sqrt(denominador_1 * denominador_2)\n", " \n", " # Fourier shell correlation\n", - " FRC = torch.div(numerador, denominador)\n", + " FRC = div(numerador, denominador)\n", "\n", " # Remove possible inf and NaN.\n", - " FRC = torch.where(torch.isinf(FRC), torch.zeros_like(FRC), FRC) # inf\n", - " FRC = torch.where(torch.isnan(FRC), torch.zeros_like(FRC), FRC) # NaN\n", + " FRC = where(isinf(FRC), zeros_like(FRC), FRC) # inf\n", + " FRC = where(isnan(FRC), zeros_like(FRC), FRC) # NaN\n", "\n", " return FRC , spatial_frequency" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -449,7 +458,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -495,7 +504,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -503,7 +512,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/bmandracchia/bioMONAI/blob/main/bioMONAI/losses.py#L238){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/bmandracchia/bioMONAI/blob/main/bioMONAI/losses.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### seventh_fourier_ring_correlation\n", "\n", @@ -521,7 +530,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/bmandracchia/bioMONAI/blob/main/bioMONAI/losses.py#L238){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/bmandracchia/bioMONAI/blob/main/bioMONAI/losses.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### seventh_fourier_ring_correlation\n", "\n", @@ -537,7 +546,7 @@ " - float: The cutoff frequency." ] }, - "execution_count": 17, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -550,74 +559,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "def SSIM(x, y, spatial_dims=2):\n", - " return 1 - SSIMLoss(spatial_dims)(x,y)\n", - "\n", - "SSIMMetric = AvgMetric(SSIM)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "def FRCM(image1, image2):\n", - "\n", - "\n", - " \"\"\"\n", - " Compute the area under the Fourier Ring Correlation (FRC) curve between two images.\n", - "\n", - " #### Args:\n", - " - image1 (torch.Tensor): The first input image.\n", - " - image2 (torch.Tensor): The second input image.\n", - "\n", - " #### Returns:\n", - " - float: The area under the FRC curve.\n", - " \"\"\"\n", - "\n", - " # Calculate the Fourier Ring Correlation and spatial frequency\n", - " FRC, spatial_frequency = get_fourier_ring_correlations(image1, image2)\n", - "\n", - " # Convert to numpy\n", - " FRC = FRC.numpy()\n", - " spatial_frequency = spatial_frequency.numpy()\n", - " \n", - " # Compute the area under the curve using trapezoidal integration\n", - " area = np.trapz(FRC, spatial_frequency)\n", - " \n", - " return area\n" + "---" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "import nbdev; nbdev.nbdev_export()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/nbs/06_metrics.ipynb b/nbs/06_metrics.ipynb index b7411a1..d2ac93c 100644 --- a/nbs/06_metrics.ipynb +++ b/nbs/06_metrics.ipynb @@ -35,7 +35,9 @@ "outputs": [], "source": [ "#| export\n", - "def foo(): pass" + "from numpy import trapz\n", + "from fastai.vision.all import AvgMetric\n", + "from bioMONAI.losses import SSIMLoss, get_fourier_ring_correlations" ] }, { @@ -43,6 +45,55 @@ "execution_count": 4, "metadata": {}, "outputs": [], + "source": [ + "#| export\n", + "\n", + "def SSIM(x, y, spatial_dims=2):\n", + " return 1 - SSIMLoss(spatial_dims)(x,y)\n", + "\n", + "SSIMMetric = AvgMetric(SSIM)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "def FRCM(image1, image2):\n", + "\n", + "\n", + " \"\"\"\n", + " Compute the area under the Fourier Ring Correlation (FRC) curve between two images.\n", + "\n", + " #### Args:\n", + " - image1 (torch.Tensor): The first input image.\n", + " - image2 (torch.Tensor): The second input image.\n", + "\n", + " #### Returns:\n", + " - float: The area under the FRC curve.\n", + " \"\"\"\n", + "\n", + " # Calculate the Fourier Ring Correlation and spatial frequency\n", + " FRC, spatial_frequency = get_fourier_ring_correlations(image1, image2)\n", + "\n", + " # Convert to numpy\n", + " FRC = FRC.numpy()\n", + " spatial_frequency = spatial_frequency.numpy()\n", + " \n", + " # Compute the area under the curve using trapezoidal integration\n", + " area = trapz(FRC, spatial_frequency)\n", + " \n", + " return area\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], "source": [ "#| hide\n", "import nbdev; nbdev.nbdev_export()"