From 48601fb1ea22bd5277d1f12372152b9ebef888a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ser=C3=B3dio?= Date: Fri, 26 Jul 2024 14:36:18 -0300 Subject: [PATCH] PatchInferencer fixes --- minerva/utils/patches.py | 145 ++++++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 62 deletions(-) diff --git a/minerva/utils/patches.py b/minerva/utils/patches.py index a8abf92..5a670d5 100644 --- a/minerva/utils/patches.py +++ b/minerva/utils/patches.py @@ -1,23 +1,23 @@ from typing import List, Tuple, Optional, Dict, Any import torch import numpy as np -from torchmetrics import Accuracy import lightning as L class BasePatchInferencer: """Inference in patches for models - This class provides utilitary methods for performing inference in patches + This class provides utility methods for performing inference in patches """ def __init__( self, model: L.LightningModule, input_shape: Tuple, - weight_function: Optional[function], - offsets: Optional[List[Tuple]], - padding: Optional[Dict[str, Any]], + output_shape: Optional[Tuple] = None, + weight_function: Optional[callable] = None, + offsets: Optional[List[Tuple]] = None, + padding: Optional[Dict[str, Any]] = None, ): """Initialize the patch inference auxiliary class @@ -27,7 +27,9 @@ def __init__( Model used in inference. input_shape : Tuple Expected input shape of the model - weight_function: function, optional + output_shape : Tuple, optional + Expected output shape of the model. Defaults to input_shape + weight_function: callable, optional Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape Useful when regions of the inference present diminishing performance when getting closer to borders, for instance. offsets : Tuple, optional @@ -40,6 +42,7 @@ def __init__( """ self.model = model self.input_shape = input_shape + self.output_shape = output_shape if output_shape is not None else input_shape self.weight_function = weight_function if offsets is not None: @@ -47,7 +50,9 @@ def __init__( assert len(input_shape) == len( offset ), f"Offset tuple does not match expected size ({len(input_shape)})" - self.offsets = offsets + self.offsets = offsets + else: + self.offsets = [] if padding is not None: assert len(input_shape) == len( @@ -57,6 +62,9 @@ def __init__( else: self.padding = {"pad": tuple([0] * len(input_shape))} + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x) + def _reconstruct_patches( self, patches: torch.Tensor, @@ -67,23 +75,23 @@ def _reconstruct_patches( """ Rearranges patches to reconstruct area of interest from patches and weights """ - reconstruct_shape = np.array(self._input_size) * np.array(index) + reconstruct_shape = np.array(self.output_shape) * np.array(index) if weights: - weight = torch.zeros(reconstruct_shape) + weight = torch.zeros(tuple(reconstruct_shape)) base_weight = ( - self._weight_function(self._input_size) - if self._weight_function - else torch.ones(self._input_size) + self.weight_function(self.input_shape) + if self.weight_function + else torch.ones(self.input_shape) ) else: weight = None if inner_dim is not None: reconstruct_shape = np.append(reconstruct_shape, inner_dim) - reconstruct = torch.zeros(reconstruct_shape) + reconstruct = torch.zeros(tuple(reconstruct_shape)) for patch_index, patch in zip(np.ndindex(index), patches): sl = [ slice(idx * patch_len, (idx + 1) * patch_len, None) - for idx, patch_len in zip(patch_index, self._input_size) + for idx, patch_len in zip(patch_index, self.input_shape) ] if weights: weight[tuple(sl)] = base_weight @@ -98,17 +106,16 @@ def _adjust_patches( ref_shape: Tuple[int], offset: Tuple[int], pad_value: int = 0, - ) -> torch.Tensor: + ) -> List[torch.Tensor]: """ Pads reconstructed_patches with 'pad_value' to have same shape as the reference shape from the base patch set """ - has_inner_dim = len(offset) < len(ref_shape) + has_inner_dim = len(offset) < len(arrays[0].shape) pad_width = [] sl = [] ref_shape = list(ref_shape) arr_shape = list(arrays[0].shape) if has_inner_dim: - ref_shape = ref_shape[:-1] arr_shape = arr_shape[:-1] for idx, lenght, ref in zip(offset, arr_shape, ref_shape): if idx > 0: @@ -119,16 +126,16 @@ def _adjust_patches( pad_width = [0, max(ref - lenght - idx, 0)] + pad_width adjusted = [ ( - torch.pad( + torch.nn.functional.pad( arr[tuple([*sl, slice(None, None, None)])], - pad_width=[0, 0, *pad_width], + pad=tuple([0, 0, *pad_width]), mode="constant", value=pad_value, ) if has_inner_dim - else np.pad( + else torch.nn.functional.pad( arr[tuple(sl)], - pad_width=pad_width, + pad=tuple(pad_width), mode="constant", value=pad_value, ) @@ -160,11 +167,25 @@ def _extract_patches( patches = [] for patch_index in np.ndindex(indexes): sl = [ - slice(idx, idx + 1, patch_len) + slice(idx * patch_len, (idx + 1) * patch_len, None) for idx, patch_len in zip(patch_index, patch_shape) ] patches.append(data[tuple(sl)]) - return torch.Tensor(patches), indexes + return torch.stack(patches), indexes + + def _compute_output_shape(self, tensor: torch.Tensor) -> Tuple[int]: + """ + Computes PatchInferencer output shape based on input tensor shape, and model's input and output shapes. + """ + if self.input_shape == self.output_shape: + return tensor.shape + shape = [] + for i, o, t in zip(self.input_shape, self.output_shape, tensor.shape): + if i != o: + shape.append(int(t * o // i)) + else: + shape.append(t) + return tuple(shape) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -178,14 +199,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: assert len(x.shape) == len( self.input_shape ), "Input and self.input_shape sizes must match" - offsets = list(self._offsets) + + self.ref_shape = self._compute_output_shape(x) + offsets = list(self.offsets) base = self.padding["pad"] offsets.insert(0, tuple([0] * len(base))) slices = [ tuple( [ - slice(i + base) # TODO: if ((i + base >= 0) and (i < in_dim)) + slice(i + base, None) # TODO: if ((i + base >= 0) and (i < in_dim)) for i, base, in_dim in zip(offset, base, x.shape) ] ) @@ -204,7 +227,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: results = [] indexes = [] for sl in slices: - patch_set, patch_idx = self._extract_patches(x_padded[sl], self.input_size) + patch_set, patch_idx = self._extract_patches(x_padded[sl], self.input_shape) results.append(self.model(patch_set)) indexes.append(patch_idx) output_slice = tuple( @@ -230,12 +253,10 @@ def _combine_patches( reconstruct, weight = self._reconstruct_patches( patches, shape, weights=True ) - if len(reconstructed) > 0: - adjusted = self._adjust_patches( - [reconstruct, weight], reconstructed[0].shape, offset - ) - reconstruct = adjusted[0] - weight = adjusted[1] + reconstruct, weight = self._adjust_patches( + [reconstruct, weight], self.ref_shape, offset + ) + reconstructed.append(reconstruct) weights.append(weight) reconstructed = torch.stack(reconstructed, dim=0) @@ -246,16 +267,18 @@ def _combine_patches( class VotingPatchInferencer(BasePatchInferencer): """ PatchInferencer with Voting combination function. + Note: Models used with VotingPatchInferencer must return class probabilities in inner dimension """ def __init__( self, model: L.LightningModule, input_shape: Tuple, - weight_function: Optional[function], - offsets: Optional[List[Tuple]], - padding: Optional[Dict[str, Any]], - num_classes: int, + output_shape: Optional[Tuple] = None, + weight_function: Optional[callable] = None, + offsets: Optional[List[Tuple]] = None, + padding: Optional[Dict[str, Any]] = None, + num_classes: int = None, voting: str = "soft", ): """Initialize the patch inference auxiliary class @@ -266,7 +289,9 @@ def __init__( Model used in inference. input_shape : Tuple Expected input shape of the model - weight_function: function, optional + output_shape : Tuple, optional + Expected output shape of the model. Defaults to input_shape + weight_function: callable, optional Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape Useful when regions of the inference present diminishing performance when getting closer to borders, for instance. offsets : Tuple, optional @@ -281,9 +306,9 @@ def __init__( voting: str voting method to use, can be either 'soft'or 'hard'. Defaults to 'soft'. """ - super().__init__(model, input_shape, weight_function, offsets, padding) - self.model = model - self.input_shape = input_shape + super().__init__( + model, input_shape, output_shape, weight_function, offsets, padding + ) assert voting in ["soft", "hard"], "voting should be either 'soft' or 'hard'" self.num_classes = num_classes self.voting = voting @@ -306,23 +331,23 @@ def _hard_voting( """ Hard voting combination function """ - reconstructed = [] - for patches, offset, shape in zip(results, offsets, indexes): - reconstruct, _ = self._reconstruct_patches( - patches, shape, weights=False, inner_dim=self.num_classes - ) - reconstruct = torch.argmax(reconstruct, dim=-1).astype(torch.float32) - if len(reconstructed) > 0: - adjusted = self._adjust_patches( - [reconstruct], reconstructed[0].shape, offset, pad_value=torch.nan - ) - reconstruct = adjusted[0] - reconstructed.append(reconstruct) - reconstructed = torch.stack(reconstructed, dim=0) - ret = torch.mode(reconstructed, dim=0, keepdims=False)[ - 0 - ] # TODO check behaviour on GPU, according to issues may have nonsense results - return ret + # torch.mode does not work like scipy.stats.mode + raise NotImplementedError("Hard voting not yet supported") + # reconstructed = [] + # for patches, offset, shape in zip(results, offsets, indexes): + # reconstruct, _ = self._reconstruct_patches( + # patches, shape, weights=False, inner_dim=self.num_classes + # ) + # reconstruct = torch.argmax(reconstruct, dim=-1).float() + # reconstruct = self._adjust_patches( + # [reconstruct], self.ref_shape, offset, pad_value=torch.nan + # )[0] + # reconstructed.append(reconstruct) + # reconstructed = torch.stack(reconstructed, dim=0) + # ret = torch.mode(reconstructed, dim=0, keepdims=False)[ + # 0 + # ] # TODO check behaviour on GPU, according to issues may have nonsense results + # return ret def _soft_voting( self, @@ -338,11 +363,7 @@ def _soft_voting( reconstruct, _ = self._reconstruct_patches( patches, shape, weights=False, inner_dim=self.num_classes ) - if len(reconstructed) > 0: - adjusted = self._adjust_patches( - [reconstruct], reconstructed[0].shape, offset - ) - reconstruct = adjusted[0] + reconstruct = self._adjust_patches([reconstruct], self.ref_shape, offset)[0] reconstructed.append(reconstruct) reconstructed = torch.stack(reconstructed, dim=0) return torch.argmax(torch.sum(reconstructed, dim=0), dim=-1)