From 8036aa540fbc64cfc941495e46c9a415743d6998 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 09:18:34 +0200 Subject: [PATCH 1/7] Add SPICE 2.0.1 --- torchmdnet/datasets/spice.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmdnet/datasets/spice.py b/torchmdnet/datasets/spice.py index 123b3ed5..b3361828 100644 --- a/torchmdnet/datasets/spice.py +++ b/torchmdnet/datasets/spice.py @@ -68,6 +68,10 @@ class SPICE(MemmappedDataset): "url": "https://zenodo.org/records/8222043/files", "file": "SPICE-1.1.4.hdf5", }, + "2.0.1": { + "url": "https://zenodo.org/records/10975225/files", + "file": "SPICE-2.0.1.hdf5", + }, } @property From b68788d9a5e570813ea204ca7e0bf1f28e8acbf6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 10:26:44 +0200 Subject: [PATCH 2/7] Place raw spice data in a subfolder --- torchmdnet/datasets/spice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/datasets/spice.py b/torchmdnet/datasets/spice.py index b3361828..63d351f4 100644 --- a/torchmdnet/datasets/spice.py +++ b/torchmdnet/datasets/spice.py @@ -76,7 +76,7 @@ class SPICE(MemmappedDataset): @property def raw_dir(self): - return os.path.join(super().raw_dir, self.version) + return os.path.join(super().raw_dir, "spice", self.version) @property def raw_file_names(self): From 610e33c8794fe7d8f84ed19dcb66e73893a8b976 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 10:27:01 +0200 Subject: [PATCH 3/7] Add check for bogus conformations --- torchmdnet/datasets/spice.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchmdnet/datasets/spice.py b/torchmdnet/datasets/spice.py index 63d351f4..620e6ced 100644 --- a/torchmdnet/datasets/spice.py +++ b/torchmdnet/datasets/spice.py @@ -10,6 +10,7 @@ from torchmdnet.datasets.memdataset import MemmappedDataset from torch_geometric.data import Data, download_url from tqdm import tqdm +import logging class SPICE(MemmappedDataset): @@ -141,7 +142,12 @@ def sample_iter(self, mol_ids=False): * self.HARTREE_TO_EV / self.BORH_TO_ANGSTROM ) - + if all_pos.ndim < 3: + logging.warning(f"Bogus conformation {mol_id}") + logging.warning( + f"Found {all_pos.shape} positions, {all_y.shape} energies and {all_neg_dy.shape} gradients" + ) + continue assert all_pos.shape[0] == all_y.shape[0] assert all_pos.shape[1] == z.shape[0] assert all_pos.shape[2] == 3 From cb7c50e083d0390576ff2eb6deb809d68fc69259 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 10:58:08 +0200 Subject: [PATCH 4/7] Fix processed name in MemmappedDataset --- torchmdnet/datasets/memdataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index a97b54be..a56fdbe5 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -57,7 +57,8 @@ def __init__( pre_filter=None, properties=("y", "neg_dy", "q", "pq", "dp"), ): - self.name = self.__class__.__name__ + if not hasattr(self, "name"): + self.name = self.__class__.__name__ self.properties = properties super().__init__(root, transform, pre_transform, pre_filter) From 9b5da8b4d9c6dec3f445b28c5cfdbecf5a599e0d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 10:58:37 +0200 Subject: [PATCH 5/7] Add hash checking of downloaded files in SPICE --- torchmdnet/datasets/spice.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchmdnet/datasets/spice.py b/torchmdnet/datasets/spice.py index 620e6ced..45a28f89 100644 --- a/torchmdnet/datasets/spice.py +++ b/torchmdnet/datasets/spice.py @@ -56,22 +56,27 @@ class SPICE(MemmappedDataset): "1.1.1": { "url": "https://zenodo.org/record/7258940/files", "file": "SPICE-1.1.1.hdf5", + "hash": "5411e7014c6d18ff07d108c9ad820b53", }, "1.1.2": { "url": "https://zenodo.org/record/7338495/files", "file": "SPICE-1.1.2.hdf5", + "hash": "a2b5ae2d1f72581040e1cceb20a79a33", }, "1.1.3": { "url": "https://zenodo.org/record/7606550/files", "file": "SPICE-1.1.3.hdf5", + "hash": "be93706b3bb2b2e327b690b185905856", }, "1.1.4": { "url": "https://zenodo.org/records/8222043/files", "file": "SPICE-1.1.4.hdf5", + "hash": "f27d4c81da0e37d6547276bf6b4ae6a1", }, "2.0.1": { "url": "https://zenodo.org/records/10975225/files", "file": "SPICE-2.0.1.hdf5", + "hash": "bfba2224b6540e1390a579569b475510", }, } @@ -178,3 +183,7 @@ def sample_iter(self, mol_ids=False): def download(self): download_url(self.raw_url, self.raw_dir) + if "hash" in self.VERSIONS[self.version]: + with open(self.raw_paths[0], "rb") as f: + file_hash = hashlib.md5(f.read()).hexdigest() + assert file_hash == self.VERSIONS[self.version]["hash"] From 9a5ddf119bd0f8e8775fb603b7579c03cd7c90e7 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 11:15:59 +0200 Subject: [PATCH 6/7] Fix typo in ACE dataset test --- tests/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8d8695b7..e084ce93 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -291,8 +291,8 @@ def test_ace(tmpdir): mol["forces"].attrs["units"] = "eV/Å" mol["partial_charges"] = np.random.random((2, 3)) mol["partial_charges"].attrs["units"] = "e" - mol["dipole_moment"] = np.random.random((2, 3)) - mol["dipole_moment"].attrs["units"] = "e*Å" + mol["dipole_moments"] = np.random.random((2, 3)) + mol["dipole_moments"].attrs["units"] = "e*Å" dataset_v2 = Ace(root=tmpdir, paths=tmpfilename_v2) assert len(dataset_v2) == 6 f2.flush() From 66a84d20b3251722a60d89046f75e2e6dded1d22 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 28 Jun 2024 11:23:33 +0200 Subject: [PATCH 7/7] Update to Ace dataset test --- tests/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e084ce93..22b00249 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -279,7 +279,7 @@ def test_ace(tmpdir): f2.attrs["layout_version"] = "2.0" f2.attrs["name"] = "sample_molecule_data_v2" master_mol_group = f2.create_group("master_molecule_group") - for m in range(3): # Three molecules + for m in range(4): mol = master_mol_group.create_group(f"mol_{m+1}") mol["atomic_numbers"] = [1, 6, 8] # H, C, O mol["formal_charges"] = [0, 0, 0] # Neutral charges @@ -294,6 +294,6 @@ def test_ace(tmpdir): mol["dipole_moments"] = np.random.random((2, 3)) mol["dipole_moments"].attrs["units"] = "e*Å" dataset_v2 = Ace(root=tmpdir, paths=tmpfilename_v2) - assert len(dataset_v2) == 6 + assert len(dataset_v2) == 8 f2.flush() f2.close()