From 936fc025ac24cca462d6cf5d3e6e0e563d0eb713 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 28 Mar 2024 12:26:08 +0100 Subject: [PATCH 01/48] new mdcath dataset --- torchmdnet/datasets/mdcath.py | 173 ++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 torchmdnet/datasets/mdcath.py diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py new file mode 100644 index 00000000..7a2c5441 --- /dev/null +++ b/torchmdnet/datasets/mdcath.py @@ -0,0 +1,173 @@ +# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org +# Distributed under the MIT License. +# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) + +import os +from os.path import join as opj +import h5py +import torch +from tqdm import tqdm +import math +import numpy as np +from torch_geometric.data import Dataset, download_url, Data + + +class mdCATH(Dataset): + def __init__( + self, + root, + transform=None, + pre_transform=None, + pre_filter=None, + preload_dataset_limit=None, + numAtoms=5000, + numNoHAtoms=None, + numResidues=1000, + temperatures=["348"], + skipFrames=1, + pdb_list=None, + min_gyration_radius=None, + max_gyration_radius=None, + alpha_beta_coil=None, + numFrames=None, + ): + """ mdCATH dataset class for PyTorch Geometric to load protein structures and dynamics from the mdCATH dataset. + + Parameters: + ----------- + root: str + Root directory where the dataset should be stored. Data will be downloaded to 'root/'. + numAtoms: int + Max number of atoms in the protein structure. + numNoHAtoms: int + Max number of non-hydrogen atoms in the protein structure. + numResidues: int + Max number of residues in the protein structure. + temperatures: list + List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450'] + skipFrames: int + Number of frames to skip in the trajectory. Default is 1. + pdb_list: list + List of PDB IDs to download. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded. + min_gyration_radius: float + Minimum gyration radius (in nm) of the protein structure. Default is None. + max_gyration_radius: float + Maximum gyration radius (in nm) of the protein structure. Default is None. + alpha_beta_coil: tuple + Tuple with the minimum percentage of alpha-helix, beta-sheet and coil residues in the protein structure. Default is None. + numFrames: int + Minimum number of frames in the trajectory in order to be considered. Default is None. 
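+        Examples:
+        ---------
+        A minimal usage sketch; the root path and argument values below are
+        illustrative only, not part of the dataset:
+
+        >>> dataset = mdCATH(root="data/mdcath", temperatures=["348"], skipFrames=2)
+        >>> sample = dataset[0]  # torch_geometric Data object with fields z, pos and neg_dy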
+ """ + + self.url = "https://zenodo.org/record//files/" + self.preload_dataset_limit = preload_dataset_limit + super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) + + self.numAtoms = numAtoms + self.numNoHAtoms = numNoHAtoms + self.numResidues = numResidues + self.temperatures = temperatures + self.skipFrames = skipFrames + self.pdb_list = pdb_list + self.min_gyration_radius = min_gyration_radius + self.max_gyration_radius = max_gyration_radius + self.alpha_beta_coil = alpha_beta_coil + self.numFrames = numFrames + self.idx = None + self.process_data_source() + print(f"Total number of domains: {len(self.to_download.keys())}") + print(f"Total number of conformers: {self.num_conformers}") + + @property + def raw_file_names(self): + # Check if the dataset has been processed, and if not, return the original source file + if not hasattr(self, 'to_download'): + return ['mdCATH_source.h5'] + # Otherwise, return the list of HDF5 files that passed the filtering criteria + return [f"cath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] + + def download(self): + if not hasattr(self, 'to_download') or not self.to_download: + download_url(opj(self.url, 'mdCATH_source.h5'), self.root) + return + for pdb_id in self.to_download.keys(): + download_url(opj(self.url, f"cath_dataset_{pdb_id}.h5"), self.root) + + def process_data_source(self): + print("Processing mdCATH source") + data_info_path = opj(self.root, 'mdCATH_source.h5') + if not os.path.exists(data_info_path): + self.download() + # the to_downlaod is the dictionary that will store the pdb ids and the corresponding temp and replica ids if they pass the filter + self.to_download = {} + self.num_conformers = 0 + with h5py.File(data_info_path, 'r') as f: + domains = f.keys() if self.pdb_list is None else self.pdb_list + for pdb in tqdm(domains, total=len(domains), desc="Processing mdCATH source"): + pdb_group = f[pdb] + if pdb_group.attrs['numProteinAtoms'] > self.numAtoms: + continue + if pdb_group.attrs['numResidues'] > self.numResidues: + continue + if self.numNoHAtoms is not None and pdb_group.attrs['numNoHAtoms'] > self.numNoHAtoms: + continue + for temp in self.temperatures: + if temp not in pdb_group.keys(): + continue + for replica in pdb_group[temp].keys(): + if self.numFrames is not None and pdb_group[temp][replica].attrs['numFrames'] < self.numFrames: + continue + if self.min_gyration_radius is not None and pdb_group[temp][replica].attrs['min_gyration_radius'] < self.min_gyration_radius: + continue + if self.max_gyration_radius is not None and pdb_group[temp][replica].attrs['max_gyration_radius'] > self.max_gyration_radius: + continue + if self.alpha_beta_coil is not None: + alpha = pdb_group[temp][replica].attrs['alpha'] + beta = pdb_group[temp][replica].attrs['beta'] + coil = pdb_group[temp][replica].attrs['coil'] + if not np.isclose([alpha, beta, coil], list(self.alpha_beta_coil)).all(): + continue + if pdb not in self.to_download: + self.to_download[pdb] = [] + self.to_download[pdb].append((temp, replica)) + # append the number of frames of the trajectory to the total number of molecules + self.num_conformers += math.ceil(pdb_group[temp][replica].attrs['numFrames'] / self.skipFrames) + + def len(self): + return self.num_conformers + + def process_specific_group(self, pdb, file, group_info): + with h5py.File(file, 'r') as f: + z = f[pdb]['z'][()] + group = f[pdb][f'sims{group_info[0]}K'][group_info[1]] + # if any assertion fail print the same message + # coords and forces shape (num_frames, num_atoms, 3) + 
+ assert group['coords'].shape[0] == group['forces'].shape[0], f"Number of frames mismatch between coords and forces: {group['coords'].shape[0]} vs {group['forces'].shape[0]}" + assert group['coords'].shape[1] == z.shape[0], f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}" + assert group['forces'].shape[1] == z.shape[0], f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}" + assert group['coords'].attrs['unit'] == 'Angstrom', f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}" + assert group['forces'].attrs['unit'] == 'kcal/mol/Angstrom', f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" + + coords = torch.tensor(group['coords'][()])[::self.skipFrames, :, :] + forces = torch.tensor(group['forces'][()])[::self.skipFrames, :, :] + z = torch.tensor(z) + return Data(pos=coords, neg_dy=forces, z=z) + + def _setup_idx(self): + files = [opj(self.root,f"cath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] + self.idx = [] + for i, (pdb, group_info) in enumerate(self.to_download.items()): + for temp, replica in group_info: + data = self.process_specific_group(pdb, files[i], (temp, replica)) + self.idx.extend([(data.pos[frame], data.neg_dy[frame], data.z) for frame in range(data.pos.shape[0])]) + assert len(self.idx) == self.num_conformers, f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" + + def get(self, idx): + data = Data() + if self.idx is None: + self._setup_idx() + data.pos = self.idx[idx][0] + data.neg_dy = self.idx[idx][1] + data.z = self.idx[idx][2] + return data \ No newline at end of file From 05d298cc7e3a5b1eb5c0017edf80b3858c577b77 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 28 Mar 2024 12:26:19 +0100 Subject: [PATCH 02/48] add mdcath init --- torchmdnet/datasets/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py index b57cd95a..2f15dd13 100644 --- a/torchmdnet/datasets/__init__.py +++ b/torchmdnet/datasets/__init__.py @@ -14,6 +14,7 @@ COMP6v1, COMP6v2, ) +from .mdcath import mdCATH from .custom import Custom from .water import WaterBox from .hdf import HDF5 @@ -39,6 +40,7 @@ "GDB10to13", "GenentechTorsions", "HDF5", + "mdCATH", "MD17", "MD22", "QM9", From 0dd872f466cfc89d82f18055e68e799c05ad4c5f Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 28 Mar 2024 13:54:52 +0100 Subject: [PATCH 03/48] improve memory usage avoiding to run out of memory --- torchmdnet/datasets/mdcath.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 7a2c5441..a999449a 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -140,19 +140,18 @@ def process_specific_group(self, pdb, file, group_info): with h5py.File(file, 'r') as f: z = f[pdb]['z'][()] group = f[pdb][f'sims{group_info[0]}K'][group_info[1]] + coords = group['coords'][()][::self.skipFrames, :, :] + forces = group['forces'][()][::self.skipFrames, :, :] # if any assertion fail print the same message # coords and forces shape (num_frames, num_atoms, 3) - assert group['coords'].shape[0] == group['forces'].shape[0], f"Number of frames mismatch between coords and forces: {group['coords'].shape[0]} vs {group['forces'].shape[0]}" - assert group['coords'].shape[1] == z.shape[0], f"Number of atoms mismatch between coords and z: 
{group['coords'].shape[1]} vs {z.shape[0]}" - assert group['forces'].shape[1] == z.shape[0], f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}" + assert coords.shape[0] == forces.shape[0], f"Number of frames mismatch between coords and forces: {group['coords'].shape[0]} vs {group['forces'].shape[0]}" + assert coords.shape[1] == z.shape[0], f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}" + assert forces.shape[1] == z.shape[0], f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}" assert group['coords'].attrs['unit'] == 'Angstrom', f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}" assert group['forces'].attrs['unit'] == 'kcal/mol/Angstrom', f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" - coords = torch.tensor(group['coords'][()])[::self.skipFrames, :, :] - forces = torch.tensor(group['forces'][()])[::self.skipFrames, :, :] - z = torch.tensor(z) - return Data(pos=coords, neg_dy=forces, z=z) + return [z, coords, forces] def _setup_idx(self): files = [opj(self.root,f"cath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] @@ -160,14 +159,17 @@ def _setup_idx(self): for i, (pdb, group_info) in enumerate(self.to_download.items()): for temp, replica in group_info: data = self.process_specific_group(pdb, files[i], (temp, replica)) - self.idx.extend([(data.pos[frame], data.neg_dy[frame], data.z) for frame in range(data.pos.shape[0])]) + conformer_indices = range(data[1].shape[0]) + self.idx.extend([tuple([data[0], data[1][j], data[2][j], [j]]) for j in conformer_indices]) assert len(self.idx) == self.num_conformers, f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" def get(self, idx): data = Data() if self.idx is None: self._setup_idx() - data.pos = self.idx[idx][0] - data.neg_dy = self.idx[idx][1] - data.z = self.idx[idx][2] + *fields_data, i = self.idx[idx] + print(fields_data) + data.z = torch.tensor(fields_data[0][i], dtype=torch.long) + data.pos = torch.tensor(fields_data[1][i], dtype=torch.float32) + data.neg_dy = torch.tensor(fields_data[2][i], dtype=torch.float32) return data \ No newline at end of file From 6a75b021869dec837a93dc6211c7fe9be57bb2a2 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 28 Mar 2024 13:56:00 +0100 Subject: [PATCH 04/48] remove debug print --- torchmdnet/datasets/mdcath.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index a999449a..b52a023b 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -168,7 +168,6 @@ def get(self, idx): if self.idx is None: self._setup_idx() *fields_data, i = self.idx[idx] - print(fields_data) data.z = torch.tensor(fields_data[0][i], dtype=torch.long) data.pos = torch.tensor(fields_data[1][i], dtype=torch.float32) data.neg_dy = torch.tensor(fields_data[2][i], dtype=torch.float32) From bec8144be1c0384bd32cb6b1bd8f4cbef2aec23c Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 28 Mar 2024 14:06:23 +0100 Subject: [PATCH 05/48] to black --- torchmdnet/datasets/mdcath.py | 132 ++++++++++++++++++++++------------ 1 file changed, 86 insertions(+), 46 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index b52a023b..e560e6f3 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -31,8 +31,8 @@ def __init__( alpha_beta_coil=None, numFrames=None, ): - """ mdCATH 
dataset class for PyTorch Geometric to load protein structures and dynamics from the mdCATH dataset. - + """mdCATH dataset class for PyTorch Geometric to load protein structures and dynamics from the mdCATH dataset. + Parameters: ----------- root: str @@ -58,11 +58,11 @@ def __init__( numFrames: int Minimum number of frames in the trajectory in order to be considered. Default is None. """ - + self.url = "https://zenodo.org/record//files/" self.preload_dataset_limit = preload_dataset_limit super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) - + self.numAtoms = numAtoms self.numNoHAtoms = numNoHAtoms self.numResidues = numResidues @@ -77,92 +77,132 @@ def __init__( self.process_data_source() print(f"Total number of domains: {len(self.to_download.keys())}") print(f"Total number of conformers: {self.num_conformers}") - + @property def raw_file_names(self): # Check if the dataset has been processed, and if not, return the original source file - if not hasattr(self, 'to_download'): - return ['mdCATH_source.h5'] + if not hasattr(self, "to_download"): + return ["mdCATH_source.h5"] # Otherwise, return the list of HDF5 files that passed the filtering criteria return [f"cath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] def download(self): - if not hasattr(self, 'to_download') or not self.to_download: - download_url(opj(self.url, 'mdCATH_source.h5'), self.root) - return + if not hasattr(self, "to_download") or not self.to_download: + download_url(opj(self.url, "mdCATH_source.h5"), self.root) + return for pdb_id in self.to_download.keys(): download_url(opj(self.url, f"cath_dataset_{pdb_id}.h5"), self.root) - + def process_data_source(self): print("Processing mdCATH source") - data_info_path = opj(self.root, 'mdCATH_source.h5') + data_info_path = opj(self.root, "mdCATH_source.h5") if not os.path.exists(data_info_path): self.download() # the to_downlaod is the dictionary that will store the pdb ids and the corresponding temp and replica ids if they pass the filter self.to_download = {} self.num_conformers = 0 - with h5py.File(data_info_path, 'r') as f: + with h5py.File(data_info_path, "r") as f: domains = f.keys() if self.pdb_list is None else self.pdb_list - for pdb in tqdm(domains, total=len(domains), desc="Processing mdCATH source"): + for pdb in tqdm( + domains, total=len(domains), desc="Processing mdCATH source" + ): pdb_group = f[pdb] - if pdb_group.attrs['numProteinAtoms'] > self.numAtoms: + if pdb_group.attrs["numProteinAtoms"] > self.numAtoms: continue - if pdb_group.attrs['numResidues'] > self.numResidues: + if pdb_group.attrs["numResidues"] > self.numResidues: continue - if self.numNoHAtoms is not None and pdb_group.attrs['numNoHAtoms'] > self.numNoHAtoms: + if ( + self.numNoHAtoms is not None + and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms + ): continue for temp in self.temperatures: if temp not in pdb_group.keys(): continue for replica in pdb_group[temp].keys(): - if self.numFrames is not None and pdb_group[temp][replica].attrs['numFrames'] < self.numFrames: + if ( + self.numFrames is not None + and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames + ): continue - if self.min_gyration_radius is not None and pdb_group[temp][replica].attrs['min_gyration_radius'] < self.min_gyration_radius: + if ( + self.min_gyration_radius is not None + and pdb_group[temp][replica].attrs["min_gyration_radius"] + < self.min_gyration_radius + ): continue - if self.max_gyration_radius is not None and pdb_group[temp][replica].attrs['max_gyration_radius'] > 
self.max_gyration_radius: + if ( + self.max_gyration_radius is not None + and pdb_group[temp][replica].attrs["max_gyration_radius"] + > self.max_gyration_radius + ): continue if self.alpha_beta_coil is not None: - alpha = pdb_group[temp][replica].attrs['alpha'] - beta = pdb_group[temp][replica].attrs['beta'] - coil = pdb_group[temp][replica].attrs['coil'] - if not np.isclose([alpha, beta, coil], list(self.alpha_beta_coil)).all(): + alpha = pdb_group[temp][replica].attrs["alpha"] + beta = pdb_group[temp][replica].attrs["beta"] + coil = pdb_group[temp][replica].attrs["coil"] + if not np.isclose( + [alpha, beta, coil], list(self.alpha_beta_coil) + ).all(): continue if pdb not in self.to_download: self.to_download[pdb] = [] self.to_download[pdb].append((temp, replica)) # append the number of frames of the trajectory to the total number of molecules - self.num_conformers += math.ceil(pdb_group[temp][replica].attrs['numFrames'] / self.skipFrames) - + self.num_conformers += math.ceil( + pdb_group[temp][replica].attrs["numFrames"] / self.skipFrames + ) + def len(self): return self.num_conformers - + def process_specific_group(self, pdb, file, group_info): - with h5py.File(file, 'r') as f: - z = f[pdb]['z'][()] - group = f[pdb][f'sims{group_info[0]}K'][group_info[1]] - coords = group['coords'][()][::self.skipFrames, :, :] - forces = group['forces'][()][::self.skipFrames, :, :] - # if any assertion fail print the same message + with h5py.File(file, "r") as f: + z = f[pdb]["z"][()] + group = f[pdb][f"sims{group_info[0]}K"][group_info[1]] + coords = group["coords"][()][:: self.skipFrames, :, :] + forces = group["forces"][()][:: self.skipFrames, :, :] + # if any assertion fail print the same message # coords and forces shape (num_frames, num_atoms, 3) - - assert coords.shape[0] == forces.shape[0], f"Number of frames mismatch between coords and forces: {group['coords'].shape[0]} vs {group['forces'].shape[0]}" - assert coords.shape[1] == z.shape[0], f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}" - assert forces.shape[1] == z.shape[0], f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}" - assert group['coords'].attrs['unit'] == 'Angstrom', f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}" - assert group['forces'].attrs['unit'] == 'kcal/mol/Angstrom', f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" - + + assert ( + coords.shape[0] == forces.shape[0] + ), f"Number of frames mismatch between coords and forces: {group['coords'].shape[0]} vs {group['forces'].shape[0]}" + assert ( + coords.shape[1] == z.shape[0] + ), f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}" + assert ( + forces.shape[1] == z.shape[0] + ), f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}" + assert ( + group["coords"].attrs["unit"] == "Angstrom" + ), f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}" + assert ( + group["forces"].attrs["unit"] == "kcal/mol/Angstrom" + ), f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" + return [z, coords, forces] - + def _setup_idx(self): - files = [opj(self.root,f"cath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] + files = [ + opj(self.root, f"cath_dataset_{pdb_id}.h5") + for pdb_id in self.to_download.keys() + ] self.idx = [] for i, (pdb, group_info) in enumerate(self.to_download.items()): for temp, replica in group_info: data = 
self.process_specific_group(pdb, files[i], (temp, replica)) conformer_indices = range(data[1].shape[0]) - self.idx.extend([tuple([data[0], data[1][j], data[2][j], [j]]) for j in conformer_indices]) - assert len(self.idx) == self.num_conformers, f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" - + self.idx.extend( + [ + tuple([data[0], data[1][j], data[2][j], [j]]) + for j in conformer_indices + ] + ) + assert ( + len(self.idx) == self.num_conformers + ), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" + def get(self, idx): data = Data() if self.idx is None: @@ -171,4 +211,4 @@ def get(self, idx): data.z = torch.tensor(fields_data[0][i], dtype=torch.long) data.pos = torch.tensor(fields_data[1][i], dtype=torch.float32) data.neg_dy = torch.tensor(fields_data[2][i], dtype=torch.float32) - return data \ No newline at end of file + return data From 00380dd43638bca373019a2f30684e2a0cfd6095 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 4 Apr 2024 10:07:23 +0200 Subject: [PATCH 06/48] change raw dir to self.root --- torchmdnet/datasets/mdcath.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index e560e6f3..f4a035f3 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -86,6 +86,11 @@ def raw_file_names(self): # Otherwise, return the list of HDF5 files that passed the filtering criteria return [f"cath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] + @property + def raw_dir(self): + # Override the raw_dir property to prevent the creation of a 'raw' directory + # The files will be downloaded to the root directory + return self.root def download(self): if not hasattr(self, "to_download") or not self.to_download: download_url(opj(self.url, "mdCATH_source.h5"), self.root) From 153c2b1aeb788eb8bef3826c042534502f2d65b1 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 4 Apr 2024 10:08:43 +0200 Subject: [PATCH 07/48] add solid_ss to filter while processing --- torchmdnet/datasets/mdcath.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index f4a035f3..29869a5f 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -29,6 +29,7 @@ def __init__( min_gyration_radius=None, max_gyration_radius=None, alpha_beta_coil=None, + solid_ss = None, numFrames=None, ): """mdCATH dataset class for PyTorch Geometric to load protein structures and dynamics from the mdCATH dataset. 
@@ -142,14 +143,17 @@ def process_data_source(self): > self.max_gyration_radius ): continue - if self.alpha_beta_coil is not None: + if self.alpha_beta_coil is not None or self.solid_ss is not None: alpha = pdb_group[temp][replica].attrs["alpha"] beta = pdb_group[temp][replica].attrs["beta"] coil = pdb_group[temp][replica].attrs["coil"] - if not np.isclose( - [alpha, beta, coil], list(self.alpha_beta_coil) - ).all(): - continue + solid_ss = (alpha + beta) / pdb_group.attrs["numResidues"] * 100 + if self.solid_ss is not None: + if solid_ss < self.solid_ss: + continue + else: + if not np.isclose([alpha, beta, coil], list(self.alpha_beta_coil)).all(): + continue if pdb not in self.to_download: self.to_download[pdb] = [] self.to_download[pdb].append((temp, replica)) From 0b4ab6780af40f4421aa2986c4ab1ad749953287 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 4 Apr 2024 11:53:32 +0200 Subject: [PATCH 08/48] compute dataset size --- torchmdnet/datasets/mdcath.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 29869a5f..9088cdf8 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -76,8 +76,12 @@ def __init__( self.numFrames = numFrames self.idx = None self.process_data_source() + # Calculate the total size of the dataset in MB + self.total_size_mb = self.calculate_dataset_size() + print(f"Total number of domains: {len(self.to_download.keys())}") print(f"Total number of conformers: {self.num_conformers}") + print(f"Total size of dataset: {self.total_size_mb} MB") @property def raw_file_names(self): @@ -97,8 +101,13 @@ def download(self): download_url(opj(self.url, "mdCATH_source.h5"), self.root) return for pdb_id in self.to_download.keys(): - download_url(opj(self.url, f"cath_dataset_{pdb_id}.h5"), self.root) - + def calculate_dataset_size(self): + total_size_bytes = 0 + for pdb_id in self.to_download.keys(): + file_name = f"cath_noh_dataset_{pdb_id}.h5" if self.noh_mode else f"cath_dataset_{pdb_id}.h5" + total_size_bytes += os.path.getsize(opj(self.root, file_name)) + total_size_mb = round(total_size_bytes / (1024 * 1024), 4) + return total_size_mb def process_data_source(self): print("Processing mdCATH source") data_info_path = opj(self.root, "mdCATH_source.h5") From 848212711295d2ad778896525e26dc6bf9bc16be Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 4 Apr 2024 11:54:04 +0200 Subject: [PATCH 09/48] add solid_ss to documentation --- torchmdnet/datasets/mdcath.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 9088cdf8..12cbe2ce 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -56,6 +56,8 @@ def __init__( Maximum gyration radius (in nm) of the protein structure. Default is None. alpha_beta_coil: tuple Tuple with the minimum percentage of alpha-helix, beta-sheet and coil residues in the protein structure. Default is None. + solid_ss: float + minimum percentage of solid secondary structure in the protein structure (alpha + beta)/total_residues * 100. Default is None. numFrames: int Minimum number of frames in the trajectory in order to be considered. Default is None. 
""" From 949a53f63efc9b80e4ba52a7d630f243c899b2e4 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Thu, 4 Apr 2024 11:59:42 +0200 Subject: [PATCH 10/48] fix self.solid_ss --- torchmdnet/datasets/mdcath.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 12cbe2ce..a79fe745 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -76,6 +76,7 @@ def __init__( self.max_gyration_radius = max_gyration_radius self.alpha_beta_coil = alpha_beta_coil self.numFrames = numFrames + self.solid_ss = solid_ss self.idx = None self.process_data_source() # Calculate the total size of the dataset in MB From 788d6fb90514a4f5093afbbe159dad99fb9b3d88 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 5 Apr 2024 17:48:03 +0200 Subject: [PATCH 11/48] fix replica for loop indentation --- torchmdnet/datasets/mdcath.py | 81 +++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index a79fe745..3ca3c5fd 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -137,42 +137,51 @@ def process_data_source(self): for temp in self.temperatures: if temp not in pdb_group.keys(): continue - for replica in pdb_group[temp].keys(): - if ( - self.numFrames is not None - and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames - ): - continue - if ( - self.min_gyration_radius is not None - and pdb_group[temp][replica].attrs["min_gyration_radius"] - < self.min_gyration_radius - ): - continue - if ( - self.max_gyration_radius is not None - and pdb_group[temp][replica].attrs["max_gyration_radius"] - > self.max_gyration_radius - ): - continue - if self.alpha_beta_coil is not None or self.solid_ss is not None: - alpha = pdb_group[temp][replica].attrs["alpha"] - beta = pdb_group[temp][replica].attrs["beta"] - coil = pdb_group[temp][replica].attrs["coil"] - solid_ss = (alpha + beta) / pdb_group.attrs["numResidues"] * 100 - if self.solid_ss is not None: - if solid_ss < self.solid_ss: - continue - else: - if not np.isclose([alpha, beta, coil], list(self.alpha_beta_coil)).all(): - continue - if pdb not in self.to_download: - self.to_download[pdb] = [] - self.to_download[pdb].append((temp, replica)) - # append the number of frames of the trajectory to the total number of molecules - self.num_conformers += math.ceil( - pdb_group[temp][replica].attrs["numFrames"] / self.skipFrames - ) + for replica in pdb_group[temp].keys(): + if ( + self.numFrames is not None + and pdb_group[temp][replica].attrs["numFrames"] + < self.numFrames + ): + continue + if ( + self.min_gyration_radius is not None + and pdb_group[temp][replica].attrs["min_gyration_radius"] + < self.min_gyration_radius + ): + continue + if ( + self.max_gyration_radius is not None + and pdb_group[temp][replica].attrs["max_gyration_radius"] + > self.max_gyration_radius + ): + continue + if ( + self.alpha_beta_coil is not None + or self.solid_ss is not None + ): + alpha = pdb_group[temp][replica].attrs["alpha"] + beta = pdb_group[temp][replica].attrs["beta"] + coil = pdb_group[temp][replica].attrs["coil"] + solid_ss = ( + (alpha + beta) / pdb_group.attrs["numResidues"] * 100 + ) + if self.solid_ss is not None: + if solid_ss < self.solid_ss: + continue + else: + if not np.isclose( + [alpha, beta, coil], list(self.alpha_beta_coil) + ).all(): + continue + if pdb not in self.to_download: + self.to_download[pdb] = [] + self.to_download[pdb].append((temp, 
replica)) + # append the number of frames of the trajectory to the total number of molecules + self.num_conformers += math.ceil( + pdb_group[temp][replica].attrs["numFrames"] + / self.skipFrames + ) def len(self): return self.num_conformers From 104dfa992d7ba1c480cbad297f8181b0ab0fd0e9 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 5 Apr 2024 17:53:18 +0200 Subject: [PATCH 12/48] add possibility to get pdb list from file --- torchmdnet/datasets/mdcath.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 3ca3c5fd..652b33e2 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -12,6 +12,21 @@ from torch_geometric.data import Dataset, download_url, Data +def get_pdb_list(pdb_list): + # pdb list could be a list of pdb ids or a file with the pdb ids + if isinstance(pdb_list, list): + return pdb_list + elif isinstance(pdb_list, str): + if os.path.exists(pdb_list): + print(f"Reading PDB list from {pdb_list}") + with open(pdb_list, "r") as f: + return [line.strip() for line in f] + else: + raise FileNotFoundError(f"File {pdb_list} not found") + else: + return None + + class mdCATH(Dataset): def __init__( self, @@ -48,7 +63,7 @@ def __init__( List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450'] skipFrames: int Number of frames to skip in the trajectory. Default is 1. - pdb_list: list + pdb_list: list or str List of PDB IDs to download. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded. min_gyration_radius: float Minimum gyration radius (in nm) of the protein structure. Default is None. @@ -71,7 +86,7 @@ def __init__( self.numResidues = numResidues self.temperatures = temperatures self.skipFrames = skipFrames - self.pdb_list = pdb_list + self.pdb_list = get_pdb_list(pdb_list) self.min_gyration_radius = min_gyration_radius self.max_gyration_radius = max_gyration_radius self.alpha_beta_coil = alpha_beta_coil From e555277189179ac78e2417c78a7450740a4c040e Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 5 Apr 2024 17:56:57 +0200 Subject: [PATCH 13/48] fix filtering on numResidues and numAtoms, could be None --- torchmdnet/datasets/mdcath.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 652b33e2..a2c9417f 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -140,9 +140,15 @@ def process_data_source(self): domains, total=len(domains), desc="Processing mdCATH source" ): pdb_group = f[pdb] - if pdb_group.attrs["numProteinAtoms"] > self.numAtoms: + if ( + self.numAtoms is not None + and pdb_group.attrs["numProteinAtoms"] > self.numAtoms + ): continue - if pdb_group.attrs["numResidues"] > self.numResidues: + if ( + self.numResidues is not None + and pdb_group.attrs["numResidues"] > self.numResidues + ): continue if ( self.numNoHAtoms is not None From 7652c4c6437118987297dd76712ac4a9e1a2e7b7 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 5 Apr 2024 17:57:33 +0200 Subject: [PATCH 14/48] remove comment --- torchmdnet/datasets/mdcath.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index a2c9417f..b89df5fc 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -213,7 +213,6 @@ def process_specific_group(self, pdb, file, group_info): 
group = f[pdb][f"sims{group_info[0]}K"][group_info[1]] coords = group["coords"][()][:: self.skipFrames, :, :] forces = group["forces"][()][:: self.skipFrames, :, :] - # if any assertion fail print the same message # coords and forces shape (num_frames, num_atoms, 3) assert ( From c9a19032333ebe9af28b912e34bf98907d2b1d51 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 5 Apr 2024 17:58:31 +0200 Subject: [PATCH 15/48] remove preload dataset feature --- torchmdnet/datasets/mdcath.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index b89df5fc..301f4b43 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -34,7 +34,7 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, - preload_dataset_limit=None, + numAtoms=5000, numNoHAtoms=None, numResidues=1000, @@ -53,6 +53,8 @@ def __init__( ----------- root: str Root directory where the dataset should be stored. Data will be downloaded to 'root/'. + preload_dataset_limit: int + Maximum size of the dataset in MB to load into memory. If the dataset is larger than this limit, a warning will be printed. Default is 1024 MB. numAtoms: int Max number of atoms in the protein structure. numNoHAtoms: int From 2eab7eaf0ceedd0c5fd767a4954f9dc5a9976fb2 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 8 Apr 2024 13:34:57 +0200 Subject: [PATCH 16/48] fix process specific group, return and skipframes --- torchmdnet/datasets/mdcath.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 301f4b43..d06ca4a3 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -211,10 +211,10 @@ def len(self): def process_specific_group(self, pdb, file, group_info): with h5py.File(file, "r") as f: - z = f[pdb]["z"][()] + z = f[pdb]["z"][:] group = f[pdb][f"sims{group_info[0]}K"][group_info[1]] - coords = group["coords"][()][:: self.skipFrames, :, :] - forces = group["forces"][()][:: self.skipFrames, :, :] + coords = group["coords"][::self.skipFrames, :, :] + forces = group["forces"][::self.skipFrames, :, :] # coords and forces shape (num_frames, num_atoms, 3) assert ( @@ -233,7 +233,7 @@ def process_specific_group(self, pdb, file, group_info): group["forces"].attrs["unit"] == "kcal/mol/Angstrom" ), f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" - return [z, coords, forces] + return (z, coords, forces) def _setup_idx(self): files = [ From 519ee6a3055635f3f29fdb288797f3f192840785 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 8 Apr 2024 13:35:36 +0200 Subject: [PATCH 17/48] fix setup idx and get function --- torchmdnet/datasets/mdcath.py | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index d06ca4a3..997917ba 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -236,31 +236,31 @@ def process_specific_group(self, pdb, file, group_info): return (z, coords, forces) def _setup_idx(self): - files = [ - opj(self.root, f"cath_dataset_{pdb_id}.h5") - for pdb_id in self.to_download.keys() - ] + if self.noh_mode: + files = [opj(self.root, f"cath_noh_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] + else: + files = [opj(self.root, f"cath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] self.idx = [] for i, (pdb, group_info) in 
enumerate(self.to_download.items()): for temp, replica in group_info: + # data will return a tuple with the z, coords and forces data = self.process_specific_group(pdb, files[i], (temp, replica)) + # conformer_indices is a list with the indices of the conformers, from the coords (i.e. data[1]) conformer_indices = range(data[1].shape[0]) - self.idx.extend( - [ - tuple([data[0], data[1][j], data[2][j], [j]]) - for j in conformer_indices - ] - ) - assert ( - len(self.idx) == self.num_conformers - ), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" - - def get(self, idx): + d = [Data(z=data[0], pos=data[1][j], neg_dy=data[2][j]) for j in conformer_indices] + self.idx.extend(d) + + assert (len(self.idx) == self.num_conformers), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" + + def get(self, element): data = Data() if self.idx is None: + # this process will be performed, num_workers * num_gpus + print("Setting up idx, this may take a while...") self._setup_idx() - *fields_data, i = self.idx[idx] - data.z = torch.tensor(fields_data[0][i], dtype=torch.long) - data.pos = torch.tensor(fields_data[1][i], dtype=torch.float32) - data.neg_dy = torch.tensor(fields_data[2][i], dtype=torch.float32) + # fields_data is a data object with the z, pos and neg_dy + fields_data = self.idx[element] + data.z = torch.tensor(fields_data.z, dtype=torch.long) + data.pos = torch.tensor(fields_data.pos, dtype=torch.float32) + data.neg_dy = torch.tensor(fields_data.neg_dy, dtype=torch.float32) return data From 68ba1401bddf0cd647a9d48565db829a54d2133a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 30 Apr 2024 15:09:07 +0200 Subject: [PATCH 18/48] allow zero num workers in dataloader --- torchmdnet/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/data.py b/torchmdnet/data.py index ba91e8be..b587f0a1 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -133,7 +133,7 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True): dataset=dataset, batch_size=batch_size, num_workers=self.hparams["num_workers"], - persistent_workers=True, + persistent_workers=True if self.hparams["num_workers"] > 0 else False, pin_memory=True, shuffle=shuffle, ) From 6fb7c4b2a3571680e4339ffdc1831fd1908a493d Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 30 Apr 2024 15:12:31 +0200 Subject: [PATCH 19/48] remove unused module --- torchmdnet/scripts/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 2e69212b..c3d80462 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -9,7 +9,6 @@ import logging import torch import lightning.pytorch as pl -from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.loggers import WandbLogger, CSVLogger, TensorBoardLogger from lightning.pytorch.callbacks import ( ModelCheckpoint, From 1bd39a0181c9fe0a0aa17616ede11db70e86da02 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 30 Apr 2024 15:15:41 +0200 Subject: [PATCH 20/48] use read_direct from h5py --- torchmdnet/datasets/mdcath.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 997917ba..e1da5ee0 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -209,22 +209,29 @@ def process_data_source(self): def len(self): return self.num_conformers - def 
process_specific_group(self, pdb, file, group_info):
+    def process_specific_group(self, pdb, file, temp, repl, conf_idx):
+        # use the read_direct and np.s_ to get the coords and forces of interest directly
+        conf_idx = conf_idx*self.skipFrames
+        slice_idxs = np.s_[conf_idx:conf_idx+1]
         with h5py.File(file, "r") as f:
             z = f[pdb]["z"][:]
-            group = f[pdb][f"sims{group_info[0]}K"][group_info[1]]
-            coords = group["coords"][::self.skipFrames, :, :]
-            forces = group["forces"][::self.skipFrames, :, :]
-            # coords and forces shape (num_frames, num_atoms, 3)
+            coords = np.zeros((z.shape[0], 3))
+            forces = np.zeros((z.shape[0], 3))
+
+            group = f[f'{pdb}/{temp}/{repl}']
+            group['coords'].read_direct(coords, slice_idxs)
+            group['forces'].read_direct(forces, slice_idxs)
+
+            # coords and forces shape (num_atoms, 3)
 
             assert (
                 coords.shape[0] == forces.shape[0]
             ), f"Number of frames mismatch between coords and forces: {group['coords'].shape[0]} vs {group['forces'].shape[0]}"
             assert (
-                coords.shape[1] == z.shape[0]
+                coords.shape[0] == z.shape[0]
             ), f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}"
             assert (
-                forces.shape[1] == z.shape[0]
+                forces.shape[0] == z.shape[0]
             ), f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}"
             assert (
                 group["coords"].attrs["unit"] == "Angstrom"
             ), f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}"

From c374ec841284f43a33ed0a3cde4c649364f8dc1b Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Tue, 30 Apr 2024 15:16:07 +0200
Subject: [PATCH 21/48] update get function

---
 torchmdnet/datasets/mdcath.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index e1da5ee0..d4c8650d 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -262,12 +262,12 @@ def _setup_idx(self):
     def get(self, element):
         data = Data()
         if self.idx is None:
-            # this process will be performed, num_workers * num_gpus
-            print("Setting up idx, this may take a while...")
+            # this setup will be performed num_workers * num_gpus times (once per worker)
             self._setup_idx()
-        # fields_data is a data object with the z, pos and neg_dy
-        fields_data = self.idx[element]
-        data.z = torch.tensor(fields_data.z, dtype=torch.long)
-        data.pos = torch.tensor(fields_data.pos, dtype=torch.float32)
-        data.neg_dy = torch.tensor(fields_data.neg_dy, dtype=torch.float32)
-        return data
+        # each self.idx entry is a tuple with the pdb id, file path, temp, replica and conf_idx
+        pdb_id, file_path, temp, replica, conf_idx = self.idx[element]
+        z, coords, forces = self.process_specific_group(pdb_id, file_path, temp, replica, conf_idx)
+        data.z = torch.tensor(z, dtype=torch.long)
+        data.pos = torch.tensor(coords, dtype=torch.float)
+        data.neg_dy = torch.tensor(forces, dtype=torch.float)
+        return data
\ No newline at end of file

From 180535fe46bb33a8fb251c892acdc73946bdea18 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Tue, 30 Apr 2024 15:19:22 +0200
Subject: [PATCH 22/48] fix setup idx

---
 torchmdnet/datasets/mdcath.py | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index d4c8650d..c300f1bb 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -209,6 +209,18 @@ def process_data_source(self):
     def len(self):
         return self.num_conformers
 
+    def _setup_idx(self):
+        files = [opj(self.root, f"mdcath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()]
+        self.idx = []
+
for i, (pdb, group_info) in enumerate(self.to_download.items()): + for temp, replica, num_frames in group_info: + # build the catalog here for each conformer + d = [(pdb, files[i], temp, replica, conf_id) for conf_id in range(num_frames)] + self.idx.extend(d) + + assert (len(self.idx) == self.num_conformers), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" + + def process_specific_group(self, pdb, file, temp, repl, conf_idx): # use the read_direct and np.s_ to get the coords and forces of interest directly conf_idx = conf_idx*self.skipFrames @@ -241,24 +253,7 @@ def process_specific_group(self, pdb, file, temp, repl, conf_idx): ), f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" return (z, coords, forces) - - def _setup_idx(self): - if self.noh_mode: - files = [opj(self.root, f"cath_noh_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] - else: - files = [opj(self.root, f"cath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] - self.idx = [] - for i, (pdb, group_info) in enumerate(self.to_download.items()): - for temp, replica in group_info: - # data will return a tuple with the z, coords and forces - data = self.process_specific_group(pdb, files[i], (temp, replica)) - # conformer_indices is a list with the indices of the conformers, from the coords (i.e. data[1]) - conformer_indices = range(data[1].shape[0]) - d = [Data(z=data[0], pos=data[1][j], neg_dy=data[2][j]) for j in conformer_indices] - self.idx.extend(d) - - assert (len(self.idx) == self.num_conformers), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}" - + def get(self, element): data = Data() if self.idx is None: From 89df1cc06e0553f64099b343a7f7756737ff138a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 30 Apr 2024 15:21:46 +0200 Subject: [PATCH 23/48] append more info to to_download dict, according to new setup idx --- torchmdnet/datasets/mdcath.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index c300f1bb..0dd14d46 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -197,14 +197,13 @@ def process_data_source(self): [alpha, beta, coil], list(self.alpha_beta_coil) ).all(): continue - if pdb not in self.to_download: - self.to_download[pdb] = [] - self.to_download[pdb].append((temp, replica)) - # append the number of frames of the trajectory to the total number of molecules - self.num_conformers += math.ceil( - pdb_group[temp][replica].attrs["numFrames"] - / self.skipFrames - ) + + if pdb not in self.to_download: + self.to_download[pdb] = [] + num_frames = math.ceil(pdb_group[temp][replica].attrs["numFrames"]/self.skipFrames) + self.to_download[pdb].append((temp, replica, num_frames)) + # append the number of frames of the trajectory to the total number of molecules + self.num_conformers += num_frames def len(self): return self.num_conformers From 71059c80912f681a75e89846858581630ac116b4 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 30 Apr 2024 15:27:49 +0200 Subject: [PATCH 24/48] update file name --- torchmdnet/datasets/mdcath.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 0dd14d46..633fb524 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -109,7 +109,7 @@ def raw_file_names(self): if not hasattr(self, 
"to_download"): return ["mdCATH_source.h5"] # Otherwise, return the list of HDF5 files that passed the filtering criteria - return [f"cath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] + return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] @property def raw_dir(self): @@ -121,10 +121,15 @@ def download(self): download_url(opj(self.url, "mdCATH_source.h5"), self.root) return for pdb_id in self.to_download.keys(): + file_name = f"mdcath_dataset_{pdb_id}.h5" + file_path = opj(self.raw_dir, file_name) + if not os.path.exists(file_path): + download_url(opj(self.url, file_name), self.root) + def calculate_dataset_size(self): total_size_bytes = 0 for pdb_id in self.to_download.keys(): - file_name = f"cath_noh_dataset_{pdb_id}.h5" if self.noh_mode else f"cath_dataset_{pdb_id}.h5" + file_name = f"mdcath_dataset_{pdb_id}.h5" total_size_bytes += os.path.getsize(opj(self.root, file_name)) total_size_mb = round(total_size_bytes / (1024 * 1024), 4) return total_size_mb From 2ab24c91694bb8bff300933632cf41ea719c5db4 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Tue, 30 Apr 2024 15:32:37 +0200 Subject: [PATCH 25/48] add pytest for mdcath dataset --- tests/test_mdcath.py | 185 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 tests/test_mdcath.py diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py new file mode 100644 index 00000000..34e1ff48 --- /dev/null +++ b/tests/test_mdcath.py @@ -0,0 +1,185 @@ +from torchmdnet.datasets.mdcath import mdCATH +from torch_geometric.loader import DataLoader +from tqdm import tqdm +import numpy as np +import h5py +from os.path import join +import psutil +from pytest import mark + + +def test_mdcath(tmpdir): + num_atoms_list = np.linspace(50, 1000, 50) + source_file = h5py.File(join(tmpdir, "source.h5"), mode="w") + for num_atoms in num_atoms_list: + z = np.zeros(int(num_atoms)) + pos = np.zeros((100, int(num_atoms), 3)) + forces = np.zeros((100, int(num_atoms), 3)) + + s_group = source_file.create_group(f"A{num_atoms}") + + s_group.attrs["numChains"] = 1 + s_group.attrs["numNoHAtoms"] = int(num_atoms) / 2 + s_group.attrs["numProteinAtoms"] = int(num_atoms) + s_group.attrs["numResidues"] = int(num_atoms) / 10 + s_temp_group = s_group.create_group("348") + s_replica_group = s_temp_group.create_group("0") + s_replica_group.attrs["numFrames"] = 100 + s_replica_group.attrs["alpha"] = 0.30 + s_replica_group.attrs["beta"] = 0.25 + s_replica_group.attrs["coil"] = 0.45 + s_replica_group.attrs["max_gyration_radius"] = 2 + s_replica_group.attrs["max_num_neighbors_5A"] = 55 + s_replica_group.attrs["max_num_neighbors_9A"] = 200 + s_replica_group.attrs["min_gyration_radius"] = 1 + + # write the dataset + data = h5py.File(join(tmpdir, f"mdcath_dataset_A{num_atoms}.h5"), mode="w") + group = data.create_group(f"A{num_atoms}") + group.create_dataset("z", data=z) + tempgroup = group.create_group("348") + replicagroup = tempgroup.create_group("0") + replicagroup.create_dataset("coords", data=pos) + replicagroup.create_dataset("forces", data=forces) + # add some attributes + replicagroup.attrs["numFrames"] = 100 + replicagroup["coords"].attrs["unit"] = "Angstrom" + replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom" + + data.flush() + data.close() + + dataset = mdCATH(root=tmpdir) + dl = DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=True, + persistent_workers=False, + ) + for _, data in enumerate(tqdm(dl)): + pass + + +def 
test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):
+    # generate sample data
+    z = np.zeros(num_entries)
+    pos = np.zeros((numFrames, num_entries, 3))
+    forces = np.zeros((numFrames, num_entries, 3))
+
+    source_file = h5py.File(join(tmpdir, "source.h5"), mode="w")
+    s_group = source_file.create_group("A00")
+
+    s_group.attrs["numChains"] = 1
+    s_group.attrs["numNoHAtoms"] = num_entries / 2
+    s_group.attrs["numProteinAtoms"] = num_entries
+    s_group.attrs["numResidues"] = num_entries / 10
+    s_temp_group = s_group.create_group("348")
+    s_replica_group = s_temp_group.create_group("0")
+    s_replica_group.attrs["numFrames"] = numFrames
+    s_replica_group.attrs["alpha"] = 0.30
+    s_replica_group.attrs["beta"] = 0.25
+    s_replica_group.attrs["coil"] = 0.45
+    s_replica_group.attrs["max_gyration_radius"] = 2
+    s_replica_group.attrs["max_num_neighbors_5A"] = 55
+    s_replica_group.attrs["max_num_neighbors_9A"] = 200
+    s_replica_group.attrs["min_gyration_radius"] = 1
+
+    # write the dataset
+    data = h5py.File(join(tmpdir, "mdcath_dataset_A00.h5"), mode="w")
+    group = data.create_group("A00")
+    group.create_dataset("z", data=z)
+    tempgroup = group.create_group("348")
+    replicagroup = tempgroup.create_group("0")
+    replicagroup.create_dataset("coords", data=pos)
+    replicagroup.create_dataset("forces", data=forces)
+    # add some attributes
+    replicagroup.attrs["numFrames"] = numFrames
+    replicagroup["coords"].attrs["unit"] = "Angstrom"
+    replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"
+
+    data.flush()
+    data.close()
+
+    # make sure creating the dataset doesn't open any files on the main process
+    proc = psutil.Process()
+    n_open = len(proc.open_files())
+
+    dset = mdCATH(
+        root=tmpdir,
+        source_file=join(tmpdir, "source.h5"),
+    )
+    assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"
+
+
+def replacer(arr, skipframes):
+    tmp_arr = arr.copy()
+    # take a numpy array of zeros and, based on the skipframes value, replace the zeros with 1s at those frame positions
+    for i in range(0, len(tmp_arr), skipframes):
+        tmp_arr[i, :, :] = 1
+    return tmp_arr
+
+
+@mark.parametrize("skipframes", [1, 2, 5])
+@mark.parametrize("batch_size", [1, 10])
+def test_mdcath_skipframes(tmpdir, skipframes, batch_size):
+
+    with h5py.File(join(tmpdir, "source.h5"), mode="w") as source_file:
+        num_frames_list = np.linspace(50, 1000, 50).astype(int)
+        for num_frame in tqdm(num_frames_list, desc="Creating tmp files"):
+            z = np.zeros(100)
+            pos = np.zeros((num_frame, 100, 3))
+            forces = np.zeros((num_frame, 100, 3))
+
+            pos = replacer(pos, skipframes)
+            forces = replacer(forces, skipframes)
+
+            s_group = source_file.create_group(f"A{num_frame}")
+
+            s_group.attrs["numChains"] = 1
+            s_group.attrs["numNoHAtoms"] = 100 / 2
+            s_group.attrs["numProteinAtoms"] = 100
+            s_group.attrs["numResidues"] = 100 / 10
+            s_temp_group = s_group.create_group("348")
+            s_replica_group = s_temp_group.create_group("0")
+            s_replica_group.attrs["numFrames"] = num_frame
+            s_replica_group.attrs["alpha"] = 0.30
+            s_replica_group.attrs["beta"] = 0.25
+            s_replica_group.attrs["coil"] = 0.45
+            s_replica_group.attrs["max_gyration_radius"] = 2
+            s_replica_group.attrs["max_num_neighbors_5A"] = 55
+            s_replica_group.attrs["max_num_neighbors_9A"] = 200
+            s_replica_group.attrs["min_gyration_radius"] = 1
+
+            # write the dataset
+            data = h5py.File(join(tmpdir, f"mdcath_dataset_A{num_frame}.h5"), mode="w")
+            group = data.create_group(f"A{num_frame}")
+            group.create_dataset("z", data=z)
+            tempgroup = group.create_group("348")
+            replicagroup = tempgroup.create_group("0")
+            replicagroup.create_dataset("coords", data=pos)
+            replicagroup.create_dataset("forces", data=forces)
+            # add some attributes
+            replicagroup.attrs["numFrames"] = num_frame
+            replicagroup["coords"].attrs["unit"] = "Angstrom"
+            replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"
+
+            data.flush()
+            data.close()
+
+    dataset = mdCATH(
+        root=tmpdir, skipFrames=skipframes, source_file=join(tmpdir, "source.h5")
+    )
+    dl = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+        persistent_workers=False,
+    )
+    for _, data in enumerate(tqdm(dl)):
+        # if the skipframes works correctly, data returned should be only 1s
+        assert data.pos.all() == 1, "skipframes not working correctly for positions"
+        assert data.neg_dy.all() == 1, "skipframes not working correctly for forces"

From 160fb61317d3d85abbb4bde0fa024f6c584a9cfb Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Fri, 17 May 2024 13:04:30 +0200
Subject: [PATCH 26/48] fix memory occupancy due to attrs assertion in get

---
 torchmdnet/datasets/mdcath.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index 633fb524..53294311 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -34,7 +34,6 @@ def __init__(
         transform=None,
         pre_transform=None,
         pre_filter=None,
-
         numAtoms=5000,
         numNoHAtoms=None,
         numResidues=1000,
@@ -246,15 +245,6 @@ def process_specific_group(self, pdb, file, temp, repl, conf_idx):
             assert (
                 coords.shape[0] == z.shape[0]
             ), f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}"
-            assert (
-                forces.shape[0] == z.shape[0]
-            ), f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}"
-            assert (
-                group["coords"].attrs["unit"] == "Angstrom"
-            ), f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}"
-            assert (
-                group["forces"].attrs["unit"] == "kcal/mol/Angstrom"
-            ), f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}"

From 86296219b9431ec69f6147d1bfa16bcd85c1c984 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Fri, 17 May 2024 13:20:59 +0200
Subject: [PATCH 27/48] remove preload_dataset_limit attr

---
 torchmdnet/datasets/mdcath.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index 53294311..ec3e50cb 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -79,7 +79,6 @@ def __init__(
         """
 
         self.url = "https://zenodo.org/record//files/"
-        self.preload_dataset_limit = preload_dataset_limit
         super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter)
 
         self.numAtoms = numAtoms

From 30e39a6d27b2c21e3c506a09016b3a20c5672ba4 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Fri, 17 May 2024 13:21:26 +0200
Subject: [PATCH 28/48] update source file name

---
 torchmdnet/datasets/mdcath.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index ec3e50cb..d3ae4ae4 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -105,7 +105,7 @@ def __init__(
     def raw_file_names(self):
         # Check if the dataset has been processed, and if not, return the original source file
         if not hasattr(self, "to_download"):
-            return ["mdCATH_source.h5"]
+            return ["mdcath_source.h5"]
         # Otherwise, return the list of HDF5 files that passed the filtering criteria
         return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()]
@@ -133,7 +133,7 @@ def calculate_dataset_size(self):
         return total_size_mb
     def process_data_source(self):
         print("Processing mdCATH source")
-        data_info_path = opj(self.root, "mdCATH_source.h5")
+        data_info_path = opj(self.root, "mdcath_source.h5")
         if not os.path.exists(data_info_path):
             self.download()

From 104348cb2299cc828b5adfb469bd441568f981f0 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Fri, 31 May 2024 12:13:38 +0200
Subject: [PATCH 29/48] source file string defined as Instance variable

---
 torchmdnet/datasets/mdcath.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index 633fb524..a7ba2c7c 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -82,7 +82,7 @@ def __init__(
         self.url = "https://zenodo.org/record//files/"
         self.preload_dataset_limit = preload_dataset_limit
         super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter)
-
+        self.source_file = "mdcath_source.h5"
         self.numAtoms = numAtoms
         self.numNoHAtoms = numNoHAtoms
         self.numResidues = numResidues
@@ -107,7 +107,7 @@ def __init__(
     def raw_file_names(self):
         # Check if the dataset has been processed, and if not, return the original source file
         if not hasattr(self, "to_download"):
-            return ["mdCATH_source.h5"]
+            return [self.source_file]
         # Otherwise, return the list of HDF5 files that passed the filtering criteria
         return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()]
@@ -118,7 +118,7 @@ def raw_dir(self):
         return self.root
     def download(self):
         if not hasattr(self, "to_download") or not self.to_download:
-            download_url(opj(self.url, "mdCATH_source.h5"), self.root)
+            download_url(opj(self.url, self.source_file), self.root)
             return
         for pdb_id in self.to_download.keys():
             file_name = f"mdcath_dataset_{pdb_id}.h5"
@@ -135,7 +135,7 @@ def calculate_dataset_size(self):
         return total_size_mb
     def process_data_source(self):
         print("Processing mdCATH source")
-        data_info_path = opj(self.root, "mdCATH_source.h5")
+        data_info_path = opj(self.root, self.source_file)
         if not os.path.exists(data_info_path):
             self.download()

From 481255b4d6c1c063643e6921f916cd0719eae7f6 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Fri, 31 May 2024 12:14:13 +0200
Subject: [PATCH 30/48] fix memory leak

---
 torchmdnet/datasets/mdcath.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index a7ba2c7c..876c478a 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -226,6 +226,7 @@ def _setup_idx(self):

     def process_specific_group(self, pdb, file, temp, repl, conf_idx):
+        # do not use attributes from h5group because it will cause a memory leak
         # use the read_direct and np.s_ to get the coords and forces of interest directly
         conf_idx = conf_idx*self.skipFrames
         slice_idxs = np.s_[conf_idx:conf_idx+1]
@@ -238,7 +239,7 @@ def process_specific_group(self, pdb, file, temp, repl, conf_idx):
             group['coords'].read_direct(coords, slice_idxs)
             group['forces'].read_direct(forces, slice_idxs)
-
+
             # coords and forces shape
(num_atoms, 3) assert ( coords.shape[0] == forces.shape[0] @@ -246,16 +247,6 @@ def process_specific_group(self, pdb, file, temp, repl, conf_idx): assert ( coords.shape[0] == z.shape[0] ), f"Number of atoms mismatch between coords and z: {group['coords'].shape[1]} vs {z.shape[0]}" - assert ( - forces.shape[0] == z.shape[0] - ), f"Number of atoms mismatch between forces and z: {group['forces'].shape[1]} vs {z.shape[0]}" - assert ( - group["coords"].attrs["unit"] == "Angstrom" - ), f"Coords unit is not Angstrom: {group['coords'].attrs['unit']}" - assert ( - group["forces"].attrs["unit"] == "kcal/mol/Angstrom" - ), f"Forces unit is not kcal/mol/Angstrom: {group['forces'].attrs['unit']}" - return (z, coords, forces) def get(self, element): From d63d0813e75bee2ea14eba39554c73e302c4c590 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 31 May 2024 12:14:36 +0200 Subject: [PATCH 31/48] small change --- torchmdnet/datasets/mdcath.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 876c478a..760f0d02 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -34,7 +34,6 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, - numAtoms=5000, numNoHAtoms=None, numResidues=1000, @@ -44,7 +43,7 @@ def __init__( min_gyration_radius=None, max_gyration_radius=None, alpha_beta_coil=None, - solid_ss = None, + solid_ss=None, numFrames=None, ): """mdCATH dataset class for PyTorch Geometric to load protein structures and dynamics from the mdCATH dataset. From 78ef15ac2608663aabe1fbbdb1cb6d6f10977010 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 31 May 2024 12:14:58 +0200 Subject: [PATCH 32/48] remove unused instance --- torchmdnet/datasets/mdcath.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 760f0d02..f907741e 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -79,7 +79,6 @@ def __init__( """ self.url = "https://zenodo.org/record//files/" - self.preload_dataset_limit = preload_dataset_limit super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) self.source_file = "mdcath_source.h5" self.numAtoms = numAtoms From e27e8d8e5de69544ef7b4970c18c30e50728e330 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 31 May 2024 12:16:09 +0200 Subject: [PATCH 33/48] get also detailed info from get function --- torchmdnet/datasets/mdcath.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index f907741e..c20403a4 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -258,4 +258,5 @@ def get(self, element): data.z = torch.tensor(z, dtype=torch.long) data.pos = torch.tensor(coords, dtype=torch.float) data.neg_dy = torch.tensor(forces, dtype=torch.float) + data.info = f'{pdb_id}_{temp}_{replica}_{conf_idx}' return data \ No newline at end of file From 714554760b4b748ebf6b0bbc671a574633318920 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 31 May 2024 12:52:25 +0200 Subject: [PATCH 34/48] fix self.source_file initialization --- torchmdnet/datasets/mdcath.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index c20403a4..ad653a5c 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -78,9 +78,9 @@ def __init__( Minimum number of frames in the 
trajectory in order to be considered. Default is None. """ - self.url = "https://zenodo.org/record//files/" - super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) + self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/tree/main/" self.source_file = "mdcath_source.h5" + super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) self.numAtoms = numAtoms self.numNoHAtoms = numNoHAtoms self.numResidues = numResidues @@ -120,7 +120,7 @@ def download(self): return for pdb_id in self.to_download.keys(): file_name = f"mdcath_dataset_{pdb_id}.h5" - file_path = opj(self.raw_dir, file_name) + file_path = opj(self.raw_dir, 'data', file_name) if not os.path.exists(file_path): download_url(opj(self.url, file_name), self.root) From 4f6fd294e63586437f81e817ef46d4f2968bb053 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 31 May 2024 14:03:48 +0200 Subject: [PATCH 35/48] update --- torchmdnet/datasets/mdcath.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index ef73a2c2..790254dd 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -78,7 +78,8 @@ def __init__( Minimum number of frames in the trajectory in order to be considered. Default is None. """ - self.url = "https://zenodo.org/record//files/" + self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/tree/main/" + self.source_file = "mdcath_source.h5" super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) self.numAtoms = numAtoms self.numNoHAtoms = numNoHAtoms @@ -104,7 +105,7 @@ def __init__( def raw_file_names(self): # Check if the dataset has been processed, and if not, return the original source file if not hasattr(self, "to_download"): - return ["mdcath_source.h5"] + return [self.source_file] # Otherwise, return the list of HDF5 files that passed the filtering criteria return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] @@ -132,7 +133,7 @@ def calculate_dataset_size(self): return total_size_mb def process_data_source(self): print("Processing mdCATH source") - data_info_path = opj(self.root, "mdcath_source.h5") + data_info_path = opj(self.root, self.source_file) if not os.path.exists(data_info_path): self.download() # the to_downlaod is the dictionary that will store the pdb ids and the corresponding temp and replica ids if they pass the filter From 0d0ee06b7b2e7081e9d4f83e58fbce9338dfd1e5 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 19 Jul 2024 18:01:12 +0200 Subject: [PATCH 36/48] update to allcaps --- torchmdnet/datasets/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py index 2f15dd13..3a51a79b 100644 --- a/torchmdnet/datasets/__init__.py +++ b/torchmdnet/datasets/__init__.py @@ -14,7 +14,7 @@ COMP6v1, COMP6v2, ) -from .mdcath import mdCATH +from .mdcath import MDCATH from .custom import Custom from .water import WaterBox from .hdf import HDF5 @@ -40,7 +40,7 @@ "GDB10to13", "GenentechTorsions", "HDF5", - "mdCATH", + "MDCATH", "MD17", "MD22", "QM9", From 6134086a27169debf51674757479da392ba4411f Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 19 Jul 2024 18:02:40 +0200 Subject: [PATCH 37/48] reorder and update code --- torchmdnet/datasets/mdcath.py | 200 ++++++++++++++-------------------- 1 file changed, 84 insertions(+), 116 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py 
b/torchmdnet/datasets/mdcath.py index 790254dd..8b630c22 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -3,31 +3,30 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import os -from os.path import join as opj import h5py import torch -from tqdm import tqdm import math +import logging import numpy as np -from torch_geometric.data import Dataset, download_url, Data +from tqdm import tqdm +from os.path import join as opj +from torch_geometric.data import Dataset, Data +import urllib.request +from collections import defaultdict +logger = logging.getLogger('MDCATH') -def get_pdb_list(pdb_list): - # pdb list could be a list of pdb ids or a file with the pdb ids +def load_pdb_list(pdb_list): + """Load PDB list from a file or return list directly.""" if isinstance(pdb_list, list): return pdb_list - elif isinstance(pdb_list, str): - if os.path.exists(pdb_list): - print(f"Reading PDB list from {pdb_list}") - with open(pdb_list, "r") as f: - return [line.strip() for line in f] - else: - raise FileNotFoundError(f"File {pdb_list} not found") - else: - return None - + elif isinstance(pdb_list, str) and os.path.isfile(pdb_list): + logger.info(f"Reading PDB list from {pdb_list}") + with open(pdb_list, "r") as file: + return [line.strip() for line in file] + raise ValueError("Invalid PDB list. Please provide a list or a path to a file.") -class mdCATH(Dataset): +class MDCATH(Dataset): def __init__( self, root, @@ -35,7 +34,6 @@ def __init__( pre_transform=None, pre_filter=None, numAtoms=5000, - numNoHAtoms=None, numResidues=1000, temperatures=["348"], skipFrames=1, @@ -52,8 +50,6 @@ def __init__( ----------- root: str Root directory where the dataset should be stored. Data will be downloaded to 'root/'. - preload_dataset_limit: int - Maximum size of the dataset in MB to load into memory. If the dataset is larger than this limit, a warning will be printed. Default is 1024 MB. numAtoms: int Max number of atoms in the protein structure. numNoHAtoms: int @@ -65,7 +61,8 @@ def __init__( skipFrames: int Number of frames to skip in the trajectory. Default is 1. pdb_list: list or str - List of PDB IDs to download. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded. + List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded. + The filters will be applied to the PDB IDs in this list in any case. Default is None. min_gyration_radius: float Minimum gyration radius (in nm) of the protein structure. Default is None. max_gyration_radius: float @@ -78,143 +75,114 @@ def __init__( Minimum number of frames in the trajectory in order to be considered. Default is None. 
""" - self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/tree/main/" + self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/" self.source_file = "mdcath_source.h5" - super(mdCATH, self).__init__(root, transform, pre_transform, pre_filter) + self.root = root + os.makedirs(root, exist_ok=True) self.numAtoms = numAtoms - self.numNoHAtoms = numNoHAtoms self.numResidues = numResidues self.temperatures = temperatures self.skipFrames = skipFrames - self.pdb_list = get_pdb_list(pdb_list) + self.pdb_list = load_pdb_list(pdb_list) self.min_gyration_radius = min_gyration_radius self.max_gyration_radius = max_gyration_radius self.alpha_beta_coil = alpha_beta_coil self.numFrames = numFrames self.solid_ss = solid_ss + self._ensure_source_file() + self._filter_and_prepare_data() self.idx = None - self.process_data_source() + super(MDCATH, self).__init__(root, transform, pre_transform, pre_filter) # Calculate the total size of the dataset in MB self.total_size_mb = self.calculate_dataset_size() + - print(f"Total number of domains: {len(self.to_download.keys())}") - print(f"Total number of conformers: {self.num_conformers}") - print(f"Total size of dataset: {self.total_size_mb} MB") + logger.info(f"Total number of domains: {len(self.processed.keys())}") + logger.info(f"Total number of conformers: {self.num_conformers}") + logger.info(f"Total size of dataset: {self.total_size_mb} MB") @property def raw_file_names(self): - # Check if the dataset has been processed, and if not, return the original source file - if not hasattr(self, "to_download"): - return [self.source_file] - # Otherwise, return the list of HDF5 files that passed the filtering criteria - return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.to_download.keys()] + return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.processed.keys()] @property def raw_dir(self): - # Override the raw_dir property to prevent the creation of a 'raw' directory + # Override the raw_dir property to return the root directory # The files will be downloaded to the root directory return self.root + + def _ensure_source_file(self): + """Ensure the source file is downloaded before processing.""" + source_path = os.path.join(self.root, self.source_file) + if not os.path.exists(source_path): + logger.info(f"Downloading source file {self.source_file}") + urllib.request.urlretrieve(opj(self.url, self.source_file), source_path) + def download(self): - if not hasattr(self, "to_download") or not self.to_download: - download_url(opj(self.url, self.source_file), self.root) - return - for pdb_id in self.to_download.keys(): + for pdb_id in self.processed.keys(): file_name = f"mdcath_dataset_{pdb_id}.h5" - file_path = opj(self.raw_dir, 'data', file_name) + file_path = opj(self.raw_dir, file_name) if not os.path.exists(file_path): - download_url(opj(self.url, file_name), self.root) + # Download the file if it does not exist + urllib.request.urlretrieve(opj(self.url, 'data', file_name), file_path) def calculate_dataset_size(self): total_size_bytes = 0 - for pdb_id in self.to_download.keys(): + for pdb_id in self.processed.keys(): file_name = f"mdcath_dataset_{pdb_id}.h5" total_size_bytes += os.path.getsize(opj(self.root, file_name)) total_size_mb = round(total_size_bytes / (1024 * 1024), 4) return total_size_mb - def process_data_source(self): - print("Processing mdCATH source") - data_info_path = opj(self.root, self.source_file) - if not os.path.exists(data_info_path): - self.download() - # the to_downlaod is the dictionary that will 
store the pdb ids and the corresponding temp and replica ids if they pass the filter - self.to_download = {} + + def _filter_and_prepare_data(self): + source_info_path = os.path.join(self.root, self.source_file) + + self.processed = defaultdict(list) self.num_conformers = 0 - with h5py.File(data_info_path, "r") as f: - domains = f.keys() if self.pdb_list is None else self.pdb_list - for pdb in tqdm( - domains, total=len(domains), desc="Processing mdCATH source" - ): - pdb_group = f[pdb] - if ( - self.numAtoms is not None - and pdb_group.attrs["numProteinAtoms"] > self.numAtoms - ): - continue - if ( - self.numResidues is not None - and pdb_group.attrs["numResidues"] > self.numResidues - ): + + with h5py.File(source_info_path, "r") as file: + domains = file.keys() if self.pdb_list is None else self.pdb_list + + for pdb_id in tqdm(domains, desc="Processing mdcath source"): + pdb_group = file[pdb_id] + if self.numAtoms is not None and pdb_group.attrs["numProteinAtoms"] > self.numAtoms: continue - if ( - self.numNoHAtoms is not None - and pdb_group.attrs["numNoHAtoms"] > self.numNoHAtoms - ): + if self.numResidues is not None and pdb_group.attrs["numResidues"] > self.numResidues: continue - for temp in self.temperatures: - if temp not in pdb_group.keys(): - continue - for replica in pdb_group[temp].keys(): - if ( - self.numFrames is not None - and pdb_group[temp][replica].attrs["numFrames"] - < self.numFrames - ): - continue - if ( - self.min_gyration_radius is not None - and pdb_group[temp][replica].attrs["min_gyration_radius"] - < self.min_gyration_radius - ): - continue - if ( - self.max_gyration_radius is not None - and pdb_group[temp][replica].attrs["max_gyration_radius"] - > self.max_gyration_radius - ): - continue - if ( - self.alpha_beta_coil is not None - or self.solid_ss is not None - ): - alpha = pdb_group[temp][replica].attrs["alpha"] - beta = pdb_group[temp][replica].attrs["beta"] - coil = pdb_group[temp][replica].attrs["coil"] - solid_ss = ( - (alpha + beta) / pdb_group.attrs["numResidues"] * 100 - ) - if self.solid_ss is not None: - if solid_ss < self.solid_ss: - continue - else: - if not np.isclose( - [alpha, beta, coil], list(self.alpha_beta_coil) - ).all(): - continue - - if pdb not in self.to_download: - self.to_download[pdb] = [] - num_frames = math.ceil(pdb_group[temp][replica].attrs["numFrames"]/self.skipFrames) - self.to_download[pdb].append((temp, replica, num_frames)) - # append the number of frames of the trajectory to the total number of molecules - self.num_conformers += num_frames + self._process_temperatures(pdb_id, pdb_group) + + def _process_temperatures(self, pdb_id, pdb_group): + for temp in self.temperatures: + for replica in pdb_group[temp].keys(): + self._evaluate_replica(pdb_id, temp, replica, pdb_group) + + def _evaluate_replica(self, pdb_id, temp, replica, pdb_group): + conditions = [ + self.numFrames is not None and pdb_group[temp][replica].attrs["numFrames"] < self.numFrames, + self.min_gyration_radius is not None and pdb_group[temp][replica].attrs["min_gyration_radius"] < self.min_gyration_radius, + self.max_gyration_radius is not None and pdb_group[temp][replica].attrs["max_gyration_radius"] > self.max_gyration_radius, + self._evaluate_structure(pdb_group, temp, replica) + ] + if any(conditions): + return + + num_frames = math.ceil(pdb_group[temp][replica].attrs["numFrames"] / self.skipFrames) + self.processed[pdb_id].append((temp, replica, num_frames)) + self.num_conformers += num_frames + + def _evaluate_structure(self, pdb_group, temp, replica): + 
alpha = pdb_group[temp][replica].attrs["alpha"] + beta = pdb_group[temp][replica].attrs["beta"] + solid_ss = (alpha + beta) / pdb_group.attrs["numResidues"] * 100 + return self.solid_ss is not None and solid_ss < self.solid_ss def len(self): return self.num_conformers def _setup_idx(self): - files = [opj(self.root, f"mdcath_dataset_{pdb_id}.h5") for pdb_id in self.to_download.keys()] + files = [opj(self.root, f"mdcath_dataset_{pdb_id}.h5") for pdb_id in self.processed.keys()] self.idx = [] - for i, (pdb, group_info) in enumerate(self.to_download.items()): + for i, (pdb, group_info) in enumerate(self.processed.items()): for temp, replica, num_frames in group_info: # build the catalog here for each conformer d = [(pdb, files[i], temp, replica, conf_id) for conf_id in range(num_frames)] From 262f8adb71220316dcff3d323422f18a1b34914f Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 19 Jul 2024 18:07:32 +0200 Subject: [PATCH 38/48] fix class name in test --- tests/test_mdcath.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py index 34e1ff48..1e2e2e72 100644 --- a/tests/test_mdcath.py +++ b/tests/test_mdcath.py @@ -1,11 +1,11 @@ -from torchmdnet.datasets.mdcath import mdCATH -from torch_geometric.loader import DataLoader -from tqdm import tqdm -import numpy as np import h5py -from os.path import join import psutil +import numpy as np from pytest import mark +from os.path import join +from torchmdnet.datasets.mdcath import MDCATH +from torch_geometric.loader import DataLoader +from tqdm import tqdm def test_mdcath(tmpdir): @@ -49,7 +49,7 @@ def test_mdcath(tmpdir): data.flush() data.close() - dataset = mdCATH(root=tmpdir) + dataset = MDCATH(root=tmpdir) dl = DataLoader( dataset, batch_size=1, @@ -106,7 +106,7 @@ def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10): proc = psutil.Process() n_open = len(proc.open_files()) - dset = mdCATH( + dset = MDCATH( root=tmpdir, source_file=join(tmpdir, "source.h5"), ) @@ -168,7 +168,7 @@ def test_mdcath_skipframes(tmpdir, skipframes, batch_size): data.flush() data.close() - dataset = mdCATH( + dataset = MDCATH( root=tmpdir, skipFrames=skipframes, source_file=join(tmpdir, "source.h5") ) dl = DataLoader( From 723164fb404b63131c4885d165b0e669f499e3e4 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 19 Jul 2024 18:08:08 +0200 Subject: [PATCH 39/48] undo on persistent workers --- torchmdnet/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/data.py b/torchmdnet/data.py index b587f0a1..ba91e8be 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -133,7 +133,7 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True): dataset=dataset, batch_size=batch_size, num_workers=self.hparams["num_workers"], - persistent_workers=True if self.hparams["num_workers"] > 0 else False, + persistent_workers=True, pin_memory=True, shuffle=shuffle, ) From fc561697081c342980da7fb4393566bffe489e6b Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Fri, 19 Jul 2024 18:15:02 +0200 Subject: [PATCH 40/48] persistent_workers to False --- torchmdnet/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/data.py b/torchmdnet/data.py index ba91e8be..986e19f7 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -133,7 +133,7 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True): dataset=dataset, batch_size=batch_size, num_workers=self.hparams["num_workers"], - 
persistent_workers=True, + persistent_workers=False, pin_memory=True, shuffle=shuffle, ) From c0d39b774aba1b626b2ee8d1a8143fc6ab0b31ff Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Sat, 20 Jul 2024 10:34:48 +0200 Subject: [PATCH 41/48] fix self.pdb_list in init --- torchmdnet/datasets/mdcath.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 8b630c22..2d00a8be 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -83,7 +83,7 @@ def __init__( self.numResidues = numResidues self.temperatures = temperatures self.skipFrames = skipFrames - self.pdb_list = load_pdb_list(pdb_list) + self.pdb_list = load_pdb_list(pdb_list) if pdb_list is not None else None self.min_gyration_radius = min_gyration_radius self.max_gyration_radius = max_gyration_radius self.alpha_beta_coil = alpha_beta_coil From 1d04f6e5d2be10a1dc1ba9e8e2e03af959ea1e8a Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Sat, 20 Jul 2024 10:37:17 +0200 Subject: [PATCH 42/48] fix source file name in unit test --- tests/test_mdcath.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py index 1e2e2e72..d6983d05 100644 --- a/tests/test_mdcath.py +++ b/tests/test_mdcath.py @@ -10,7 +10,7 @@ def test_mdcath(tmpdir): num_atoms_list = np.linspace(50, 1000, 50) - source_file = h5py.File(join(tmpdir, "source.h5"), mode="w") + source_file = h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") for num_atoms in num_atoms_list: z = np.zeros(int(num_atoms)) pos = np.zeros((100, int(num_atoms), 3)) @@ -68,7 +68,7 @@ def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10): pos = np.zeros((numFrames, num_entries, 3)) forces = np.zeros((numFrames, num_entries, 3)) - source_file = h5py.File(join(tmpdir, "source.h5"), mode="w") + source_file = h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") s_group = source_file.create_group("A00") s_group.attrs["numChains"] = 1 @@ -108,7 +108,6 @@ def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10): dset = MDCATH( root=tmpdir, - source_file=join(tmpdir, "source.h5"), ) assert len(proc.open_files()) == n_open, "creating the dataset object opened a file" @@ -125,7 +124,7 @@ def replacer(arr, skipframes): @mark.parametrize("batch_size", [1, 10]) def test_mdcath_skipframes(tmpdir, skipframes, batch_size): - with h5py.File(join(tmpdir, "source.h5"), mode="w") as source_file: + with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file: num_frames_list = np.linspace(50, 1000, 50).astype(int) for num_frame in tqdm(num_frames_list, desc="Creating tmp files"): z = np.zeros(100) @@ -169,7 +168,7 @@ def test_mdcath_skipframes(tmpdir, skipframes, batch_size): data.close() dataset = MDCATH( - root=tmpdir, skipFrames=skipframes, source_file=join(tmpdir, "source.h5") + root=tmpdir, skipFrames=skipframes ) dl = DataLoader( dataset, From 2b4dd43b9a6788fbe61222cfeb1157f0eab6b773 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Sat, 20 Jul 2024 10:39:16 +0200 Subject: [PATCH 43/48] add unit test for pdb_list in mdcath --- tests/test_mdcath.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py index d6983d05..8601d0de 100644 --- a/tests/test_mdcath.py +++ b/tests/test_mdcath.py @@ -122,8 +122,8 @@ def replacer(arr, skipframes): @mark.parametrize("skipframes", [1, 2, 5]) @mark.parametrize("batch_size", [1, 10]) -def 
test_mdcath_skipframes(tmpdir, skipframes, batch_size): - +@mark.parametrize("pdb_list", [["A50", "A612", "A1000"], None]) +def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list): with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file: num_frames_list = np.linspace(50, 1000, 50).astype(int) for num_frame in tqdm(num_frames_list, desc="Creating tmp files"): @@ -168,7 +168,7 @@ def test_mdcath_skipframes(tmpdir, skipframes, batch_size): data.close() dataset = MDCATH( - root=tmpdir, skipFrames=skipframes + root=tmpdir, skipFrames=skipframes, pdb_list=pdb_list ) dl = DataLoader( dataset, From 7148b6124fd2bcfc27c77db608575ebfcc4a8720 Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 22 Jul 2024 12:30:13 +0200 Subject: [PATCH 44/48] rename arg skipFrames to skip_frames --- torchmdnet/datasets/mdcath.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py index 2d00a8be..f9ae052e 100644 --- a/torchmdnet/datasets/mdcath.py +++ b/torchmdnet/datasets/mdcath.py @@ -36,7 +36,7 @@ def __init__( numAtoms=5000, numResidues=1000, temperatures=["348"], - skipFrames=1, + skip_frames=1, pdb_list=None, min_gyration_radius=None, max_gyration_radius=None, @@ -58,7 +58,7 @@ def __init__( Max number of residues in the protein structure. temperatures: list List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450'] - skipFrames: int + skip_frames: int Number of frames to skip in the trajectory. Default is 1. pdb_list: list or str List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded. @@ -82,7 +82,7 @@ def __init__( self.numAtoms = numAtoms self.numResidues = numResidues self.temperatures = temperatures - self.skipFrames = skipFrames + self.skip_frames = skip_frames self.pdb_list = load_pdb_list(pdb_list) if pdb_list is not None else None self.min_gyration_radius = min_gyration_radius self.max_gyration_radius = max_gyration_radius @@ -166,7 +166,7 @@ def _evaluate_replica(self, pdb_id, temp, replica, pdb_group): if any(conditions): return - num_frames = math.ceil(pdb_group[temp][replica].attrs["numFrames"] / self.skipFrames) + num_frames = math.ceil(pdb_group[temp][replica].attrs["numFrames"] / self.skip_frames) self.processed[pdb_id].append((temp, replica, num_frames)) self.num_conformers += num_frames @@ -194,7 +194,7 @@ def _setup_idx(self): def process_specific_group(self, pdb, file, temp, repl, conf_idx): # do not use attributes from h5group beause is will cause memory leak # use the read_direct and np.s_ to get the coords and forces of interest directly - conf_idx = conf_idx*self.skipFrames + conf_idx = conf_idx*self.skip_frames slice_idxs = np.s_[conf_idx:conf_idx+1] with h5py.File(file, "r") as f: z = f[pdb]["z"][:] From 45f7f572d25bf6496cd84bb660f8b8f086d5a54e Mon Sep 17 00:00:00 2001 From: Antonio Mirarchi Date: Mon, 22 Jul 2024 12:31:30 +0200 Subject: [PATCH 45/48] update skip_frames in mdcathtest --- tests/test_mdcath.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py index 8601d0de..0c744342 100644 --- a/tests/test_mdcath.py +++ b/tests/test_mdcath.py @@ -168,7 +168,7 @@ def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list): data.close() dataset = MDCATH( - root=tmpdir, skipFrames=skipframes, pdb_list=pdb_list + root=tmpdir, skip_frames=skipframes, 
pdb_list=pdb_list
)
dl = DataLoader(
dataset,

From 92d5f08faad5f70495e541876e874f781c2e0e1f Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Mon, 29 Jul 2024 10:42:30 +0200
Subject: [PATCH 46/48] avoid error due to temperatures list, force str dtype

---
torchmdnet/datasets/mdcath.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index f9ae052e..3c52afff 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -19,7 +19,7 @@ def load_pdb_list(pdb_list):
"""Load PDB list from a file or return list directly."""
if isinstance(pdb_list, list):
- return pdb_list
+ return [str(pdb) for pdb in pdb_list]
elif isinstance(pdb_list, str) and os.path.isfile(pdb_list):
logger.info(f"Reading PDB list from {pdb_list}")
with open(pdb_list, "r") as file:
@@ -192,7 +192,7 @@ def _setup_idx(self):
def process_specific_group(self, pdb, file, temp, repl, conf_idx):
- # do not use attributes from h5group beause is will cause memory leak
+ # do not use attributes from h5group because it will cause a memory leak
# use the read_direct and np.s_ to get the coords and forces of interest directly
conf_idx = conf_idx*self.skip_frames
slice_idxs = np.s_[conf_idx:conf_idx+1]

From c12a2bdd69c5a751d243a4b6bc6672d5ce7ed114 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Mon, 29 Jul 2024 10:50:23 +0200
Subject: [PATCH 47/48] undo

---
torchmdnet/datasets/mdcath.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index 3c52afff..4d7055b2 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -19,7 +19,7 @@ def load_pdb_list(pdb_list):
"""Load PDB list from a file or return list directly."""
if isinstance(pdb_list, list):
- return [str(pdb) for pdb in pdb_list]
+ return pdb_list
elif isinstance(pdb_list, str) and os.path.isfile(pdb_list):
logger.info(f"Reading PDB list from {pdb_list}")
with open(pdb_list, "r") as file:

From f5c2d4683479340284ede894be0faf743a065b49 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Mon, 29 Jul 2024 10:50:58 +0200
Subject: [PATCH 48/48] force str dtype in temperatures list

---
torchmdnet/datasets/mdcath.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
index 4d7055b2..62868fc9 100644
--- a/torchmdnet/datasets/mdcath.py
+++ b/torchmdnet/datasets/mdcath.py
@@ -81,7 +81,7 @@ def __init__(
os.makedirs(root, exist_ok=True)
self.numAtoms = numAtoms
self.numResidues = numResidues
- self.temperatures = temperatures
+ self.temperatures = [str(temp) for temp in temperatures]
self.skip_frames = skip_frames
self.pdb_list = load_pdb_list(pdb_list) if pdb_list is not None else None
self.min_gyration_radius = min_gyration_radius
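
For reference, a minimal usage sketch of the MDCATH class as it stands after PATCH 48. This is illustrative only: the root path, filter values and domain ID are placeholders, and constructing the dataset downloads 'mdcath_source.h5' plus the matching per-domain HDF5 files from the Hugging Face URL hardcoded in the class, so network access is assumed.

    from torch_geometric.loader import DataLoader
    from torchmdnet.datasets import MDCATH

    # Placeholder root and filters; temperatures entries are coerced to str (PATCH 48)
    # and skip_frames is the post-PATCH-44 name of the old skipFrames argument.
    dataset = MDCATH(
        root="data/mdcath",
        temperatures=[348],
        skip_frames=2,
        numAtoms=3000,
        pdb_list=["1a2bA00"],  # hypothetical CATH domain ID; a path to a text file also works
    )

    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # Each sample carries z (atomic numbers), pos (coordinates, Angstrom),
    # neg_dy (forces, kcal/mol/Angstrom) and, since PATCH 33, an info string
    # of the form "{pdb_id}_{temp}_{replica}_{conf_idx}".
    print(batch.z.shape, batch.pos.shape, batch.neg_dy.shape, batch.info)

The batch size and filters above are arbitrary; any other keyword accepted by __init__ (min_gyration_radius, max_gyration_radius, solid_ss, numFrames, ...) can be combined the same way.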