Skip to content

Commit

Permalink
Readability improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Apr 17, 2024
1 parent fb4ce31 commit ee62cce
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions torch_harmonics/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,41 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in

# find the indices where the rotated position falls into the support of the kernel
if nr % 2 == 1:
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere(cond_r & cond_phi)
vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
# compute the distance to the collocation points
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = (1 - dist_r / dr) / norm_factor
vals *= torch.where(
(iidx[:, 0] > 0),
(1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
(1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
1.0,
)
else:
# in the even case, the inner casis functions overlap into areas with a negative areas
rn = - r
phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi)

# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)

# find indices where conditions are met
iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi)
valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum((phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi)

dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) / norm_factor
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) / norm_factor
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
vals += valsn

return iidx, vals
Expand Down

0 comments on commit ee62cce

Please sign in to comment.