Skip to content

Commit

Permalink
Checking the weight grad as well
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Dec 18, 2023
1 parent fb75855 commit e364219
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,18 +197,22 @@ def test_disco_convolution(
torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
)

# create a copy of the weight
w_ref = conv.weight.detach().clone()
w_ref.requires_grad_(True)

# create an input signal
x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)

# perform the reference computation
x_ref = x.clone().detach()
x_ref.requires_grad_(True)
if transpose:
y_ref = torch.einsum("oif,biqr->bofqr", conv.weight, x_ref)
y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
else:
y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
y_ref = torch.einsum("oif,biftp->botp", conv.weight, y_ref)
y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)

# use the convolution module
y = conv(x)
Expand All @@ -220,9 +224,10 @@ def test_disco_convolution(
grad_input = torch.randn_like(y)
y_ref.backward(grad_input)
y.backward(grad_input)

# compare
self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))

if __name__ == "__main__":
unittest.main()

0 comments on commit e364219

Please sign in to comment.