-
Notifications
You must be signed in to change notification settings - Fork 61
/
utils.py
89 lines (81 loc) · 2.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Shared model-building components."""
from typing import Optional
import numpy as np
import torch
import random
from torch import nn
def _resolve_split_size(count, ratio, total_size, label):
    """Return ``count`` if given, else ``int(ratio * total_size)``."""
    if count is not None:
        return count
    if ratio is None:
        raise ValueError(
            f"Either n_{label} or {label}_ratio must be provided."
        )
    return int(ratio * total_size)


def get_id_train_val_test(
    total_size=1000,
    split_seed=123,
    train_ratio=None,
    val_ratio=0.1,
    test_ratio=0.1,
    n_train=None,
    n_test=None,
    n_val=None,
    keep_data_order=False,
):
    """Split ``range(total_size)`` into train, validation and test IDs.

    Args:
        total_size: Number of samples; IDs are ``0 .. total_size - 1``.
        split_seed: Seed passed to ``random.seed`` before shuffling.
        train_ratio: Fraction of samples for training. When ``None`` and
            both other ratios are set, it becomes
            ``1 - val_ratio - test_ratio``.
        val_ratio: Fraction of samples for validation.
        test_ratio: Fraction of samples for testing.
        n_train: Explicit training count; overrides ``train_ratio``.
        n_test: Explicit test count; overrides ``test_ratio``.
        n_val: Explicit validation count; overrides ``val_ratio``.
        keep_data_order: If true, IDs are not shuffled.

    Returns:
        Tuple ``(id_train, id_val, id_test)`` of lists. Training IDs are
        taken from the front of the (optionally shuffled) ID list, test
        IDs from the very end, and validation IDs directly before the
        test IDs.

    Raises:
        AssertionError: If ``val_ratio + test_ratio >= 1`` while
            ``train_ratio`` is being derived from them.
        ValueError: If a split has neither a count nor a ratio, or if the
            requested counts exceed ``total_size``.
    """
    if (
        train_ratio is None
        and val_ratio is not None
        and test_ratio is not None
    ):
        assert val_ratio + test_ratio < 1
        train_ratio = 1 - val_ratio - test_ratio
        print("Using rest of the dataset except the test and val sets.")

    n_train = _resolve_split_size(n_train, train_ratio, total_size, "train")
    n_test = _resolve_split_size(n_test, test_ratio, total_size, "test")
    n_val = _resolve_split_size(n_val, val_ratio, total_size, "val")

    # Validate before touching the global RNG so a bad request has no
    # side effects. This also covers explicit ratios that sum to more
    # than one.
    requested = n_train + n_val + n_test
    if requested > total_size:
        raise ValueError(
            f"Check total number of samples. {requested} > {total_size}"
        )

    ids = list(np.arange(total_size))
    if not keep_data_order:
        random.seed(split_seed)
        random.shuffle(ids)

    # Use non-negative offsets from the end: negative slicing breaks for
    # n_test == 0 because ``ids[-0:]`` is the whole list.
    test_start = len(ids) - n_test
    val_start = test_start - n_val
    id_train = ids[:n_train]
    id_val = ids[val_start:test_start]
    id_test = ids[test_start:]
    return id_train, id_val, id_test
class RBFExpansion(nn.Module):
    """Expand interatomic distances with radial basis functions.

    Each distance ``d`` is mapped to ``exp(-gamma * (d - c_k) ** 2)`` for
    every center ``c_k`` of an evenly spaced grid on ``[vmin, vmax]``.
    """

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register torch parameters for RBF expansion.

        Args:
            vmin: Position of the first center.
            vmax: Position of the last center.
            bins: Number of centers (output features per distance).
            lengthscale: Width of the basis functions. When ``None`` it
                is set to the mean spacing between adjacent centers.

        Raises:
            ValueError: If ``lengthscale`` is ``None`` and ``bins < 2``,
                because the center spacing is then undefined.
        """
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        # A buffer (not a parameter): it follows the module across
        # devices and into the state dict but is never trained.
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            if self.bins < 2:
                raise ValueError(
                    "bins must be >= 2 when lengthscale is not given."
                )
            # SchNet-style
            # set lengthscales relative to granularity of RBF expansion
            # Stay in torch and store a plain float so no NumPy scalar
            # ends up in later tensor arithmetic.
            self.lengthscale = torch.diff(self.centers).mean().item()
            # NOTE: this branch uses 1 / lengthscale while the explicit
            # branch uses 1 / lengthscale ** 2. Kept as-is because
            # changing it would alter outputs of existing models.
            self.gamma = 1 / self.lengthscale
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale ** 2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply RBF expansion to interatomic distance tensor.

        A new axis is inserted at position 1 and broadcast against the
        centers, so a 1-D input of length ``N`` yields ``(N, bins)``.
        Presumably ``distance`` is always 1-D; confirm against callers.
        """
        return torch.exp(
            -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2
        )