diff --git a/tests/test_mdcath.py b/tests/test_mdcath.py
new file mode 100644
index 00000000..0c744342
--- /dev/null
+++ b/tests/test_mdcath.py
@@ -0,0 +1,132 @@
+import h5py
+import psutil
+import numpy as np
+from pytest import mark
+from os.path import join
+from torchmdnet.datasets.mdcath import MDCATH
+from torch_geometric.loader import DataLoader
+from tqdm import tqdm
+
+
+def _write_source_entry(source_file, pdb_id, num_atoms, num_frames):
+    """Add the metadata group of one domain to the mock mdcath_source.h5 file."""
+    s_group = source_file.create_group(pdb_id)
+    s_group.attrs["numChains"] = 1
+    s_group.attrs["numNoHAtoms"] = num_atoms / 2
+    s_group.attrs["numProteinAtoms"] = num_atoms
+    s_group.attrs["numResidues"] = num_atoms / 10
+    s_replica_group = s_group.create_group("348").create_group("0")
+    s_replica_group.attrs["numFrames"] = num_frames
+    s_replica_group.attrs["alpha"] = 0.30
+    s_replica_group.attrs["beta"] = 0.25
+    s_replica_group.attrs["coil"] = 0.45
+    s_replica_group.attrs["max_gyration_radius"] = 2
+    s_replica_group.attrs["max_num_neighbors_5A"] = 55
+    s_replica_group.attrs["max_num_neighbors_9A"] = 200
+    s_replica_group.attrs["min_gyration_radius"] = 1
+
+
+def _write_dataset_file(root, pdb_id, z, pos, forces):
+    """Write the mock trajectory file mdcath_dataset_{pdb_id}.h5 into root."""
+    # the context manager flushes and closes the file even if a write fails
+    with h5py.File(join(root, f"mdcath_dataset_{pdb_id}.h5"), mode="w") as data:
+        group = data.create_group(pdb_id)
+        group.create_dataset("z", data=z)
+        replicagroup = group.create_group("348").create_group("0")
+        replicagroup.create_dataset("coords", data=pos)
+        replicagroup.create_dataset("forces", data=forces)
+        # add some attributes
+        replicagroup.attrs["numFrames"] = pos.shape[0]
+        replicagroup["coords"].attrs["unit"] = "Angstrom"
+        replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"
+
+
+def test_mdcath(tmpdir):
+    # integer atom counts, so domains are named "A50" rather than "A50.0"
+    num_atoms_list = np.linspace(50, 1000, 50).astype(int).tolist()
+    # close the source file before MDCATH opens it for reading
+    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
+        for num_atoms in num_atoms_list:
+            pdb_id = f"A{num_atoms}"
+            _write_source_entry(source_file, pdb_id, num_atoms, num_frames=100)
+            _write_dataset_file(
+                tmpdir,
+                pdb_id,
+                z=np.zeros(num_atoms),
+                pos=np.zeros((100, num_atoms, 3)),
+                forces=np.zeros((100, num_atoms, 3)),
+            )
+
+    dataset = MDCATH(root=tmpdir)
+    dl = DataLoader(
+        dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+        persistent_workers=False,
+    )
+    # iterating the whole loader is the test: every sample must load without error
+    for _ in tqdm(dl):
+        pass
+
+
+def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):
+    # generate sample data
+    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
+        _write_source_entry(source_file, "A00", num_entries, numFrames)
+    _write_dataset_file(
+        tmpdir,
+        "A00",
+        z=np.zeros(num_entries),
+        pos=np.zeros((numFrames, num_entries, 3)),
+        forces=np.zeros((numFrames, num_entries, 3)),
+    )
+
+    # make sure creating the dataset doesn't open any files on the main process
+    proc = psutil.Process()
+    n_open = len(proc.open_files())
+
+    # keep a reference so the dataset is alive while the open files are counted
+    dset = MDCATH(root=tmpdir)
+    assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"
+
+
+def replacer(arr, skipframes):
+    """Return a copy of arr with every skipframes-th frame (first axis) set to 1."""
+    tmp_arr = arr.copy()
+    tmp_arr[::skipframes] = 1
+    return tmp_arr
+
+
+@mark.parametrize("skipframes", [1, 2, 5])
+@mark.parametrize("batch_size", [1, 10])
+@mark.parametrize("pdb_list", [["A50", "A612", "A1000"], None])
+def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list):
+    num_frames_list = np.linspace(50, 1000, 50).astype(int).tolist()
+    with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
+        for num_frame in tqdm(num_frames_list, desc="Creating tmp files"):
+            pdb_id = f"A{num_frame}"
+            _write_source_entry(source_file, pdb_id, 100, num_frame)
+            # frames 0, skipframes, 2 * skipframes, ... hold 1, all others 0
+            _write_dataset_file(
+                tmpdir,
+                pdb_id,
+                z=np.zeros(100),
+                pos=replacer(np.zeros((num_frame, 100, 3)), skipframes),
+                forces=replacer(np.zeros((num_frame, 100, 3)), skipframes),
+            )
+
+    dataset = MDCATH(root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list)
+    dl = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+        persistent_workers=False,
+    )
+    for data in tqdm(dl):
+        # if skip_frames works correctly, every returned value is exactly 1
+        assert (data.pos == 1).all(), "skipframes not working correctly for positions"
+        assert (data.neg_dy == 1).all(), "skipframes not working correctly for forces"
diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py
index f395e6e6..92a70bd2 100644
--- a/torchmdnet/datasets/__init__.py
+++ b/torchmdnet/datasets/__init__.py
@@ -14,6 +14,7 @@
     COMP6v1,
     COMP6v2,
 )
+from .mdcath import MDCATH
 from .custom import Custom
 from .water import WaterBox
 from .hdf import HDF5
@@ -40,6 +41,7 @@
     "GDB10to13",
     "GenentechTorsions",
     "HDF5",
+    "MDCATH",
     "MD17",
     "MD22",
     "QM9",
diff --git a/torchmdnet/datasets/mdcath.py b/torchmdnet/datasets/mdcath.py
new file mode 100644
index 00000000..62868fc9
--- /dev/null
+++ b/torchmdnet/datasets/mdcath.py
@@ -0,0 +1,276 @@
+# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
+# Distributed under the MIT License.
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
+
+import os
+import h5py
+import torch
+import math
+import logging
+import numpy as np
+from tqdm import tqdm
+from os.path import join as opj
+from torch_geometric.data import Dataset, Data
+import urllib.request
+from collections import defaultdict
+
+logger = logging.getLogger('MDCATH')
+
+
+def load_pdb_list(pdb_list):
+    """Load PDB list from a file or return list directly.
+
+    A list is returned as is; a path is read one PDB ID per line, ignoring blank lines.
+    Anything else raises ValueError.
+    """
+    if isinstance(pdb_list, list):
+        return pdb_list
+    elif isinstance(pdb_list, str) and os.path.isfile(pdb_list):
+        logger.info(f"Reading PDB list from {pdb_list}")
+        with open(pdb_list, "r") as file:
+            # a blank line (e.g. a trailing newline) would yield an empty, unusable ID
+            return [line.strip() for line in file if line.strip()]
+    raise ValueError("Invalid PDB list. Please provide a list or a path to a file.")
+
+
+class MDCATH(Dataset):
+    def __init__(
+        self,
+        root,
+        transform=None,
+        pre_transform=None,
+        pre_filter=None,
+        numAtoms=5000,
+        numResidues=1000,
+        temperatures=["348"],
+        skip_frames=1,
+        pdb_list=None,
+        min_gyration_radius=None,
+        max_gyration_radius=None,
+        alpha_beta_coil=None,
+        solid_ss=None,
+        numFrames=None,
+    ):
+        """mdCATH dataset class for PyTorch Geometric to load protein structures and dynamics from the mdCATH dataset.
+
+        Parameters:
+        -----------
+        root: str
+            Root directory where the dataset should be stored. Data will be downloaded to 'root/'.
+        transform, pre_transform, pre_filter: callable
+            Forwarded unchanged to the torch_geometric Dataset constructor. Default is None.
+        numAtoms: int
+            Max number of atoms in the protein structure. None disables this filter. Default is 5000.
+        numResidues: int
+            Max number of residues in the protein structure. None disables this filter. Default is 1000.
+        temperatures: list
+            List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450']
+        skip_frames: int
+            Number of frames to skip in the trajectory. Must be at least 1. Default is 1.
+        pdb_list: list or str
+            List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5' will be downloaded.
+            The filters will be applied to the PDB IDs in this list in any case. IDs missing from 'mdcath_source.h5' are skipped with a warning. Default is None.
+        min_gyration_radius: float
+            Minimum gyration radius (in nm) of the protein structure. Default is None.
+        max_gyration_radius: float
+            Maximum gyration radius (in nm) of the protein structure. Default is None.
+        alpha_beta_coil: tuple
+            Tuple with the minimum percentage of alpha-helix, beta-sheet and coil residues in the protein structure.
+            NOTE: the value is stored, but no filter in this class reads it yet. Default is None.
+        solid_ss: float
+            minimum percentage of solid secondary structure in the protein structure (alpha + beta)/total_residues * 100. Default is None.
+        numFrames: int
+            Minimum number of frames in the trajectory in order to be considered. Default is None.
+
+        Raises:
+        -------
+        ValueError
+            If skip_frames is smaller than 1, or pdb_list is neither a list nor a path to an existing file.
+        """
+        # a stride below 1 would divide by zero or index frames backwards later on
+        if skip_frames < 1:
+            raise ValueError(f"skip_frames must be at least 1, got {skip_frames}")
+
+        self.url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/"
+        self.source_file = "mdcath_source.h5"
+        self.root = root
+        os.makedirs(root, exist_ok=True)
+        self.numAtoms = numAtoms
+        self.numResidues = numResidues
+        self.temperatures = [str(temp) for temp in temperatures]
+        self.skip_frames = skip_frames
+        self.pdb_list = load_pdb_list(pdb_list) if pdb_list is not None else None
+        self.min_gyration_radius = min_gyration_radius
+        self.max_gyration_radius = max_gyration_radius
+        self.alpha_beta_coil = alpha_beta_coil
+        self.numFrames = numFrames
+        self.solid_ss = solid_ss
+        # self.processed is built before the base constructor runs,
+        # because raw_file_names and download() both read it
+        self._ensure_source_file()
+        self._filter_and_prepare_data()
+        # the per-conformer catalog is built lazily by get(), see _setup_idx()
+        self.idx = None
+        super().__init__(root, transform, pre_transform, pre_filter)
+        # Calculate the total size of the dataset in MB
+        self.total_size_mb = self.calculate_dataset_size()
+
+        logger.info(f"Total number of domains: {len(self.processed)}")
+        logger.info(f"Total number of conformers: {self.num_conformers}")
+        logger.info(f"Total size of dataset: {self.total_size_mb} MB")
+
+    @property
+    def raw_file_names(self):
+        """Names of the per-domain HDF5 files of the domains that passed the filters."""
+        return [f"mdcath_dataset_{pdb_id}.h5" for pdb_id in self.processed]
+
+    @property
+    def raw_dir(self):
+        # Override the raw_dir property to return the root directory
+        # The files will be downloaded to the root directory
+        return self.root
+
+    def _ensure_source_file(self):
+        """Ensure the source file is downloaded before processing."""
+        source_path = os.path.join(self.root, self.source_file)
+        if not os.path.exists(source_path):
+            logger.info(f"Downloading source file {self.source_file}")
+            # a URL is not a filesystem path: concatenate instead of using os.path.join
+            urllib.request.urlretrieve(f"{self.url}{self.source_file}", source_path)
+
+    def download(self):
+        """Download the HDF5 file of every selected domain that is not yet in raw_dir."""
+        for pdb_id in self.processed:
+            file_name = f"mdcath_dataset_{pdb_id}.h5"
+            file_path = opj(self.raw_dir, file_name)
+            if not os.path.exists(file_path):
+                # Download the file if it does not exist
+                # URLs always use "/", whereas os.path.join uses the platform separator
+                urllib.request.urlretrieve(f"{self.url}data/{file_name}", file_path)
+
+    def calculate_dataset_size(self):
+        """Return the summed on-disk size, in MB, of the selected domain files."""
+        total_size_bytes = 0
+        for pdb_id in self.processed:
+            file_name = f"mdcath_dataset_{pdb_id}.h5"
+            total_size_bytes += os.path.getsize(opj(self.root, file_name))
+        total_size_mb = round(total_size_bytes / (1024 * 1024), 4)
+        return total_size_mb
+
+    def _filter_and_prepare_data(self):
+        """Fill self.processed and self.num_conformers from the source file."""
+        source_info_path = os.path.join(self.root, self.source_file)
+
+        self.processed = defaultdict(list)
+        self.num_conformers = 0
+
+        with h5py.File(source_info_path, "r") as file:
+            domains = file.keys() if self.pdb_list is None else self.pdb_list
+
+            for pdb_id in tqdm(domains, desc="Processing mdcath source"):
+                # a user-supplied ID may be absent from the source file
+                if pdb_id not in file:
+                    logger.warning(f"{pdb_id} not found in {self.source_file}, skipping")
+                    continue
+                pdb_group = file[pdb_id]
+                if self.numAtoms is not None and pdb_group.attrs["numProteinAtoms"] > self.numAtoms:
+                    continue
+                if self.numResidues is not None and pdb_group.attrs["numResidues"] > self.numResidues:
+                    continue
+                self._process_temperatures(pdb_id, pdb_group)
+
+    def _process_temperatures(self, pdb_id, pdb_group):
+        """Evaluate every replica of every requested temperature of one domain."""
+        for temp in self.temperatures:
+            for replica in pdb_group[temp].keys():
+                self._evaluate_replica(pdb_id, temp, replica, pdb_group)
+
+    def _evaluate_replica(self, pdb_id, temp, replica, pdb_group):
+        """Register the replica in self.processed unless one of the filters rejects it."""
+        attrs = pdb_group[temp][replica].attrs
+        conditions = [
+            self.numFrames is not None and attrs["numFrames"] < self.numFrames,
+            self.min_gyration_radius is not None and attrs["min_gyration_radius"] < self.min_gyration_radius,
+            self.max_gyration_radius is not None and attrs["max_gyration_radius"] > self.max_gyration_radius,
+            self._evaluate_structure(pdb_group, temp, replica),
+        ]
+        if any(conditions):
+            return
+
+        # frames 0, skip_frames, 2 * skip_frames, ... are kept, hence the ceil
+        num_frames = math.ceil(attrs["numFrames"] / self.skip_frames)
+        self.processed[pdb_id].append((temp, replica, num_frames))
+        self.num_conformers += num_frames
+
+    def _evaluate_structure(self, pdb_group, temp, replica):
+        """Return True when the replica has to be rejected by the solid_ss filter."""
+        if self.solid_ss is None:
+            return False
+        alpha = pdb_group[temp][replica].attrs["alpha"]
+        beta = pdb_group[temp][replica].attrs["beta"]
+        solid_ss = (alpha + beta) / pdb_group.attrs["numResidues"] * 100
+        return solid_ss < self.solid_ss
+
+    def len(self):
+        """Total number of conformers over all selected replicas."""
+        return self.num_conformers
+
+    def _setup_idx(self):
+        """Build self.idx, one (pdb, file, temp, replica, conformer) tuple per sample."""
+        self.idx = []
+        for pdb, group_info in self.processed.items():
+            file_path = opj(self.root, f"mdcath_dataset_{pdb}.h5")
+            for temp, replica, num_frames in group_info:
+                # build the catalog here for each conformer
+                self.idx.extend(
+                    (pdb, file_path, temp, replica, conf_id) for conf_id in range(num_frames)
+                )
+
+        assert (len(self.idx) == self.num_conformers), f"Mismatch between number of conformers and idxs: {self.num_conformers} vs {len(self.idx)}"
+
+    def process_specific_group(self, pdb, file, temp, repl, conf_idx):
+        """Read atomic numbers, coordinates and forces of a single conformer.
+
+        conf_idx counts the kept conformers; it is multiplied by skip_frames to
+        address the frame stored in the file. Returns the tuple (z, coords, forces),
+        where coords and forces have shape (len(z), 3).
+
+        Raises ValueError when a frame of coords or forces is not of shape (len(z), 3).
+        """
+        # do not use attributes from h5group because is will cause memory leak
+        # use the read_direct and np.s_ to get the coords and forces of interest directly
+        conf_idx = conf_idx * self.skip_frames
+        slice_idxs = np.s_[conf_idx:conf_idx + 1]
+        with h5py.File(file, "r") as f:
+            z = f[pdb]["z"][:]
+            coords = np.zeros((z.shape[0], 3))
+            forces = np.zeros((z.shape[0], 3))
+
+            group = f[f'{pdb}/{temp}/{repl}']
+
+            # validate while the file is open: the datasets are unusable once it is closed
+            for name in ('coords', 'forces'):
+                if group[name].shape[1:] != coords.shape:
+                    raise ValueError(
+                        f"Number of atoms mismatch between {name} and z: {group[name].shape[1:]} vs {coords.shape}"
+                    )
+
+            group['coords'].read_direct(coords, slice_idxs)
+            group['forces'].read_direct(forces, slice_idxs)
+
+        return (z, coords, forces)
+
+    def get(self, element):
+        """Return the conformer with dataset index element as a torch_geometric Data object."""
+        data = Data()
+        if self.idx is None:
+            # this process will be performed, num_workers * num_gpus (one per thread)
+            self._setup_idx()
+        # each catalog entry is a tuple with the pdb, file, temp, replica, conf_idx
+        pdb_id, file_path, temp, replica, conf_idx = self.idx[element]
+        z, coords, forces = self.process_specific_group(pdb_id, file_path, temp, replica, conf_idx)
+        data.z = torch.tensor(z, dtype=torch.long)
+        data.pos = torch.tensor(coords, dtype=torch.float)
+        data.neg_dy = torch.tensor(forces, dtype=torch.float)
+        data.info = f'{pdb_id}_{temp}_{replica}_{conf_idx}'
+        return data
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 7f2d8e07..0951b92d 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -9,7 +9,6 @@
 import logging
 import torch
 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.loggers import WandbLogger, CSVLogger, TensorBoardLogger
 from lightning.pytorch.callbacks import (
     ModelCheckpoint,