Skip to content

Commit

Permalink
Merge pull request #333 from RaulPPelaez/sp2
Browse files (browse the repository at this point in the history)
Add SPICE 2.0.1, Fix bug in MemmappedDataset
  • Loading branch information
stefdoerr authored Jun 28, 2024
2 parents e908988 + 66a84d2 commit c800af1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
8 changes: 4 additions & 4 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def test_ace(tmpdir):
f2.attrs["layout_version"] = "2.0"
f2.attrs["name"] = "sample_molecule_data_v2"
master_mol_group = f2.create_group("master_molecule_group")
for m in range(3): # Three molecules
for m in range(4):
mol = master_mol_group.create_group(f"mol_{m+1}")
mol["atomic_numbers"] = [1, 6, 8] # H, C, O
mol["formal_charges"] = [0, 0, 0] # Neutral charges
Expand All @@ -291,9 +291,9 @@ def test_ace(tmpdir):
mol["forces"].attrs["units"] = "eV/Å"
mol["partial_charges"] = np.random.random((2, 3))
mol["partial_charges"].attrs["units"] = "e"
mol["dipole_moment"] = np.random.random((2, 3))
mol["dipole_moment"].attrs["units"] = "e*Å"
mol["dipole_moments"] = np.random.random((2, 3))
mol["dipole_moments"].attrs["units"] = "e*Å"
dataset_v2 = Ace(root=tmpdir, paths=tmpfilename_v2)
assert len(dataset_v2) == 6
assert len(dataset_v2) == 8
f2.flush()
f2.close()
3 changes: 2 additions & 1 deletion torchmdnet/datasets/memdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(
pre_filter=None,
properties=("y", "neg_dy", "q", "pq", "dp"),
):
self.name = self.__class__.__name__
if not hasattr(self, "name"):
self.name = self.__class__.__name__
self.properties = properties
super().__init__(root, transform, pre_transform, pre_filter)

Expand Down
23 changes: 21 additions & 2 deletions torchmdnet/datasets/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torchmdnet.datasets.memdataset import MemmappedDataset
from torch_geometric.data import Data, download_url
from tqdm import tqdm
import logging


class SPICE(MemmappedDataset):
Expand Down Expand Up @@ -55,24 +56,33 @@ class SPICE(MemmappedDataset):
"1.1.1": {
"url": "https://zenodo.org/record/7258940/files",
"file": "SPICE-1.1.1.hdf5",
"hash": "5411e7014c6d18ff07d108c9ad820b53",
},
"1.1.2": {
"url": "https://zenodo.org/record/7338495/files",
"file": "SPICE-1.1.2.hdf5",
"hash": "a2b5ae2d1f72581040e1cceb20a79a33",
},
"1.1.3": {
"url": "https://zenodo.org/record/7606550/files",
"file": "SPICE-1.1.3.hdf5",
"hash": "be93706b3bb2b2e327b690b185905856",
},
"1.1.4": {
"url": "https://zenodo.org/records/8222043/files",
"file": "SPICE-1.1.4.hdf5",
"hash": "f27d4c81da0e37d6547276bf6b4ae6a1",
},
"2.0.1": {
"url": "https://zenodo.org/records/10975225/files",
"file": "SPICE-2.0.1.hdf5",
"hash": "bfba2224b6540e1390a579569b475510",
},
}

@property
def raw_dir(self):
return os.path.join(super().raw_dir, self.version)
return os.path.join(super().raw_dir, "spice", self.version)

@property
def raw_file_names(self):
Expand Down Expand Up @@ -137,7 +147,12 @@ def sample_iter(self, mol_ids=False):
* self.HARTREE_TO_EV
/ self.BORH_TO_ANGSTROM
)

if all_pos.ndim < 3:
logging.warning(f"Bogus conformation {mol_id}")
logging.warning(
f"Found {all_pos.shape} positions, {all_y.shape} energies and {all_neg_dy.shape} gradients"
)
continue
assert all_pos.shape[0] == all_y.shape[0]
assert all_pos.shape[1] == z.shape[0]
assert all_pos.shape[2] == 3
Expand Down Expand Up @@ -168,3 +183,7 @@ def sample_iter(self, mol_ids=False):

def download(self):
    """Download the raw file for ``self.version`` and verify its MD5 checksum.

    The file is fetched from ``self.raw_url`` into ``self.raw_dir``. When
    the entry for this version in ``VERSIONS`` has a ``"hash"`` key, the
    MD5 digest of the downloaded file (``self.raw_paths[0]``) is compared
    against it.

    Raises:
        AssertionError: If the digest of the downloaded file does not match
            the expected hash.
    """
    download_url(self.raw_url, self.raw_dir)
    version_info = self.VERSIONS[self.version]
    if "hash" not in version_info:
        # No reference checksum recorded for this version; nothing to verify.
        return
    # Hash incrementally so the whole file never has to be held in memory.
    md5 = hashlib.md5()
    with open(self.raw_paths[0], "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            md5.update(chunk)
    file_hash = md5.hexdigest()
    expected_hash = version_info["hash"]
    # Raise explicitly rather than using ``assert`` so the check is not
    # stripped under ``python -O``; the exception type is unchanged.
    if file_hash != expected_hash:
        raise AssertionError(
            f"MD5 mismatch for {self.raw_paths[0]}: "
            f"expected {expected_hash}, got {file_hash}"
        )

0 comments on commit c800af1

Please sign in to comment.