Skip to content

Commit

Permalink
update losses, metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
bmandracchia committed Aug 26, 2024
1 parent c3a4c61 commit c58a82e
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 162 deletions.
5 changes: 2 additions & 3 deletions bioMONAI/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,14 @@
'bioMONAI.losses.DiceLoss.__init__': ('losses.html#diceloss.__init__', 'bioMONAI/losses.py'),
'bioMONAI.losses.DiceLoss.forward': ('losses.html#diceloss.forward', 'bioMONAI/losses.py'),
'bioMONAI.losses.FRCLoss': ('losses.html#frcloss', 'bioMONAI/losses.py'),
'bioMONAI.losses.FRCM': ('losses.html#frcm', 'bioMONAI/losses.py'),
'bioMONAI.losses.SSIM': ('losses.html#ssim', 'bioMONAI/losses.py'),
'bioMONAI.losses.get_fourier_ring_correlations': ( 'losses.html#get_fourier_ring_correlations',
'bioMONAI/losses.py'),
'bioMONAI.losses.get_radial_masks': ('losses.html#get_radial_masks', 'bioMONAI/losses.py'),
'bioMONAI.losses.radial_mask': ('losses.html#radial_mask', 'bioMONAI/losses.py'),
'bioMONAI.losses.seventh_fourier_ring_correlation': ( 'losses.html#seventh_fourier_ring_correlation',
'bioMONAI/losses.py')},
'bioMONAI.metrics': {'bioMONAI.metrics.foo': ('metrics.html#foo', 'bioMONAI/metrics.py')},
'bioMONAI.metrics': { 'bioMONAI.metrics.FRCM': ('metrics.html#frcm', 'bioMONAI/metrics.py'),
'bioMONAI.metrics.SSIM': ('metrics.html#ssim', 'bioMONAI/metrics.py')},
'bioMONAI.nets': { 'bioMONAI.nets.DnCNN': ('nets.html#dncnn', 'bioMONAI/nets.py'),
'bioMONAI.nets.DnCNN.__init__': ('nets.html#dncnn.__init__', 'bioMONAI/nets.py'),
'bioMONAI.nets.DnCNN.forward': ('nets.html#dncnn.forward', 'bioMONAI/nets.py'),
Expand Down
2 changes: 1 addition & 1 deletion bioMONAI/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch import from_numpy as torch_from_numpy

# %% ../nbs/00_core.ipynb 6
from fastai.vision.all import BypassNewMeta, DisplayedTransform
from fastai.vision.all import BypassNewMeta, DisplayedTransform, store_attr
from fastai.data.all import delegates, hasattrs, Path, List, L, typedispatch

# %% ../nbs/00_core.ipynb 8
Expand Down
78 changes: 26 additions & 52 deletions bioMONAI/losses.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_losses.ipynb.

# %% auto 0
__all__ = ['SSIMMetric', 'CombinedLoss', 'DiceLoss', 'radial_mask', 'get_radial_masks', 'get_fourier_ring_correlations',
'FRCLoss', 'seventh_fourier_ring_correlation', 'SSIM', 'FRCM']
__all__ = ['CombinedLoss', 'DiceLoss', 'radial_mask', 'get_radial_masks', 'get_fourier_ring_correlations', 'FRCLoss',
'seventh_fourier_ring_correlation']

# %% ../nbs/03_losses.ipynb 4
from fastai.vision.all import *
from .core import store_attr

# %% ../nbs/03_losses.ipynb 5
import numpy as np

from monai.losses import SSIMLoss
import torch.nn as nn
from torch import abs, sqrt, div, sigmoid, complex64, where, isinf, zeros_like, real, isnan
from torch.fft import fftshift
from torch.fft import fft2

from .core import torch_from_numpy

# %% ../nbs/03_losses.ipynb 8
class CombinedLoss:
Expand Down Expand Up @@ -58,7 +66,7 @@ def __init__(self, smooth=1):
def forward(self, inputs, targets):

# Make sure the inputs are probabilities
inputs = torch.sigmoid(inputs)
inputs = sigmoid(inputs)

# Flatten tensors
inputs = inputs.view(-1)
Expand Down Expand Up @@ -175,23 +183,23 @@ def get_fourier_ring_correlations(image1, image2):
freq_nyq = len(spatial_frequency)

# Transform tensor to complex
image1 = image1.to(torch.complex64)
image2 = image2.to(torch.complex64)
image1 = image1.to(complex64)
image2 = image2.to(complex64)

# Transofrm array dimensions to (freq_nyq, width. height)
image1 = image1.unsqueeze(0).repeat(freq_nyq, 1, 1)
image2 = image2.unsqueeze(0).repeat(freq_nyq, 1, 1)

# Convert spatial frequency and radial masks to torch.tensor
spatial_frequency = torch.from_numpy(spatial_frequency)
radial_masks = torch.from_numpy(radial_masks)
spatial_frequency = torch_from_numpy(spatial_frequency)
radial_masks = torch_from_numpy(radial_masks)

# Transform tensor to complex
radial_masks = radial_masks.to(torch.complex64)
radial_masks = radial_masks.to(complex64)

# Compute fourier transform
fft_image1 = torch.fft.fftshift(torch.fft.fft2(image1))
fft_image2 = torch.fft.fftshift(torch.fft.fft2(image2))
fft_image1 = fftshift(fft2(image1))
fft_image2 = fftshift(fft2(image2))

# Get elements only in the ring
t1 = fft_image1 * radial_masks
Expand All @@ -201,19 +209,19 @@ def get_fourier_ring_correlations(image1, image2):
t2_conj = t2.conj()

# Numerator
numerador = torch.abs(torch.real((t1 * t2_conj).sum(dim=(1,2))))
numerador = abs(real((t1 * t2_conj).sum(dim=(1,2))))

# Denominator
denominador_1 = ((torch.abs(t1) * torch.abs(t1)).sum(dim=(1,2)))
denominador_2 = ((torch.abs(t2) * torch.abs(t2)).sum(dim=(1,2)))
denominador = torch.sqrt(denominador_1 * denominador_2)
denominador_1 = ((abs(t1) * abs(t1)).sum(dim=(1,2)))
denominador_2 = ((abs(t2) * abs(t2)).sum(dim=(1,2)))
denominador = sqrt(denominador_1 * denominador_2)

# Fourier shell correlation
FRC = torch.div(numerador, denominador)
FRC = div(numerador, denominador)

# Remove possible inf and NaN.
FRC = torch.where(torch.isinf(FRC), torch.zeros_like(FRC), FRC) # inf
FRC = torch.where(torch.isnan(FRC), torch.zeros_like(FRC), FRC) # NaN
FRC = where(isinf(FRC), zeros_like(FRC), FRC) # inf
FRC = where(isnan(FRC), zeros_like(FRC), FRC) # NaN

return FRC , spatial_frequency

Expand Down Expand Up @@ -271,37 +279,3 @@ def exponential_func(x, a, b, c):
cutoff_frequency = (exponential_func((1/7), *params))

return cutoff_frequency

# %% ../nbs/03_losses.ipynb 22
def SSIM(x, y, spatial_dims=2):
return 1 - SSIMLoss(spatial_dims)(x,y)

SSIMMetric = AvgMetric(SSIM)

# %% ../nbs/03_losses.ipynb 23
def FRCM(image1, image2):


"""
Compute the area under the Fourier Ring Correlation (FRC) curve between two images.
#### Args:
- image1 (torch.Tensor): The first input image.
- image2 (torch.Tensor): The second input image.
#### Returns:
- float: The area under the FRC curve.
"""

# Calculate the Fourier Ring Correlation and spatial frequency
FRC, spatial_frequency = get_fourier_ring_correlations(image1, image2)

# Convert to numpy
FRC = FRC.numpy()
spatial_frequency = spatial_frequency.numpy()

# Compute the area under the curve using trapezoidal integration
area = np.trapz(FRC, spatial_frequency)

return area

40 changes: 38 additions & 2 deletions bioMONAI/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,43 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_metrics.ipynb.

# %% auto 0
__all__ = ['foo']
__all__ = ['SSIMMetric', 'SSIM', 'FRCM']

# %% ../nbs/06_metrics.ipynb 3
def foo(): pass
from numpy import trapz
from fastai.vision.all import AvgMetric
from .losses import SSIMLoss, get_fourier_ring_correlations

# %% ../nbs/06_metrics.ipynb 4
def SSIM(x, y, spatial_dims=2):
return 1 - SSIMLoss(spatial_dims)(x,y)

SSIMMetric = AvgMetric(SSIM)

# %% ../nbs/06_metrics.ipynb 5
def FRCM(image1, image2):


"""
Compute the area under the Fourier Ring Correlation (FRC) curve between two images.
#### Args:
- image1 (torch.Tensor): The first input image.
- image2 (torch.Tensor): The second input image.
#### Returns:
- float: The area under the FRC curve.
"""

# Calculate the Fourier Ring Correlation and spatial frequency
FRC, spatial_frequency = get_fourier_ring_correlations(image1, image2)

# Convert to numpy
FRC = FRC.numpy()
spatial_frequency = spatial_frequency.numpy()

# Compute the area under the curve using trapezoidal integration
area = trapz(FRC, spatial_frequency)

return area

16 changes: 8 additions & 8 deletions nbs/00_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -65,12 +65,12 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from fastai.vision.all import BypassNewMeta, DisplayedTransform\n",
"from fastai.vision.all import BypassNewMeta, DisplayedTransform, store_attr\n",
"from fastai.data.all import delegates, hasattrs, Path, List, L, typedispatch"
]
},
Expand All @@ -83,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -104,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand Down
Loading

0 comments on commit c58a82e

Please sign in to comment.