Skip to content

Commit

Permalink
update - losses and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
bmandracchia committed Sep 1, 2024
1 parent c44659b commit c3a3502
Show file tree
Hide file tree
Showing 5 changed files with 397 additions and 360 deletions.
12 changes: 6 additions & 6 deletions bioMONAI/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@
'bioMONAI.losses.DiceLoss.__init__': ('losses.html#diceloss.__init__', 'bioMONAI/losses.py'),
'bioMONAI.losses.DiceLoss.forward': ('losses.html#diceloss.forward', 'bioMONAI/losses.py'),
'bioMONAI.losses.FRCLoss': ('losses.html#frcloss', 'bioMONAI/losses.py'),
'bioMONAI.losses.get_fourier_ring_correlations': ( 'losses.html#get_fourier_ring_correlations',
'bioMONAI/losses.py'),
'bioMONAI.losses.get_radial_masks': ('losses.html#get_radial_masks', 'bioMONAI/losses.py'),
'bioMONAI.losses.radial_mask': ('losses.html#radial_mask', 'bioMONAI/losses.py'),
'bioMONAI.losses.seventh_fourier_ring_correlation': ( 'losses.html#seventh_fourier_ring_correlation',
'bioMONAI/losses.py')},
'bioMONAI.metrics': { 'bioMONAI.metrics.FRCM': ('metrics.html#frcm', 'bioMONAI/metrics.py'),
'bioMONAI.metrics.SSIM': ('metrics.html#ssim', 'bioMONAI/metrics.py')},
'bioMONAI.metrics': { 'bioMONAI.metrics.FRCMetric': ('metrics.html#frcmetric', 'bioMONAI/metrics.py'),
'bioMONAI.metrics.SSIMMetric': ('metrics.html#ssimmetric', 'bioMONAI/metrics.py'),
'bioMONAI.metrics.get_fourier_ring_correlations': ( 'metrics.html#get_fourier_ring_correlations',
'bioMONAI/metrics.py'),
'bioMONAI.metrics.get_radial_masks': ('metrics.html#get_radial_masks', 'bioMONAI/metrics.py'),
'bioMONAI.metrics.radial_mask': ('metrics.html#radial_mask', 'bioMONAI/metrics.py')},
'bioMONAI.nets': { 'bioMONAI.nets.ASPP_module': ('nets.html#aspp_module', 'bioMONAI/nets.py'),
'bioMONAI.nets.ASPP_module.__init__': ('nets.html#aspp_module.__init__', 'bioMONAI/nets.py'),
'bioMONAI.nets.ASPP_module.forward': ('nets.html#aspp_module.forward', 'bioMONAI/nets.py'),
Expand Down
164 changes: 10 additions & 154 deletions bioMONAI/losses.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_losses.ipynb.

# %% auto 0
__all__ = ['CombinedLoss', 'DiceLoss', 'radial_mask', 'get_radial_masks', 'get_fourier_ring_correlations', 'FRCLoss',
'seventh_fourier_ring_correlation']
__all__ = ['CombinedLoss', 'DiceLoss', 'FRCLoss', 'seventh_fourier_ring_correlation']

# %% ../nbs/03_losses.ipynb 4
from .core import store_attr
Expand All @@ -12,11 +11,12 @@

from monai.losses import SSIMLoss
import torch.nn as nn
from torch import abs, sqrt, div, sigmoid, complex64, where, isinf, zeros_like, real, isnan
from torch.fft import fftshift
from torch.fft import fft2
from torch import sigmoid

from scipy.optimize import curve_fit

from .metrics import FRCMetric, get_fourier_ring_correlations

from .core import torch_from_numpy

# %% ../nbs/03_losses.ipynb 8
class CombinedLoss:
Expand Down Expand Up @@ -84,148 +84,7 @@ def forward(self, inputs, targets):
return loss


# %% ../nbs/03_losses.ipynb 14
def radial_mask(r, # Radius of the radial mask
cx=128, # X coordinate mask center
cy=128, # Y coordinate maske center
sx=np.arange(0, 256),
sy=np.arange(0, 256),
delta=1,
):

"""
Generate a radial mask.
#### Parameters:
- r (int or float): Radius of the circular mask.
- cx (int, optional): X-coordinate of the center of the circular mask. Defaults to 128.
- cy (int, optional): Y-coordinate of the center of the circular mask. Defaults to 128.
- sx (numpy.ndarray, optional): Array of x-coordinates forming a grid. Defaults to np.arange(0, 256).
- sy (numpy.ndarray, optional): Array of y-coordinates forming a grid. Defaults to np.arange(0, 256).
- delta (int or float, optional): Thickness adjustment for the circular mask. Defaults to 1.
#### Returns:
- numpy.ndarray: Radial mask.
"""

# Calculate squared distances from each point in the grid to the center
ind = (sx[np.newaxis, :] - cx) ** 2 + (sy[:, np.newaxis] - cy) ** 2

# Define inner boundary of the circular mask
ind1 = ind <= ((r[0] + delta) ** 2)

# Define outer boundary of the circular mask
ind2 = ind > (r[0] ** 2)

# Create the radial mask by combining inner and outer boundaries
return ind1 * ind2


# %% ../nbs/03_losses.ipynb 15
def get_radial_masks(width, height):

"""
Generates a set of radial masks and corresponding to spatial frequencies.
#### Parameters:
- width (int): Width of the image.
- height (int): Height of the image.
#### Returns:
tuple: A tuple containing:
- numpy.ndarray: Array of radial masks.
- numpy.ndarray: Array of spatial frequencies corresponding to the masks.
"""

# Calculate Nyquist frequency
freq_nyq = int(np.floor(int(min(width, height)) / 2.0))

# Generate radii from 0 to Nyquist frequency
radii = np.arange(freq_nyq).reshape(freq_nyq, 1)

# Generate radial masks using the radial_mask function
radial_masks = np.apply_along_axis(radial_mask, 1, radii, width/2, height/2, np.arange(0, width), np.arange(0, height), 1)

# Calculate spatial frequencies
spatial_freq = radii.astype(np.float32) / freq_nyq
spatial_freq = spatial_freq / max(spatial_freq)
spatial_freq = spatial_freq.squeeze(1)

return radial_masks, spatial_freq


# %% ../nbs/03_losses.ipynb 17
def get_fourier_ring_correlations(image1, image2):


"""
Compute Fourier Ring Correlation (FRC) between two images.
#### Args:
- image1 (torch.Tensor): First input image.
- image2 (torch.Tensor): Second input image.
#### Returns:
tuple: A tuple containing:
- torch.Tensor: Fourier Ring Correlation values.
- torch.Tensor: Array of spatial frequencies.
"""


# Get image height and width
height = image1.shape[len(image1.shape)-1]
width = image1.shape[len(image1.shape)-2]

# Get set of radial masks, spatial frequency, and Nyquist frequency
radial_masks, spatial_frequency = get_radial_masks(height,width)

# Get Nyquist frequency
freq_nyq = len(spatial_frequency)

# Transform tensor to complex
image1 = image1.to(complex64)
image2 = image2.to(complex64)

# Transofrm array dimensions to (freq_nyq, width. height)
image1 = image1.unsqueeze(0).repeat(freq_nyq, 1, 1)
image2 = image2.unsqueeze(0).repeat(freq_nyq, 1, 1)

# Convert spatial frequency and radial masks to torch.tensor
spatial_frequency = torch_from_numpy(spatial_frequency)
radial_masks = torch_from_numpy(radial_masks)

# Transform tensor to complex
radial_masks = radial_masks.to(complex64)

# Compute fourier transform
fft_image1 = fftshift(fft2(image1))
fft_image2 = fftshift(fft2(image2))

# Get elements only in the ring
t1 = fft_image1 * radial_masks
t2 = fft_image2 * radial_masks

# image2 to complex conjugate
t2_conj = t2.conj()

# Numerator
numerador = abs(real((t1 * t2_conj).sum(dim=(1,2))))

# Denominator
denominador_1 = ((abs(t1) * abs(t1)).sum(dim=(1,2)))
denominador_2 = ((abs(t2) * abs(t2)).sum(dim=(1,2)))
denominador = sqrt(denominador_1 * denominador_2)

# Fourier shell correlation
FRC = div(numerador, denominador)

# Remove possible inf and NaN.
FRC = where(isinf(FRC), zeros_like(FRC), FRC) # inf
FRC = where(isnan(FRC), zeros_like(FRC), FRC) # NaN

return FRC , spatial_frequency

# %% ../nbs/03_losses.ipynb 18
# %% ../nbs/03_losses.ipynb 13
def FRCLoss(image1, image2):

"""
Expand All @@ -239,13 +98,10 @@ def FRCLoss(image1, image2):
- torch.Tensor: The FRC loss.
"""

return (1 - FRCM(image1, image2))
return (1 - FRCMetric(image1, image2))


# %% ../nbs/03_losses.ipynb 19
from scipy.optimize import curve_fit


# %% ../nbs/03_losses.ipynb 14
def seventh_fourier_ring_correlation(image1,image2):


Expand Down Expand Up @@ -273,7 +129,7 @@ def exponential_func(x, a, b, c):
return a * np.exp(-b * x) + c

# Make fit
params, params_covariance = curve_fit(exponential_func, x, y, p0=[1, 1, 1])
params, _ = curve_fit(exponential_func, x, y, p0=[1, 1, 1])

# Get Cutoff requency at 1/7
cutoff_frequency = (exponential_func((1/7), *params))
Expand Down
Loading

0 comments on commit c3a3502

Please sign in to comment.