Skip to content

Commit

Permalink
test_last
Browse files Browse the repository at this point in the history
  • Loading branch information
Vepricov committed Nov 5, 2024
1 parent 9851c58 commit 4804071
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,25 @@
GaussianRelaxedBernoulli
)

def test_rsample():
# a = torch.tensor([0.2, 0.4, 0.3, 0.1], requires_grad=True)
def test_GaussianRelaxedBernoulli():
loc = torch.tensor([0.], requires_grad=True)
scale = torch.tensor([1.], requires_grad=True)

### rsample test ###
distr = GaussianRelaxedBernoulli(loc = loc, scale=scale)
samples = distr.rsample(sample_shape = torch.Size([3]))
assert samples.shape == torch.Size([3, 1])
assert samples.requires_grad == True
print("GaussianRelaxedBernoulli is OK")

def test_HardConcrete():
alpha = torch.tensor([1.], requires_grad=True)
beta = torch.tensor([2.], requires_grad=True)
gamma = torch.tensor([-3.], requires_grad=True)
xi = torch.tensor([4.], requires_grad=True)

distr_2 = GaussianRelaxedBernoulli(loc = loc, scale=scale)
samples_2 = distr_2.rsample(sample_shape = torch.Size([3]))
assert samples_2.shape == torch.Size([3, 1])
assert samples_2.requires_grad == True
distr_3 = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples_3 = distr_3.rsample(sample_shape = torch.Size([3]))
assert samples_3.shape == torch.Size([3, 1])
assert samples_3.requires_grad == True
assert samples_3.requires_grad == False
print("rsample is OK")
distr = HardConcrete(alpha=alpha, beta=beta, gamma=gamma, xi=xi)
samples = distr.rsample(sample_shape = torch.Size([3]))
assert samples.shape == torch.Size([3, 1])
assert samples.requires_grad == True
print("HardConcrete is OK")

0 comments on commit 4804071

Please sign in to comment.