From 78fdd7863bc4e4c885b60d0047847943708efe81 Mon Sep 17 00:00:00 2001
From: Gabriel Gutierrez
Date: Tue, 12 Nov 2024 16:18:08 -0300
Subject: [PATCH] Fix variable naming and padding logic in PatchInferencerEngine
 for improved clarity and functionality

---
 minerva/engines/patch_inferencer_engine.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/minerva/engines/patch_inferencer_engine.py b/minerva/engines/patch_inferencer_engine.py
index c730794..49a99ea 100644
--- a/minerva/engines/patch_inferencer_engine.py
+++ b/minerva/engines/patch_inferencer_engine.py
@@ -213,13 +213,13 @@ def _adjust_patches(
         sl = []
         ref_shape = list(ref_shape)
         arr_shape = list(arrays[0].shape)
-        for idx, lenght, ref in zip([0, *offset], arr_shape, ref_shape):
+        for idx, length, ref in zip([0, *offset], arr_shape, ref_shape):
             if idx > 0:
-                sl.append(slice(0, min(lenght, ref), None))
-                pad_width = [idx, max(ref - lenght - idx, 0)] + pad_width
+                sl.append(slice(0, min(length, ref), None))
+                pad_width = [idx, max(ref - length - idx, 0)] + pad_width
             else:
-                sl.append(slice(np.abs(idx), min(lenght, ref - idx), None))
-                pad_width = [0, max(ref - lenght - idx, 0)] + pad_width
+                sl.append(slice(np.abs(idx), min(length, ref - idx), None))
+                pad_width = [0, max(ref - length - idx, 0)] + pad_width
         adjusted = [
             (
                 torch.nn.functional.pad(
@@ -251,6 +251,7 @@ def _combine_patches(
             )
             reconstructed.append(reconstruct)
             weights.append(weight)
+            print(reconstruct.shape)
         reconstructed = torch.stack(reconstructed, dim=0)
         weights = torch.stack(weights, dim=0)
         return torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0)
@@ -319,7 +320,7 @@ def __call__(
         slices = [
             tuple(
                 [
-                    slice(i + base, None)  # TODO: if ((i + base >= 0) and (i < in_dim))
+                    slice(i, None)  # TODO: if ((i + base >= 0) and (i < in_dim))
                     for i, base, in_dim in zip([0, *offset], base, x.shape)
                 ]
             )
@@ -328,7 +329,7 @@ def __call__(

        torch_pad = []
        for pad_value in reversed(base):
-            torch_pad = torch_pad + [pad_value, pad_value]
+            torch_pad = torch_pad + [0, pad_value]
         x_padded = torch.nn.functional.pad(
             x,
             pad=tuple(torch_pad),
@@ -349,7 +350,7 @@ def __call__(
             else:
                 results.append(inference)
                 indexes.append(patch_idx)
-        output_slice = tuple([slice(0, lenght) for lenght in self.ref_shape])
+        output_slice = tuple([slice(0, length) for length in self.ref_shape])
         if self.return_tuple:
             comb_list = []
             for i in range(self.return_tuple):