diff --git a/kornia/feature/laf.py b/kornia/feature/laf.py index af5bc053de..82ac0086d4 100644 --- a/kornia/feature/laf.py +++ b/kornia/feature/laf.py @@ -438,11 +438,12 @@ def extract_patches_from_pyramid( scale_mask = (pyr_idx[i] == cur_pyr_level).squeeze() if (scale_mask.float().sum().item()) == 0: continue - scale_mask = (scale_mask > 0).view(-1) + scale_mask = (scale_mask > 0).view(-1).to(nlaf.dtype).to(nlaf.device) grid = generate_patch_grid_from_normalized_LAF(cur_img[i : i + 1], nlaf[i : i + 1, scale_mask, :, :], PS) patches = F.grid_sample( cur_img[i : i + 1].expand(grid.shape[0], ch, h, w), grid, padding_mode="border", align_corners=False - ) + ).to(nlaf.dtype).to(nlaf.device) + out[i].masked_scatter_(scale_mask.view(-1, 1, 1, 1), patches) we_are_in_business = min(cur_img.size(2), cur_img.size(3)) >= PS if not we_are_in_business: