Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MDCATH #337

Merged
merged 58 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
936fc02
new mdcath dataset
AntonioMirarchi Mar 28, 2024
05d298c
add mdcath init
AntonioMirarchi Mar 28, 2024
5548e6c
Merge branch 'main' of https://github.com/AntonioMirarchi/torchmd-net…
AntonioMirarchi Mar 28, 2024
0dd872f
improve memory usage avoiding to run out of memory
AntonioMirarchi Mar 28, 2024
6a75b02
remove debug print
AntonioMirarchi Mar 28, 2024
bec8144
to black
AntonioMirarchi Mar 28, 2024
00380dd
change raw dir to self.root
AntonioMirarchi Apr 4, 2024
153c2b1
add solid_ss to filter while processing
AntonioMirarchi Apr 4, 2024
3626751
Merge branch 'torchmd:main' into mdcath_dataloader
AntonioMirarchi Apr 4, 2024
7536c30
Merge branch 'mdcath_dataloader' of https://github.com/AntonioMirarch…
AntonioMirarchi Apr 4, 2024
0b4ab67
compute dataset size
AntonioMirarchi Apr 4, 2024
8482127
add solid_ss to documentation
AntonioMirarchi Apr 4, 2024
949a53f
fix self.solid_ss
AntonioMirarchi Apr 4, 2024
788d6fb
fix replica for loop indentation
AntonioMirarchi Apr 5, 2024
104dfa9
add possibility to get pdb list from file
AntonioMirarchi Apr 5, 2024
e555277
fix filtering on numResidues and numAtoms, could be None
AntonioMirarchi Apr 5, 2024
7652c4c
remove comment
AntonioMirarchi Apr 5, 2024
c9a1903
remove preload dataset feature
AntonioMirarchi Apr 5, 2024
2eab7ea
fix process specific group, return and skipframes
AntonioMirarchi Apr 8, 2024
519ee6a
fix setup idx and get function
AntonioMirarchi Apr 8, 2024
a54d934
Merge branch 'main' of https://github.com/AntonioMirarchi/torchmd-net…
AntonioMirarchi Apr 30, 2024
68ba140
allow zero num workers in dataloader
AntonioMirarchi Apr 30, 2024
6fb7c4b
remove unused module
AntonioMirarchi Apr 30, 2024
1bd39a0
use read_direct from h5py
AntonioMirarchi Apr 30, 2024
c374ec8
update get function
AntonioMirarchi Apr 30, 2024
180535f
fix setup idx
AntonioMirarchi Apr 30, 2024
89df1cc
append more info to to_download dict, according to new setup idx
AntonioMirarchi Apr 30, 2024
71059c8
update file name
AntonioMirarchi Apr 30, 2024
2ab24c9
add pytest for mdcath dataset
AntonioMirarchi Apr 30, 2024
160fb61
fix memory occupancy due to attrs assertion in get
AntonioMirarchi May 17, 2024
8629621
remove preload_dataset_limit attr
AntonioMirarchi May 17, 2024
30e39a6
update source file name
AntonioMirarchi May 17, 2024
3ab4406
Merge branch 'main' into mdcath_dataloader
AntonioMirarchi May 31, 2024
104348c
source file string defined as Instance variable
AntonioMirarchi May 31, 2024
481255b
fix memory leak
AntonioMirarchi May 31, 2024
d63d081
small change
AntonioMirarchi May 31, 2024
78ef15a
remove unused instance
AntonioMirarchi May 31, 2024
e27e8d8
get also detailed info from get function
AntonioMirarchi May 31, 2024
7145547
fix self.source_file initialization
AntonioMirarchi May 31, 2024
6f68d62
Merge branch 'mdcath_dataloader' of https://github.com/AntonioMirarch…
AntonioMirarchi May 31, 2024
4f6fd29
update
AntonioMirarchi May 31, 2024
9687f6f
Merge branch 'main' of https://github.com/AntonioMirarchi/torchmd-net…
AntonioMirarchi Jul 10, 2024
7b576c1
Merge branch 'mdcath_dataloader' of https://github.com/AntonioMirarch…
AntonioMirarchi Jul 10, 2024
0d0ee06
update to allcaps
AntonioMirarchi Jul 19, 2024
6134086
reorder and update code
AntonioMirarchi Jul 19, 2024
41b681a
Merge branch 'main' of https://github.com/AntonioMirarchi/torchmd-net…
AntonioMirarchi Jul 19, 2024
6e090e4
Merge branch 'mdcath_dataloader' of https://github.com/AntonioMirarch…
AntonioMirarchi Jul 19, 2024
262f8ad
fix class name in test
AntonioMirarchi Jul 19, 2024
723164f
undo on persistent workers
AntonioMirarchi Jul 19, 2024
fc56169
persistent_workers to False
AntonioMirarchi Jul 19, 2024
c0d39b7
fix self.pdb_list in init
AntonioMirarchi Jul 20, 2024
1d04f6e
fix source file name in unit test
AntonioMirarchi Jul 20, 2024
2b4dd43
add unit test for pdb_list in mdcath
AntonioMirarchi Jul 20, 2024
7148b61
rename arg skipFrames to skip_frames
AntonioMirarchi Jul 22, 2024
45f7f57
update skip_frames in mdcathtest
AntonioMirarchi Jul 22, 2024
92d5f08
avoid error due to temperatures list, force str dtype
AntonioMirarchi Jul 29, 2024
c12a2bd
undo
AntonioMirarchi Jul 29, 2024
f5c2d46
force str dtype in temperatures list
AntonioMirarchi Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions tests/test_mdcath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import h5py
import psutil
import numpy as np
from pytest import mark
from os.path import join
from torchmdnet.datasets.mdcath import MDCATH
from torch_geometric.loader import DataLoader
from tqdm import tqdm


def test_mdcath(tmpdir):
num_atoms_list = np.linspace(50, 1000, 50)
source_file = h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w")
for num_atoms in num_atoms_list:
z = np.zeros(int(num_atoms))
pos = np.zeros((100, int(num_atoms), 3))
forces = np.zeros((100, int(num_atoms), 3))

s_group = source_file.create_group(f"A{num_atoms}")

s_group.attrs["numChains"] = 1
s_group.attrs["numNoHAtoms"] = int(num_atoms) / 2
s_group.attrs["numProteinAtoms"] = int(num_atoms)
s_group.attrs["numResidues"] = int(num_atoms) / 10
s_temp_group = s_group.create_group("348")
s_replica_group = s_temp_group.create_group("0")
s_replica_group.attrs["numFrames"] = 100
s_replica_group.attrs["alpha"] = 0.30
s_replica_group.attrs["beta"] = 0.25
s_replica_group.attrs["coil"] = 0.45
s_replica_group.attrs["max_gyration_radius"] = 2
s_replica_group.attrs["max_num_neighbors_5A"] = 55
s_replica_group.attrs["max_num_neighbors_9A"] = 200
s_replica_group.attrs["min_gyration_radius"] = 1

# write the dataset
data = h5py.File(join(tmpdir, f"mdcath_dataset_A{num_atoms}.h5"), mode="w")
group = data.create_group(f"A{num_atoms}")
group.create_dataset("z", data=z)
tempgroup = group.create_group("348")
replicagroup = tempgroup.create_group("0")
replicagroup.create_dataset("coords", data=pos)
replicagroup.create_dataset("forces", data=forces)
# add some attributes
replicagroup.attrs["numFrames"] = 100
replicagroup["coords"].attrs["unit"] = "Angstrom"
replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

data.flush()
data.close()

dataset = MDCATH(root=tmpdir)
dl = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
pin_memory=True,
persistent_workers=False,
)
for _, data in enumerate(tqdm(dl)):
pass


def test_mdcath_multiprocessing(tmpdir, num_entries=100, numFrames=10):
# generate sample data
z = np.zeros(num_entries)
pos = np.zeros((numFrames, num_entries, 3))
forces = np.zeros((numFrames, num_entries, 3))

source_file = h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w")
s_group = source_file.create_group("A00")

s_group.attrs["numChains"] = 1
s_group.attrs["numNoHAtoms"] = num_entries / 2
s_group.attrs["numProteinAtoms"] = num_entries
s_group.attrs["numResidues"] = num_entries / 10
s_temp_group = s_group.create_group("348")
s_replica_group = s_temp_group.create_group("0")
s_replica_group.attrs["numFrames"] = numFrames
s_replica_group.attrs["alpha"] = 0.30
s_replica_group.attrs["beta"] = 0.25
s_replica_group.attrs["coil"] = 0.45
s_replica_group.attrs["max_gyration_radius"] = 2
s_replica_group.attrs["max_num_neighbors_5A"] = 55
s_replica_group.attrs["max_num_neighbors_9A"] = 200
s_replica_group.attrs["min_gyration_radius"] = 1

# write the dataset
data = h5py.File(join(tmpdir, "mdcath_dataset_A00.h5"), mode="w")
group = data.create_group("A00")
group.create_dataset("z", data=z)
tempgroup = group.create_group("348")
replicagroup = tempgroup.create_group("0")
replicagroup.create_dataset("coords", data=pos)
replicagroup.create_dataset("forces", data=forces)
# add some attributes
replicagroup.attrs["numFrames"] = numFrames
replicagroup["coords"].attrs["unit"] = "Angstrom"
replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

data.flush()
data.close()

# make sure creating the dataset doesn't open any files on the main process
proc = psutil.Process()
n_open = len(proc.open_files())

dset = MDCATH(
root=tmpdir,
)
assert len(proc.open_files()) == n_open, "creating the dataset object opened a file"


def replacer(arr, skipframes):
tmp_arr = arr.copy()
# function that take a numpy array of zeros and based on a skipframes value, replaces the zeros with 1s in that position
for i in range(0, len(tmp_arr), skipframes):
tmp_arr[i, :, :] = 1
return tmp_arr


@mark.parametrize("skipframes", [1, 2, 5])
@mark.parametrize("batch_size", [1, 10])
@mark.parametrize("pdb_list", [["A50", "A612", "A1000"], None])
def test_mdcath_args(tmpdir, skipframes, batch_size, pdb_list):
with h5py.File(join(tmpdir, "mdcath_source.h5"), mode="w") as source_file:
num_frames_list = np.linspace(50, 1000, 50).astype(int)
for num_frame in tqdm(num_frames_list, desc="Creating tmp files"):
z = np.zeros(100)
pos = np.zeros((num_frame, 100, 3))
forces = np.zeros((num_frame, 100, 3))

pos = replacer(pos, skipframes)
forces = replacer(forces, skipframes)

s_group = source_file.create_group(f"A{num_frame}")

s_group.attrs["numChains"] = 1
s_group.attrs["numNoHAtoms"] = 100 / 2
s_group.attrs["numProteinAtoms"] = 100
s_group.attrs["numResidues"] = 100 / 10
s_temp_group = s_group.create_group("348")
s_replica_group = s_temp_group.create_group("0")
s_replica_group.attrs["numFrames"] = num_frame
s_replica_group.attrs["alpha"] = 0.30
s_replica_group.attrs["beta"] = 0.25
s_replica_group.attrs["coil"] = 0.45
s_replica_group.attrs["max_gyration_radius"] = 2
s_replica_group.attrs["max_num_neighbors_5A"] = 55
s_replica_group.attrs["max_num_neighbors_9A"] = 200
s_replica_group.attrs["min_gyration_radius"] = 1

# write the dataset
data = h5py.File(join(tmpdir, f"mdcath_dataset_A{num_frame}.h5"), mode="w")
group = data.create_group(f"A{num_frame}")
group.create_dataset("z", data=z)
tempgroup = group.create_group("348")
replicagroup = tempgroup.create_group("0")
replicagroup.create_dataset("coords", data=pos)
replicagroup.create_dataset("forces", data=forces)
# add some attributes
replicagroup.attrs["numFrames"] = num_frame
replicagroup["coords"].attrs["unit"] = "Angstrom"
replicagroup["forces"].attrs["unit"] = "kcal/mol/Angstrom"

data.flush()
data.close()

dataset = MDCATH(
root=tmpdir, skip_frames=skipframes, pdb_list=pdb_list
)
dl = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
persistent_workers=False,
)
for _, data in enumerate(tqdm(dl)):
# if the skipframes works correclty, data returned should be only 1s
assert data.pos.all() == 1, "skipframes not working correctly for positions"
assert data.neg_dy.all() == 1, "skipframes not working correctly for forces"
2 changes: 2 additions & 0 deletions torchmdnet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
COMP6v1,
COMP6v2,
)
from .mdcath import MDCATH
from .custom import Custom
from .water import WaterBox
from .hdf import HDF5
Expand All @@ -40,6 +41,7 @@
"GDB10to13",
"GenentechTorsions",
"HDF5",
"MDCATH",
"MD17",
"MD22",
"QM9",
Expand Down
Loading
Loading