From 3a27156660786d6fbee3da0fc711b15a662c4bf1 Mon Sep 17 00:00:00 2001 From: jdickerson95 Date: Mon, 18 Nov 2024 11:55:01 -0800 Subject: [PATCH] feat: First (cpu only) version up and running --- .pre-commit-config.yaml | 2 +- .typos.toml | 8 + src/ttsim3d/__init__.py | 2 +- src/ttsim3d/device_handler.py | 137 +++++ src/ttsim3d/elastic_scattering_factors.json | 240 ++++++++ src/ttsim3d/grid_coords.py | 49 ++ src/ttsim3d/pdb_handler.py | 49 ++ src/ttsim3d/run_ttsim3d.py | 31 ++ src/ttsim3d/scattering_potential.py | 180 ++++++ src/ttsim3d/simulate3d.py | 576 ++++++++++++++++++++ src/ttsim3d/test_code.txt | 103 ++++ 11 files changed, 1375 insertions(+), 2 deletions(-) create mode 100644 .typos.toml create mode 100644 src/ttsim3d/device_handler.py create mode 100644 src/ttsim3d/elastic_scattering_factors.json create mode 100644 src/ttsim3d/grid_coords.py create mode 100644 src/ttsim3d/pdb_handler.py create mode 100644 src/ttsim3d/run_ttsim3d.py create mode 100644 src/ttsim3d/scattering_potential.py create mode 100644 src/ttsim3d/simulate3d.py create mode 100644 src/ttsim3d/test_code.txt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86d5cc1..4cfb088 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: rev: v1.24.6 hooks: - id: typos - args: [--force-exclude] # omitting --write-changes + args: ["--force-exclude"] # omitting --write-changes - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.7.2 diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 0000000..88f40dc --- /dev/null +++ b/.typos.toml @@ -0,0 +1,8 @@ +[files] +extend-exclude = ["src/ttsim3d/elastic_scattering_factors.json"] + +[default.dictionary] +terms = [ + "BA", + "ND", +] \ No newline at end of file diff --git a/src/ttsim3d/__init__.py b/src/ttsim3d/__init__.py index 562dd50..df7af16 100644 --- a/src/ttsim3d/__init__.py +++ b/src/ttsim3d/__init__.py @@ -1,4 +1,4 @@ -"""Simulate a 3D electrostatic potential map from a PDB in pyTorch""" +"""Simulate a 3D electrostatic potential map from a PDB in pyTorch.""" from importlib.metadata import PackageNotFoundError, version diff --git a/src/ttsim3d/device_handler.py b/src/ttsim3d/device_handler.py new file mode 100644 index 0000000..8661b92 --- /dev/null +++ b/src/ttsim3d/device_handler.py @@ -0,0 +1,137 @@ +"""Handles cpu/gpu device selection.""" + +import multiprocessing as mp +from typing import Optional + +import torch + + +def get_cpu_cores() -> int: + """ + Get the number of CPU cores available. + + Returns + ------- + int: Number of CPU cores available. + """ + return mp.cpu_count() + + +def select_gpus( + gpu_ids: Optional[list[int]] = None, + num_gpus: int = 1, +) -> list[torch.device]: + """ + Select multiple GPU devices based on IDs or available memory. + + Args: + gpu_ids: List of specific GPU IDs to use. + If None, selects GPUs with most available memory. + num_gpus: Number of GPUs to use if gpu_ids is None. + + Returns + ------- + list[torch.device]: Selected GPU devices or [CPU] if no GPU available + """ + # Check if you can actually use a cuda gpu + if not torch.cuda.is_available(): + print("No GPU available, using CPU") + return [torch.device("cpu")] + + n_gpus = torch.cuda.device_count() + if n_gpus == 0: + print("No GPU available, using CPU") + return [torch.device("cpu")] + + # If specific GPUs requested, validate and return them + if gpu_ids is not None: + valid_devices = [] + for gpu_id in gpu_ids: + if gpu_id >= n_gpus: + print(f"Requested GPU {gpu_id} not available. 
Max GPU ID is {n_gpus-1}") + continue + valid_devices.append(torch.device(f"cuda:{gpu_id}")) + + if not valid_devices: + print("No valid GPUs specified. Using CPU") + return [torch.device("cpu")] + return valid_devices + + # Find GPUs with most available memory + gpu_memory_available = [] + print("\nAvailable GPUs:") + for i in range(n_gpus): + torch.cuda.set_device(i) + total_memory = torch.cuda.get_device_properties(i).total_memory + allocated_memory = torch.cuda.memory_allocated(i) + available = total_memory - allocated_memory + + print(f"GPU {i}: {torch.cuda.get_device_properties(i).name}") + print(f" Total memory: {total_memory/1024**3:.1f} GB") + print(f" Available memory: {available/1024**3:.1f} GB") + + gpu_memory_available.append((i, available)) + + # Sort by available memory and select the top num_gpus + gpu_memory_available.sort(key=lambda x: x[1], reverse=True) + selected_gpus = [ + torch.device(f"cuda:{idx}") for idx, _ in gpu_memory_available[:num_gpus] + ] + + print("\nSelected GPUs:", [str(device) for device in selected_gpus]) + return selected_gpus + + +def calculate_batch_size_gpu( + total_atoms: int, + neighborhood_size: int, + device: torch.device, + safety_factor: float = 0.8, # Use only 80% of available memory by default + min_batch_size: int = 100, +) -> int: + """ + Calculate optimal batch size based on available GPU memory and data size. + + Args: + total_atoms: Total number of atoms to process + neighborhood_size: Size of neighborhood around each atom + device: PyTorch device (GPU) + safety_factor: Fraction of available memory to use (0.0 to 1.0) + + Returns + ------- + Optimal batch size + """ + # Get available GPU memory in bytes + gpu_memory = torch.cuda.get_device_properties(device).total_memory + + # Calculate memory requirements per atom + voxels_per_atom = (2 * neighborhood_size + 1) ** 3 + bytes_per_float = 4 # 32-bit float + + # Memory needed for: + # 1. Voxel positions (float32): batch_size * voxels_per_atom * 3 coordinates + # 2. Valid mask (bool): batch_size * voxels_per_atom + # 3. Relative coordinates (float32): batch_size * voxels_per_atom * 3 + # 4. 
Potentials (float32): batch_size * voxels_per_atom + # Plus some overhead for temporary variables + memory_per_atom = ( + voxels_per_atom * (3 * bytes_per_float) # Voxel positions + + voxels_per_atom * 1 # Valid mask (bool) + + voxels_per_atom * (3 * bytes_per_float) # Relative coordinates + + voxels_per_atom * bytes_per_float # Potentials + + 1024 # Additional overhead + ) + + # Calculate batch size + optimal_batch_size = int((gpu_memory * safety_factor) / memory_per_atom) + + # Ensure batch size is at least 1 but not larger than total atoms + optimal_batch_size = max(min_batch_size, min(optimal_batch_size, total_atoms)) + + print(f"Total GPU memory: {gpu_memory / 1024**3:.2f} GB") + # print(f"Available GPU memory: {gpu_memory_available / 1024**3:.2f} GB") + print(f"Estimated memory per atom: {memory_per_atom / 1024**2:.2f} MB") + print(f"Optimal batch size: {optimal_batch_size}") + + return optimal_batch_size diff --git a/src/ttsim3d/elastic_scattering_factors.json b/src/ttsim3d/elastic_scattering_factors.json new file mode 100644 index 0000000..6b4e3f4 --- /dev/null +++ b/src/ttsim3d/elastic_scattering_factors.json @@ -0,0 +1,240 @@ +{ + "parameters_a": { + "H": [0.0349, 0.1201, 0.1970, 0.0573, 0.1195], + "HE": [0.0317, 0.0838, 0.1526, 0.1334, 0.0164], + "LI": [0.0750, 0.2249, 0.5548, 1.4954, 0.9354], + "BE": [0.0780, 0.2210, 0.6740, 1.3867, 0.6925], + "B": [0.0909, 0.2551, 0.7738, 1.2136, 0.4606], + + "C": [0.0893, 0.2563, 0.7570, 1.0487, 0.3575], + "N": [0.1022, 0.3219, 0.7982, 0.8197, 0.1715], + "O": [0.0974, 0.2921, 0.6910, 0.6990, 0.2039], + "F": [0.1083, 0.3175, 0.6487, 0.5846, 0.1421], + "NE": [0.1269, 0.3535, 0.5582, 0.4674, 0.1460], + + "NA": [0.2142, 0.6853, 0.7692, 1.6589, 1.4482], + "MG": [0.2314, 0.6866, 0.9677, 2.1882, 1.1339], + "AL": [0.2390, 0.6573, 1.2011, 2.5586, 1.2312], + "SI": [0.2519, 0.6372, 1.3795, 2.5082, 1.0500], + "P": [0.2548, 0.6106, 1.4541, 2.3204, 0.8477], + + "S": [0.2497, 0.5628, 1.3899, 2.1865, 0.7715], + "CL": [0.2443, 0.5397, 1.3919, 2.0197, 0.6621], + "AR": [0.2385, 0.5017, 1.3428, 1.8899, 0.6079], + "K": [0.4115, 1.4031, 2.2784, 2.6742, 2.2162], + "CA": [0.4054, 1.3880, 2.1602, 3.7532, 2.2063], + + "SC": [], + "TI": [], + "V": [], + "CR": [], + "MN": [], + + "FE": [], + "CO": [], + "NI": [], + "CU": [], + "ZN": [], + + "GA": [], + "GE": [], + "AS": [], + "SE": [], + "BR": [], + + "KR": [], + "RB": [], + "SR": [], + "Y": [], + "ZR": [], + + "NB": [], + "MO": [], + "TC": [], + "RU": [], + "RH": [], + + "PD": [], + "AG": [], + "CD": [], + "IN": [], + "SN": [], + + "SB": [], + "TE": [], + "I": [], + "XE": [], + "CS": [], + + "BA": [], + "LA": [], + "CE": [], + "PR": [], + "ND": [], + + "PM": [], + "SM": [], + "EU": [], + "GD": [], + "TB": [], + + "DY": [], + "HO": [], + "ER": [], + "TM": [], + "YB": [], + + "LU": [], + "HF": [], + "TA": [], + "W": [], + "RE": [], + + "OS": [], + "IR": [], + "PT": [], + "AU": [], + "HG": [], + + "TL": [], + "PB": [], + "BI": [], + "PO": [], + "AT": [], + + "RN": [], + "FR": [], + "RA": [], + "AC": [], + "TH": [], + + "PA": [], + "U": [], + "NP": [], + "PU": [], + "AM": [], + + "CM": [], + "BK": [], + "CF": [] + }, + "parameters_b": { + "H": [0.5347, 3.5867, 12.3471, 18.9525, 38.6269], + "HE": [0.2507, 1.4751, 4.4938, 12.6646, 31.1653], + "LI": [0.3864, 2.9383, 15.3829, 53.5545, 138.7337], + "BE": [0.3131, 2.2381, 10.1517, 30.9061, 78.3273], + "B": [0.2995, 2.1155, 8.3816, 24.1292, 63.1314], + + "C": [0.2465, 1.7100, 6.4094, 18.6113, 50.2523], + "N": [0.2451, 1.7481, 6.1925, 17.3894, 48.1431], + "O": [0.2067, 1.3815, 
4.6943, 12.7105, 32.4726], + "F": [0.2057, 1.3439, 4.2788, 11.3932, 28.7881], + "NE": [0.2200, 1.3779, 4.0203, 9.4934, 23.1278], + + "NA": [0.3334, 2.3446, 10.0830, 48.3037, 138.2700], + "MG": [0.3278, 2.2720, 10.9241, 39.2898, 101.9748], + "AL": [0.3138, 2.1063, 10.4163, 34.4552, 98.5344], + "SI": [0.3075, 2.0174, 9.6746, 29.3744, 80.4732], + "P": [0.2908, 1.8740, 8.5176, 24.3434, 63.2996], + + "S": [0.2681, 1.6711, 7.0267, 19.5377, 50.3888], + "CL": [0.2468, 1.5242, 6.1537, 16.6687, 42.3086], + "AR": [0.2289, 1.3694, 5.2561, 14.0928, 35.5361], + "K": [0.3703, 3.3874, 13.1029, 68.9592, 194.4329], + "CA": [0.3499, 3.0991, 11.9608, 53.9353, 142.3892], + + "SC": [], + "TI": [], + "V": [], + "CR": [], + "MN": [], + + "FE": [], + "CO": [], + "NI": [], + "CU": [], + "ZN": [], + + "GA": [], + "GE": [], + "AS": [], + "SE": [], + "BR": [], + + "KR": [], + "RB": [], + "SR": [], + "Y": [], + "ZR": [], + + "NB": [], + "MO": [], + "TC": [], + "RU": [], + "RH": [], + + "PD": [], + "AG": [], + "CD": [], + "IN": [], + "SN": [], + + "SB": [], + "TE": [], + "I": [], + "XE": [], + "CS": [], + + "BA": [], + "LA": [], + "CE": [], + "PR": [], + "ND": [], + + "PM": [], + "SM": [], + "EU": [], + "GD": [], + "TB": [], + + "DY": [], + "HO": [], + "ER": [], + "TM": [], + "YB": [], + + "LU": [], + "HF": [], + "TA": [], + "W": [], + "RE": [], + + "OS": [], + "IR": [], + "PT": [], + "AU": [], + "HG": [], + + "TL": [], + "PB": [], + "BI": [], + "PO": [], + "AT": [], + + "RN": [], + "FR": [], + "RA": [], + "AC": [], + "TH": [], + + "PA": [], + "U": [], + "NP": [], + "PU": [], + "AM": [], + + "CM": [], + "BK": [], + "CF": [] + } +} diff --git a/src/ttsim3d/grid_coords.py b/src/ttsim3d/grid_coords.py new file mode 100644 index 0000000..aab601f --- /dev/null +++ b/src/ttsim3d/grid_coords.py @@ -0,0 +1,49 @@ +"""Deals with grid coordinates.""" + +import torch + + +def get_upsampling( + wanted_pixel_size: float, wanted_output_size: int, max_size: int = 1536 +) -> int: + """ + Calculate the upsampling factor for the simulation volume. + + Args: + wanted_pixel_size: The pixel size in Angstroms. + wanted_output_size: The output size of the 3D volume. + max_size: The maximum size of the 3D volume. + + Returns + ------- + int: The upsampling factor. + """ + if wanted_pixel_size > 1.5 and wanted_output_size * 4 < max_size: + print("Oversampling your 3d by a factor of 4 for calculation.") + return 4 + + if 0.75 < wanted_pixel_size <= 1.5 and wanted_output_size * 2 < max_size: + print("Oversampling your 3d by a factor of 2 for calculation.") + return 2 + + return 1 + + +def get_size_neighborhood_cistem( + mean_b_factor: float, upsampled_pixel_size: float +) -> int: + """ + Calculate the size of the neighborhood of voxels. + + Args: + mean_b_factor: The mean B factor of the atoms. + upsampled_pixel_size: The pixel size in Angstroms. + + Returns + ------- + int: The size of the neighborhood. + """ + return int( + 1 + + torch.round((0.4 * (0.6 * mean_b_factor) ** 0.5 + 0.2) / upsampled_pixel_size) + ) diff --git a/src/ttsim3d/pdb_handler.py b/src/ttsim3d/pdb_handler.py new file mode 100644 index 0000000..3f92efc --- /dev/null +++ b/src/ttsim3d/pdb_handler.py @@ -0,0 +1,49 @@ +"""Handle PDB related operations.""" + +import mmdf +import torch + + +def load_model( + file_path: str, +) -> tuple[torch.Tensor, list[str], torch.Tensor]: + """ + Load model from pdb file_path and return atom coordinates in Angstroms. + + Args: + file_paths: A list of file paths. + + Returns + ------- + atom coordinates in Angstroms. 
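+    Specifically, a tuple of:
+        atom_zyx (torch.Tensor): (n_atoms, 3) zyx coordinates in Angstroms,
+            centered by subtracting the mean atom position.
+        atom_id (list[str]): upper-case element symbols, one per atom.
+        atom_b_factor (torch.Tensor): isotropic B factors, one per atom.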
+ """ + df = mmdf.read(file_path) + atom_zyx = torch.tensor(df[["z", "y", "x"]].to_numpy()).float() # (n_atoms, 3) + atom_zyx -= torch.mean(atom_zyx, dim=0, keepdim=True) # center + atom_id = df["element"].str.upper().tolist() + atom_b_factor = torch.tensor(df["b_isotropic"].to_numpy()).float() + return atom_zyx, atom_id, atom_b_factor + + +def remove_hydrogens( + atoms_zyx: torch.Tensor, + atoms_id: list, + atoms_b_factor_scaled: torch.Tensor, +) -> tuple[torch.Tensor, list[str], torch.Tensor]: + """ + Remove hydrogen atoms from the atom list. + + Args: + atoms_zyx: Atom coordinates in Angstroms. + atoms_id: Atom IDs. + atoms_b_factor_scaled: Atom B factors. + + Returns + ------- + Atom coordinates in Angstroms. + """ + non_h_mask = [aid != "H" for aid in atoms_id] + atoms_zyx_filtered = atoms_zyx[non_h_mask] + atoms_id_filtered = [aid for i, aid in enumerate(atoms_id) if non_h_mask[i]] + atoms_b_factor_scaled_filtered = atoms_b_factor_scaled[non_h_mask] + return atoms_zyx_filtered, atoms_id_filtered, atoms_b_factor_scaled_filtered diff --git a/src/ttsim3d/run_ttsim3d.py b/src/ttsim3d/run_ttsim3d.py new file mode 100644 index 0000000..f766c2e --- /dev/null +++ b/src/ttsim3d/run_ttsim3d.py @@ -0,0 +1,31 @@ +"""Simple run script.""" + +from ttsim3d.simulate3d import simulate3d + + +def main() -> None: + """A test function to run the simulate3d function from the ttsim3d package.""" + simulate3d( + pdb_filename="/Users/josh/git/2dtm_tests/simulator/parsed_6Q8Y_whole_LSU_match3.pdb", + output_filename="/Users/josh/git/2dtm_tests/simulator/simulated_6Q8Y_whole_LSU_match3.mrc", + sim_volume_shape=(400, 400, 400), + sim_pixel_spacing=0.95, + num_frames=50, + fluence_per_frame=1, + beam_energy_kev=300, + dose_weighting=True, + dose_B=-1, + apply_dqe=True, + mtf_filename="/Users/josh/git/2dtm_tests/simulator/mtf_k2_300kV.star", + b_scaling=0.5, + added_B=0.0, + upsampling=-1, + n_cpu_cores=-1, + gpu_ids=[-999], + num_gpus=0, + modify_signal=1, # This is how to apply the dose weighting. + ) + + +if __name__ == "__main__": + main() diff --git a/src/ttsim3d/scattering_potential.py b/src/ttsim3d/scattering_potential.py new file mode 100644 index 0000000..b2b528f --- /dev/null +++ b/src/ttsim3d/scattering_potential.py @@ -0,0 +1,180 @@ +"""Calculates the scatttering potential.""" + +import json +from pathlib import Path + +import torch +from scipy import constants as C + + +def calculate_relativistic_electron_wavelength(energy: float) -> float: + """Calculate the relativistic electron wavelength in SI units. + + For derivation see: + 1. Kirkland, E. J. Advanced Computing in Electron Microscopy. + (Springer International Publishing, 2020). doi:10.1007/978-3-030-33260-0. + + 2. https://en.wikipedia.org/wiki/Electron_diffraction#Relativistic_theory + + Parameters + ---------- + energy: float + acceleration potential in volts. + + Returns + ------- + wavelength: float + relativistic wavelength of the electron in meters. + """ + h = C.Planck + c = C.speed_of_light + m0 = C.electron_mass + e = C.elementary_charge + V = energy + eV = e * V + + numerator = h * c + denominator = (eV * (2 * m0 * c**2 + eV)) ** 0.5 + return float(numerator / denominator) + + +def get_scattering_parameters() -> tuple[dict, dict]: + """ + Load scattering parameters from JSON file. + + Args: + None + + Returns + ------- + scattering_params_a: dict + Scattering parameters for atom type A. + scattering_params_b: dict + Scattering parameters for atom type B. 
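+
+    Note: "a" and "b" are the amplitude and width parameters of the
+    five-Gaussian elastic scattering-factor fit for each element, keyed by
+    upper-case element symbol; elements with empty entries in the JSON are
+    dropped.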
+    """
+    scattering_param_path = Path(__file__).parent / "elastic_scattering_factors.json"
+
+    with open(scattering_param_path) as f:
+        data = json.load(f)
+
+    scattering_params_a = {k: v for k, v in data["parameters_a"].items() if v != []}
+    scattering_params_b = {k: v for k, v in data["parameters_b"].items() if v != []}
+    return scattering_params_a, scattering_params_b
+
+
+def get_total_b_param(
+    scattering_params_b: dict,
+    atoms_id_filtered: list[str],
+    atoms_b_factor_scaled_filtered: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Combine the scaled atomic B factors with the 'b' scattering parameters.
+
+    Args:
+        scattering_params_b: dict
+            The 'b' scattering-factor parameters, keyed by element symbol.
+        atoms_id_filtered: list[str]
+            Atom IDs (element symbols).
+        atoms_b_factor_scaled_filtered: torch.Tensor
+            Scaled atom B factors.
+
+    Returns
+    -------
+    bPlusB: torch.Tensor
+        Combined term 2*pi / sqrt(b_factor + b) for each atom and each of the
+        five scattering parameters.
+    """
+    b_params = torch.stack(
+        [torch.tensor(scattering_params_b[atom_id]) for atom_id in atoms_id_filtered]
+    )
+    bPlusB = (
+        2
+        * torch.pi
+        / torch.sqrt(atoms_b_factor_scaled_filtered.unsqueeze(1) + b_params)
+    )
+    return bPlusB
+
+
+def get_scattering_potential_of_voxel(
+    zyx_coords1: torch.Tensor,  # Shape: (N, 3)
+    zyx_coords2: torch.Tensor,  # Shape: (N, 3)
+    bPlusB: torch.Tensor,
+    atom_id: str,
+    lead_term: float,
+    scattering_params_a: dict,  # Add parameter dictionary
+    device: torch.device = None,
+) -> torch.Tensor:
+    """
+    Calculate the scattering potential for all voxels in the neighborhood of the atom.
+
+    Args:
+        zyx_coords1: torch.Tensor
+            Lower-edge zyx coordinates of each voxel relative to the atom
+            center, in Angstroms. Shape (N, 3).
+        zyx_coords2: torch.Tensor
+            Upper-edge zyx coordinates of each voxel (lower edge plus one
+            upsampled pixel), in Angstroms. Shape (N, 3).
+        bPlusB: torch.Tensor
+            Combined B-factor terms for this atom, one per scattering parameter.
+        atom_id: str
+            Element symbol of the atom.
+        lead_term: float
+            Lead term for the scattering potential.
+        scattering_params_a: dict
+            The 'a' scattering-factor parameters, keyed by element symbol.
+        device: torch.device
+            Device to run the computation on.
+
+    Returns
+    -------
+    potential: torch.Tensor
+        Scattering potential for all voxels in the neighborhood.
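+
+    Notes
+    -----
+    Each scattering term i contributes
+        a_i * |prod over d in (z, y, x) of [erf(B_i * d2) - erf(B_i * d1)]|,
+    i.e. the Gaussian-parameterized atomic potential integrated over the voxel
+    extent; when a voxel straddles the atom center along an axis, the erf
+    magnitudes are added instead. The sum over the five terms is scaled by
+    lead_term, where B_i is the corresponding entry of bPlusB.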
+ """ + # If device not specified, use the device of input tensors + if device is None: + device = zyx_coords1.device + + # Get scattering parameters for this atom type and move to correct device + # Convert parameters to tensor and move to device + if isinstance(scattering_params_a[atom_id], torch.Tensor): + a_params = scattering_params_a[atom_id].clone().detach().to(device) + else: + a_params = torch.as_tensor(scattering_params_a[atom_id], device=device) + + # Compare signs element-wise for batched coordinates + t1 = (zyx_coords1[:, 2] * zyx_coords2[:, 2]) >= 0 # Shape: (N,) + t2 = (zyx_coords1[:, 1] * zyx_coords2[:, 1]) >= 0 # Shape: (N,) + t3 = (zyx_coords1[:, 0] * zyx_coords2[:, 0]) >= 0 # Shape: (N,) + + temp_potential = torch.zeros(len(zyx_coords1), device=device) + + for i, bb in enumerate(bPlusB): + a = a_params[i] + # Handle x dimension + x_term = torch.where( + t1, + torch.special.erf(bb * zyx_coords2[:, 2]) + - torch.special.erf(bb * zyx_coords1[:, 2]), + torch.abs(torch.special.erf(bb * zyx_coords2[:, 2])) + + torch.abs(torch.special.erf(bb * zyx_coords1[:, 2])), + ) + + # Handle y dimension + y_term = torch.where( + t2, + torch.special.erf(bb * zyx_coords2[:, 1]) + - torch.special.erf(bb * zyx_coords1[:, 1]), + torch.abs(torch.special.erf(bb * zyx_coords2[:, 1])) + + torch.abs(torch.special.erf(bb * zyx_coords1[:, 1])), + ) + + # Handle z dimension + z_term = torch.where( + t3, + torch.special.erf(bb * zyx_coords2[:, 0]) + - torch.special.erf(bb * zyx_coords1[:, 0]), + torch.abs(torch.special.erf(bb * zyx_coords2[:, 0])) + + torch.abs(torch.special.erf(bb * zyx_coords1[:, 0])), + ) + + t0 = z_term * y_term * x_term + temp_potential += a * torch.abs(t0) + + return lead_term * temp_potential diff --git a/src/ttsim3d/simulate3d.py b/src/ttsim3d/simulate3d.py new file mode 100644 index 0000000..f6553e3 --- /dev/null +++ b/src/ttsim3d/simulate3d.py @@ -0,0 +1,576 @@ +"""The main simulation function.""" + +import multiprocessing as mp +import time +from typing import Optional + +import mrcfile +import numpy as np +import torch +from torch_fourier_filter.dose_weight import cumulative_dose_filter_3d +from torch_fourier_filter.mtf import make_mtf_grid, read_mtf + +from ttsim3d.device_handler import get_cpu_cores, select_gpus +from ttsim3d.grid_coords import get_size_neighborhood_cistem, get_upsampling +from ttsim3d.pdb_handler import load_model, remove_hydrogens +from ttsim3d.scattering_potential import ( + calculate_relativistic_electron_wavelength, + get_scattering_parameters, + get_scattering_potential_of_voxel, + get_total_b_param, +) + +BOND_SCALING_FACTOR = 1.043 +PIXEL_OFFSET = 0.5 + + +# This will definitely be moved to a different program +def fourier_rescale_3d_force_size( + volume_fft: torch.Tensor, + volume_shape: tuple[int, int, int], + target_size: int, + rfft: bool = True, + fftshift: bool = False, +) -> torch.Tensor: + """ + Crop a 3D Fourier-transformed volume to a specific target size. + + Parameters + ---------- + volume_fft: torch.Tensor + The Fourier-transformed volume. + volume_shape: tuple[int, int, int] + The original shape of the volume. + target_size: int + The target size of the cropped volume. + rfft: bool + Whether the input is a real-to-complex Fourier Transform. + fftshift: bool + Whether the zero frequency is shifted to the center. + + Returns + ------- + - cropped_fft_shifted_back (torch.Tensor): The cropped fft + """ + # Ensure the target size is even + assert target_size > 0, "Target size must be positive." 
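+    # (Only positivity is enforced above; the centered crop below additionally
+    # assumes target_size does not exceed the FFT grid dimensions.)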
+ + # Get the original size of the volume + assert ( + volume_shape[0] == volume_shape[1] == volume_shape[2] + ), "Volume must be cubic." + + # Step 1: Perform real-to-complex Fourier Transform (rfftn) + # and shift the zero frequency to the center + if not fftshift: + volume_fft = torch.fft.fftshift( + volume_fft, dim=(-3, -2, -1) + ) # Shift along first two dimensions only + + # Calculate the dimensions of the rfftn output + rfft_size_z, rfft_size_y, rfft_size_x = volume_fft.shape + + # Calculate cropping indices for each dimension + center_z = rfft_size_z // 2 + center_y = rfft_size_y // 2 + center_x = rfft_size_x // 2 + + # Define the cropping ranges + crop_start_z = int(center_z - target_size // 2) + crop_end_z = int(crop_start_z + target_size) + crop_start_y = int(center_y - target_size // 2) + crop_end_y = int(crop_start_y + target_size) + crop_start_x = int(center_x - target_size // 2) + crop_end_x = int( + target_size // 2 + 1 + ) # Crop from the high-frequency end only along the last dimension + + # Step 2: Crop the Fourier-transformed volume + cropped_fft = torch.zeros_like(volume_fft) + if rfft: + cropped_fft = volume_fft[ + crop_start_z:crop_end_z, crop_start_y:crop_end_y, -crop_end_x: + ] + else: + crop_end_x = int(crop_start_x + target_size) + cropped_fft = volume_fft[ + crop_start_z:crop_end_z, crop_start_y:crop_end_y, crop_start_x:crop_end_x + ] + + # Step 3: Inverse shift and apply the inverse rFFT to return to real space + cropped_fft_shifted_back = torch.fft.ifftshift(cropped_fft, dim=(-3, -2)) + + return cropped_fft_shifted_back + + +def process_atom_batch( + batch_args: tuple, +) -> torch.Tensor: + """ + Process a batch of atoms to calculate the scattering potential in parallel. + + Args: + batch_args: Tuple containing the arguments for the batch. + + Returns + ------- + torch.Tensor: The local volume grid for the batch. 
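+
+    batch_args is expected to contain, in order: atom_indices_batch,
+    atom_dds_batch, bPlusB_batch, atoms_id_filtered_batch, voxel_offsets_flat,
+    upsampled_shape, upsampled_pixel_size, lead_term and scattering_params_a,
+    matching the tuples built in process_atoms_parallel.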
+ """ + try: + # Unpack the tuple correctly + ( + atom_indices_batch, + atom_dds_batch, + bPlusB_batch, + atoms_id_filtered_batch, + voxel_offsets_flat, + upsampled_shape, + upsampled_pixel_size, + lead_term, + scattering_params_a, + ) = batch_args + + # Move tensors to CPU and ensure they're contiguous + atom_indices_batch = atom_indices_batch.cpu().contiguous() + atom_dds_batch = atom_dds_batch.cpu().contiguous() + voxel_offsets_flat = voxel_offsets_flat.cpu().contiguous() + + # Initialize local volume grid for this batch + local_volume = torch.zeros(upsampled_shape, device="cpu") + + # Add debug print to verify data + print(f"Processing batch of size {len(atom_indices_batch)}") + + # offset_test = upsampled_pixel_size/2 + # Process each atom in the batch + for i in range(len(atom_indices_batch)): + atom_pos = atom_indices_batch[i] + atom_dds = atom_dds_batch[i] + atom_id = atoms_id_filtered_batch[i] + + # Calculate voxel positions relative to atom center + voxel_positions = ( + atom_pos.view(1, 3) + voxel_offsets_flat + ) # indX/Y/Z equivalent + + # print(voxel_positions.shape) + # Check bounds for each dimension separately + valid_z = (voxel_positions[:, 0] >= 0) & ( + voxel_positions[:, 0] < upsampled_shape[0] + ) + valid_y = (voxel_positions[:, 1] >= 0) & ( + voxel_positions[:, 1] < upsampled_shape[1] + ) + valid_x = (voxel_positions[:, 2] >= 0) & ( + voxel_positions[:, 2] < upsampled_shape[2] + ) + valid_mask = valid_z & valid_y & valid_x + + if valid_mask.any(): + # Calculate coordinates relative to atom center + relative_coords = ( + voxel_positions[valid_mask] - atom_pos - atom_dds - PIXEL_OFFSET + ) * upsampled_pixel_size + coords1 = relative_coords + coords2 = relative_coords + upsampled_pixel_size + + # Calculate potentials for valid positions + potentials = get_scattering_potential_of_voxel( + zyx_coords1=coords1, + zyx_coords2=coords2, + bPlusB=bPlusB_batch[i], + atom_id=atom_id, + lead_term=lead_term, + scattering_params_a=scattering_params_a, # Pass the parameters + ) + + # Get valid voxel positions + valid_positions = voxel_positions[valid_mask].long() + + # Update local volume + local_volume[ + valid_positions[:, 0], valid_positions[:, 1], valid_positions[:, 2] + ] += potentials + except Exception as e: + print(f"Error in process_atom_batch: {e!s}") + raise e + + return local_volume + + +def process_atoms_parallel( + atom_indices: torch.Tensor, + atom_dds: torch.Tensor, + bPlusB: torch.Tensor, + scattering_params_a: dict, + atoms_id_filtered: list[str], + voxel_offsets_flat: torch.Tensor, + upsampled_shape: tuple[int, int, int], + upsampled_pixel_size: float, + lead_term: float, + n_cores: int = 1, +) -> torch.Tensor: + """ + Scattering potential of atoms in parallel using cpu multiprocessing. + + Args: + atom_indices: The indices of the atoms. + atom_dds: The offset from the edge of the voxel. + bPlusB: The sum of the B factors from scattering and pdb file. + scattering_params_a: The 'a' scattering parameters. + atoms_id_filtered: The list of atom IDs (no H). + voxel_offsets_flat: The flattened voxel offsets for the neighborhood. + upsampled_shape: The shape of the upsampled volume. + upsampled_pixel_size: The pixel size of the upsampled volume. + lead_term: The lead term for the calculation. + n_cores: The number of CPU cores to use. + + Returns + ------- + torch.Tensor: The final volume grid. 
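+
+    The atoms are split into roughly n_cores batches, each batch is scattered
+    into its own zero-initialized volume by process_atom_batch via
+    multiprocessing.Pool.imap_unordered, and the per-batch volumes are summed
+    to give the returned grid.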
+ """ + # Ensure all inputs are on CPU and contiguous + atom_indices = atom_indices.cpu().contiguous() + atom_dds = atom_dds.cpu().contiguous() + voxel_offsets_flat = voxel_offsets_flat.cpu().contiguous() + + # Convert pandas Series to list if necessary + if hasattr(atoms_id_filtered, "tolist"): + atoms_id_filtered = atoms_id_filtered.tolist() + + num_atoms = len(atom_indices) + batch_size = max(1, num_atoms // (n_cores)) # Divide work into smaller batches + + print(f"Processing {num_atoms} atoms in batches of {batch_size}") + + # Prepare batches + batches = [] + for start_idx in range(0, num_atoms, batch_size): + end_idx = min(start_idx + batch_size, num_atoms) + batch_args = ( + atom_indices[start_idx:end_idx], + atom_dds[start_idx:end_idx], + bPlusB[start_idx:end_idx], + atoms_id_filtered[start_idx:end_idx], + voxel_offsets_flat, + upsampled_shape, + upsampled_pixel_size, + lead_term, + scattering_params_a, + ) + batches.append(batch_args) + + # Process batches in parallel + with mp.Pool(n_cores) as pool: + results = [] + for i, result in enumerate(pool.imap_unordered(process_atom_batch, batches)): + results.append(result) + if (i + 1) % 10 == 0: + print(f"Processed {(i + 1) * batch_size} atoms of {num_atoms}") + + # Combine results + final_volume = torch.zeros(upsampled_shape, device="cpu") + for result in results: + final_volume += result + + return final_volume + + +def process_device_atoms( + args: tuple, +) -> torch.Tensor: + """ + Process atoms for a single device (gpu or cpu) in parallel. + + Args: + args: Tuple containing the arguments for the device. + + Returns + ------- + torch.Tensor: The final volume grid for the device. + """ + ( + device_atom_indices, + device_atom_dds, + device_bPlusB, + device_atoms_id, + scattering_params_a, + voxel_offsets_flat, + upsampled_shape, + upsampled_pixel_size, + lead_term, + device, + n_cpu_cores, + ) = args + + print(f"\nProcessing atoms on {device}") + + if device.type == "cuda": + print("Not done this yet!") + else: + volume_grid = process_atoms_parallel( + atom_indices=device_atom_indices, + atom_dds=device_atom_dds, + bPlusB=device_bPlusB, + scattering_params_a=scattering_params_a, + atoms_id_filtered=device_atoms_id, + voxel_offsets_flat=voxel_offsets_flat, + upsampled_shape=upsampled_shape, + upsampled_pixel_size=upsampled_pixel_size, + lead_term=lead_term, + n_cores=n_cpu_cores, + ) + + return volume_grid + + +def simulate3d( + pdb_filename: str, + output_filename: str, + sim_volume_shape: tuple[int, int, int], + sim_pixel_spacing: float, + num_frames: int, + fluence_per_frame: float, + beam_energy_kev: float = 300, + dose_weighting: bool = True, + dose_B: float = -1, # -1 is use Grant Grigorieff dose weighting + apply_dqe: bool = True, + mtf_filename: str = "", + b_scaling: float = 1.0, + added_B: float = 0.0, + upsampling: int = -1, # -1 is calculate automatically + n_cpu_cores: int = 1, # -1 id get automatically + gpu_ids: Optional[list[int]] = None, # [-999] cpu, [-1] auto, [0, 1] etc=gpuid + num_gpus: int = 1, + modify_signal: int = 1, +) -> None: + """ + Run the 3D simulation. + + Args: + pdb_filename: The filename of the PDB file. + output_filename: The filename of the output MRC file. + sim_volume_shape: The shape of the simulation volume. + sim_pixel_spacing: The pixel spacing of the simulation volume. + num_frames: The number of frames for the simulation. + fluence_per_frame: The fluence per frame. + beam_energy_kev: The beam energy in keV. + dose_weighting: Whether to apply dose weighting. 
+ dose_B: The B factor for dose weighting. + apply_dqe: Whether to apply DQE. + mtf_filename: The filename of the MTF file. + b_scaling: The B scaling factor. + added_B: The added B factor. + upsampling: The upsampling factor. + n_cpu_cores: The number of CPU cores. + gpu_ids: The list of GPU IDs. + num_gpus: The number of GPUs. + modify_signal: The signal modification factor. + + Returns + ------- + None + """ + # This is the main program + start_time = time.time() + # Get the wavelength from the beam energy + wavelength_A = ( + calculate_relativistic_electron_wavelength(beam_energy_kev * 1000) * 1e10 + ) + # Get lead term, call it something better and move it out elsewhere + lead_term = BOND_SCALING_FACTOR * wavelength_A / 8.0 / (sim_pixel_spacing**2) + # get the scattering parameters + scattering_params_a, scattering_params_b = get_scattering_parameters() + + # It is called by ttsim3d.py with all the inputs + + # Select devices + if gpu_ids == [-999]: # Special case for CPU-only + devices = [torch.device("cpu")] + else: + devices = select_gpus(gpu_ids, num_gpus) + if devices[0].type == "cpu": + if n_cpu_cores == -1: + n_cpu_cores = get_cpu_cores() + print(f"Using devices: {[str(device) for device in devices]}") + + # Then load pdb (a separate file) and get non-H atom + # list with zyx coords and isotropic b factors + atoms_zyx, atoms_id, atoms_b_factor = load_model(pdb_filename) + atoms_zyx_filtered, atoms_id_filtered, atoms_b_factor_filtered = remove_hydrogens( + atoms_zyx, atoms_id, atoms_b_factor + ) + # Scale the B-factors (now doing it after filtered unlike before) + # the 0.25 is strange but keeping like cisTEM for now + atoms_b_factor_scaled = 0.25 * (atoms_b_factor_filtered * b_scaling + added_B) + mean_b_factor = torch.mean(atoms_b_factor_scaled) + # Get the B parameter for each atom plus scattering parameter B + total_b_param = get_total_b_param( + scattering_params_b, atoms_id_filtered, atoms_b_factor_scaled + ) + + # Set up the simulation volume - push this out into a separate file grid_coords.py + # Start with upsampling to improve accuracy + upsampling = ( + get_upsampling(sim_pixel_spacing, sim_volume_shape[0], max_size=1536) + if upsampling == -1 + else upsampling + ) + upsampled_pixel_size = sim_pixel_spacing / upsampling + upsampled_shape = tuple(np.array(sim_volume_shape) * upsampling) + # Get the centre if the upsampled volume + origin_idx = ( + upsampled_shape[0] / 2, + upsampled_shape[1] / 2, + upsampled_shape[2] / 2, + ) + # Get the size of the voxel neighbourhood to calculate the potential of each atom + size_neighborhood = get_size_neighborhood_cistem( + mean_b_factor, upsampled_pixel_size + ) + neighborhood_range = torch.arange(-size_neighborhood, size_neighborhood + 1) + # Create coordinate grids for the neighborhood + sz, sy, sx = torch.meshgrid( + neighborhood_range, neighborhood_range, neighborhood_range, indexing="ij" + ) + voxel_offsets = torch.stack([sz, sy, sx]) # (3, n, n, n) + # Flatten while preserving the relative positions + voxel_offsets_flat = voxel_offsets.reshape(3, -1).T # (n^3, 3) + # Calculate the pixel coordinates of each atom + this_coords = ( + (atoms_zyx_filtered / upsampled_pixel_size) + + torch.tensor(origin_idx).unsqueeze(0) + + PIXEL_OFFSET + ) + atom_indices = torch.floor(this_coords) # these are the voxel indices + atom_dds = ( + this_coords - atom_indices - PIXEL_OFFSET + ) # this is offset from the edge of the voxel + + # Now divide into chunks for parallel processing + + # atoms_per_device = len(atoms_id_filtered) // num_devices 
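+    # Only the CPU path is implemented in this first version; the GPU branch
+    # below is still a placeholder.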
+ device_outputs = [] + # device_args = [] + if devices[0].type == "cpu": + # If CPU only, use the original parallel processing directly + volume_grid = process_atoms_parallel( + atom_indices=atom_indices, + atom_dds=atom_dds, + bPlusB=total_b_param, + scattering_params_a=scattering_params_a, + atoms_id_filtered=atoms_id_filtered, + voxel_offsets_flat=voxel_offsets_flat, + upsampled_shape=upsampled_shape, + upsampled_pixel_size=upsampled_pixel_size, + lead_term=lead_term, + n_cores=n_cpu_cores, + ) + device_outputs = [volume_grid] + else: + print("Not done this yet!") + + # Combine results from all devices + main_device = devices[0] + final_volume = torch.zeros(upsampled_shape, device=main_device) + for volume in device_outputs: + final_volume += volume.to(main_device) + + # Convert to Fourier space for filtering + final_volume = torch.fft.fftshift(final_volume, dim=(-3, -2, -1)) + final_volume_FFT = torch.fft.rfftn(final_volume, dim=(-3, -2, -1)) + # Dose weight + if dose_weighting: + dose_filter = cumulative_dose_filter_3d( + volume_shape=final_volume.shape, + num_frames=num_frames, + start_exposure=0, + pixel_size=upsampled_pixel_size, + flux=fluence_per_frame, + Bfac=dose_B, + rfft=True, + fftshift=False, + ) + if modify_signal == 1: + # Add small epsilon to prevent division by zero + denominator = 1 + dose_filter + epsilon = 1e-10 + denominator = torch.clamp(denominator, min=epsilon) + modification = 1 - (1 - dose_filter) / denominator + + # Check for invalid values + if torch.any(torch.isnan(modification)): + print("Warning: NaN values in modification factor") + modification = torch.nan_to_num(modification, nan=1.0) + + final_volume_FFT *= modification + elif modify_signal == 2: + final_volume_FFT *= dose_filter**0.5 + else: + final_volume_FFT *= dose_filter + # apply dqe + # I should really apply the mtf after Fourier cropping + # cisTEM does it before + """ + if apply_dqe: + mtf_frequencies, mtf_amplitudes = read_mtf( + file_path=mtf_filename + ) + mtf = make_mtf_grid( + image_shape=final_volume.shape, + mtf_frequencies=mtf_frequencies, #1D tensor + mtf_amplitudes=mtf_amplitudes, #1D tensor + rfft=True, + fftshift=False, + ) + final_volume_FFT *= mtf + """ + # fourier crop back to desired output size + + if upsampling > 1: + final_volume_FFT = fourier_rescale_3d_force_size( + volume_fft=final_volume_FFT, + volume_shape=final_volume.shape, + target_size=sim_volume_shape[0], + rfft=True, + fftshift=False, + ) + + if apply_dqe: + """ + mtf = get_dqe_parameterized( + image_shape=final_volume.shape, + pixel_size=upsampled_pixel_size, + rfft=True, + fftshift=False, + ) + """ + mtf_frequencies, mtf_amplitudes = read_mtf(file_path=mtf_filename) + mtf = make_mtf_grid( + image_shape=sim_volume_shape, + mtf_frequencies=mtf_frequencies, # 1D tensor + mtf_amplitudes=mtf_amplitudes, # 1D tensor + rfft=True, + fftshift=False, + ) + final_volume_FFT *= mtf + + # inverse FFT + cropped_volume = torch.fft.irfftn( + final_volume_FFT, + s=(sim_volume_shape[0], sim_volume_shape[0], sim_volume_shape[0]), + dim=(-3, -2, -1), + ) + cropped_volume = torch.fft.ifftshift(cropped_volume, dim=(-3, -2, -1)) + + # Write now for testing + with mrcfile.new(output_filename, overwrite=True) as mrc: + mrc.set_data(cropped_volume.cpu().numpy()) + mrc.voxel_size = (sim_pixel_spacing, sim_pixel_spacing, sim_pixel_spacing) + # Populate more of the metadata... 
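+        # e.g. (not part of this commit) mrc.update_header_from_data() and
+        # mrc.update_header_stats() would fill in dimensions and density stats.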
+ + end_time = time.time() + elapsed_time = end_time - start_time + minutes = int(elapsed_time // 60) + seconds = int(elapsed_time % 60) + print(f"Total simulation time: {minutes} minutes {seconds} seconds") diff --git a/src/ttsim3d/test_code.txt b/src/ttsim3d/test_code.txt new file mode 100644 index 0000000..d78578a --- /dev/null +++ b/src/ttsim3d/test_code.txt @@ -0,0 +1,103 @@ +from torch.utils.data import Dataset + +class AtomDataset(Dataset): + def __init__(self, atom_indices, atom_dds, bPlusB, atoms_id_filtered, voxel_offsets_flat, + upsampled_shape, upsampled_pixel_size, lead_term, scattering_params_a): + self.atom_indices = atom_indices + self.atom_dds = atom_dds + self.bPlusB = bPlusB + self.atoms_id_filtered = atoms_id_filtered + self.voxel_offsets_flat = voxel_offsets_flat + self.upsampled_shape = upsampled_shape + self.upsampled_pixel_size = upsampled_pixel_size + self.lead_term = lead_term + self.scattering_params_a = scattering_params_a + + def __len__(self): + return len(self.atom_indices) + + def __getitem__(self, idx): + return (self.atom_indices[idx], self.atom_dds[idx], self.bPlusB[idx], self.atoms_id_filtered[idx], + self.voxel_offsets_flat, self.upsampled_shape, self.upsampled_pixel_size, self.lead_term, + self.scattering_params_a) + + +def process_atom_batch(batch_args): + try: + (atom_indices_batch, atom_dds_batch, bPlusB_batch, atoms_id_filtered_batch, + voxel_offsets_flat, upsampled_shape, upsampled_pixel_size, lead_term, + scattering_params_a) = batch_args + + # Move tensors to the device (GPU or CPU) + device = atom_indices_batch.device + local_volume = torch.zeros(upsampled_shape, device=device) + + voxel_positions = (atom_indices_batch.unsqueeze(1) + voxel_offsets_flat) # (batch_size, n^3, 3) + valid_mask = (voxel_positions >= 0) & (voxel_positions < torch.tensor(upsampled_shape, device=device)) + + # Apply valid_mask for each dimension and calculate potentials + valid_positions = voxel_positions[valid_mask] + relative_coords = ((valid_positions - atom_indices_batch.unsqueeze(1) - atom_dds_batch.unsqueeze(1) + - PIXEL_OFFSET) * upsampled_pixel_size) + + coords1 = relative_coords + coords2 = relative_coords + upsampled_pixel_size + + potentials = get_scattering_potential_of_voxel( + zyx_coords1=coords1, + zyx_coords2=coords2, + bPlusB=bPlusB_batch, + atom_id=atoms_id_filtered_batch, + lead_term=lead_term, + scattering_params_a=scattering_params_a + ) + + local_volume[valid_positions[:, 0], valid_positions[:, 1], valid_positions[:, 2]] += potentials + except Exception as e: + print(f"Error in process_atom_batch: {str(e)}") + raise e + + return local_volume + + +from torch.utils.data import DataLoader + +def process_atoms_dataloader(atom_indices, atom_dds, bPlusB, atoms_id_filtered, voxel_offsets_flat, + upsampled_shape, upsampled_pixel_size, lead_term, scattering_params_a, + batch_size, device): + dataset = AtomDataset(atom_indices, atom_dds, bPlusB, atoms_id_filtered, + voxel_offsets_flat, upsampled_shape, upsampled_pixel_size, lead_term, scattering_params_a) + + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + + final_volume = torch.zeros(upsampled_shape, device=device) + + for batch in dataloader: + # Move batch data to the GPU + batch = [item.to(device) if isinstance(item, torch.Tensor) else item for item in batch] + local_volume = process_atom_batch(batch) + final_volume += local_volume + + return final_volume + + +def simulate3d(...): + # Existing setup code here... 
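+    # (Setup sketch mirroring simulate3d.py: load_model -> remove_hydrogens ->
+    #  scale B factors as 0.25 * (b * b_scaling + added_B) -> get_total_b_param,
+    #  then compute atom_indices / atom_dds on the upsampled grid.)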
+ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Move atom data to the selected device + atom_indices = atom_indices.to(device) + atom_dds = atom_dds.to(device) + bPlusB = total_b_param.to(device) + + # Process with DataLoader + batch_size = 256 # Adjust based on GPU memory + final_volume = process_atoms_dataloader( + atom_indices, atom_dds, bPlusB, atoms_id_filtered, voxel_offsets_flat, + upsampled_shape, upsampled_pixel_size, lead_term, scattering_params_a, + batch_size, device + ) + + # Post-processing...
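+    # (Post-processing would follow the same steps as simulate3d.py: fftshift +
+    #  rfftn the volume, optional dose weighting and MTF/DQE filtering,
+    #  Fourier-crop back to sim_volume_shape, irfftn + ifftshift, then write
+    #  the MRC file.)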