Skip to content

Commit

Permalink
Remove dtype parameter, use previously existing "precision" instead (#…
Browse files Browse the repository at this point in the history
…208)

* Remove dtype parameter, use previously existing "precision" instead

* Do not store dtype in args when creating the model

* Wrap the dataset in the DataLoader to cast data to the requested precision

* Inherit every member from the wrapped datset when casting to other
float precision

* blacken

* Add tests for double precision training

* Remove unnecessary default

* Add precision to a test

* Fix a test
  • Loading branch information
RaulPPelaez authored Aug 8, 2023
1 parent 4645fa4 commit dca6679
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/ET-QM9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ train_size: 110000
trainable_rbf: false
val_size: 10000
weight_decay: 0.0
dtype: float
precision: 32
21 changes: 11 additions & 10 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping

from utils import load_example_args, create_example_batch


@mark.parametrize("model_name", models.__all__)
@mark.parametrize("use_batch", [True, False])
@mark.parametrize("explicit_q_s", [True, False])
@mark.parametrize("dtype", [torch.float32, torch.float64])
def test_forward(model_name, use_batch, explicit_q_s, dtype):
@mark.parametrize("precision", [32, 64])
def test_forward(model_name, use_batch, explicit_q_s, precision):
z, pos, batch = create_example_batch()
pos = pos.to(dtype=dtype)
model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype))
pos = pos.to(dtype=dtype_mapping[precision])
model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
batch = batch if use_batch else None
if explicit_q_s:
model(z, pos, batch=batch, q=None, s=None)
Expand All @@ -28,10 +29,10 @@ def test_forward(model_name, use_batch, explicit_q_s, dtype):

@mark.parametrize("model_name", models.__all__)
@mark.parametrize("output_model", output_modules.__all__)
@mark.parametrize("dtype", [torch.float32, torch.float64])
def test_forward_output_modules(model_name, output_model, dtype):
@mark.parametrize("precision", [32,64])
def test_forward_output_modules(model_name, output_model, precision):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype)
args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
model = create_model(args)
model(z, pos, batch=batch)

Expand Down Expand Up @@ -146,7 +147,7 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
@mark.parametrize("model_name", models.__all__)
def test_gradients(model_name):
pl.seed_everything(1234)
dtype = torch.float64
precision = 64
output_model = "Scalar"
# create model and sample batch
derivative = output_model in ["Scalar", "EquivariantScalar"]
Expand All @@ -155,12 +156,12 @@ def test_gradients(model_name):
remove_prior=True,
output_model=output_model,
derivative=derivative,
dtype=dtype,
precision=precision
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
pos.requires_grad_(True)
pos = pos.to(dtype)
pos = pos.to(torch.float64)
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)
6 changes: 4 additions & 2 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_load_model():

@mark.parametrize("model_name", models.__all__)
@mark.parametrize("use_atomref", [True, False])
def test_train(model_name, use_atomref, tmpdir):
@mark.parametrize("precision", [32, 64])
def test_train(model_name, use_atomref, precision, tmpdir):
args = load_example_args(
model_name,
remove_prior=not use_atomref,
Expand All @@ -37,6 +38,7 @@ def test_train(model_name, use_atomref, tmpdir):
num_layers=2,
num_rbf=16,
batch_size=8,
precision=precision,
)
datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

Expand All @@ -47,6 +49,6 @@ def test_train(model_name, use_atomref, tmpdir):

module = LNNP(args, prior_model=prior)

trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir)
trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"])
trainer.fit(module, datamodule)
trainer.test(module, datamodule)
5 changes: 3 additions & 2 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch as pt
from torchmdnet.models.model import create_model
from torchmdnet.optimize import optimize

from torchmdnet.models.utils import dtype_mapping

@mark.parametrize("device", ["cpu", "cuda"])
@mark.parametrize("num_atoms", [10, 100])
Expand Down Expand Up @@ -39,6 +39,7 @@ def test_gn(device, num_atoms):
"prior_model": None,
"output_model": "Scalar",
"reduce_op": "add",
"precision": 32,
}
ref_model = create_model(args).to(device)

Expand All @@ -47,7 +48,7 @@ def test_gn(device, num_atoms):

# Optimize the model
model = optimize(ref_model).to(device)

positions.to(dtype_mapping[args["precision"]])
# Execute the optimize model
energy, gradient = model(elements, positions)

Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
with open(config_file, "r") as f:
args = yaml.load(f, Loader=yaml.FullLoader)
if "dtype" not in args:
args["dtype"] = "float"
if "precision" not in args:
args["precision"] = 32
args["model"] = model_name
args["seed"] = 1234
if remove_prior:
Expand Down
34 changes: 33 additions & 1 deletion torchmdnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,37 @@
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import rank_zero_warn
from torchmdnet import datasets
from torch_geometric.data import Dataset
from torchmdnet.utils import make_splits, MissingEnergyException
from torch_scatter import scatter
from torchmdnet.models.utils import dtype_mapping


class FloatCastDatasetWrapper(Dataset):
def __init__(self, dataset, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__(
dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
)
self.dataset = dataset
self.dtype = dtype

def len(self):
return len(self.dataset)

def get(self, idx):
data = self.dataset.get(idx)
for key, value in data:
if torch.is_tensor(value) and torch.is_floating_point(value):
setattr(data, key, value.to(self.dtype))
return data

def __getattr__(self, name):
# Check if the attribute exists in the underlying dataset
if hasattr(self.dataset, name):
return getattr(self.dataset, name)
raise AttributeError(
f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
)


class DataModule(LightningDataModule):
Expand All @@ -34,6 +63,9 @@ def setup(self, stage):
self.dataset = getattr(datasets, self.hparams["dataset"])(
self.hparams["dataset_root"], **dataset_arg
)
self.dataset = FloatCastDatasetWrapper(
self.dataset, dtype_mapping[self.hparams["precision"]]
)

self.idx_train, self.idx_val, self.idx_test = make_splits(
len(self.dataset),
Expand Down Expand Up @@ -62,7 +94,7 @@ def val_dataloader(self):
loaders = [self._get_dataloader(self.val_dataset, "val")]
if (
len(self.test_dataset) > 0
and (self.trainer.current_epoch+1) % self.hparams["test_interval"] == 0
and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0
):
loaders.append(self._get_dataloader(self.test_dataset, "test"))
return loaders
Expand Down
9 changes: 4 additions & 5 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
-------
nn.Module: An instance of the TorchMD_Net model.
"""
args["dtype"] = "float32" if "dtype" not in args else args["dtype"]
args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"]
dtype = dtype_mapping[args["precision"]]
shared_args = dict(
hidden_channels=args["embedding_dimension"],
num_layers=args["num_layers"],
Expand All @@ -38,7 +37,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
cutoff_upper=args["cutoff_upper"],
max_z=args["max_z"],
max_num_neighbors=args["max_num_neighbors"],
dtype=args["dtype"]
dtype=dtype
)

# representation network
Expand Down Expand Up @@ -102,7 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
args["embedding_dimension"],
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=args["dtype"],
dtype=dtype,
)

# combine representation and output network
Expand All @@ -113,7 +112,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
mean=mean,
std=std,
derivative=args["derivative"],
dtype=args["dtype"],
dtype=dtype,
)
return model

Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,4 +526,4 @@ def forward(self, x, v):
"sigmoid": nn.Sigmoid,
}

dtype_mapping = {"float": torch.float, "double": torch.float64, "float32": torch.float32, "float64": torch.float64}
dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
3 changes: 1 addition & 2 deletions torchmdnet/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_args():
parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision')
parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file')
parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)')
Expand Down Expand Up @@ -67,7 +67,6 @@ def get_args():
parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')

# architectural args
parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float32 or float64')
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state')
parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
Expand Down

0 comments on commit dca6679

Please sign in to comment.