Skip to content

Commit

Permalink
Bbonev/disco refactor (#28)
Browse files Browse the repository at this point in the history
Changed the code to only implicitly use sparse tensors in the modules, in order to enable compatibility with DDP
  • Loading branch information
bonevbs authored Dec 22, 2023
1 parent 942aa4e commit c971d45
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
4 changes: 3 additions & 1 deletion tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,10 @@ def test_disco_convolution(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
).to(self.device)

psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()

self.assertTrue(
torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
)

# create a copy of the weight
Expand Down
32 changes: 20 additions & 12 deletions torch_harmonics/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,12 @@ def __init__(
idx, vals = _precompute_convolution_tensor(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
)
psi = torch.sparse_coo_tensor(
idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
).coalesce()
self.register_buffer("psi", psi, persistent=False)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
# ).coalesce()
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)

# groups
self.groups = groups
Expand All @@ -248,10 +250,12 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens
# pre-multiply x with the quadrature weights
x = self.quad_weights * x

psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()

if x.is_cuda and use_triton_kernel:
x = _disco_s2_contraction_triton(x, self.psi, self.nlon_out)
x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
else:
x = _disco_s2_contraction_torch(x, self.psi, self.nlon_out)
x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

# extract shape
B, C, K, H, W = x.shape
Expand Down Expand Up @@ -317,10 +321,12 @@ def __init__(
idx, vals = _precompute_convolution_tensor(
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
)
psi = torch.sparse_coo_tensor(
idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
).coalesce()
self.register_buffer("psi", psi, persistent=False)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
# ).coalesce()
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)

# groups
self.groups = groups
Expand Down Expand Up @@ -351,10 +357,12 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens
# pre-multiply x with the quadrature weights
x = self.quad_weights * x

psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

if x.is_cuda and use_triton_kernel:
out = _disco_s2_transpose_contraction_triton(x, self.psi, self.nlon_out)
out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
else:
out = _disco_s2_transpose_contraction_torch(x, self.psi, self.nlon_out)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)
Expand Down

0 comments on commit c971d45

Please sign in to comment.