
Commit

Fix variable naming and padding logic in PatchInferencerEngine for improved clarity and functionality
GabrielBG0 committed Nov 12, 2024
1 parent f5ae031 commit 78fdd78
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions minerva/engines/patch_inferencer_engine.py
@@ -213,13 +213,13 @@ def _adjust_patches(
         sl = []
         ref_shape = list(ref_shape)
         arr_shape = list(arrays[0].shape)
-        for idx, lenght, ref in zip([0, *offset], arr_shape, ref_shape):
+        for idx, length, ref in zip([0, *offset], arr_shape, ref_shape):
             if idx > 0:
-                sl.append(slice(0, min(lenght, ref), None))
-                pad_width = [idx, max(ref - lenght - idx, 0)] + pad_width
+                sl.append(slice(0, min(length, ref), None))
+                pad_width = [idx, max(ref - length - idx, 0)] + pad_width
             else:
-                sl.append(slice(np.abs(idx), min(lenght, ref - idx), None))
-                pad_width = [0, max(ref - lenght - idx, 0)] + pad_width
+                sl.append(slice(np.abs(idx), min(length, ref - idx), None))
+                pad_width = [0, max(ref - length - idx, 0)] + pad_width
         adjusted = [
             (
                 torch.nn.functional.pad(
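
A note on the pattern in this hunk: each dimension contributes a [left, right] pair that is prepended to pad_width, because torch.nn.functional.pad reads its flat pad tuple starting from the last dimension. Below is a minimal standalone sketch of that padding step only (made-up shapes, no slicing and no negative-offset branch; the patch tensor and values are illustrative, not taken from the repository):

import torch

ref_shape = (8, 8)
patch = torch.ones(5, 5)
offset = (2, 0)  # place the patch 2 rows down, flush left

pad_width = []
for idx, length, ref in zip(offset, patch.shape, ref_shape):
    # [left, right] padding for this dimension, prepended so the
    # last dimension ends up first, as F.pad expects
    pad_width = [idx, max(ref - length - idx, 0)] + pad_width

padded = torch.nn.functional.pad(patch, tuple(pad_width))
print(padded.shape)  # torch.Size([8, 8])
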
@@ -251,6 +251,7 @@ def _combine_patches(
             )
             reconstructed.append(reconstruct)
             weights.append(weight)
+            print(reconstruct.shape)
         reconstructed = torch.stack(reconstructed, dim=0)
         weights = torch.stack(weights, dim=0)
         return torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0)
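
For context on the return line above: overlapping reconstructions are blended by a per-element weighted average. A tiny self-contained illustration of that formula (the tensors, weights, and values here are invented for the example):

import torch

# two overlapping reconstructions of the same 4x4 output and their weights
reconstructed = torch.stack([torch.full((4, 4), 2.0), torch.full((4, 4), 4.0)], dim=0)
weights = torch.stack([torch.ones(4, 4), 3 * torch.ones(4, 4)], dim=0)

combined = torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0)
print(combined[0, 0])  # tensor(3.5000) == (2*1 + 4*3) / (1 + 3)
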
@@ -319,7 +320,7 @@ def __call__(
         slices = [
             tuple(
                 [
-                    slice(i + base, None) # TODO: if ((i + base >= 0) and (i < in_dim))
+                    slice(i, None) # TODO: if ((i + base >= 0) and (i < in_dim))
                     for i, base, in_dim in zip([0, *offset], base, x.shape)
                 ]
             )
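
The change above drops the extra base term, so each entry indexes the padded input with slice(i, None), i.e. a view starting at that dimension's offset. A generic illustration of tuple-of-slices indexing (the array and offsets are made up and are not the engine's actual offset grid):

import torch

x = torch.arange(36).reshape(6, 6)
offset = (0, 2)  # no shift on dim 0, shift of 2 on dim 1

view = x[tuple(slice(i, None) for i in offset)]
print(view.shape)  # torch.Size([6, 4]) -- rows unchanged, first two columns dropped
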
@@ … @@

         torch_pad = []
         for pad_value in reversed(base):
-            torch_pad = torch_pad + [pad_value, pad_value]
+            torch_pad = torch_pad + [0, pad_value]
         x_padded = torch.nn.functional.pad(
             x,
             pad=tuple(torch_pad),
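
torch_pad is assembled over reversed(base) because torch.nn.functional.pad expects (left, right) pairs for the last dimension first; the edit switches from symmetric padding to padding only at the end of each dimension. A quick check of that convention with invented sizes:

import torch

x = torch.zeros(2, 3)
base = (1, 2)  # desired trailing pad per dimension

torch_pad = []
for pad_value in reversed(base):
    torch_pad = torch_pad + [0, pad_value]  # pad only at the end of each dim

print(torch.nn.functional.pad(x, tuple(torch_pad)).shape)  # torch.Size([3, 5])
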
@@ … @@
             else:
                 results.append(inference)
             indexes.append(patch_idx)
-        output_slice = tuple([slice(0, lenght) for lenght in self.ref_shape])
+        output_slice = tuple([slice(0, length) for length in self.ref_shape])
         if self.return_tuple:
             comb_list = []
             for i in range(self.return_tuple):
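
output_slice crops the recombined, padded result back to the original reference shape. The same tuple-of-slices cropping in isolation (shapes invented for the example):

import torch

ref_shape = (4, 5)
padded_result = torch.zeros(6, 8)  # e.g. after one-sided padding

output_slice = tuple(slice(0, length) for length in ref_shape)
print(padded_result[output_slice].shape)  # torch.Size([4, 5])
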

