Skip to content

Commit

Permalink
refactor: change arguments for cumulative_dose_filter_3d
Browse files Browse the repository at this point in the history
  • Loading branch information
mgiammar committed Dec 23, 2024
1 parent 601b7e5 commit fc0cbf6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 51 deletions.
61 changes: 30 additions & 31 deletions src/torch_fourier_filter/dose_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,29 @@ def critical_exposure(fft_freq: torch.Tensor) -> torch.Tensor:
return Ne


def critical_exposure_Bfac(fft_freq: torch.Tensor, Bfac: float) -> torch.Tensor:
def critical_exposure_bfactor(fft_freq: torch.Tensor, bfactor: float) -> torch.Tensor:
"""
Calculate the critical exposure using a user defined B-factor.
Args:
fft_freq: The frequency grid of the Fourier transform.
Bfac: The B-factor to use.
bfactor: The B-factor to use.
Returns
-------
The critical exposure for the given frequency grid
"""
eps = 1e-10
Ne = 2 / (Bfac * fft_freq.clamp(min=eps) ** 2)
Ne = 2 / (bfactor * fft_freq.clamp(min=eps) ** 2)
return Ne


def cumulative_dose_filter_3d(
volume_shape: tuple[int, int, int],
num_frames: int,
start_exposure: float = 0.0,
pixel_size: float = 1,
flux: float = 1,
Bfac: float = -1,
start_exposure: float = 0.0,
end_exposure: float = 30.0,
crit_exposure_bfactor: int | float = -1,
rfft: bool = True,
fftshift: bool = False,
) -> torch.Tensor:
Expand All @@ -59,17 +58,17 @@ def cumulative_dose_filter_3d(
Parameters
----------
volume_shape : tuple[int, int, int]
The volume shape for dose weighting.
num_frames : int
The number of frames for dose weighting.
start_exposure : float
The start exposure for dose weighting.
The shape of the filter to calculate (real space). Rfft is
automatically calculated from this.
pixel_size : float
The pixel size of the volume.
flux : float
The fluence per frame.
Bfac : float
The B factor for dose weighting, -1=use Grant and Grigorieff values.
The pixel size of the volume, in Angstrom.
start_exposure : float
The start exposure for dose weighting, in e-/A^2. Default is 0.0.
end_exposure : float
The end exposure for dose weighting, in e-/A^2. Default is 30.0.
crit_exposure_bfactor : int | float
The B factor for dose weighting based on critical exposure. If '-1',
then use Grant and Grigorieff (2015) values.
rfft : bool
If the FFT is a real FFT.
fftshift : bool
Expand All @@ -80,24 +79,24 @@ def cumulative_dose_filter_3d(
torch.Tensor
The dose weighting filter.
"""
end_exposure = start_exposure + num_frames * flux
# Get the frequency grid for 1 frame
fft_freq_px = (
fftfreq_grid(
image_shape=volume_shape,
rfft=rfft,
fftshift=fftshift,
norm=True,
)
/ pixel_size
fft_freq_px = fftfreq_grid(
image_shape=volume_shape,
rfft=rfft,
fftshift=fftshift,
norm=True,
)
fft_freq_px /= pixel_size # Convert to Angstrom^-1

# Get the critical exposure for each frequency
Ne = (
critical_exposure_Bfac(fft_freq=fft_freq_px, Bfac=Bfac)
if Bfac >= 0
else critical_exposure(fft_freq=fft_freq_px)
)
if crit_exposure_bfactor == -1:
Ne = critical_exposure(fft_freq=fft_freq_px)
elif crit_exposure_bfactor >= 0:
Ne = critical_exposure_bfactor(
fft_freq=fft_freq_px, bfactor=crit_exposure_bfactor
)
else:
raise ValueError("B-factor must be positive or -1.")

# Add small epsilon to prevent division by zero
eps = 1e-10
Expand Down
39 changes: 19 additions & 20 deletions tests/dose_weight/test_dose_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch_fourier_filter.dose_weight import (
critical_exposure,
critical_exposure_Bfac,
critical_exposure_bfactor,
cumulative_dose_filter_3d,
)

Expand All @@ -16,57 +16,56 @@ def test_critical_exposure():
), "critical_exposure output mismatch"


def test_critical_exposure_Bfac():
def test_critical_exposure_bfactor():
fft_freq = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
Bfac = 1.0
bfac = 1.0
expected_output = torch.tensor([200.0, 50.0, 22.222, 12.5, 8.0])
output = critical_exposure_Bfac(fft_freq, Bfac)
output = critical_exposure_bfactor(fft_freq, bfac)
assert torch.allclose(
output, expected_output, atol=1e-3
), "critical_exposure_Bfac output mismatch"
), "critical_exposure_bfactor output mismatch"


def test_cumulative_dose_filter_3d():
# Test parameters
volume_shape = (32, 32, 32)
num_frames = 10
start_exposure = 0.0
end_exposure = 10.0
pixel_size = 1.0
flux = 1.0
Bfac = -1.0
crit_exposure_bfactor = -1
rfft = True
fftshift = False

# Call the function
dose_filter = cumulative_dose_filter_3d(
volume_shape=volume_shape,
num_frames=num_frames,
start_exposure=start_exposure,
pixel_size=pixel_size,
flux=flux,
Bfac=Bfac,
start_exposure=start_exposure,
end_exposure=end_exposure,
crit_exposure_bfactor=crit_exposure_bfactor,
rfft=rfft,
fftshift=fftshift,
)

# Check if the values are within a reasonable range
# TODO: Test these against known, static values rather than just range
assert torch.all(dose_filter >= 0) and torch.all(
dose_filter <= 1
), "Dose filter values out of range"

# Test with different Bfac values
Bfac_values = [0.5, 1.0, 2.0]
for Bfac in Bfac_values:
# Test with different bfac values
# TODO: Test these against known, static values rather than just range
crit_bfac_values = [0.5, 1.0, 2.0]
for bfac in crit_bfac_values:
dose_filter = cumulative_dose_filter_3d(
volume_shape=volume_shape,
num_frames=num_frames,
start_exposure=start_exposure,
pixel_size=pixel_size,
flux=flux,
Bfac=Bfac,
start_exposure=start_exposure,
end_exposure=end_exposure,
crit_exposure_bfactor=bfac,
rfft=rfft,
fftshift=fftshift,
)
assert torch.all(dose_filter >= 0) and torch.all(
dose_filter <= 1
), f"Dose filter values out of range for Bfac={Bfac}"
), f"Dose filter values out of range for bfac={bfac}"

0 comments on commit fc0cbf6

Please sign in to comment.