Implement ZBL potential #134
Conversation
Looks good to me. Tests are failing though:
Also, you might have some merge conflicts with the latest PR from Raimondas.
Wait a second, how does this error even occur? Shouldn't the
Whoever runs the model knows what units they want to use and sets them for the priors.
On Fri, Oct 7, 2022, 16:28, Peter Eastman commented on this pull request, in torchmdnet/priors.py (#134 (comment)):
> @@ -76,3 +79,52 @@ def get_init_args(self):
def forward(self, x, z, pos, batch):
return x + self.atomref(z)
+
+
+class ZBL(BasePrior):
+ """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
+ It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It
+ is an empirical potential that does a good job of describing the repulsion between atoms at very short
+ distances.
+
+ To use this prior, the Dataset must provide the following attributes.
+
+ atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z.
+ distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
The neural network is agnostic with regard to units. It receives inputs
and tries to make its outputs match the ones in the dataset. It doesn't
know or care what units they're in.
A physics based potential does depend on units. It needs to know what
units the coordinates are in, and it needs to know what units the energy is
expected to be in.
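The ZBL functional form the docstring cites (equations 9 and 10 of the reference) can be sketched in a few lines. This is an illustration only, not the PR's implementation; the function name `zbl_energy` and the choice of eV and angstrom units are my assumptions:

```python
import math

# e^2 / (4 * pi * eps0), expressed in eV * angstrom.
COULOMB_EV_ANGSTROM = 14.399645

def zbl_energy(z1: int, z2: int, r: float) -> float:
    """Screened nuclear repulsion (eV) between nuclei z1 and z2
    separated by r angstroms, using the universal ZBL screening
    function (a sum of four exponentials)."""
    # Universal screening length: 0.8854 * a0 / (z1^0.23 + z2^0.23),
    # with a0 = 0.529177 angstrom the Bohr radius.
    a = 0.8854 * 0.529177 / (z1 ** 0.23 + z2 ** 0.23)
    x = r / a
    phi = (0.18175 * math.exp(-3.19980 * x)
           + 0.50986 * math.exp(-0.94229 * x)
           + 0.28022 * math.exp(-0.40290 * x)
           + 0.02817 * math.exp(-0.20162 * x))
    return COULOMB_EV_ANGSTROM * z1 * z2 * phi / r
```

At very short range the screening function approaches 1 and the energy approaches the bare Coulomb repulsion; it decays rapidly with distance, which is why the potential only matters at very short separations.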
The units are determined by the dataset. It isn't a free choice for the user to make. If your dataset contains positions in nm and energies in kJ/mol, that's what the prior needs to work with. Any other units would produce wrong results. We could consider creating an automated unit conversion system like I suggested in #26 (comment), but that's a separate project.
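For concreteness, the conversion factors for the nm / kJ-per-mol example in this comment would look like the following sketch. The attribute names follow the prior's docstring quoted above; the convention of converting energies to Joules per molecule is my assumption, since that part of the docstring is not shown:

```python
AVOGADRO = 6.02214076e23

# If positions are stored in nanometers:
distance_scale = 1e-9          # nm -> meters

# If energies are stored in kJ/mol (assumed convention: convert to
# Joules per single molecule by dividing out Avogadro's number):
energy_scale = 1000.0 / AVOGADRO
```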
Yes, it should! I think it's a bug in the TorchScript JIT. I was able to work around it by splitting it into two lines.
This is ready for review.
The test is failing, if you could take a look.
Fixed. I couldn't reproduce it locally, but I managed to guess what it was unhappy about.
Fine by me. Does anyone else want to comment?
tests/test_module.py (Outdated)
@@ -43,7 +43,7 @@ def test_train(model_name, use_atomref, tmpdir):
     prior = None
     if use_atomref:
         prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
-        args["prior_args"] = prior.get_init_args()
+        args["prior_init_args"] = prior.get_init_args()
The loading of the pretrained models (https://github.com/torchmd/torchmd-net/tree/main/examples#loading-checkpoints) fails:
from torchmdnet.models.model import load_model
load_model('ANI1-equivariant_transformer/epoch=359-val_loss=0.0004-test_loss=0.0120.ckpt')
AssertionError: Requested prior model Atomref but the arguments are lacking the key "prior_init_args".
It sounds like we want to redo how prior args are specified, as described in #26 (comment). That means we can switch back to `prior_args` for this value. There will still be compatibility issues, because it will need to become a list with args for multiple prior models, but I can add a check for that case for backward compatibility. I'll go ahead and make the changes in this PR.
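The backward-compatibility check described here could look something like the following sketch. It is illustrative only; `normalize_prior_args` and the exact stored shapes are my assumptions, not the actual torchmd-net code:

```python
# Hypothetical normalization: older checkpoints stored a single prior's
# arguments as a dict; newer ones store a list with one entry per prior.
def normalize_prior_args(args):
    if args is None:
        # No prior was configured.
        return []
    if isinstance(args, dict):
        # Single prior saved by an older checkpoint: wrap it in a list.
        return [args]
    return list(args)
```

A check like this lets checkpoint loading accept files written before the multi-prior change without modifying the files themselves.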
Might be a good idea to add a test case for loading model checkpoints from a previous version.
I added the multiple priors support. I tried to make the syntax fairly flexible. All of the following are valid.

prior_model: Atomref

prior_model:
  - ZBL:
      cutoff_distance: 4.0
      max_num_neighbors: 50
  - Atomref

prior_model:
  - ZBL:
      cutoff_distance: 4.0
      max_num_neighbors: 50
  - Atomref:

I added tests for creating models both from config files and from checkpoints. And the test for older checkpoints in test_module.py continues to work. I can't be sure I've caught all possible problems though, so if you can check your own files that would be helpful.
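A flexible option like this can be normalized into a uniform list of (name, kwargs) pairs. The sketch below is illustrative, not the PR's actual code; it assumes the YAML has already been parsed into Python strings, lists, and dicts:

```python
# Accepts a bare string, a list mixing strings and one-key mappings,
# or a mapping, and returns [(prior_name, kwargs_dict), ...].
def parse_prior_model(option):
    if isinstance(option, str):
        # "prior_model: Atomref"
        return [(option, {})]
    entries = option if isinstance(option, list) else [option]
    priors = []
    for entry in entries:
        if isinstance(entry, str):
            # "- Atomref" (no arguments)
            priors.append((entry, {}))
        else:
            # "- ZBL: {cutoff_distance: 4.0, ...}" or "Atomref:" (None)
            for name, kwargs in entry.items():
                priors.append((name, kwargs or {}))
    return priors
```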
The mechanism has been replaced with a different one.
Is this ok to merge?
It is good to go!
This is the first piece of #26. It isn't fully tested yet, but I think it's ready for a first round of comments.

I created a mechanism for passing arguments to the prior. `prior_args` is now the option specified by the user in the config file. `prior_init_args` stores the value returned by `get_init_args()`, which contains the arguments needed for reconstructing it from a checkpoint.

This prior requires the dataset to provide several pieces of information. To keep things general, the HDF5 format allows the file to contain a `_metadata` group which can store arbitrary pieces of information. Most of the other dataset classes should be able to hardcode the necessary values, since they aren't intended to be general.
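As a sketch of the `_metadata` idea, using a plain dict to stand in for the HDF5 group structure (the group name comes from the description above; the entries and molecule groups shown are hypothetical, and in a real file they would be HDF5 datasets read via h5py):

```python
# Stand-in for the HDF5 file's layout: a "_metadata" group holding
# arbitrary values that a general-purpose reader can pick up, alongside
# the usual per-molecule data.
hdf5_layout = {
    "_metadata": {
        "atomic_number": [0, 1, 6, 7, 8],  # 1D, length max_z
        "distance_scale": 1e-9,            # stored coords -> meters
    },
    # ...molecule groups with positions, energies, etc.
}

# A reader that tolerates files without the optional group:
metadata = hdf5_layout.get("_metadata", {})
```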