Commit

add test for dataset caching
lubbersnick committed Aug 13, 2024
1 parent 864a8cb commit ac8a3c7
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions tests/dataset_writing.py
@@ -0,0 +1,112 @@
import glob
import os
import torch

torch.set_default_dtype(torch.float64)

from hippynn.databases.h5_pyanitools import PyAniFileDB
from hippynn.databases import NPZDatabase
from hippynn.tools import is_equal_state_dict


# check whether two databases are equal, split by split
def eqsplit(db1, db2, raise_error=True):
    return is_equal_state_dict(db1.splits, db2.splits, raise_where=raise_error)


if __name__ == "__main__":
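    # Round-trip a dataset through NPZ and H5 formats, and through the on-disk cache,
    # then check that the splits come back identical in every case.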
    CLEANUP = True  # delete datasets afterwards
    # Example dataset
    location = "../../datasets/new_qm9_clean.npz"
    seed = 1
    num_workers = 0
    db_info = {}
    db1 = NPZDatabase(
        file=location,
        seed=seed,
        num_workers=num_workers,
        allow_unfound=True,
        **db_info,
        inputs=None,
        targets=None,
    )

    # test remove_high_property
    db1.remove_high_property("E", species_key="Z", atomwise=False, norm_per_atom=True, std_factor=5)
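    # (assumed behavior: this drops systems whose per-atom-normalized "E" lies more than
    # std_factor standard deviations from the mean.)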

    # throw stuff away
    db1.make_random_split("random stuff", 0.99)
    del db1.splits["random stuff"]  # remove something at random

    new_ani_file = "TEST_clean_ani1x.h5"
    new_npz_file = "TEST_clean_ani1x.npz"

    # Divide up the dataset into a bunch of splits.
    db1.make_random_split("first", 0.5)
    db1.make_random_split("second", 0.2)
    db1.make_random_split("third", 3)  # integer count rather than a fraction
    db1.split_the_rest("remaining")
    # write an npz version and reload it.
    db1.write_npz(file=new_npz_file, record_split_masks=True, overwrite=True)
    db3 = NPZDatabase(file=new_npz_file, seed=seed, num_workers=num_workers, allow_unfound=True, inputs=None, targets=None, auto_split=True, **db_info)

    # write an h5 version and reload it.
    db1.write_h5(split=True, h5path=new_ani_file, species_key="Z", overwrite=True)
    db2 = PyAniFileDB(
        file=new_ani_file,
        species_key="Z",
        seed=seed,
        num_workers=num_workers,
        allow_unfound=True,
        **db_info,
        inputs=None,
        targets=None,
        auto_split=True,
    )
    new_ani_filetwo = "TEST_clean_ani1x_2.h5"
    # trim this dataset earlier than the others.
    db2.trim_by_species("Z")
    # write and load a new database.
    db2.write_h5(split=True, h5path=new_ani_filetwo, species_key="Z", overwrite=True)
    db4 = PyAniFileDB(
        file=new_ani_filetwo,
        species_key="Z",
        seed=seed,
        num_workers=num_workers,
        allow_unfound=True,
        **db_info,
        inputs=None,
        targets=None,
        auto_split=True,
    )

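    # Bring all four databases into a comparable form: sort_by_index restores a common
    # ordering, and trim_by_species removes excess padding atoms (assumed behavior).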
    for i, d in enumerate([db1, db2, db3, db4]):
        print("sorting", i)
        d.sort_by_index()
        print("trimming", i)
        d.trim_by_species("Z", keep_splits_same_size=True)  # can be done either way if the caching test is disabled.

# "sys_number" comes from h5 format databases, but not present in others.
for d in [db2, db4]:
for s in d.splits:
del d.splits[s]["sys_number"]

    db1.add_split_masks()  # db1 itself never got split masks; add them now.

print("NPZ Equality:", eqsplit(db1, db3))
print("H5 Equality:", eqsplit(db2, db4))
print("NPZ-H5 Equality1:", eqsplit(db1, db2))
print("NPZ-H5 Equality2:", eqsplit(db2, db3))

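    # make_database_cache is assumed to write the database to the given NPZ file and
    # return a new database object backed by that file.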
    # test caching routine.
    db2p = db2.make_database_cache(file="TEST_CACHE_FROMH5.npz", overwrite=True, quiet=True)
    print("H5 to cache Equality:", eqsplit(db2, db2p))
    db3p = db3.make_database_cache(file="TEST_CACHE_FROMNPZ.npz", overwrite=True, quiet=True)
    print("NPZ to cache Equality:", eqsplit(db3, db3p))

for ext in ["npz", "h5"]:

for file in glob.glob(f"./TEST*.{ext}"):
print(file)
import os

os.remove(file)
