-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
29 lines (19 loc) · 841 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import torch
def parameterized_truncated_normal(uniform, mu, sigma, a, b):
    """Map uniform samples to a normal distribution truncated to ``[a, b]``.

    Inverse-CDF sampling: each ``u`` in ``uniform`` is mapped to the
    probability ``p = Phi(alpha) + (Phi(beta) - Phi(alpha)) * u`` of the
    standard normal (``alpha = (a - mu) / sigma``, ``beta = (b - mu) / sigma``),
    which is then inverted via ``mu + sigma * sqrt(2) * erfinv(2p - 1)``.

    The whole computation stays in PyTorch (no NumPy round-trip), so CUDA
    tensors and tensors that require grad are supported.

    Args:
        uniform: Tensor (or array-like) of samples, expected in ``[0, 1]``.
        mu: Mean of the untruncated normal (scalar or broadcastable tensor).
        sigma: Standard deviation of the untruncated normal; should be > 0.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        Tensor with the broadcast shape of the inputs, in the floating dtype
        of ``uniform`` (the default dtype if ``uniform`` is not floating
        point), with every value in ``[a, b]``.
    """
    uniform = torch.as_tensor(uniform)
    if not uniform.is_floating_point():
        uniform = uniform.to(torch.get_default_dtype())
    dtype, device = uniform.dtype, uniform.device

    def _as_param(value):
        # Promote scalars/tensors to uniform's dtype/device so all math agrees.
        return torch.as_tensor(value, dtype=dtype, device=device)

    mu, sigma, a, b = _as_param(mu), _as_param(sigma), _as_param(a), _as_param(b)
    sqrt2 = 2.0 ** 0.5

    def _std_normal_cdf(t):
        # Standard normal CDF via erf; avoids Normal.cdf, which rejects
        # non-tensor values when argument validation is enabled.
        return 0.5 * (1.0 + torch.erf(t / sqrt2))

    alpha_cdf = _std_normal_cdf((a - mu) / sigma)
    beta_cdf = _std_normal_cdf((b - mu) / sigma)
    p = alpha_cdf + (beta_cdf - alpha_cdf) * uniform

    # erfinv(+-1) is +-inf; keep 2p - 1 strictly inside (-1, 1).
    eps = torch.finfo(dtype).eps
    v = torch.clamp(2.0 * p - 1.0, -1.0 + eps, 1.0 - eps)
    x = mu + sigma * sqrt2 * torch.erfinv(v)
    # Floating-point error in cdf/erfinv can land marginally outside [a, b].
    return torch.minimum(torch.maximum(x, a), b)
def truncated_normal(uniform):
    """Turn uniform samples into standard-normal samples truncated to [-2, 2].

    Thin wrapper over ``parameterized_truncated_normal`` with mean 0,
    standard deviation 1 and bounds of two standard deviations.
    """
    lower, upper = -2, 2
    return parameterized_truncated_normal(
        uniform,
        mu=0.0,
        sigma=1.0,
        a=lower,
        b=upper,
    )
def sample_truncated_normal(shape=()):
    """Draw standard-normal samples truncated to [-2, 2].

    Uniform variates come from NumPy's global RNG (``np.random.uniform``),
    so reproducibility is controlled by ``np.random.seed``.

    Args:
        shape: Output shape; ``()`` yields a 0-d tensor.

    Returns:
        Tensor of shape ``shape`` produced by ``truncated_normal``.
    """
    # np.asarray guarantees torch.from_numpy receives an ndarray, even for shape=().
    uniform = np.asarray(np.random.uniform(0, 1, shape))
    # Fixed: the original called the misspelled `truncted_normal` (NameError).
    return truncated_normal(torch.from_numpy(uniform))