Skip to content

Commit

Permalink
Merge pull request #785 from ACEsuit/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
ilyes319 authored Jan 15, 2025
2 parents 6dce504 + d5e8a38 commit fca3022
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mace/modules/radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: int
x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
Expand Down
28 changes: 28 additions & 0 deletions tests/test_foundations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional
from ase.build import molecule
from e3nn import o3
from e3nn.util import jit
from scipy.spatial.transform import Rotation as R

from mace import data, modules, tools
Expand Down Expand Up @@ -176,6 +177,33 @@ def test_multi_reference():
)


@pytest.mark.parametrize(
    "calc_factory",
    # NOTE: zero-argument factories, not calculator instances. Parametrize
    # arguments are evaluated at collection time, so instantiating here would
    # download and build every foundation model even when the test is
    # deselected — and a download failure would break collection of the file.
    [
        lambda: mace_mp(device="cpu", default_dtype="float64"),
        lambda: mace_mp(model="small", device="cpu", default_dtype="float64"),
        lambda: mace_mp(model="medium", device="cpu", default_dtype="float64"),
        lambda: mace_mp(model="large", device="cpu", default_dtype="float64"),
        lambda: mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"),
        lambda: mace_off(model="small", device="cpu", default_dtype="float64"),
        lambda: mace_off(model="medium", device="cpu", default_dtype="float64"),
        lambda: mace_off(model="large", device="cpu", default_dtype="float64"),
        lambda: mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"),
    ],
)
def test_compile_foundation(calc_factory):
    """TorchScript-compiling a foundation model must preserve its outputs.

    Builds the calculator lazily, evaluates a slightly perturbed CH4 molecule
    with the eager model and with ``e3nn.util.jit.compile``-d model, and
    checks every tensor output agrees to within 1e-5 absolute tolerance.
    """
    calc = calc_factory()
    model = calc.models[0]
    atoms = molecule("CH4")
    # Perturb positions so the test is not evaluated at a symmetric geometry.
    atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1
    batch = calc._atoms_to_batch(atoms)
    eager_output = model(batch.to_dict())
    model_compiled = jit.compile(model)
    compiled_output = model_compiled(batch.to_dict())
    for key, value in eager_output.items():
        if isinstance(value, torch.Tensor):
            assert torch.allclose(value, compiled_output[key], atol=1e-5)


@pytest.mark.parametrize(
"model",
[
Expand Down

0 comments on commit fca3022

Please sign in to comment.