PatchInferencer fixes
SerodioJ committed Jul 26, 2024
1 parent 6fad88a commit 48601fb
Showing 1 changed file with 83 additions and 62 deletions.
145 changes: 83 additions & 62 deletions minerva/utils/patches.py
@@ -1,23 +1,23 @@
from typing import List, Tuple, Optional, Dict, Any
import torch
import numpy as np
from torchmetrics import Accuracy
import lightning as L


class BasePatchInferencer:
"""Inference in patches for models
- This class provides utilitary methods for performing inference in patches
+ This class provides utility methods for performing inference in patches
"""

def __init__(
self,
model: L.LightningModule,
input_shape: Tuple,
- weight_function: Optional[function],
- offsets: Optional[List[Tuple]],
- padding: Optional[Dict[str, Any]],
+ output_shape: Optional[Tuple] = None,
+ weight_function: Optional[callable] = None,
+ offsets: Optional[List[Tuple]] = None,
+ padding: Optional[Dict[str, Any]] = None,
):
"""Initialize the patch inference auxiliary class
@@ -27,7 +27,9 @@ def __init__(
Model used in inference.
input_shape : Tuple
Expected input shape of the model
- weight_function: function, optional
+ output_shape : Tuple, optional
+ Expected output shape of the model. Defaults to input_shape
+ weight_function: callable, optional
Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape
Useful when regions of the inference present diminishing performance when getting closer to borders, for instance.
offsets : Tuple, optional
@@ -40,14 +42,17 @@
"""
self.model = model
self.input_shape = input_shape
+ self.output_shape = output_shape if output_shape is not None else input_shape
self.weight_function = weight_function

if offsets is not None:
for offset in offsets:
assert len(input_shape) == len(
offset
), f"Offset tuple does not match expected size ({len(input_shape)})"
- self.offsets = offsets
+ self.offsets = offsets
+ else:
+ self.offsets = []

if padding is not None:
assert len(input_shape) == len(
@@ -57,6 +62,9 @@ def __init__(
else:
self.padding = {"pad": tuple([0] * len(input_shape))}

+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ return self.forward(x)

def _reconstruct_patches(
self,
patches: torch.Tensor,
@@ -67,23 +75,23 @@ def _reconstruct_patches(
"""
Rearranges patches to reconstruct area of interest from patches and weights
"""
- reconstruct_shape = np.array(self._input_size) * np.array(index)
+ reconstruct_shape = np.array(self.output_shape) * np.array(index)
if weights:
- weight = torch.zeros(reconstruct_shape)
+ weight = torch.zeros(tuple(reconstruct_shape))
base_weight = (
- self._weight_function(self._input_size)
- if self._weight_function
- else torch.ones(self._input_size)
+ self.weight_function(self.input_shape)
+ if self.weight_function
+ else torch.ones(self.input_shape)
)
else:
weight = None
if inner_dim is not None:
reconstruct_shape = np.append(reconstruct_shape, inner_dim)
- reconstruct = torch.zeros(reconstruct_shape)
+ reconstruct = torch.zeros(tuple(reconstruct_shape))
for patch_index, patch in zip(np.ndindex(index), patches):
sl = [
slice(idx * patch_len, (idx + 1) * patch_len, None)
- for idx, patch_len in zip(patch_index, self._input_size)
+ for idx, patch_len in zip(patch_index, self.input_shape)
]
if weights:
weight[tuple(sl)] = base_weight
@@ -98,17 +106,16 @@ def _adjust_patches(
ref_shape: Tuple[int],
offset: Tuple[int],
pad_value: int = 0,
- ) -> torch.Tensor:
+ ) -> List[torch.Tensor]:
"""
Pads reconstructed_patches with 'pad_value' to have same shape as the reference shape from the base patch set
"""
- has_inner_dim = len(offset) < len(ref_shape)
+ has_inner_dim = len(offset) < len(arrays[0].shape)
pad_width = []
sl = []
ref_shape = list(ref_shape)
arr_shape = list(arrays[0].shape)
if has_inner_dim:
- ref_shape = ref_shape[:-1]
arr_shape = arr_shape[:-1]
for idx, length, ref in zip(offset, arr_shape, ref_shape):
if idx > 0:
@@ -119,16 +126,16 @@
pad_width = [0, max(ref - length - idx, 0)] + pad_width
adjusted = [
(
- torch.pad(
+ torch.nn.functional.pad(
arr[tuple([*sl, slice(None, None, None)])],
- pad_width=[0, 0, *pad_width],
+ pad=tuple([0, 0, *pad_width]),
mode="constant",
value=pad_value,
)
if has_inner_dim
- else np.pad(
+ else torch.nn.functional.pad(
arr[tuple(sl)],
- pad_width=pad_width,
+ pad=tuple(pad_width),
mode="constant",
value=pad_value,
)
@@ -160,11 +167,25 @@ def _extract_patches(
patches = []
for patch_index in np.ndindex(indexes):
sl = [
- slice(idx, idx + 1, patch_len)
+ slice(idx * patch_len, (idx + 1) * patch_len, None)
for idx, patch_len in zip(patch_index, patch_shape)
]
patches.append(data[tuple(sl)])
- return torch.Tensor(patches), indexes
+ return torch.stack(patches), indexes
+
+ def _compute_output_shape(self, tensor: torch.Tensor) -> Tuple[int]:
+ """
+ Computes PatchInferencer output shape based on input tensor shape, and model's input and output shapes.
+ """
+ if self.input_shape == self.output_shape:
+ return tensor.shape
+ shape = []
+ for i, o, t in zip(self.input_shape, self.output_shape, tensor.shape):
+ if i != o:
+ shape.append(int(t * o // i))
+ else:
+ shape.append(t)
+ return tuple(shape)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -178,14 +199,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
assert len(x.shape) == len(
self.input_shape
), "Input and self.input_shape sizes must match"
- offsets = list(self._offsets)
+
+ self.ref_shape = self._compute_output_shape(x)
+ offsets = list(self.offsets)
base = self.padding["pad"]
offsets.insert(0, tuple([0] * len(base)))

slices = [
tuple(
[
- slice(i + base) # TODO: if ((i + base >= 0) and (i < in_dim))
+ slice(i + base, None) # TODO: if ((i + base >= 0) and (i < in_dim))
for i, base, in_dim in zip(offset, base, x.shape)
]
)
@@ -204,7 +227,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
results = []
indexes = []
for sl in slices:
- patch_set, patch_idx = self._extract_patches(x_padded[sl], self.input_size)
+ patch_set, patch_idx = self._extract_patches(x_padded[sl], self.input_shape)
results.append(self.model(patch_set))
indexes.append(patch_idx)
output_slice = tuple(
@@ -230,12 +253,10 @@ def _combine_patches(
reconstruct, weight = self._reconstruct_patches(
patches, shape, weights=True
)
- if len(reconstructed) > 0:
- adjusted = self._adjust_patches(
- [reconstruct, weight], reconstructed[0].shape, offset
- )
- reconstruct = adjusted[0]
- weight = adjusted[1]
+ reconstruct, weight = self._adjust_patches(
+ [reconstruct, weight], self.ref_shape, offset
+ )
+
reconstructed.append(reconstruct)
weights.append(weight)
reconstructed = torch.stack(reconstructed, dim=0)
@@ -246,16 +267,18 @@
class VotingPatchInferencer(BasePatchInferencer):
"""
PatchInferencer with Voting combination function.
+ Note: Models used with VotingPatchInferencer must return class probabilities in inner dimension
"""

def __init__(
self,
model: L.LightningModule,
input_shape: Tuple,
- weight_function: Optional[function],
- offsets: Optional[List[Tuple]],
- padding: Optional[Dict[str, Any]],
- num_classes: int,
+ output_shape: Optional[Tuple] = None,
+ weight_function: Optional[callable] = None,
+ offsets: Optional[List[Tuple]] = None,
+ padding: Optional[Dict[str, Any]] = None,
+ num_classes: int = None,
voting: str = "soft",
):
"""Initialize the patch inference auxiliary class
@@ -266,7 +289,9 @@ def __init__(
Model used in inference.
input_shape : Tuple
Expected input shape of the model
- weight_function: function, optional
+ output_shape : Tuple, optional
+ Expected output shape of the model. Defaults to input_shape
+ weight_function: callable, optional
Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape
Useful when regions of the inference present diminishing performance when getting closer to borders, for instance.
offsets : Tuple, optional
@@ -281,9 +306,9 @@ def __init__(
voting: str
voting method to use, can be either 'soft' or 'hard'. Defaults to 'soft'.
"""
- super().__init__(model, input_shape, weight_function, offsets, padding)
- self.model = model
- self.input_shape = input_shape
+ super().__init__(
+ model, input_shape, output_shape, weight_function, offsets, padding
+ )
assert voting in ["soft", "hard"], "voting should be either 'soft' or 'hard'"
self.num_classes = num_classes
self.voting = voting
@@ -306,23 +331,23 @@ def _hard_voting(
"""
Hard voting combination function
"""
- reconstructed = []
- for patches, offset, shape in zip(results, offsets, indexes):
- reconstruct, _ = self._reconstruct_patches(
- patches, shape, weights=False, inner_dim=self.num_classes
- )
- reconstruct = torch.argmax(reconstruct, dim=-1).astype(torch.float32)
- if len(reconstructed) > 0:
- adjusted = self._adjust_patches(
- [reconstruct], reconstructed[0].shape, offset, pad_value=torch.nan
- )
- reconstruct = adjusted[0]
- reconstructed.append(reconstruct)
- reconstructed = torch.stack(reconstructed, dim=0)
- ret = torch.mode(reconstructed, dim=0, keepdims=False)[
- 0
- ] # TODO check behaviour on GPU, according to issues may have nonsense results
- return ret
+ # torch.mode does not work like scipy.stats.mode
+ raise NotImplementedError("Hard voting not yet supported")
+ # reconstructed = []
+ # for patches, offset, shape in zip(results, offsets, indexes):
+ # reconstruct, _ = self._reconstruct_patches(
+ # patches, shape, weights=False, inner_dim=self.num_classes
+ # )
+ # reconstruct = torch.argmax(reconstruct, dim=-1).float()
+ # reconstruct = self._adjust_patches(
+ # [reconstruct], self.ref_shape, offset, pad_value=torch.nan
+ # )[0]
+ # reconstructed.append(reconstruct)
+ # reconstructed = torch.stack(reconstructed, dim=0)
+ # ret = torch.mode(reconstructed, dim=0, keepdims=False)[
+ # 0
+ # ] # TODO check behaviour on GPU, according to issues may have nonsense results
+ # return ret

def _soft_voting(
self,
@@ -338,11 +363,7 @@ def _soft_voting(
reconstruct, _ = self._reconstruct_patches(
patches, shape, weights=False, inner_dim=self.num_classes
)
- if len(reconstructed) > 0:
- adjusted = self._adjust_patches(
- [reconstruct], reconstructed[0].shape, offset
- )
- reconstruct = adjusted[0]
+ reconstruct = self._adjust_patches([reconstruct], self.ref_shape, offset)[0]
reconstructed.append(reconstruct)
reconstructed = torch.stack(reconstructed, dim=0)
return torch.argmax(torch.sum(reconstructed, dim=0), dim=-1)
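
For context on the weight_function parameter above: a minimal sketch of one such function, assuming a simple linear falloff toward patch borders. The name border_taper and the falloff curve are invented for illustration; the commit itself ships no weight function.

import torch

def border_taper(shape):
    # Hypothetical weight_function: weights decay linearly toward patch
    # borders, so overlapping passes trust patch centers more than edges
    # when the weighted reconstructions are combined.
    weight = torch.ones(shape)
    for dim, size in enumerate(shape):
        coords = torch.arange(size, dtype=torch.float32)
        # Distance from the nearest border, normalized so the center is 1.0.
        taper = torch.minimum(coords + 1, size - coords) / ((size + 1) // 2)
        taper = taper.clamp(max=1.0)
        view = [1] * len(shape)
        view[dim] = size
        weight = weight * taper.view(view)
    return weight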
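
A hypothetical end-to-end use of the class after this commit. IdentityModel and all shapes are invented, and the exact padding/offset behavior depends on parts of the file this diff does not show, so treat this as a sketch rather than verified usage.

import torch
import lightning as L
from minerva.utils.patches import BasePatchInferencer

class IdentityModel(L.LightningModule):
    # Toy stand-in for a real model: returns its input unchanged.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

inferencer = BasePatchInferencer(
    IdentityModel(),
    input_shape=(1, 128, 128),
    weight_function=border_taper,  # e.g. the taper sketched above
    offsets=[(0, 64, 64)],         # one extra pass, shifted by half a patch
)
volume = torch.rand(1, 256, 256)
prediction = inferencer(volume)    # __call__, added in this commit, dispatches to forward()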
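
The _extract_patches change replaces a stepped slice, slice(idx, idx + 1, patch_len), which grabbed a single length-1 sliver per dimension, with contiguous blocks. A standalone demonstration of the corrected slicing on a toy tensor (names invented):

import numpy as np
import torch

data = torch.arange(16).reshape(4, 4)
patch_shape = (2, 2)
indexes = tuple(d // p for d, p in zip(data.shape, patch_shape))  # (2, 2)

patches = []
for patch_index in np.ndindex(indexes):
    sl = [
        slice(idx * patch_len, (idx + 1) * patch_len)  # contiguous tile
        for idx, patch_len in zip(patch_index, patch_shape)
    ]
    patches.append(data[tuple(sl)])
patches = torch.stack(patches)  # (4, 2, 2): the four 2x2 tiles of `data`

Switching to torch.stack also fixes the old return statement, since torch.Tensor(patches) cannot build a tensor from a list of multi-element tensors.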
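
The new _compute_output_shape rescales every dimension where the model's output patch shape differs from its input patch shape; a worked example with assumed shapes:

input_shape = (1, 128, 128)   # model input patch
output_shape = (4, 128, 128)  # model output patch, e.g. 4 class maps
tensor_shape = (1, 640, 512)  # full tensor handed to forward()

# Same rule as the method: t * o // i where shapes differ, t otherwise.
result = tuple(
    int(t * o // i) if i != o else t
    for i, o, t in zip(input_shape, output_shape, tensor_shape)
)
assert result == (4, 640, 512)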
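
Soft voting, the only path VotingPatchInferencer still supports after this commit (hard voting now raises NotImplementedError because torch.mode does not behave like scipy.stats.mode), sums the per-pass class-probability maps before taking an argmax over the inner class dimension. A self-contained sketch with made-up tensors:

import torch

torch.manual_seed(0)
# Two offset passes, each reconstructed to (H, W, num_classes) probabilities.
pass_a = torch.softmax(torch.rand(4, 4, 3), dim=-1)
pass_b = torch.softmax(torch.rand(4, 4, 3), dim=-1)

# Mirrors _soft_voting: stack passes, sum, argmax over the class dimension.
reconstructed = torch.stack([pass_a, pass_b], dim=0)
labels = torch.argmax(torch.sum(reconstructed, dim=0), dim=-1)  # (4, 4)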
