From fc0cbf61aeb14da010aff4d46cb8494bc4dc9ffc Mon Sep 17 00:00:00 2001 From: Matthew Giammar <43814525+mgiammar@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:15:16 -0500 Subject: [PATCH] refactor: change arguments for `cumulative_dose_filter_3d` --- src/torch_fourier_filter/dose_weight.py | 61 ++++++++++++------------- tests/dose_weight/test_dose_weight.py | 39 ++++++++-------- 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/src/torch_fourier_filter/dose_weight.py b/src/torch_fourier_filter/dose_weight.py index 3759f61..b0befc4 100644 --- a/src/torch_fourier_filter/dose_weight.py +++ b/src/torch_fourier_filter/dose_weight.py @@ -24,30 +24,29 @@ def critical_exposure(fft_freq: torch.Tensor) -> torch.Tensor: return Ne -def critical_exposure_Bfac(fft_freq: torch.Tensor, Bfac: float) -> torch.Tensor: +def critical_exposure_bfactor(fft_freq: torch.Tensor, bfactor: float) -> torch.Tensor: """ Calculate the critical exposure using a user defined B-factor. Args: fft_freq: The frequency grid of the Fourier transform. - Bfac: The B-factor to use. + bfactor: The B-factor to use. Returns ------- The critical exposure for the given frequency grid """ eps = 1e-10 - Ne = 2 / (Bfac * fft_freq.clamp(min=eps) ** 2) + Ne = 2 / (bfactor * fft_freq.clamp(min=eps) ** 2) return Ne def cumulative_dose_filter_3d( volume_shape: tuple[int, int, int], - num_frames: int, - start_exposure: float = 0.0, pixel_size: float = 1, - flux: float = 1, - Bfac: float = -1, + start_exposure: float = 0.0, + end_exposure: float = 30.0, + crit_exposure_bfactor: int | float = -1, rfft: bool = True, fftshift: bool = False, ) -> torch.Tensor: @@ -59,17 +58,17 @@ def cumulative_dose_filter_3d( Parameters ---------- volume_shape : tuple[int, int, int] - The volume shape for dose weighting. - num_frames : int - The number of frames for dose weighting. - start_exposure : float - The start exposure for dose weighting. + The shape of the filter to calculate (real space). Rfft is + automatically calculated from this. pixel_size : float - The pixel size of the volume. - flux : float - The fluence per frame. - Bfac : float - The B factor for dose weighting, -1=use Grant and Grigorieff values. + The pixel size of the volume, in Angstrom. + start_exposure : float + The start exposure for dose weighting, in e-/A^2. Default is 0.0. + end_exposure : float + The end exposure for dose weighting, in e-/A^2. Default is 30.0. + crit_exposure_bfactor : int | float + The B factor for dose weighting based on critical exposure. If '-1', + then use Grant and Grigorieff (2015) values. rfft : bool If the FFT is a real FFT. fftshift : bool @@ -80,24 +79,24 @@ def cumulative_dose_filter_3d( torch.Tensor The dose weighting filter. """ - end_exposure = start_exposure + num_frames * flux # Get the frequency grid for 1 frame - fft_freq_px = ( - fftfreq_grid( - image_shape=volume_shape, - rfft=rfft, - fftshift=fftshift, - norm=True, - ) - / pixel_size + fft_freq_px = fftfreq_grid( + image_shape=volume_shape, + rfft=rfft, + fftshift=fftshift, + norm=True, ) + fft_freq_px /= pixel_size # Convert to Angstrom^-1 # Get the critical exposure for each frequency - Ne = ( - critical_exposure_Bfac(fft_freq=fft_freq_px, Bfac=Bfac) - if Bfac >= 0 - else critical_exposure(fft_freq=fft_freq_px) - ) + if crit_exposure_bfactor == -1: + Ne = critical_exposure(fft_freq=fft_freq_px) + elif crit_exposure_bfactor >= 0: + Ne = critical_exposure_bfactor( + fft_freq=fft_freq_px, bfactor=crit_exposure_bfactor + ) + else: + raise ValueError("B-factor must be positive or -1.") # Add small epsilon to prevent division by zero eps = 1e-10 diff --git a/tests/dose_weight/test_dose_weight.py b/tests/dose_weight/test_dose_weight.py index 50d5d92..2d0b4fc 100644 --- a/tests/dose_weight/test_dose_weight.py +++ b/tests/dose_weight/test_dose_weight.py @@ -2,7 +2,7 @@ from torch_fourier_filter.dose_weight import ( critical_exposure, - critical_exposure_Bfac, + critical_exposure_bfactor, cumulative_dose_filter_3d, ) @@ -16,57 +16,56 @@ def test_critical_exposure(): ), "critical_exposure output mismatch" -def test_critical_exposure_Bfac(): +def test_critical_exposure_bfactor(): fft_freq = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]) - Bfac = 1.0 + bfac = 1.0 expected_output = torch.tensor([200.0, 50.0, 22.222, 12.5, 8.0]) - output = critical_exposure_Bfac(fft_freq, Bfac) + output = critical_exposure_bfactor(fft_freq, bfac) assert torch.allclose( output, expected_output, atol=1e-3 - ), "critical_exposure_Bfac output mismatch" + ), "critical_exposure_bfactor output mismatch" def test_cumulative_dose_filter_3d(): # Test parameters volume_shape = (32, 32, 32) - num_frames = 10 start_exposure = 0.0 + end_exposure = 10.0 pixel_size = 1.0 - flux = 1.0 - Bfac = -1.0 + crit_exposure_bfactor = -1 rfft = True fftshift = False # Call the function dose_filter = cumulative_dose_filter_3d( volume_shape=volume_shape, - num_frames=num_frames, - start_exposure=start_exposure, pixel_size=pixel_size, - flux=flux, - Bfac=Bfac, + start_exposure=start_exposure, + end_exposure=end_exposure, + crit_exposure_bfactor=crit_exposure_bfactor, rfft=rfft, fftshift=fftshift, ) # Check if the values are within a reasonable range + # TODO: Test these against known, static values rather than just range assert torch.all(dose_filter >= 0) and torch.all( dose_filter <= 1 ), "Dose filter values out of range" - # Test with different Bfac values - Bfac_values = [0.5, 1.0, 2.0] - for Bfac in Bfac_values: + # Test with different bfac values + # TODO: Test these against known, static values rather than just range + crit_bfac_values = [0.5, 1.0, 2.0] + for bfac in crit_bfac_values: dose_filter = cumulative_dose_filter_3d( volume_shape=volume_shape, - num_frames=num_frames, - start_exposure=start_exposure, pixel_size=pixel_size, - flux=flux, - Bfac=Bfac, + start_exposure=start_exposure, + end_exposure=end_exposure, + crit_exposure_bfactor=bfac, rfft=rfft, fftshift=fftshift, ) assert torch.all(dose_filter >= 0) and torch.all( dose_filter <= 1 - ), f"Dose filter values out of range for Bfac={Bfac}" + ), f"Dose filter values out of range for bfac={bfac}"