From 4b277fbdea1980286bfa000227f5c050b52ab1a6 Mon Sep 17 00:00:00 2001 From: Jian S Date: Wed, 3 Apr 2024 02:28:15 +0300 Subject: [PATCH 01/16] init --- kornia/augmentation/base.py | 30 +- kornia/augmentation/callbacks/__init__.py | 1 + kornia/augmentation/callbacks/_logger.py | 5 + kornia/augmentation/callbacks/base.py | 269 ++++++++++++++++++ kornia/augmentation/callbacks/wandb_logger.py | 80 ++++++ kornia/augmentation/container/augment.py | 11 + kornia/augmentation/container/base.py | 32 ++- kornia/augmentation/container/image.py | 5 +- kornia/augmentation/container/patch.py | 8 + kornia/augmentation/container/video.py | 12 +- kornia/augmentation/utils/label_maps.py | 82 ++++++ 11 files changed, 528 insertions(+), 7 deletions(-) create mode 100644 kornia/augmentation/callbacks/__init__.py create mode 100644 kornia/augmentation/callbacks/_logger.py create mode 100644 kornia/augmentation/callbacks/base.py create mode 100644 kornia/augmentation/callbacks/wandb_logger.py create mode 100644 kornia/augmentation/utils/label_maps.py diff --git a/kornia/augmentation/base.py b/kornia/augmentation/base.py index b953c2cf66..4e69a6b764 100644 --- a/kornia/augmentation/base.py +++ b/kornia/augmentation/base.py @@ -1,9 +1,10 @@ from enum import Enum -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch.distributions import Bernoulli, Distribution, RelaxedBernoulli +from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.augmentation.random_generator import RandomGeneratorBase from kornia.augmentation.utils import ( _adapted_rsampling, @@ -230,8 +231,11 @@ def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **k params, flags = self._process_kwargs_to_params_and_flags(params, self.flags, **kwargs) + [cb.on_forward_start(input, params, flags) for cb in self.callbacks] output = self.apply_func(in_tensor, params, flags) - return self.transform_output_tensor(output, input_shape) if self.keepdim else output + output = self.transform_output_tensor(output, input_shape) if self.keepdim else output + [cb.on_forward_end(output, params, flags) for cb in self.callbacks] + return output class _AugmentationBase(_BasicAugmentationBase): @@ -247,8 +251,20 @@ class _AugmentationBase(_BasicAugmentationBase): same_on_batch: apply the same transformation across the batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. + callbacks: add a list of callbacks. """ + def __init__( + self, + p: float, + p_batch: float, + same_on_batch: bool = False, + keepdim: bool = False, + callbacks: List[AugmentationCallbackBase] = [], + ) -> None: + super().__init__(p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim) + self.callbacks = callbacks + def apply_transform( self, input: Tensor, @@ -288,6 +304,7 @@ def transform_inputs( in_tensor = self.transform_tensor(input) self.validate_tensor(in_tensor) + [cb.on_transform_inputs_start(in_tensor, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform(in_tensor, params, flags, transform=transform) elif not to_apply.any(): @@ -310,6 +327,7 @@ def transform_inputs( if is_autocast_enabled(): output = output.type(input.dtype) + [cb.on_transform_inputs_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_masks( @@ -332,6 +350,7 @@ def transform_masks( in_tensor = self.transform_tensor(input, shape=shape, match_channel=False) self.validate_tensor(in_tensor) + [cb.on_transform_masks_start(in_tensor, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform_mask(in_tensor, params, flags, transform=transform) elif not to_apply.any(): @@ -346,6 +365,7 @@ def transform_masks( ) output = output.index_put((to_apply,), applied) output = _transform_output_shape(output, ori_shape, reference_shape=shape) if self.keepdim else output + [cb.on_transform_masks_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_boxes( @@ -365,6 +385,7 @@ def transform_boxes( batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. + [cb.on_transform_boxes_start(input, params, flags, transform) for cb in self.callbacks] output: Boxes if to_apply.bool().all(): output = self.apply_transform_box(input, params, flags, transform=transform) @@ -383,6 +404,7 @@ def transform_boxes( applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) + [cb.on_transform_boxes_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_keypoints( @@ -402,6 +424,7 @@ def transform_keypoints( batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. + [cb.on_transform_keypoints_start(input, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform_keypoint(input, params, flags, transform=transform) elif not to_apply.any(): @@ -418,6 +441,7 @@ def transform_keypoints( output = output.type(input.dtype) applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) + [cb.on_transform_keypoints_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_classes( @@ -434,6 +458,7 @@ def transform_classes( batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. + [cb.on_transform_classes_start(input, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform_class(input, params, flags, transform=transform) elif not to_apply.any(): @@ -447,6 +472,7 @@ def transform_classes( transform=transform if transform is None else transform[to_apply], ) output = output.index_put((to_apply,), applied) + [cb.on_transform_classes_end(output, params, flags, transform) for cb in self.callbacks] return output def apply_non_transform_mask( diff --git a/kornia/augmentation/callbacks/__init__.py b/kornia/augmentation/callbacks/__init__.py new file mode 100644 index 0000000000..3adbb81d26 --- /dev/null +++ b/kornia/augmentation/callbacks/__init__.py @@ -0,0 +1 @@ +from .base import AugmentationCallbackBase, SequentialCallbackBase diff --git a/kornia/augmentation/callbacks/_logger.py b/kornia/augmentation/callbacks/_logger.py new file mode 100644 index 0000000000..bd22030ce3 --- /dev/null +++ b/kornia/augmentation/callbacks/_logger.py @@ -0,0 +1,5 @@ +from .base import AugmentationCallbackBase + + +class Logger(AugmentationCallbackBase): + """Generic logging module.""" diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py new file mode 100644 index 0000000000..47e2249f01 --- /dev/null +++ b/kornia/augmentation/callbacks/base.py @@ -0,0 +1,269 @@ +from typing import Any, Dict, List, Optional, Union + +from kornia.augmentation.container.ops import DataType +from kornia.augmentation.container.params import ParamItem +from kornia.constants import DataKey +from kornia.core import Module, Tensor +from kornia.geometry.boxes import Boxes +from kornia.geometry.keypoints import Keypoints + +__all__ = [ + "AugmentationCallbackBase", + "SequentialCallbackBase", +] + + +class AugmentationCallbackBase(Module): + """A Meta Callback base class.""" + + def on_transform_inputs_start( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_inputs` begins.""" + ... + + def on_transform_inputs_end( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_inputs` ends.""" + ... + + def on_transform_masks_start( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_masks` begins.""" + ... + + def on_transform_masks_end( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_masks` ends.""" + ... + + def on_transform_classes_start( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_classes` begins.""" + ... + + def on_transform_classes_end( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_classes` ends.""" + ... + + def on_transform_boxes_start( + self, + input: Boxes, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_boxes` begins.""" + ... + + def on_transform_boxes_end( + self, + input: Boxes, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_boxes` ends.""" + ... + + def on_transform_keypoints_start( + self, + input: Keypoints, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_keypoints` begins.""" + ... + + def on_transform_keypoints_end( + self, + input: Keypoints, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `transform_keypoints` ends.""" + ... + + def on_inverse_inputs_start( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_input` begins.""" + ... + + def on_inverse_inputs_end( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_inputs` ends.""" + ... + + def on_inverse_masks_start( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_masks` begins.""" + ... + + def on_inverse_masks_end( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_masks` ends.""" + ... + + def on_inverse_classes_start( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `inverse_classes` begins.""" + ... + + def on_inverse_classes_end( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ): + """Called when `inverse_classes` ends.""" + ... + + def on_inverse_boxes_start( + self, + input: Boxes, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_boxes` begins.""" + ... + + def on_inverse_boxes_end( + self, + input: Boxes, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_boxes` ends.""" + ... + + def on_inverse_keypoints_start( + self, + input: Keypoints, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_keypoints` begins.""" + ... + + def on_inverse_keypoints_end( + self, + input: Keypoints, + params: Dict[str, Tensor], + flags: Dict[str, Any], + inverse: Optional[Tensor] = None, + ): + """Called when `inverse_keypoints` ends.""" + ... + + def on_forward_start(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): + """Called when `forward` starts.""" + ... + + def on_forward_end(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): + """Called when `forward` ends.""" + ... + + def on_inverse_start(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): + """Called when `inverse` starts.""" + ... + + def on_inverse_end(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): + """Called when `inverse` ends.""" + ... + + def on_sequential_forward_start( + self, + *args: Union[DataType, Dict[str, DataType]], + params: Optional[List[ParamItem]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + ): + """Called when `forward` begins for `AugmentationSequential`.""" + ... + + def on_sequential_forward_end( + self, + *args: Union[DataType, Dict[str, DataType]], + params: Optional[List[ParamItem]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + ): + """Called when `forward` ends for `AugmentationSequential`.""" + ... + + def on_sequential_inverse_start( + self, + *args: Union[DataType, Dict[str, DataType]], + params: Optional[List[ParamItem]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + ): + """Called when `inverse` begins for `AugmentationSequential`.""" + ... + + def on_sequential_inverse_end( + self, + *args: Union[DataType, Dict[str, DataType]], + params: Optional[List[ParamItem]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + ): + """Called when `inverse` ends for `AugmentationSequential`.""" + ... diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py new file mode 100644 index 0000000000..30d256ddde --- /dev/null +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -0,0 +1,80 @@ +import importlib +from typing import Dict, List, Optional, Union + +from kornia.augmentation.container.ops import DataType +from kornia.augmentation.container.params import ParamItem +from kornia.constants import DataKey +from kornia.core import Module, Tensor + +from .base import AugmentationCallbackBase + + +class WandbLogger(AugmentationCallbackBase): + """Logging images onto W&B for `AugmentationSequential`. + + Args: + batches_to_save: the number of batches to be logged. -1 is to save all batches. + num_to_log: number of images to log in a batch. + log_indices: only selected input types are logged. If `log_indices=[0, 2]` and + `data_keys=["input", "bbox", "mask"]`, only the images and masks + will be logged. + data_keys: the input type sequential. Accepts "input", "image", "mask", + "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + preprocessing: add preprocessing for images if needed. If not None, the length + must match `data_keys`. + """ + + def __init__( + self, + run: Optional["wandb.Run"] = None, + batches_to_log: int = -1, + num_to_log: int = 4, + log_indices: Optional[List[int]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + preprocessing: Optional[List[Optional[Module]]] = None, + ): + super().__init__() + self.batches_to_log = batches_to_log + self.log_indices = log_indices + self.data_keys = data_keys + self.preprocessing = preprocessing + self.num_to_log = num_to_log + if run is None: + self.wandb = importlib.import_module("wandb") + else: + self.wandb = run + + def _make_mask_data(self, mask: Tensor): + raise NotImplementedError + + def _make_bbox_data(self, mask: Tensor): + raise NotImplementedError + + def on_sequential_forward_end( + self, + *args: Union[DataType, Dict[str, DataType]], + params: Optional[List[ParamItem]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + ): + """Called when `forward` ends for `AugmentationSequential`.""" + image_data = None + mask_data = [] + box_data = [] + for i, (arg, data_key) in enumerate(zip(args, data_keys)): + if i not in self.log_indices: + continue + + preproc = self.preprocessing[self.log_indices[i]] + out_arg = arg[: self.num_to_log] + if preproc is not None: + out_arg = preproc(out_arg) + if data_key in [DataKey.INPUT]: + image_data = out_arg + if data_key in [DataKey.MASK]: + mask_data = self._make_mask_data(out_arg) + if data_key in [DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY]: + box_data = self._make_bbox_data(out_arg) + + for i, (img, mask, box) in enumerate(zip(image_data, mask_data, box_data)): + wandb_img = self.wandb.Image(img, masks=mask, boxes=box) + self.wandb.log({"kornia_augmentation": wandb_img}) diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 21716f277e..f1cc9ebca0 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -4,6 +4,7 @@ from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D from kornia.augmentation.base import _AugmentationBase +from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.constants import DataKey, Resample from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes, VideoBoxes @@ -207,6 +208,7 @@ def __init__( extra_args: Dict[DataKey, Dict[str, Any]] = { DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} }, + callbacks: List[AugmentationCallbackBase] = [], ) -> None: self._transform_matrix: Optional[Tensor] self._transform_matrices: List[Optional[Tensor]] = [] @@ -217,6 +219,7 @@ def __init__( keepdim=keepdim, random_apply=random_apply, random_apply_weights=random_apply_weights, + callbacks=callbacks, ) self._parse_transformation_matrix_mode(transformation_matrix_mode) @@ -302,6 +305,8 @@ def inverse( # type: ignore[override] ) params = self._params + [cb.on_sequential_inverse_start(in_args, params=params, data_keys=data_keys) for cb in self.callbacks] + outputs: List[DataType] = in_args for param in params[::-1]: module = self.get_submodule(param.name) @@ -317,6 +322,8 @@ def inverse( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} + [cb.on_sequential_inverse_end(*outputs, params=params, data_keys=data_keys) for cb in self.callbacks] + if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -415,6 +422,8 @@ def forward( # type: ignore[override] else: raise ValueError("`params` must be provided whilst INPUT is not in data_keys.") + [cb.on_sequential_forward_start(in_args, params=params, data_keys=data_keys) for cb in self.callbacks] + outputs: Union[Tensor, List[DataType]] = in_args for param in params: module = self.get_submodule(param.name) @@ -435,6 +444,8 @@ def forward( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} + [cb.on_sequential_forward_end(*outputs, params=params, data_keys=data_keys) for cb in self.callbacks] + if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] diff --git a/kornia/augmentation/container/base.py b/kornia/augmentation/container/base.py index 3c15ac6125..20a5d473cb 100644 --- a/kornia/augmentation/container/base.py +++ b/kornia/augmentation/container/base.py @@ -7,6 +7,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase +from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints @@ -109,11 +110,18 @@ class SequentialBase(BasicSequentialBase): to the batch form (False). If None, it will not overwrite the function-wise settings. """ - def __init__(self, *args: Module, same_on_batch: Optional[bool] = None, keepdim: Optional[bool] = None) -> None: + def __init__( + self, + *args: Module, + same_on_batch: Optional[bool] = None, + keepdim: Optional[bool] = None, + callbacks: List[AugmentationCallbackBase] = [], + ) -> None: # To name the modules properly super().__init__(*args) self._same_on_batch = same_on_batch self._keepdim = keepdim + self.callbacks = callbacks self.update_attribute(same_on_batch, keepdim) def update_attribute( @@ -193,51 +201,67 @@ def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iter raise NotImplementedError def transform_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: + [cb.on_transform_inputs_start(input, params) for cb in self.callbacks] for param in params: module = self.get_submodule(param.name) input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) + [cb.on_transform_inputs_end(input, params) for cb in self.callbacks] return input def inverse_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: + [cb.on_inverse_inputs_start(input, params) for cb in self.callbacks] for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = InputSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) + [cb.on_inverse_inputs_end(input, params) for cb in self.callbacks] return input def transform_masks(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: + [cb.on_transform_masks_start(input, params) for cb in self.callbacks] for param in params: module = self.get_submodule(param.name) input = MaskSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) + [cb.on_transform_masks_end(input, params) for cb in self.callbacks] return input def inverse_masks(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: + [cb.on_inverse_masks_start(input, params) for cb in self.callbacks] for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = MaskSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) + [cb.on_inverse_masks_end(input, params) for cb in self.callbacks] return input def transform_boxes(self, input: Boxes, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Boxes: + [cb.on_transform_boxes_start(input, params) for cb in self.callbacks] for param in params: module = self.get_submodule(param.name) input = BoxSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) + [cb.on_transform_boxes_end(input, params) for cb in self.callbacks] return input def inverse_boxes(self, input: Boxes, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Boxes: + [cb.on_inverse_boxes_start(input, params) for cb in self.callbacks] for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = BoxSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) + [cb.on_inverse_boxes_end(input, params) for cb in self.callbacks] return input def transform_keypoints( self, input: Keypoints, params: List[ParamItem], extra_args: Dict[str, Any] = {} ) -> Keypoints: + [cb.on_transform_keypoints_start(input, params) for cb in self.callbacks] for param in params: module = self.get_submodule(param.name) input = KeypointSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) + [cb.on_transform_keypoints_end(input, params) for cb in self.callbacks] return input def inverse_keypoints( self, input: Keypoints, params: List[ParamItem], extra_args: Dict[str, Any] = {} ) -> Keypoints: + [cb.on_inverse_keypoints_start(input, params) for cb in self.callbacks] for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = KeypointSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) + [cb.on_inverse_keypoints_end(input, params) for cb in self.callbacks] return input def inverse( @@ -255,9 +279,9 @@ def inverse( "or passing valid params into this function." ) params = self._params - + [cb.on_inverse_start(input, params=params) for cb in self.callbacks] input = self.inverse_inputs(input, params, extra_args=extra_args) - + [cb.on_inverse_end(input, params=params) for cb in self.callbacks] return input def forward( @@ -270,7 +294,9 @@ def forward( _, out_shape = self.autofill_dim(inp, dim_range=(2, 4)) params = self.forward_parameters(out_shape) + [cb.on_forward_start(input, params=params) for cb in self.callbacks] input = self.transform_inputs(input, params=params, extra_args=extra_args) + [cb.on_forward_end(input, params=params) for cb in self.callbacks] self._params = params return input diff --git a/kornia/augmentation/container/image.py b/kornia/augmentation/container/image.py index 3f465b4d18..58b44afa98 100644 --- a/kornia/augmentation/container/image.py +++ b/kornia/augmentation/container/image.py @@ -4,6 +4,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase +from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.augmentation.utils import override_parameters from kornia.core import Module, Tensor, as_tensor from kornia.utils import eye_like @@ -32,6 +33,7 @@ class ImageSequential(ImageSequentialBase): If False, the whole list of args will be processed as a sequence in original order. random_apply_weights: a list of selection weights for each operation. The length shall be as same as the number of operations. By default, operations are sampled uniformly. + callbacks: add a list of callbacks. .. note:: Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module. @@ -85,8 +87,9 @@ def __init__( random_apply: Union[int, bool, Tuple[int, int]] = False, random_apply_weights: Optional[List[float]] = None, if_unsupported_ops: str = "raise", + callbacks: List[AugmentationCallbackBase] = [], ) -> None: - super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim) + super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim, callbacks=callbacks) self.random_apply = self._read_random_apply(random_apply, len(args)) if random_apply_weights is not None and len(random_apply_weights) != len(self): diff --git a/kornia/augmentation/container/patch.py b/kornia/augmentation/container/patch.py index e7d28aa471..3a86b22959 100644 --- a/kornia/augmentation/container/patch.py +++ b/kornia/augmentation/container/patch.py @@ -5,6 +5,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase +from kornia.augmentation.callbacks import SequentialCallbackBase from kornia.contrib.extract_patches import extract_tensor_patches from kornia.core import Module, Tensor, concatenate from kornia.core import pad as fpad @@ -51,6 +52,7 @@ class PatchSequential(ImageSequential): If ``False`` and not ``patchwise_apply``, the whole list of args will be processed in original order. If ``False`` and ``patchwise_apply``, the whole list of args will be processed in original order location-wisely. + callbacks: add a list of callbacks. .. note:: Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module. @@ -119,6 +121,7 @@ def __init__( patchwise_apply: bool = True, random_apply: Union[int, bool, Tuple[int, int]] = False, random_apply_weights: Optional[List[float]] = None, + callbacks: List[SequentialCallbackBase] = [], ) -> None: _random_apply: Optional[Union[int, Tuple[int, int]]] @@ -143,6 +146,7 @@ def __init__( keepdim=keepdim, random_apply=_random_apply, random_apply_weights=random_apply_weights, + callbacks=callbacks, ) if padding not in ("same", "valid"): raise ValueError(f"`padding` must be either `same` or `valid`. Got {padding}.") @@ -385,6 +389,8 @@ def inverse( # type: ignore[override] provided parameters. """ if self.is_intensity_only(): + [cb.on_inverse_start(input, params=params) for cb in self.callbacks] + [cb.on_inverse_end(input, params=params) for cb in self.callbacks] return input raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.") @@ -398,7 +404,9 @@ def forward(self, input: Tensor, params: Optional[List[PatchParamItem]] = None) if params is None: params = self.forward_parameters(input.shape) + [cb.on_forward_start(input, params=params) for cb in self.callbacks] output = self.transform_inputs(input, params=params) + [cb.on_forward_end(input, params=params) for cb in self.callbacks] self._params = params diff --git a/kornia/augmentation/container/video.py b/kornia/augmentation/container/video.py index 372f26b83e..94a9b6a1f6 100644 --- a/kornia/augmentation/container/video.py +++ b/kornia/augmentation/container/video.py @@ -4,6 +4,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase +from kornia.augmentation.callbacks import SequentialCallbackBase from kornia.augmentation.container.base import SequentialBase from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape from kornia.core import Module, Tensor @@ -33,6 +34,7 @@ class VideoSequential(ImageSequential): If (a,), x number of transformations (a <= x <= len(args)) will be selected. If (a, b), x number of transformations (a <= x <= b) will be selected. If None, the whole list of args will be processed as a sequence. + callbacks: add a list of callbacks. Note: Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module. @@ -105,6 +107,7 @@ def __init__( same_on_frame: bool = True, random_apply: Union[int, bool, Tuple[int, int]] = False, random_apply_weights: Optional[List[float]] = None, + callbacks: List[SequentialCallbackBase] = [], ) -> None: super().__init__( *args, @@ -112,6 +115,7 @@ def __init__( keepdim=None, random_apply=random_apply, random_apply_weights=random_apply_weights, + callbacks=callbacks, ) self.same_on_frame = same_on_frame self.data_format = data_format.upper() @@ -322,7 +326,11 @@ def inverse( else: raise RuntimeError("No valid params to inverse the transformation.") - return self.inverse_inputs(input, params, extra_args=extra_args) + [cb.on_inverse_start(input, params=params) for cb in self.callbacks] + output = self.inverse_inputs(input, params, extra_args=extra_args) + [cb.on_inverse_end(input, params=params) for cb in self.callbacks] + + return output def forward( self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Dict[str, Any] = {} @@ -335,6 +343,8 @@ def forward( self._params = self.forward_parameters(input.shape) params = self._params + [cb.on_forward_start(input, params=params) for cb in self.callbacks] output = self.transform_inputs(input, params, extra_args=extra_args) + [cb.on_forward_end(input, params=params) for cb in self.callbacks] return output diff --git a/kornia/augmentation/utils/label_maps.py b/kornia/augmentation/utils/label_maps.py new file mode 100644 index 0000000000..de17d76e5e --- /dev/null +++ b/kornia/augmentation/utils/label_maps.py @@ -0,0 +1,82 @@ +coco_category_map = { + 1: "person", + 2: "bicycle", + 3: "car", + 4: "motorcycle", + 5: "airplane", + 6: "bus", + 7: "train", + 8: "truck", + 9: "boat", + 10: "traffic light", + 11: "fire hydrant", + 13: "stop sign", + 14: "parking meter", + 15: "bench", + 16: "bird", + 17: "cat", + 18: "dog", + 19: "horse", + 20: "sheep", + 21: "cow", + 22: "elephant", + 23: "bear", + 24: "zebra", + 25: "giraffe", + 27: "backpack", + 28: "umbrella", + 31: "handbag", + 32: "tie", + 33: "suitcase", + 34: "frisbee", + 35: "skis", + 36: "snowboard", + 37: "sports ball", + 38: "kite", + 39: "baseball bat", + 40: "baseball glove", + 41: "skateboard", + 42: "surfboard", + 43: "tennis racket", + 44: "bottle", + 46: "wine glass", + 47: "cup", + 48: "fork", + 49: "knife", + 50: "spoon", + 51: "bowl", + 52: "banana", + 53: "apple", + 54: "sandwich", + 55: "orange", + 56: "broccoli", + 57: "carrot", + 58: "hot dog", + 59: "pizza", + 60: "donut", + 61: "cake", + 62: "chair", + 63: "couch", + 64: "potted plant", + 65: "bed", + 67: "dining table", + 70: "toilet", + 72: "tv", + 73: "laptop", + 74: "mouse", + 75: "remote", + 76: "keyboard", + 77: "cell phone", + 78: "microwave", + 79: "oven", + 80: "toaster", + 81: "sink", + 82: "refrigerator", + 84: "book", + 85: "clock", + 86: "vase", + 87: "scissors", + 88: "teddy bear", + 89: "hair drier", + 90: "toothbrush", +} From 6b328ab638e6f2a037da822630c3e9c6a0f94bc1 Mon Sep 17 00:00:00 2001 From: Jian S Date: Wed, 3 Apr 2024 03:21:37 +0300 Subject: [PATCH 02/16] update --- kornia/augmentation/auto/base.py | 3 +- kornia/augmentation/auto/operations/policy.py | 10 +- kornia/augmentation/callbacks/base.py | 82 +++++---------- kornia/augmentation/container/augment.py | 12 +-- kornia/augmentation/container/base.py | 99 +++++-------------- kornia/augmentation/container/mixins.py | 88 +++++++++++++++++ 6 files changed, 151 insertions(+), 143 deletions(-) create mode 100644 kornia/augmentation/container/mixins.py diff --git a/kornia/augmentation/auto/base.py b/kornia/augmentation/auto/base.py index 175d5880b9..c8893d110a 100644 --- a/kornia/augmentation/auto/base.py +++ b/kornia/augmentation/auto/base.py @@ -4,7 +4,8 @@ from kornia.augmentation.auto.operations.base import OperationBase from kornia.augmentation.auto.operations.policy import PolicySequential -from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn +from kornia.augmentation.container.base import ImageSequentialBase +from kornia.augmentation.container.mixins import TransformMatrixMinIn from kornia.augmentation.container.ops import InputSequentialOps from kornia.augmentation.container.params import ParamItem from kornia.core import Module, Tensor diff --git a/kornia/augmentation/auto/operations/policy.py b/kornia/augmentation/auto/operations/policy.py index 475cf38a3b..e53c3f1abb 100644 --- a/kornia/augmentation/auto/operations/policy.py +++ b/kornia/augmentation/auto/operations/policy.py @@ -4,7 +4,9 @@ import kornia.augmentation as K from kornia.augmentation.auto.operations import OperationBase -from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn +from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.augmentation.container.base import ImageSequentialBase +from kornia.augmentation.container.mixins import TransformMatrixMinIn from kornia.augmentation.container.ops import InputSequentialOps from kornia.augmentation.container.params import ParamItem from kornia.augmentation.utils import _transform_input, override_parameters @@ -12,16 +14,16 @@ from kornia.utils import eye_like -class PolicySequential(TransformMatrixMinIn, ImageSequentialBase): +class PolicySequential(ImageSequentialBase, TransformMatrixMinIn): """Policy tuple for applying multiple operations. Args: operations: a list of operations to perform. """ - def __init__(self, *operations: OperationBase) -> None: + def __init__(self, *operations: OperationBase, callbacks: List[AugmentationCallbackBase] = [],) -> None: self.validate_operations(*operations) - super().__init__(*operations) + super().__init__(*operations, callbacks=callbacks) self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,) def _update_transform_matrix_for_valid_op(self, module: Module) -> None: diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index 47e2249f01..81ecf0a7c3 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List, Optional, Union +from kornia.augmentation.base import _BasicAugmentationBase +from kornia.augmentation.container.augment import AugmentationSequential from kornia.augmentation.container.ops import DataType from kornia.augmentation.container.params import ParamItem from kornia.constants import DataKey @@ -20,8 +22,7 @@ def on_transform_inputs_start( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_inputs` begins.""" ... @@ -30,8 +31,7 @@ def on_transform_inputs_end( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_inputs` ends.""" ... @@ -40,8 +40,7 @@ def on_transform_masks_start( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_masks` begins.""" ... @@ -50,8 +49,7 @@ def on_transform_masks_end( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_masks` ends.""" ... @@ -60,8 +58,7 @@ def on_transform_classes_start( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_classes` begins.""" ... @@ -70,8 +67,7 @@ def on_transform_classes_end( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_classes` ends.""" ... @@ -80,8 +76,7 @@ def on_transform_boxes_start( self, input: Boxes, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_boxes` begins.""" ... @@ -90,8 +85,7 @@ def on_transform_boxes_end( self, input: Boxes, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_boxes` ends.""" ... @@ -100,8 +94,7 @@ def on_transform_keypoints_start( self, input: Keypoints, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_keypoints` begins.""" ... @@ -110,8 +103,7 @@ def on_transform_keypoints_end( self, input: Keypoints, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `transform_keypoints` ends.""" ... @@ -120,8 +112,7 @@ def on_inverse_inputs_start( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_input` begins.""" ... @@ -130,8 +121,7 @@ def on_inverse_inputs_end( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_inputs` ends.""" ... @@ -140,8 +130,7 @@ def on_inverse_masks_start( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_masks` begins.""" ... @@ -150,8 +139,7 @@ def on_inverse_masks_end( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_masks` ends.""" ... @@ -160,8 +148,7 @@ def on_inverse_classes_start( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_classes` begins.""" ... @@ -170,8 +157,7 @@ def on_inverse_classes_end( self, input: Tensor, params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_classes` ends.""" ... @@ -180,8 +166,7 @@ def on_inverse_boxes_start( self, input: Boxes, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_boxes` begins.""" ... @@ -190,8 +175,7 @@ def on_inverse_boxes_end( self, input: Boxes, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_boxes` ends.""" ... @@ -200,8 +184,7 @@ def on_inverse_keypoints_start( self, input: Keypoints, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_keypoints` begins.""" ... @@ -210,33 +193,17 @@ def on_inverse_keypoints_end( self, input: Keypoints, params: Dict[str, Tensor], - flags: Dict[str, Any], - inverse: Optional[Tensor] = None, + module: _BasicAugmentationBase, ): """Called when `inverse_keypoints` ends.""" ... - def on_forward_start(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): - """Called when `forward` starts.""" - ... - - def on_forward_end(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): - """Called when `forward` ends.""" - ... - - def on_inverse_start(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): - """Called when `inverse` starts.""" - ... - - def on_inverse_end(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **kwargs: Any): - """Called when `inverse` ends.""" - ... - def on_sequential_forward_start( self, *args: Union[DataType, Dict[str, DataType]], params: Optional[List[ParamItem]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + module: AugmentationSequential, ): """Called when `forward` begins for `AugmentationSequential`.""" ... @@ -246,6 +213,7 @@ def on_sequential_forward_end( *args: Union[DataType, Dict[str, DataType]], params: Optional[List[ParamItem]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + module: AugmentationSequential, ): """Called when `forward` ends for `AugmentationSequential`.""" ... @@ -255,6 +223,7 @@ def on_sequential_inverse_start( *args: Union[DataType, Dict[str, DataType]], params: Optional[List[ParamItem]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + module: AugmentationSequential, ): """Called when `inverse` begins for `AugmentationSequential`.""" ... @@ -264,6 +233,7 @@ def on_sequential_inverse_end( *args: Union[DataType, Dict[str, DataType]], params: Optional[List[ParamItem]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + module: AugmentationSequential, ): """Called when `inverse` ends for `AugmentationSequential`.""" ... diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index f1cc9ebca0..84e981596e 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -11,7 +11,7 @@ from kornia.geometry.keypoints import Keypoints, VideoKeypoints from kornia.utils import eye_like, is_autocast_enabled -from .base import TransformMatrixMinIn +from .mixins import TransformMatrixMinIn from .image import ImageSequential from .ops import AugmentationSequentialOps, DataType from .params import ParamItem @@ -25,7 +25,7 @@ _IMG_MSK_OPTIONS = {DataKey.INPUT, DataKey.MASK} -class AugmentationSequential(TransformMatrixMinIn, ImageSequential): +class AugmentationSequential(ImageSequential, TransformMatrixMinIn): r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once. .. image:: _static/img/AugmentationSequential.png @@ -305,7 +305,7 @@ def inverse( # type: ignore[override] ) params = self._params - [cb.on_sequential_inverse_start(in_args, params=params, data_keys=data_keys) for cb in self.callbacks] + self.run_callbacks("on_sequential_inverse_start", input=in_args, params=params) outputs: List[DataType] = in_args for param in params[::-1]: @@ -322,7 +322,7 @@ def inverse( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} - [cb.on_sequential_inverse_end(*outputs, params=params, data_keys=data_keys) for cb in self.callbacks] + self.run_callbacks("on_sequential_inverse_end", input=outputs, params=params) if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -422,7 +422,7 @@ def forward( # type: ignore[override] else: raise ValueError("`params` must be provided whilst INPUT is not in data_keys.") - [cb.on_sequential_forward_start(in_args, params=params, data_keys=data_keys) for cb in self.callbacks] + self.run_callbacks("on_sequential_forward_start", input=in_args, params=params) outputs: Union[Tensor, List[DataType]] = in_args for param in params: @@ -444,7 +444,7 @@ def forward( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} - [cb.on_sequential_forward_end(*outputs, params=params, data_keys=data_keys) for cb in self.callbacks] + self.run_callbacks("on_sequential_forward_end", input=outputs, params=params) if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] diff --git a/kornia/augmentation/container/base.py b/kornia/augmentation/container/base.py index 20a5d473cb..90b1b55132 100644 --- a/kornia/augmentation/container/base.py +++ b/kornia/augmentation/container/base.py @@ -8,6 +8,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints @@ -97,7 +98,7 @@ def get_params_by_module(self, named_modules: Iterator[Tuple[str, Module]]) -> I yield ParamItem(name, None) -class SequentialBase(BasicSequentialBase): +class SequentialBase(BasicSequentialBase, CallbacksMixIn): r"""SequentialBase for creating kornia modulized processing pipeline. Args: @@ -123,6 +124,7 @@ def __init__( self._keepdim = keepdim self.callbacks = callbacks self.update_attribute(same_on_batch, keepdim) + self.register_callbacks(callbacks) def update_attribute( self, @@ -201,67 +203,67 @@ def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iter raise NotImplementedError def transform_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - [cb.on_transform_inputs_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_inputs_start", input=input, params=param) for param in params: module = self.get_submodule(param.name) input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - [cb.on_transform_inputs_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_inputs_end", input=input, params=params) return input def inverse_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - [cb.on_inverse_inputs_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_inputs_start", input=input, params=param) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = InputSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - [cb.on_inverse_inputs_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_inputs_end", input=input, params=params) return input def transform_masks(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - [cb.on_transform_masks_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_masks_start", input=input, params=params) for param in params: module = self.get_submodule(param.name) input = MaskSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - [cb.on_transform_masks_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_masks_end", input=input, params=params) return input def inverse_masks(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - [cb.on_inverse_masks_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_masks_start", input=input, params=params) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = MaskSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - [cb.on_inverse_masks_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_masks_end", input=input, params=params) return input def transform_boxes(self, input: Boxes, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Boxes: - [cb.on_transform_boxes_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_boxes_start", input=input, params=params) for param in params: module = self.get_submodule(param.name) input = BoxSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - [cb.on_transform_boxes_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_boxes_end", input=input, params=params) return input def inverse_boxes(self, input: Boxes, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Boxes: - [cb.on_inverse_boxes_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_boxes_start", input=input, params=params) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = BoxSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - [cb.on_inverse_boxes_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_boxes_end", input=input, params=params) return input def transform_keypoints( self, input: Keypoints, params: List[ParamItem], extra_args: Dict[str, Any] = {} ) -> Keypoints: - [cb.on_transform_keypoints_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_keypoints_start", input=input, params=params) for param in params: module = self.get_submodule(param.name) input = KeypointSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - [cb.on_transform_keypoints_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_transform_keypoints_end", input=input, params=params) return input def inverse_keypoints( self, input: Keypoints, params: List[ParamItem], extra_args: Dict[str, Any] = {} ) -> Keypoints: - [cb.on_inverse_keypoints_start(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_keypoints_start", input=input, params=params) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): input = KeypointSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - [cb.on_inverse_keypoints_end(input, params) for cb in self.callbacks] + self.run_callbacks("on_inverse_keypoints_end", input=input, params=params) return input def inverse( @@ -279,9 +281,9 @@ def inverse( "or passing valid params into this function." ) params = self._params - [cb.on_inverse_start(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_inverse_start", input=input, params=params) input = self.inverse_inputs(input, params, extra_args=extra_args) - [cb.on_inverse_end(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_inverse_end", input=input, params=params) return input def forward( @@ -294,64 +296,9 @@ def forward( _, out_shape = self.autofill_dim(inp, dim_range=(2, 4)) params = self.forward_parameters(out_shape) - [cb.on_forward_start(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_forward_start", input=input, params=params) input = self.transform_inputs(input, params=params, extra_args=extra_args) - [cb.on_forward_end(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_forward_end", input=input, params=params) self._params = params return input - - -class TransformMatrixMinIn: - """Enables computation matrix computation.""" - - _valid_ops_for_transform_computation: Tuple[Any, ...] = () - _transformation_matrix_arg: str = "silent" - - def __init__(self, *args, **kwargs) -> None: # type:ignore - super().__init__(*args, **kwargs) - self._transform_matrix: Optional[Tensor] = None - self._transform_matrices: List[Optional[Tensor]] = [] - - def _parse_transformation_matrix_mode(self, transformation_matrix_mode: str) -> None: - _valid_transformation_matrix_args = {"silence", "silent", "rigid", "skip"} - if transformation_matrix_mode not in _valid_transformation_matrix_args: - raise ValueError( - f"`transformation_matrix` has to be one of {_valid_transformation_matrix_args}. " - f"Got {transformation_matrix_mode}." - ) - self._transformation_matrix_arg = transformation_matrix_mode - - @property - def transform_matrix(self) -> Optional[Tensor]: - # In AugmentationSequential, the parent class is accessed first. - # So that it was None in the beginning. We hereby use lazy computation here. - if self._transform_matrix is None and len(self._transform_matrices) != 0: - self._transform_matrix = self._transform_matrices[0] - for mat in self._transform_matrices[1:]: - self._update_transform_matrix(mat) - return self._transform_matrix - - def _update_transform_matrix_for_valid_op(self, module: Module) -> None: - raise NotImplementedError(module) - - def _update_transform_matrix_by_module(self, module: Module) -> None: - if self._transformation_matrix_arg == "skip": - return - if isinstance(module, self._valid_ops_for_transform_computation): - self._update_transform_matrix_for_valid_op(module) - elif self._transformation_matrix_arg == "rigid": - raise RuntimeError( - f"Non-rigid module `{module}` is not supported under `rigid` computation mode. " - "Please either update the module or change the `transformation_matrix` argument." - ) - - def _update_transform_matrix(self, transform_matrix: Optional[Tensor]) -> None: - if self._transform_matrix is None: - self._transform_matrix = transform_matrix - else: - self._transform_matrix = transform_matrix @ self._transform_matrix - - def _reset_transform_matrix_state(self) -> None: - self._transform_matrix = None - self._transform_matrices = [] diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py new file mode 100644 index 0000000000..5b6e0ffee9 --- /dev/null +++ b/kornia/augmentation/container/mixins.py @@ -0,0 +1,88 @@ +from typing import Any, List, Optional, Tuple + +from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.core import Module, Tensor + +__all__ = ["CallbacksMixIn", "TransformMatrixMinIn",] + + +class CallbacksMixIn: + + def __init__(self, *args, **kwargs) -> None: # type:ignore + super().__init__(*args, **kwargs) + self._callbacks: List[AugmentationCallbackBase] = [] + self._hooks = [] + + @property + def callbacks(self,): + return self._callbacks + + def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: + [self._callbacks.append(cb) for cb in callbacks] + + def run_callbacks(self, hook: str, *args, **kwargs) -> None: + for cb in self.callbacks: + if not hasattr(cb, hook): + continue + + hook_callable = getattr(cb, hook) + + if not callable(hook_callable): + continue + + hook_callable(*args, **kwargs) + + +class TransformMatrixMinIn: + """Enables computation matrix computation.""" + + _valid_ops_for_transform_computation: Tuple[Any, ...] = () + _transformation_matrix_arg: str = "silent" + + def __init__(self, *args, **kwargs) -> None: # type:ignore + super().__init__(*args, **kwargs) + self._transform_matrix: Optional[Tensor] = None + self._transform_matrices: List[Optional[Tensor]] = [] + + def _parse_transformation_matrix_mode(self, transformation_matrix_mode: str) -> None: + _valid_transformation_matrix_args = {"silence", "silent", "rigid", "skip"} + if transformation_matrix_mode not in _valid_transformation_matrix_args: + raise ValueError( + f"`transformation_matrix` has to be one of {_valid_transformation_matrix_args}. " + f"Got {transformation_matrix_mode}." + ) + self._transformation_matrix_arg = transformation_matrix_mode + + @property + def transform_matrix(self) -> Optional[Tensor]: + # In AugmentationSequential, the parent class is accessed first. + # So that it was None in the beginning. We hereby use lazy computation here. + if self._transform_matrix is None and len(self._transform_matrices) != 0: + self._transform_matrix = self._transform_matrices[0] + for mat in self._transform_matrices[1:]: + self._update_transform_matrix(mat) + return self._transform_matrix + + def _update_transform_matrix_for_valid_op(self, module: Module) -> None: + raise NotImplementedError(module) + + def _update_transform_matrix_by_module(self, module: Module) -> None: + if self._transformation_matrix_arg == "skip": + return + if isinstance(module, self._valid_ops_for_transform_computation): + self._update_transform_matrix_for_valid_op(module) + elif self._transformation_matrix_arg == "rigid": + raise RuntimeError( + f"Non-rigid module `{module}` is not supported under `rigid` computation mode. " + "Please either update the module or change the `transformation_matrix` argument." + ) + + def _update_transform_matrix(self, transform_matrix: Optional[Tensor]) -> None: + if self._transform_matrix is None: + self._transform_matrix = transform_matrix + else: + self._transform_matrix = transform_matrix @ self._transform_matrix + + def _reset_transform_matrix_state(self) -> None: + self._transform_matrix = None + self._transform_matrices = [] From de807a25458eb821407d8343580ebffd8ab1e0b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Apr 2024 00:21:53 +0000 Subject: [PATCH 03/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/augmentation/auto/operations/policy.py | 6 +++++- kornia/augmentation/callbacks/base.py | 2 +- kornia/augmentation/container/augment.py | 2 +- kornia/augmentation/container/mixins.py | 10 +++++++--- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/kornia/augmentation/auto/operations/policy.py b/kornia/augmentation/auto/operations/policy.py index e53c3f1abb..e86e254686 100644 --- a/kornia/augmentation/auto/operations/policy.py +++ b/kornia/augmentation/auto/operations/policy.py @@ -21,7 +21,11 @@ class PolicySequential(ImageSequentialBase, TransformMatrixMinIn): operations: a list of operations to perform. """ - def __init__(self, *operations: OperationBase, callbacks: List[AugmentationCallbackBase] = [],) -> None: + def __init__( + self, + *operations: OperationBase, + callbacks: List[AugmentationCallbackBase] = [], + ) -> None: self.validate_operations(*operations) super().__init__(*operations, callbacks=callbacks) self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index 81ecf0a7c3..34ac6a3e84 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from kornia.augmentation.base import _BasicAugmentationBase from kornia.augmentation.container.augment import AugmentationSequential diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 84e981596e..19ee4fed5c 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -11,8 +11,8 @@ from kornia.geometry.keypoints import Keypoints, VideoKeypoints from kornia.utils import eye_like, is_autocast_enabled -from .mixins import TransformMatrixMinIn from .image import ImageSequential +from .mixins import TransformMatrixMinIn from .ops import AugmentationSequentialOps, DataType from .params import ParamItem from .patch import PatchSequential diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index 5b6e0ffee9..76e6f1b515 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -3,18 +3,22 @@ from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.core import Module, Tensor -__all__ = ["CallbacksMixIn", "TransformMatrixMinIn",] +__all__ = [ + "CallbacksMixIn", + "TransformMatrixMinIn", +] class CallbacksMixIn: - def __init__(self, *args, **kwargs) -> None: # type:ignore super().__init__(*args, **kwargs) self._callbacks: List[AugmentationCallbackBase] = [] self._hooks = [] @property - def callbacks(self,): + def callbacks( + self, + ): return self._callbacks def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: From 7e06b3595d184415bc28840a64bb3f0ae7b0389b Mon Sep 17 00:00:00 2001 From: shijianjian Date: Thu, 11 Apr 2024 11:04:39 +0300 Subject: [PATCH 04/16] update --- kornia/augmentation/base.py | 16 ---- kornia/augmentation/callbacks/__init__.py | 2 +- kornia/augmentation/callbacks/base.py | 81 ++++++++++----------- kornia/augmentation/container/data_types.py | 10 +++ kornia/augmentation/container/ops.py | 5 +- kornia/augmentation/container/patch.py | 12 +-- kornia/augmentation/container/video.py | 12 +-- 7 files changed, 64 insertions(+), 74 deletions(-) create mode 100644 kornia/augmentation/container/data_types.py diff --git a/kornia/augmentation/base.py b/kornia/augmentation/base.py index 4e69a6b764..b9b0a6c3e6 100644 --- a/kornia/augmentation/base.py +++ b/kornia/augmentation/base.py @@ -4,7 +4,6 @@ import torch from torch.distributions import Bernoulli, Distribution, RelaxedBernoulli -from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.augmentation.random_generator import RandomGeneratorBase from kornia.augmentation.utils import ( _adapted_rsampling, @@ -231,10 +230,8 @@ def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **k params, flags = self._process_kwargs_to_params_and_flags(params, self.flags, **kwargs) - [cb.on_forward_start(input, params, flags) for cb in self.callbacks] output = self.apply_func(in_tensor, params, flags) output = self.transform_output_tensor(output, input_shape) if self.keepdim else output - [cb.on_forward_end(output, params, flags) for cb in self.callbacks] return output @@ -251,7 +248,6 @@ class _AugmentationBase(_BasicAugmentationBase): same_on_batch: apply the same transformation across the batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. - callbacks: add a list of callbacks. """ def __init__( @@ -260,10 +256,8 @@ def __init__( p_batch: float, same_on_batch: bool = False, keepdim: bool = False, - callbacks: List[AugmentationCallbackBase] = [], ) -> None: super().__init__(p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim) - self.callbacks = callbacks def apply_transform( self, @@ -304,7 +298,6 @@ def transform_inputs( in_tensor = self.transform_tensor(input) self.validate_tensor(in_tensor) - [cb.on_transform_inputs_start(in_tensor, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform(in_tensor, params, flags, transform=transform) elif not to_apply.any(): @@ -327,7 +320,6 @@ def transform_inputs( if is_autocast_enabled(): output = output.type(input.dtype) - [cb.on_transform_inputs_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_masks( @@ -350,7 +342,6 @@ def transform_masks( in_tensor = self.transform_tensor(input, shape=shape, match_channel=False) self.validate_tensor(in_tensor) - [cb.on_transform_masks_start(in_tensor, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform_mask(in_tensor, params, flags, transform=transform) elif not to_apply.any(): @@ -365,7 +356,6 @@ def transform_masks( ) output = output.index_put((to_apply,), applied) output = _transform_output_shape(output, ori_shape, reference_shape=shape) if self.keepdim else output - [cb.on_transform_masks_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_boxes( @@ -385,7 +375,6 @@ def transform_boxes( batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. - [cb.on_transform_boxes_start(input, params, flags, transform) for cb in self.callbacks] output: Boxes if to_apply.bool().all(): output = self.apply_transform_box(input, params, flags, transform=transform) @@ -404,7 +393,6 @@ def transform_boxes( applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) - [cb.on_transform_boxes_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_keypoints( @@ -424,7 +412,6 @@ def transform_keypoints( batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. - [cb.on_transform_keypoints_start(input, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform_keypoint(input, params, flags, transform=transform) elif not to_apply.any(): @@ -441,7 +428,6 @@ def transform_keypoints( output = output.type(input.dtype) applied = applied.type(input.dtype) output = output.index_put((to_apply,), applied) - [cb.on_transform_keypoints_end(output, params, flags, transform) for cb in self.callbacks] return output def transform_classes( @@ -458,7 +444,6 @@ def transform_classes( batch_prob = params["batch_prob"] to_apply = batch_prob > 0.5 # NOTE: in case of Relaxed Distributions. - [cb.on_transform_classes_start(input, params, flags, transform) for cb in self.callbacks] if to_apply.all(): output = self.apply_transform_class(input, params, flags, transform=transform) elif not to_apply.any(): @@ -472,7 +457,6 @@ def transform_classes( transform=transform if transform is None else transform[to_apply], ) output = output.index_put((to_apply,), applied) - [cb.on_transform_classes_end(output, params, flags, transform) for cb in self.callbacks] return output def apply_non_transform_mask( diff --git a/kornia/augmentation/callbacks/__init__.py b/kornia/augmentation/callbacks/__init__.py index 3adbb81d26..3087e99b96 100644 --- a/kornia/augmentation/callbacks/__init__.py +++ b/kornia/augmentation/callbacks/__init__.py @@ -1 +1 @@ -from .base import AugmentationCallbackBase, SequentialCallbackBase +from .base import AugmentationCallbackBase diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index 34ac6a3e84..556d4f62da 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -1,9 +1,9 @@ from typing import Dict, List, Optional, Union -from kornia.augmentation.base import _BasicAugmentationBase -from kornia.augmentation.container.augment import AugmentationSequential -from kornia.augmentation.container.ops import DataType -from kornia.augmentation.container.params import ParamItem +# NOTE: fix circular import +import kornia.augmentation as K +# .data_types import DataType +# from kornia.augmentation.container.params import ParamItem from kornia.constants import DataKey from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes @@ -11,7 +11,6 @@ __all__ = [ "AugmentationCallbackBase", - "SequentialCallbackBase", ] @@ -22,7 +21,7 @@ def on_transform_inputs_start( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_inputs` begins.""" ... @@ -31,7 +30,7 @@ def on_transform_inputs_end( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_inputs` ends.""" ... @@ -40,7 +39,7 @@ def on_transform_masks_start( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_masks` begins.""" ... @@ -49,7 +48,7 @@ def on_transform_masks_end( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_masks` ends.""" ... @@ -58,7 +57,7 @@ def on_transform_classes_start( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_classes` begins.""" ... @@ -67,7 +66,7 @@ def on_transform_classes_end( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_classes` ends.""" ... @@ -76,7 +75,7 @@ def on_transform_boxes_start( self, input: Boxes, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_boxes` begins.""" ... @@ -85,7 +84,7 @@ def on_transform_boxes_end( self, input: Boxes, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_boxes` ends.""" ... @@ -94,7 +93,7 @@ def on_transform_keypoints_start( self, input: Keypoints, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_keypoints` begins.""" ... @@ -103,7 +102,7 @@ def on_transform_keypoints_end( self, input: Keypoints, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `transform_keypoints` ends.""" ... @@ -112,7 +111,7 @@ def on_inverse_inputs_start( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_input` begins.""" ... @@ -121,7 +120,7 @@ def on_inverse_inputs_end( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_inputs` ends.""" ... @@ -130,7 +129,7 @@ def on_inverse_masks_start( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_masks` begins.""" ... @@ -139,7 +138,7 @@ def on_inverse_masks_end( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_masks` ends.""" ... @@ -148,7 +147,7 @@ def on_inverse_classes_start( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_classes` begins.""" ... @@ -157,7 +156,7 @@ def on_inverse_classes_end( self, input: Tensor, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_classes` ends.""" ... @@ -166,7 +165,7 @@ def on_inverse_boxes_start( self, input: Boxes, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_boxes` begins.""" ... @@ -175,7 +174,7 @@ def on_inverse_boxes_end( self, input: Boxes, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_boxes` ends.""" ... @@ -184,7 +183,7 @@ def on_inverse_keypoints_start( self, input: Keypoints, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_keypoints` begins.""" ... @@ -193,47 +192,47 @@ def on_inverse_keypoints_end( self, input: Keypoints, params: Dict[str, Tensor], - module: _BasicAugmentationBase, + module: object, ): """Called when `inverse_keypoints` ends.""" ... def on_sequential_forward_start( self, - *args: Union[DataType, Dict[str, DataType]], - params: Optional[List[ParamItem]] = None, - data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - module: AugmentationSequential, + *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], + params: Optional[List["K.container.params.DataType"]] = None, + data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, + module: object, ): """Called when `forward` begins for `AugmentationSequential`.""" ... def on_sequential_forward_end( self, - *args: Union[DataType, Dict[str, DataType]], - params: Optional[List[ParamItem]] = None, - data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - module: AugmentationSequential, + *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], + params: Optional[List["K.container.params.DataType"]] = None, + data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, + module: object, ): """Called when `forward` ends for `AugmentationSequential`.""" ... def on_sequential_inverse_start( self, - *args: Union[DataType, Dict[str, DataType]], - params: Optional[List[ParamItem]] = None, - data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - module: AugmentationSequential, + *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], + params: Optional[List["K.container.params.DataType"]] = None, + data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, + module: object, ): """Called when `inverse` begins for `AugmentationSequential`.""" ... def on_sequential_inverse_end( self, - *args: Union[DataType, Dict[str, DataType]], - params: Optional[List[ParamItem]] = None, - data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - module: AugmentationSequential, + *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], + params: Optional[List["K.container.params.DataType"]] = None, + data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, + module: object, ): """Called when `inverse` ends for `AugmentationSequential`.""" ... diff --git a/kornia/augmentation/container/data_types.py b/kornia/augmentation/container/data_types.py new file mode 100644 index 0000000000..97939cad2d --- /dev/null +++ b/kornia/augmentation/container/data_types.py @@ -0,0 +1,10 @@ +from typing import List, Union + +from kornia.core import Tensor +from kornia.geometry.boxes import Boxes +from kornia.geometry.keypoints import Keypoints + +DataType = Union[Tensor, List[Tensor], Boxes, Keypoints] + +# NOTE: shouldn't this SequenceDataType alias be equals to List[DataType]? +SequenceDataType = Union[List[Tensor], List[List[Tensor]], List[Boxes], List[Keypoints]] diff --git a/kornia/augmentation/container/ops.py b/kornia/augmentation/container/ops.py index 0b6e8bc7b8..4492c0f00b 100644 --- a/kornia/augmentation/container/ops.py +++ b/kornia/augmentation/container/ops.py @@ -11,12 +11,9 @@ from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints +from .data_types import DataType, SequenceDataType from .params import ParamItem -DataType = Union[Tensor, List[Tensor], Boxes, Keypoints] - -# NOTE: shouldn't this SequenceDataType alias be equals to List[DataType]? -SequenceDataType = Union[List[Tensor], List[List[Tensor]], List[Boxes], List[Keypoints]] T = TypeVar("T") diff --git a/kornia/augmentation/container/patch.py b/kornia/augmentation/container/patch.py index 3a86b22959..7dd579df1e 100644 --- a/kornia/augmentation/container/patch.py +++ b/kornia/augmentation/container/patch.py @@ -5,7 +5,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase -from kornia.augmentation.callbacks import SequentialCallbackBase +from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.contrib.extract_patches import extract_tensor_patches from kornia.core import Module, Tensor, concatenate from kornia.core import pad as fpad @@ -121,7 +121,7 @@ def __init__( patchwise_apply: bool = True, random_apply: Union[int, bool, Tuple[int, int]] = False, random_apply_weights: Optional[List[float]] = None, - callbacks: List[SequentialCallbackBase] = [], + callbacks: List[AugmentationCallbackBase] = [], ) -> None: _random_apply: Optional[Union[int, Tuple[int, int]]] @@ -389,8 +389,8 @@ def inverse( # type: ignore[override] provided parameters. """ if self.is_intensity_only(): - [cb.on_inverse_start(input, params=params) for cb in self.callbacks] - [cb.on_inverse_end(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_inverse_start", input=outputs, params=params) + self.run_callbacks("on_sequential_inverse_end", input=outputs, params=params) return input raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.") @@ -404,9 +404,9 @@ def forward(self, input: Tensor, params: Optional[List[PatchParamItem]] = None) if params is None: params = self.forward_parameters(input.shape) - [cb.on_forward_start(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_forward_start", input=outputs, params=params) output = self.transform_inputs(input, params=params) - [cb.on_forward_end(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_forward_end", input=outputs, params=params) self._params = params diff --git a/kornia/augmentation/container/video.py b/kornia/augmentation/container/video.py index 94a9b6a1f6..6c5d6b400b 100644 --- a/kornia/augmentation/container/video.py +++ b/kornia/augmentation/container/video.py @@ -4,7 +4,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase -from kornia.augmentation.callbacks import SequentialCallbackBase +from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.augmentation.container.base import SequentialBase from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape from kornia.core import Module, Tensor @@ -107,7 +107,7 @@ def __init__( same_on_frame: bool = True, random_apply: Union[int, bool, Tuple[int, int]] = False, random_apply_weights: Optional[List[float]] = None, - callbacks: List[SequentialCallbackBase] = [], + callbacks: List[AugmentationCallbackBase] = [], ) -> None: super().__init__( *args, @@ -326,9 +326,9 @@ def inverse( else: raise RuntimeError("No valid params to inverse the transformation.") - [cb.on_inverse_start(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_inverse_start", input=outputs, params=params) output = self.inverse_inputs(input, params, extra_args=extra_args) - [cb.on_inverse_end(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_inverse_end", input=outputs, params=params) return output @@ -343,8 +343,8 @@ def forward( self._params = self.forward_parameters(input.shape) params = self._params - [cb.on_forward_start(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_forward_start", input=outputs, params=params) output = self.transform_inputs(input, params, extra_args=extra_args) - [cb.on_forward_end(input, params=params) for cb in self.callbacks] + self.run_callbacks("on_sequential_forward_start", input=outputs, params=params) return output From 4675c135159f5482bbf4977fa714f0a6001539a1 Mon Sep 17 00:00:00 2001 From: shijianjian Date: Thu, 11 Apr 2024 11:39:47 +0300 Subject: [PATCH 05/16] update --- kornia/augmentation/base.py | 4 ++-- kornia/augmentation/container/augment.py | 5 ++++- kornia/augmentation/container/base.py | 6 +----- kornia/augmentation/container/image.py | 7 +++++-- kornia/augmentation/container/mixins.py | 10 ++-------- kornia/augmentation/container/patch.py | 13 +++++++------ kornia/augmentation/container/video.py | 13 +++++++------ 7 files changed, 28 insertions(+), 30 deletions(-) diff --git a/kornia/augmentation/base.py b/kornia/augmentation/base.py index b9b0a6c3e6..6b37efe009 100644 --- a/kornia/augmentation/base.py +++ b/kornia/augmentation/base.py @@ -252,8 +252,8 @@ class _AugmentationBase(_BasicAugmentationBase): def __init__( self, - p: float, - p_batch: float, + p: float = 0.5, + p_batch: float = 1.0, same_on_batch: bool = False, keepdim: bool = False, ) -> None: diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 19ee4fed5c..74022e7428 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -5,6 +5,7 @@ from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.constants import DataKey, Resample from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes, VideoBoxes @@ -25,7 +26,7 @@ _IMG_MSK_OPTIONS = {DataKey.INPUT, DataKey.MASK} -class AugmentationSequential(ImageSequential, TransformMatrixMinIn): +class AugmentationSequential(ImageSequential, TransformMatrixMinIn, CallbacksMixIn): r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once. .. image:: _static/img/AugmentationSequential.png @@ -258,6 +259,8 @@ def __init__( self._transform_matrix = None self.extra_args = extra_args + self.register_callbacks(callbacks) + def clear_state(self) -> None: self._reset_transform_matrix_state() return super().clear_state() diff --git a/kornia/augmentation/container/base.py b/kornia/augmentation/container/base.py index 90b1b55132..dc4e15e0c0 100644 --- a/kornia/augmentation/container/base.py +++ b/kornia/augmentation/container/base.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase -from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints @@ -98,7 +97,7 @@ def get_params_by_module(self, named_modules: Iterator[Tuple[str, Module]]) -> I yield ParamItem(name, None) -class SequentialBase(BasicSequentialBase, CallbacksMixIn): +class SequentialBase(BasicSequentialBase): r"""SequentialBase for creating kornia modulized processing pipeline. Args: @@ -116,15 +115,12 @@ def __init__( *args: Module, same_on_batch: Optional[bool] = None, keepdim: Optional[bool] = None, - callbacks: List[AugmentationCallbackBase] = [], ) -> None: # To name the modules properly super().__init__(*args) self._same_on_batch = same_on_batch self._keepdim = keepdim - self.callbacks = callbacks self.update_attribute(same_on_batch, keepdim) - self.register_callbacks(callbacks) def update_attribute( self, diff --git a/kornia/augmentation/container/image.py b/kornia/augmentation/container/image.py index 58b44afa98..6b124b8e07 100644 --- a/kornia/augmentation/container/image.py +++ b/kornia/augmentation/container/image.py @@ -5,6 +5,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.augmentation.utils import override_parameters from kornia.core import Module, Tensor, as_tensor from kornia.utils import eye_like @@ -15,7 +16,7 @@ __all__ = ["ImageSequential"] -class ImageSequential(ImageSequentialBase): +class ImageSequential(ImageSequentialBase, CallbacksMixIn): r"""Sequential for creating kornia image processing pipeline. Args: @@ -89,7 +90,7 @@ def __init__( if_unsupported_ops: str = "raise", callbacks: List[AugmentationCallbackBase] = [], ) -> None: - super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim, callbacks=callbacks) + super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim) self.random_apply = self._read_random_apply(random_apply, len(args)) if random_apply_weights is not None and len(random_apply_weights) != len(self): @@ -100,6 +101,8 @@ def __init__( self.random_apply_weights = as_tensor(random_apply_weights or torch.ones((len(self),))) self.if_unsupported_ops = if_unsupported_ops + self.register_callbacks(callbacks) + def _read_random_apply( self, random_apply: Union[int, bool, Tuple[int, int]], max_length: int ) -> Union[Tuple[int, int], bool]: diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index 76e6f1b515..e05d2b8346 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -10,22 +10,16 @@ class CallbacksMixIn: + """Enables callbacks life cycle.""" def __init__(self, *args, **kwargs) -> None: # type:ignore super().__init__(*args, **kwargs) self._callbacks: List[AugmentationCallbackBase] = [] - self._hooks = [] - - @property - def callbacks( - self, - ): - return self._callbacks def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: [self._callbacks.append(cb) for cb in callbacks] def run_callbacks(self, hook: str, *args, **kwargs) -> None: - for cb in self.callbacks: + for cb in self._callbacks: if not hasattr(cb, hook): continue diff --git a/kornia/augmentation/container/patch.py b/kornia/augmentation/container/patch.py index 7dd579df1e..3934355178 100644 --- a/kornia/augmentation/container/patch.py +++ b/kornia/augmentation/container/patch.py @@ -6,6 +6,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.contrib.extract_patches import extract_tensor_patches from kornia.core import Module, Tensor, concatenate from kornia.core import pad as fpad @@ -20,7 +21,7 @@ __all__ = ["PatchSequential"] -class PatchSequential(ImageSequential): +class PatchSequential(ImageSequential, CallbacksMixIn): r"""Container for performing patch-level image data augmentation. .. image:: _static/img/PatchSequential.png @@ -146,7 +147,6 @@ def __init__( keepdim=keepdim, random_apply=_random_apply, random_apply_weights=random_apply_weights, - callbacks=callbacks, ) if padding not in ("same", "valid"): raise ValueError(f"`padding` must be either `same` or `valid`. Got {padding}.") @@ -154,6 +154,7 @@ def __init__( self.padding = padding self.patchwise_apply = patchwise_apply self._params: Optional[List[PatchParamItem]] # type: ignore[assignment] + self.register_callbacks(callbacks) def compute_padding( self, input: Tensor, padding: str, grid_size: Optional[Tuple[int, int]] = None @@ -389,8 +390,8 @@ def inverse( # type: ignore[override] provided parameters. """ if self.is_intensity_only(): - self.run_callbacks("on_sequential_inverse_start", input=outputs, params=params) - self.run_callbacks("on_sequential_inverse_end", input=outputs, params=params) + self.run_callbacks("on_sequential_inverse_start", input=input, params=params) + self.run_callbacks("on_sequential_inverse_end", input=input, params=params) return input raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.") @@ -404,9 +405,9 @@ def forward(self, input: Tensor, params: Optional[List[PatchParamItem]] = None) if params is None: params = self.forward_parameters(input.shape) - self.run_callbacks("on_sequential_forward_start", input=outputs, params=params) + self.run_callbacks("on_sequential_forward_start", input=input, params=params) output = self.transform_inputs(input, params=params) - self.run_callbacks("on_sequential_forward_end", input=outputs, params=params) + self.run_callbacks("on_sequential_forward_end", input=output, params=params) self._params = params diff --git a/kornia/augmentation/container/video.py b/kornia/augmentation/container/video.py index 6c5d6b400b..a52e6c4db4 100644 --- a/kornia/augmentation/container/video.py +++ b/kornia/augmentation/container/video.py @@ -5,6 +5,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.augmentation.container.base import SequentialBase from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape from kornia.core import Module, Tensor @@ -16,7 +17,7 @@ __all__ = ["VideoSequential"] -class VideoSequential(ImageSequential): +class VideoSequential(ImageSequential, CallbacksMixIn): r"""VideoSequential for processing 5-dim video data like (B, T, C, H, W) and (B, C, T, H, W). `VideoSequential` is used to replace `nn.Sequential` for processing video data augmentations. @@ -115,7 +116,6 @@ def __init__( keepdim=None, random_apply=random_apply, random_apply_weights=random_apply_weights, - callbacks=callbacks, ) self.same_on_frame = same_on_frame self.data_format = data_format.upper() @@ -126,6 +126,7 @@ def __init__( self._temporal_channel = 2 elif self.data_format == "BTCHW": self._temporal_channel = 1 + self.register_callbacks(callbacks) def __infer_channel_exclusive_batch_shape__(self, batch_shape: torch.Size, chennel_index: int) -> torch.Size: # Fix mypy complains: error: Incompatible return value type (got "Tuple[int, ...]", expected "Size") @@ -326,9 +327,9 @@ def inverse( else: raise RuntimeError("No valid params to inverse the transformation.") - self.run_callbacks("on_sequential_inverse_start", input=outputs, params=params) + self.run_callbacks("on_sequential_inverse_start", input=input, params=params) output = self.inverse_inputs(input, params, extra_args=extra_args) - self.run_callbacks("on_sequential_inverse_end", input=outputs, params=params) + self.run_callbacks("on_sequential_inverse_end", input=output, params=params) return output @@ -343,8 +344,8 @@ def forward( self._params = self.forward_parameters(input.shape) params = self._params - self.run_callbacks("on_sequential_forward_start", input=outputs, params=params) + self.run_callbacks("on_sequential_forward_start", input=input, params=params) output = self.transform_inputs(input, params, extra_args=extra_args) - self.run_callbacks("on_sequential_forward_start", input=outputs, params=params) + self.run_callbacks("on_sequential_forward_start", input=output, params=params) return output From 84532a2d6176662aa100a4c283b24019848de4c6 Mon Sep 17 00:00:00 2001 From: shijianjian Date: Thu, 11 Apr 2024 11:41:02 +0300 Subject: [PATCH 06/16] update --- kornia/augmentation/container/mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index e05d2b8346..740dd1f982 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -17,6 +17,7 @@ def __init__(self, *args, **kwargs) -> None: # type:ignore def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: [self._callbacks.append(cb) for cb in callbacks] + assert False, self._callbacks def run_callbacks(self, hook: str, *args, **kwargs) -> None: for cb in self._callbacks: From f0d83033272cab1940046bb623578e46ccddafbb Mon Sep 17 00:00:00 2001 From: shijianjian Date: Thu, 11 Apr 2024 11:48:46 +0300 Subject: [PATCH 07/16] update --- kornia/augmentation/container/base.py | 41 +++++++++++++------------ kornia/augmentation/container/mixins.py | 9 +++--- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/kornia/augmentation/container/base.py b/kornia/augmentation/container/base.py index dc4e15e0c0..83dd7fa040 100644 --- a/kornia/augmentation/container/base.py +++ b/kornia/augmentation/container/base.py @@ -199,67 +199,68 @@ def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iter raise NotImplementedError def transform_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - self.run_callbacks("on_transform_inputs_start", input=input, params=param) for param in params: module = self.get_submodule(param.name) + # NOTE: temp disabled + # self.run_callbacks("on_transform_inputs_start", input=input, params=param) input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_transform_inputs_end", input=input, params=params) + # self.run_callbacks("on_transform_inputs_end", input=input, params=params) return input def inverse_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - self.run_callbacks("on_inverse_inputs_start", input=input, params=param) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): + # self.run_callbacks("on_inverse_inputs_start", input=input, params=param) input = InputSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_inverse_inputs_end", input=input, params=params) + # self.run_callbacks("on_inverse_inputs_end", input=input, params=params) return input def transform_masks(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - self.run_callbacks("on_transform_masks_start", input=input, params=params) for param in params: module = self.get_submodule(param.name) + # self.run_callbacks("on_transform_masks_start", input=input, params=params) input = MaskSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_transform_masks_end", input=input, params=params) + # self.run_callbacks("on_transform_masks_end", input=input, params=params) return input def inverse_masks(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor: - self.run_callbacks("on_inverse_masks_start", input=input, params=params) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): + # self.run_callbacks("on_inverse_masks_start", input=input, params=params) input = MaskSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_inverse_masks_end", input=input, params=params) + # self.run_callbacks("on_inverse_masks_end", input=input, params=params) return input def transform_boxes(self, input: Boxes, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Boxes: - self.run_callbacks("on_transform_boxes_start", input=input, params=params) for param in params: module = self.get_submodule(param.name) + # self.run_callbacks("on_transform_boxes_start", input=input, params=params) input = BoxSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_transform_boxes_end", input=input, params=params) + # self.run_callbacks("on_transform_boxes_end", input=input, params=params) return input def inverse_boxes(self, input: Boxes, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Boxes: - self.run_callbacks("on_inverse_boxes_start", input=input, params=params) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): + # self.run_callbacks("on_inverse_boxes_start", input=input, params=params) input = BoxSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_inverse_boxes_end", input=input, params=params) + # self.run_callbacks("on_inverse_boxes_end", input=input, params=params) return input def transform_keypoints( self, input: Keypoints, params: List[ParamItem], extra_args: Dict[str, Any] = {} ) -> Keypoints: - self.run_callbacks("on_transform_keypoints_start", input=input, params=params) for param in params: module = self.get_submodule(param.name) + # self.run_callbacks("on_transform_keypoints_start", input=input, params=params) input = KeypointSequentialOps.transform(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_transform_keypoints_end", input=input, params=params) + # self.run_callbacks("on_transform_keypoints_end", input=input, params=params) return input def inverse_keypoints( self, input: Keypoints, params: List[ParamItem], extra_args: Dict[str, Any] = {} ) -> Keypoints: - self.run_callbacks("on_inverse_keypoints_start", input=input, params=params) for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]): + # self.run_callbacks("on_inverse_keypoints_start", input=input, params=params) input = KeypointSequentialOps.inverse(input, module=module, param=param, extra_args=extra_args) - self.run_callbacks("on_inverse_keypoints_end", input=input, params=params) + # self.run_callbacks("on_inverse_keypoints_end", input=input, params=params) return input def inverse( @@ -277,9 +278,9 @@ def inverse( "or passing valid params into this function." ) params = self._params - self.run_callbacks("on_inverse_start", input=input, params=params) + self.run_callbacks("on_sequential_inverse_start", input=input, params=params) input = self.inverse_inputs(input, params, extra_args=extra_args) - self.run_callbacks("on_inverse_end", input=input, params=params) + self.run_callbacks("on_sequential_inverse_end", input=input, params=params) return input def forward( @@ -292,9 +293,9 @@ def forward( _, out_shape = self.autofill_dim(inp, dim_range=(2, 4)) params = self.forward_parameters(out_shape) - self.run_callbacks("on_forward_start", input=input, params=params) + self.run_callbacks("on_sequential_forward_start", input=input, params=params) input = self.transform_inputs(input, params=params, extra_args=extra_args) - self.run_callbacks("on_forward_end", input=input, params=params) + self.run_callbacks("on_sequential_forward_end", input=input, params=params) self._params = params return input diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index 740dd1f982..2770a3d425 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -11,13 +11,14 @@ class CallbacksMixIn: """Enables callbacks life cycle.""" - def __init__(self, *args, **kwargs) -> None: # type:ignore - super().__init__(*args, **kwargs) - self._callbacks: List[AugmentationCallbackBase] = [] + _callbacks: List[AugmentationCallbackBase] = [] + + @property + def callbacks(self,): + return self._callbacks def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: [self._callbacks.append(cb) for cb in callbacks] - assert False, self._callbacks def run_callbacks(self, hook: str, *args, **kwargs) -> None: for cb in self._callbacks: From 0042a1297c6896761754ff7679f0ec1ff9fa15df Mon Sep 17 00:00:00 2001 From: shijianjian Date: Thu, 11 Apr 2024 12:16:07 +0300 Subject: [PATCH 08/16] Added local logger --- kornia/augmentation/callbacks/base.py | 68 +++++++++++++++++++ kornia/augmentation/callbacks/local_logger.py | 52 ++++++++++++++ kornia/augmentation/callbacks/wandb_logger.py | 65 +++++++----------- 3 files changed, 145 insertions(+), 40 deletions(-) create mode 100644 kornia/augmentation/callbacks/local_logger.py diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index 556d4f62da..2ea3a4bd93 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -236,3 +236,71 @@ def on_sequential_inverse_end( ): """Called when `inverse` ends for `AugmentationSequential`.""" ... + + +class AugmentationCallback(AugmentationCallbackBase): + """Logging images for `AugmentationSequential`. + + Args: + batches_to_save: the number of batches to be logged. -1 is to save all batches. + num_to_log: number of images to log in a batch. + log_indices: only selected input types are logged. If `log_indices=[0, 2]` and + `data_keys=["input", "bbox", "mask"]`, only the images and masks + will be logged. + data_keys: the input type sequential. Accepts "input", "image", "mask", + "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + postprocessing: add postprocessing for images if needed. If not None, the length + must match `data_keys`. + """ + + def __init__( + self, + batches_to_save: int = 10, + num_to_log: int = 4, + log_indices: Optional[List[int]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + postprocessing: Optional[List[Optional[Module]]] = None, + ): + super().__init__() + self.batches_to_log = batches_to_log + self.log_indices = log_indices + self.data_keys = data_keys + self.postprocessing = postprocessing + self.num_to_log = num_to_log + + def _make_mask_data(self, mask: Tensor): + raise NotImplementedError + + def _make_bbox_data(self, bbox: Tensor): + raise NotImplementedError + + def _log_data(self, data: SequenceDataType): + raise NotImplementedError + + def on_sequential_forward_end( + self, + *args: Union[DataType, Dict[str, DataType]], + params: Optional[List[ParamItem]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + ): + """Called when `forward` ends for `AugmentationSequential`.""" + image_data = None + output_data = [] + for i, (arg, data_key) in enumerate(zip(args, data_keys)): + if i not in self.log_indices: + continue + + postproc = self.postprocessing[self.log_indices[i]] + data = arg[: self.num_to_log] + if postproc is not None: + data = postproc(data) + if data_key in [DataKey.INPUT]: + data = data + if data_key in [DataKey.MASK]: + data = self._make_mask_data(data) + if data_key in [DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY]: + data = self._make_bbox_data(data) + + output_data.append(data) + + self._log_data(output_data) diff --git a/kornia/augmentation/callbacks/local_logger.py b/kornia/augmentation/callbacks/local_logger.py new file mode 100644 index 0000000000..93c3cfbf2b --- /dev/null +++ b/kornia/augmentation/callbacks/local_logger.py @@ -0,0 +1,52 @@ +import importlib +from typing import Dict, List, Optional, Union + +from kornia.augmentation.container.ops import DataType, SequenceDataType +from kornia.augmentation.container.params import ParamItem +from kornia.constants import DataKey +from kornia.core import Module, Tensor + +from .base import AugmentationCallback + + +class WandbLogger(AugmentationCallback): + """Logging images onto W&B for `AugmentationSequential`. + + Args: + batches_to_save: the number of batches to be logged. -1 is to save all batches. + num_to_log: number of images to log in a batch. + log_indices: only selected input types are logged. If `log_indices=[0, 2]` and + `data_keys=["input", "bbox", "mask"]`, only the images and masks + will be logged. + data_keys: the input type sequential. Accepts "input", "image", "mask", + "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + preprocessing: add preprocessing for images if needed. If not None, the length + must match `data_keys`. + """ + + def __init__( + self, + log_dir: str = "./kornia_logs", + batches_to_save: int = 10, + num_to_log: int = 4, + log_indices: Optional[List[int]] = None, + data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, + postprocessing: Optional[List[Optional[Module]]] = None, + ): + super().__init__( + batches_to_log=batches_to_log, + num_to_log=num_to_log, + log_indices=log_indices, + data_keys=data_keys, + postprocessing=postprocessing, + ) + self.log_dir = log_dir + + def _make_mask_data(self, mask: Tensor): + ... + + def _make_bbox_data(self, bbox: Tensor): + ... + + def _log_data(self, data: SequenceDataType): + ... \ No newline at end of file diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index 30d256ddde..80d34fe3b6 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -1,15 +1,15 @@ import importlib from typing import Dict, List, Optional, Union -from kornia.augmentation.container.ops import DataType +from kornia.augmentation.container.ops import DataType, SequenceDataType from kornia.augmentation.container.params import ParamItem from kornia.constants import DataKey from kornia.core import Module, Tensor -from .base import AugmentationCallbackBase +from .base import AugmentationCallback -class WandbLogger(AugmentationCallbackBase): +class WandbLogger(AugmentationCallback): """Logging images onto W&B for `AugmentationSequential`. Args: @@ -27,54 +27,39 @@ class WandbLogger(AugmentationCallbackBase): def __init__( self, run: Optional["wandb.Run"] = None, - batches_to_log: int = -1, + batches_to_save: int = 10, num_to_log: int = 4, log_indices: Optional[List[int]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - preprocessing: Optional[List[Optional[Module]]] = None, + postprocessing: Optional[List[Optional[Module]]] = None, ): - super().__init__() - self.batches_to_log = batches_to_log - self.log_indices = log_indices - self.data_keys = data_keys - self.preprocessing = preprocessing - self.num_to_log = num_to_log + super().__init__( + batches_to_log=batches_to_log, + num_to_log=num_to_log, + log_indices=log_indices, + data_keys=data_keys, + postprocessing=postprocessing, + ) if run is None: self.wandb = importlib.import_module("wandb") else: self.wandb = run + + self.has_duplication(data_keys) + + def has_duplication(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None): + # WANDB only supports visualization without duplication + ... def _make_mask_data(self, mask: Tensor): raise NotImplementedError - def _make_bbox_data(self, mask: Tensor): + def _make_bbox_data(self, bbox: Tensor): raise NotImplementedError - def on_sequential_forward_end( - self, - *args: Union[DataType, Dict[str, DataType]], - params: Optional[List[ParamItem]] = None, - data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - ): - """Called when `forward` ends for `AugmentationSequential`.""" - image_data = None - mask_data = [] - box_data = [] - for i, (arg, data_key) in enumerate(zip(args, data_keys)): - if i not in self.log_indices: - continue - - preproc = self.preprocessing[self.log_indices[i]] - out_arg = arg[: self.num_to_log] - if preproc is not None: - out_arg = preproc(out_arg) - if data_key in [DataKey.INPUT]: - image_data = out_arg - if data_key in [DataKey.MASK]: - mask_data = self._make_mask_data(out_arg) - if data_key in [DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY]: - box_data = self._make_bbox_data(out_arg) - - for i, (img, mask, box) in enumerate(zip(image_data, mask_data, box_data)): - wandb_img = self.wandb.Image(img, masks=mask, boxes=box) - self.wandb.log({"kornia_augmentation": wandb_img}) + def _log_data(self, data: SequenceDataType): + ... + # assert self.data_keys no duplication, ... + # for i, (value, key) in enumerate(zip(data, self.data_keys)): + # wandb_img = self.wandb.Image(img, masks=mask, boxes=box) + # self.wandb.log({"kornia_augmentation": wandb_img}) From 5d93a60e961ded8bf503091d0ff298c13c9c6f6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Apr 2024 09:17:56 +0000 Subject: [PATCH 09/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/augmentation/base.py | 2 +- kornia/augmentation/callbacks/base.py | 1 + kornia/augmentation/callbacks/local_logger.py | 15 +++++---------- kornia/augmentation/callbacks/wandb_logger.py | 7 +++---- kornia/augmentation/container/base.py | 1 - kornia/augmentation/container/mixins.py | 5 ++++- kornia/augmentation/container/ops.py | 1 - kornia/augmentation/container/video.py | 2 +- 8 files changed, 15 insertions(+), 19 deletions(-) diff --git a/kornia/augmentation/base.py b/kornia/augmentation/base.py index 6b37efe009..df6f370526 100644 --- a/kornia/augmentation/base.py +++ b/kornia/augmentation/base.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch from torch.distributions import Bernoulli, Distribution, RelaxedBernoulli diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index 2ea3a4bd93..a48aed5693 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -2,6 +2,7 @@ # NOTE: fix circular import import kornia.augmentation as K + # .data_types import DataType # from kornia.augmentation.container.params import ParamItem from kornia.constants import DataKey diff --git a/kornia/augmentation/callbacks/local_logger.py b/kornia/augmentation/callbacks/local_logger.py index 93c3cfbf2b..f0a7bf87c2 100644 --- a/kornia/augmentation/callbacks/local_logger.py +++ b/kornia/augmentation/callbacks/local_logger.py @@ -1,8 +1,6 @@ -import importlib -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union -from kornia.augmentation.container.ops import DataType, SequenceDataType -from kornia.augmentation.container.params import ParamItem +from kornia.augmentation.container.ops import SequenceDataType from kornia.constants import DataKey from kornia.core import Module, Tensor @@ -42,11 +40,8 @@ def __init__( ) self.log_dir = log_dir - def _make_mask_data(self, mask: Tensor): - ... + def _make_mask_data(self, mask: Tensor): ... - def _make_bbox_data(self, bbox: Tensor): - ... + def _make_bbox_data(self, bbox: Tensor): ... - def _log_data(self, data: SequenceDataType): - ... \ No newline at end of file + def _log_data(self, data: SequenceDataType): ... diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index 80d34fe3b6..d1e9d0e03e 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -1,8 +1,7 @@ import importlib -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union -from kornia.augmentation.container.ops import DataType, SequenceDataType -from kornia.augmentation.container.params import ParamItem +from kornia.augmentation.container.ops import SequenceDataType from kornia.constants import DataKey from kornia.core import Module, Tensor @@ -44,7 +43,7 @@ def __init__( self.wandb = importlib.import_module("wandb") else: self.wandb = run - + self.has_duplication(data_keys) def has_duplication(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None): diff --git a/kornia/augmentation/container/base.py b/kornia/augmentation/container/base.py index 83dd7fa040..4bed8b14d3 100644 --- a/kornia/augmentation/container/base.py +++ b/kornia/augmentation/container/base.py @@ -7,7 +7,6 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase -from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index 2770a3d425..f32872670e 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -11,10 +11,13 @@ class CallbacksMixIn: """Enables callbacks life cycle.""" + _callbacks: List[AugmentationCallbackBase] = [] @property - def callbacks(self,): + def callbacks( + self, + ): return self._callbacks def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: diff --git a/kornia/augmentation/container/ops.py b/kornia/augmentation/container/ops.py index 4492c0f00b..bb2ef0ae9a 100644 --- a/kornia/augmentation/container/ops.py +++ b/kornia/augmentation/container/ops.py @@ -14,7 +14,6 @@ from .data_types import DataType, SequenceDataType from .params import ParamItem - T = TypeVar("T") diff --git a/kornia/augmentation/container/video.py b/kornia/augmentation/container/video.py index a52e6c4db4..a9caf9d1bc 100644 --- a/kornia/augmentation/container/video.py +++ b/kornia/augmentation/container/video.py @@ -5,9 +5,9 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase -from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.augmentation.container.base import SequentialBase from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints From 776ee4c93904d670344a2e0ec560acba5e7e8d44 Mon Sep 17 00:00:00 2001 From: Jian S Date: Mon, 15 Apr 2024 19:00:23 +0300 Subject: [PATCH 10/16] update --- kornia/augmentation/callbacks/base.py | 10 ++++------ kornia/augmentation/callbacks/local_logger.py | 6 +++--- kornia/augmentation/callbacks/wandb_logger.py | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index a48aed5693..dc71c1f666 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -3,8 +3,6 @@ # NOTE: fix circular import import kornia.augmentation as K -# .data_types import DataType -# from kornia.augmentation.container.params import ParamItem from kornia.constants import DataKey from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes @@ -263,7 +261,7 @@ def __init__( postprocessing: Optional[List[Optional[Module]]] = None, ): super().__init__() - self.batches_to_log = batches_to_log + self.batches_to_save = batches_to_save self.log_indices = log_indices self.data_keys = data_keys self.postprocessing = postprocessing @@ -275,13 +273,13 @@ def _make_mask_data(self, mask: Tensor): def _make_bbox_data(self, bbox: Tensor): raise NotImplementedError - def _log_data(self, data: SequenceDataType): + def _log_data(self, data: "K.container.data_types.SequenceDataType"): raise NotImplementedError def on_sequential_forward_end( self, - *args: Union[DataType, Dict[str, DataType]], - params: Optional[List[ParamItem]] = None, + *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], + params: Optional[List["K.container.params.ParamItem"]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, ): """Called when `forward` ends for `AugmentationSequential`.""" diff --git a/kornia/augmentation/callbacks/local_logger.py b/kornia/augmentation/callbacks/local_logger.py index f0a7bf87c2..c3eab1e4bb 100644 --- a/kornia/augmentation/callbacks/local_logger.py +++ b/kornia/augmentation/callbacks/local_logger.py @@ -7,8 +7,8 @@ from .base import AugmentationCallback -class WandbLogger(AugmentationCallback): - """Logging images onto W&B for `AugmentationSequential`. +class LocalLogger(AugmentationCallback): + """Logging images to the desired folder for `AugmentationSequential`. Args: batches_to_save: the number of batches to be logged. -1 is to save all batches. @@ -32,7 +32,7 @@ def __init__( postprocessing: Optional[List[Optional[Module]]] = None, ): super().__init__( - batches_to_log=batches_to_log, + batches_to_save=batches_to_save, num_to_log=num_to_log, log_indices=log_indices, data_keys=data_keys, diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index d1e9d0e03e..53cc9fd7fe 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -33,7 +33,7 @@ def __init__( postprocessing: Optional[List[Optional[Module]]] = None, ): super().__init__( - batches_to_log=batches_to_log, + batches_to_save=batches_to_save, num_to_log=num_to_log, log_indices=log_indices, data_keys=data_keys, @@ -44,9 +44,9 @@ def __init__( else: self.wandb = run - self.has_duplication(data_keys) + self.contains_duplicated_keys(data_keys) - def has_duplication(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None): + def contains_duplicated_keys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None): # WANDB only supports visualization without duplication ... From 858e449872f2bae5e79abce7932ea5cc8a8df2f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:03:08 +0000 Subject: [PATCH 11/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/augmentation/callbacks/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index dc71c1f666..e58f28e4d9 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -2,7 +2,6 @@ # NOTE: fix circular import import kornia.augmentation as K - from kornia.constants import DataKey from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes From 3802c6df0fc6ff25824048b6d481fc7e6d12c78c Mon Sep 17 00:00:00 2001 From: Jian S Date: Mon, 15 Apr 2024 19:34:47 +0300 Subject: [PATCH 12/16] update --- kornia/augmentation/auto/operations/policy.py | 5 +++-- kornia/augmentation/container/augment.py | 3 +-- kornia/augmentation/container/base.py | 3 ++- kornia/augmentation/container/image.py | 3 +-- kornia/augmentation/container/patch.py | 3 +-- kornia/augmentation/container/video.py | 3 +-- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/kornia/augmentation/auto/operations/policy.py b/kornia/augmentation/auto/operations/policy.py index e86e254686..0796ec8b06 100644 --- a/kornia/augmentation/auto/operations/policy.py +++ b/kornia/augmentation/auto/operations/policy.py @@ -14,7 +14,7 @@ from kornia.utils import eye_like -class PolicySequential(ImageSequentialBase, TransformMatrixMinIn): +class PolicySequential(TransformMatrixMinIn, ImageSequentialBase): """Policy tuple for applying multiple operations. Args: @@ -27,8 +27,9 @@ def __init__( callbacks: List[AugmentationCallbackBase] = [], ) -> None: self.validate_operations(*operations) - super().__init__(*operations, callbacks=callbacks) + super().__init__(*operations) self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,) + self.register_callbacks(callbacks) def _update_transform_matrix_for_valid_op(self, module: Module) -> None: self._transform_matrices.append(module.transform_matrix) diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 74022e7428..f99d8bb95b 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -5,7 +5,6 @@ from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase -from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.constants import DataKey, Resample from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes, VideoBoxes @@ -26,7 +25,7 @@ _IMG_MSK_OPTIONS = {DataKey.INPUT, DataKey.MASK} -class AugmentationSequential(ImageSequential, TransformMatrixMinIn, CallbacksMixIn): +class AugmentationSequential(TransformMatrixMinIn, ImageSequential): r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once. .. image:: _static/img/AugmentationSequential.png diff --git a/kornia/augmentation/container/base.py b/kornia/augmentation/container/base.py index 4bed8b14d3..bf67e51125 100644 --- a/kornia/augmentation/container/base.py +++ b/kornia/augmentation/container/base.py @@ -7,6 +7,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase +from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints @@ -168,7 +169,7 @@ def autofill_dim(self, input: Tensor, dim_range: Tuple[int, int] = (2, 4)) -> Tu return ori_shape, input.shape -class ImageSequentialBase(SequentialBase): +class ImageSequentialBase(CallbacksMixIn, SequentialBase): def identity_matrix(self, input: Tensor) -> Tensor: """Return identity matrix.""" raise NotImplementedError diff --git a/kornia/augmentation/container/image.py b/kornia/augmentation/container/image.py index 6b124b8e07..d05a0c023c 100644 --- a/kornia/augmentation/container/image.py +++ b/kornia/augmentation/container/image.py @@ -5,7 +5,6 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase -from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.augmentation.utils import override_parameters from kornia.core import Module, Tensor, as_tensor from kornia.utils import eye_like @@ -16,7 +15,7 @@ __all__ = ["ImageSequential"] -class ImageSequential(ImageSequentialBase, CallbacksMixIn): +class ImageSequential(ImageSequentialBase): r"""Sequential for creating kornia image processing pipeline. Args: diff --git a/kornia/augmentation/container/patch.py b/kornia/augmentation/container/patch.py index 3934355178..640a138f95 100644 --- a/kornia/augmentation/container/patch.py +++ b/kornia/augmentation/container/patch.py @@ -6,7 +6,6 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase -from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.contrib.extract_patches import extract_tensor_patches from kornia.core import Module, Tensor, concatenate from kornia.core import pad as fpad @@ -21,7 +20,7 @@ __all__ = ["PatchSequential"] -class PatchSequential(ImageSequential, CallbacksMixIn): +class PatchSequential(ImageSequential): r"""Container for performing patch-level image data augmentation. .. image:: _static/img/PatchSequential.png diff --git a/kornia/augmentation/container/video.py b/kornia/augmentation/container/video.py index a9caf9d1bc..ae45953ce3 100644 --- a/kornia/augmentation/container/video.py +++ b/kornia/augmentation/container/video.py @@ -7,7 +7,6 @@ from kornia.augmentation.callbacks import AugmentationCallbackBase from kornia.augmentation.container.base import SequentialBase from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape -from kornia.augmentation.container.mixins import CallbacksMixIn from kornia.core import Module, Tensor from kornia.geometry.boxes import Boxes from kornia.geometry.keypoints import Keypoints @@ -17,7 +16,7 @@ __all__ = ["VideoSequential"] -class VideoSequential(ImageSequential, CallbacksMixIn): +class VideoSequential(ImageSequential): r"""VideoSequential for processing 5-dim video data like (B, T, C, H, W) and (B, C, T, H, W). `VideoSequential` is used to replace `nn.Sequential` for processing video data augmentations. From 522f6f116b7963f1b4b88c47811171240a23f167 Mon Sep 17 00:00:00 2001 From: Jian S Date: Mon, 15 Apr 2024 21:20:38 +0300 Subject: [PATCH 13/16] Fixed typing --- kornia/augmentation/callbacks/base.py | 120 +++++++++--------- kornia/augmentation/callbacks/local_logger.py | 14 +- kornia/augmentation/callbacks/wandb_logger.py | 15 ++- kornia/augmentation/container/augment.py | 12 +- kornia/augmentation/container/mixins.py | 9 +- kornia/augmentation/container/patch.py | 13 +- kornia/utils/draw.py | 2 +- 7 files changed, 104 insertions(+), 81 deletions(-) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index e58f28e4d9..d764e8cae6 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import cast, Dict, List, Optional, Union # NOTE: fix circular import import kornia.augmentation as K @@ -20,7 +20,7 @@ def on_transform_inputs_start( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_inputs` begins.""" ... @@ -29,7 +29,7 @@ def on_transform_inputs_end( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_inputs` ends.""" ... @@ -38,7 +38,7 @@ def on_transform_masks_start( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_masks` begins.""" ... @@ -47,7 +47,7 @@ def on_transform_masks_end( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_masks` ends.""" ... @@ -56,7 +56,7 @@ def on_transform_classes_start( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_classes` begins.""" ... @@ -65,7 +65,7 @@ def on_transform_classes_end( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_classes` ends.""" ... @@ -74,7 +74,7 @@ def on_transform_boxes_start( input: Boxes, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_boxes` begins.""" ... @@ -83,7 +83,7 @@ def on_transform_boxes_end( input: Boxes, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_boxes` ends.""" ... @@ -92,7 +92,7 @@ def on_transform_keypoints_start( input: Keypoints, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_keypoints` begins.""" ... @@ -101,7 +101,7 @@ def on_transform_keypoints_end( input: Keypoints, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `transform_keypoints` ends.""" ... @@ -110,7 +110,7 @@ def on_inverse_inputs_start( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_input` begins.""" ... @@ -119,7 +119,7 @@ def on_inverse_inputs_end( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_inputs` ends.""" ... @@ -128,7 +128,7 @@ def on_inverse_masks_start( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_masks` begins.""" ... @@ -137,7 +137,7 @@ def on_inverse_masks_end( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_masks` ends.""" ... @@ -146,7 +146,7 @@ def on_inverse_classes_start( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_classes` begins.""" ... @@ -155,7 +155,7 @@ def on_inverse_classes_end( input: Tensor, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_classes` ends.""" ... @@ -164,7 +164,7 @@ def on_inverse_boxes_start( input: Boxes, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_boxes` begins.""" ... @@ -173,7 +173,7 @@ def on_inverse_boxes_end( input: Boxes, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_boxes` ends.""" ... @@ -182,7 +182,7 @@ def on_inverse_keypoints_start( input: Keypoints, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_keypoints` begins.""" ... @@ -191,47 +191,47 @@ def on_inverse_keypoints_end( input: Keypoints, params: Dict[str, Tensor], module: object, - ): + ) -> None: """Called when `inverse_keypoints` ends.""" ... def on_sequential_forward_start( self, - *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], - params: Optional[List["K.container.params.DataType"]] = None, - data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, - module: object, - ): + *args: "K.container.data_types.DataType", + module: "K.AugmentationSequential", + params: List["K.container.params.ParamItem"], + data_keys: Union[List[str], List[int], List[DataKey]], + ) -> None: """Called when `forward` begins for `AugmentationSequential`.""" ... def on_sequential_forward_end( self, - *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], - params: Optional[List["K.container.params.DataType"]] = None, - data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, - module: object, - ): + *args: "K.container.data_types.DataType", + module: "K.AugmentationSequential", + params: List["K.container.params.ParamItem"], + data_keys: Union[List[str], List[int], List[DataKey]], + ) -> None: """Called when `forward` ends for `AugmentationSequential`.""" ... def on_sequential_inverse_start( self, - *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], - params: Optional[List["K.container.params.DataType"]] = None, - data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, - module: object, - ): + *args: "K.container.data_types.DataType", + module: "K.AugmentationSequential", + params: List["K.container.params.ParamItem"], + data_keys: Union[List[str], List[int], List[DataKey]], + ) -> None: """Called when `inverse` begins for `AugmentationSequential`.""" ... def on_sequential_inverse_end( self, - *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], - params: Optional[List["K.container.params.DataType"]] = None, - data_keys: Optional[Union[List[str], List[int], List["K.container.data_types.DataType"]]] = None, - module: object, - ): + *args: "K.container.data_types.DataType", + module: "K.AugmentationSequential", + params: List["K.container.params.ParamItem"], + data_keys: Union[List[str], List[int], List[DataKey]], + ) -> None: """Called when `inverse` ends for `AugmentationSequential`.""" ... @@ -266,39 +266,45 @@ def __init__( self.postprocessing = postprocessing self.num_to_log = num_to_log - def _make_mask_data(self, mask: Tensor): + def _make_mask_data(self, mask: Tensor) -> Tensor: raise NotImplementedError - def _make_bbox_data(self, bbox: Tensor): + def _make_bbox_data(self, bbox: Boxes) -> Boxes: raise NotImplementedError - def _log_data(self, data: "K.container.data_types.SequenceDataType"): + def _log_data(self, data: List["K.container.data_types.DataType"]) -> None: raise NotImplementedError def on_sequential_forward_end( self, - *args: Union["K.container.data_types.DataType", Dict[str, "K.container.data_types.DataType"]], - params: Optional[List["K.container.params.ParamItem"]] = None, - data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, - ): + *args: "K.container.data_types.DataType", + module: "K.AugmentationSequential", + params: List["K.container.params.ParamItem"], + data_keys: Union[List[str], List[int], List[DataKey]], + ) -> None: """Called when `forward` ends for `AugmentationSequential`.""" - image_data = None - output_data = [] + output_data: List["K.container.data_types.DataType"] = [] + + # Log all the indices + if self.log_indices is None: + self.log_indices = list(range(len(data_keys))) + for i, (arg, data_key) in enumerate(zip(args, data_keys)): if i not in self.log_indices: continue - postproc = self.postprocessing[self.log_indices[i]] - data = arg[: self.num_to_log] + postproc = None + if self.postprocessing is not None: + postproc = self.postprocessing[self.log_indices[i]] + data = arg[:self.num_to_log] if postproc is not None: data = postproc(data) + if data_key in [DataKey.INPUT]: - data = data + output_data.append(data) if data_key in [DataKey.MASK]: - data = self._make_mask_data(data) + output_data.append(self._make_mask_data(cast(Tensor, data))) if data_key in [DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY]: - data = self._make_bbox_data(data) - - output_data.append(data) + output_data.append(self._make_bbox_data(cast(Boxes, data))) self._log_data(output_data) diff --git a/kornia/augmentation/callbacks/local_logger.py b/kornia/augmentation/callbacks/local_logger.py index c3eab1e4bb..19476bad67 100644 --- a/kornia/augmentation/callbacks/local_logger.py +++ b/kornia/augmentation/callbacks/local_logger.py @@ -1,8 +1,9 @@ from typing import List, Optional, Union -from kornia.augmentation.container.ops import SequenceDataType +from kornia.augmentation.container.ops import DataType from kornia.constants import DataKey from kornia.core import Module, Tensor +from kornia.geometry.boxes import Boxes from .base import AugmentationCallback @@ -30,7 +31,7 @@ def __init__( log_indices: Optional[List[int]] = None, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None, postprocessing: Optional[List[Optional[Module]]] = None, - ): + ) -> None: super().__init__( batches_to_save=batches_to_save, num_to_log=num_to_log, @@ -40,8 +41,11 @@ def __init__( ) self.log_dir = log_dir - def _make_mask_data(self, mask: Tensor): ... + def _make_mask_data(self, mask: Tensor) -> Tensor: + raise NotImplementedError - def _make_bbox_data(self, bbox: Tensor): ... + def _make_bbox_data(self, bbox: Boxes) -> Boxes: + raise NotImplementedError - def _log_data(self, data: SequenceDataType): ... + def _log_data(self, data: List[DataType]) -> None: + raise NotImplementedError diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index 53cc9fd7fe..6ee90de8e0 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -1,9 +1,10 @@ import importlib from typing import List, Optional, Union -from kornia.augmentation.container.ops import SequenceDataType +from kornia.augmentation.container.ops import DataType from kornia.constants import DataKey from kornia.core import Module, Tensor +from kornia.geometry.boxes import Boxes from .base import AugmentationCallback @@ -25,7 +26,7 @@ class WandbLogger(AugmentationCallback): def __init__( self, - run: Optional["wandb.Run"] = None, + run: Optional["wandb.Run"] = None, # type: ignore batches_to_save: int = 10, num_to_log: int = 4, log_indices: Optional[List[int]] = None, @@ -46,17 +47,19 @@ def __init__( self.contains_duplicated_keys(data_keys) - def contains_duplicated_keys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None): + def contains_duplicated_keys( + self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None + ) -> None: # WANDB only supports visualization without duplication ... - def _make_mask_data(self, mask: Tensor): + def _make_mask_data(self, mask: Tensor) -> Tensor: raise NotImplementedError - def _make_bbox_data(self, bbox: Tensor): + def _make_bbox_data(self, bbox: Boxes) -> Boxes: raise NotImplementedError - def _log_data(self, data: SequenceDataType): + def _log_data(self, data: List[DataType]) -> None: ... # assert self.data_keys no duplication, ... # for i, (value, key) in enumerate(zip(data, self.data_keys)): diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index f99d8bb95b..ccaf51ea51 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -307,7 +307,8 @@ def inverse( # type: ignore[override] ) params = self._params - self.run_callbacks("on_sequential_inverse_start", input=in_args, params=params) + self.run_callbacks( + "on_sequential_inverse_start", input=in_args, module=self, params=params, data_keys=data_keys) outputs: List[DataType] = in_args for param in params[::-1]: @@ -324,7 +325,8 @@ def inverse( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} - self.run_callbacks("on_sequential_inverse_end", input=outputs, params=params) + self.run_callbacks( + "on_sequential_inverse_end", input=outputs, module=self, params=params, data_keys=data_keys) if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -424,7 +426,8 @@ def forward( # type: ignore[override] else: raise ValueError("`params` must be provided whilst INPUT is not in data_keys.") - self.run_callbacks("on_sequential_forward_start", input=in_args, params=params) + self.run_callbacks( + "on_sequential_forward_start", input=in_args, module=self, params=params, data_keys=data_keys) outputs: Union[Tensor, List[DataType]] = in_args for param in params: @@ -446,7 +449,8 @@ def forward( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} - self.run_callbacks("on_sequential_forward_end", input=outputs, params=params) + self.run_callbacks( + "on_sequential_forward_end", input=outputs, module=self, params=params, data_keys=data_keys) if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index f32872670e..87948a4427 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -17,13 +17,14 @@ class CallbacksMixIn: @property def callbacks( self, - ): + ) -> List[AugmentationCallbackBase]: return self._callbacks - def register_callbacks(self, callbacks: AugmentationCallbackBase) -> None: - [self._callbacks.append(cb) for cb in callbacks] + def register_callbacks(self, callbacks: List[AugmentationCallbackBase]) -> None: + for cb in callbacks: + self._callbacks.append(cb) - def run_callbacks(self, hook: str, *args, **kwargs) -> None: + def run_callbacks(self, hook: str, *args, **kwargs) -> None: # type: ignore for cb in self._callbacks: if not hasattr(cb, hook): continue diff --git a/kornia/augmentation/container/patch.py b/kornia/augmentation/container/patch.py index 640a138f95..e47cb77788 100644 --- a/kornia/augmentation/container/patch.py +++ b/kornia/augmentation/container/patch.py @@ -6,6 +6,7 @@ import kornia.augmentation as K from kornia.augmentation.base import _AugmentationBase from kornia.augmentation.callbacks import AugmentationCallbackBase +from kornia.constants import DataKey from kornia.contrib.extract_patches import extract_tensor_patches from kornia.core import Module, Tensor, concatenate from kornia.core import pad as fpad @@ -389,8 +390,10 @@ def inverse( # type: ignore[override] provided parameters. """ if self.is_intensity_only(): - self.run_callbacks("on_sequential_inverse_start", input=input, params=params) - self.run_callbacks("on_sequential_inverse_end", input=input, params=params) + self.run_callbacks( + "on_sequential_inverse_start", input=[input], module=self, params=params, data_keys=[DataKey.INPUT]) + self.run_callbacks( + "on_sequential_inverse_end", input=[input], module=self, params=params, data_keys=[DataKey.INPUT]) return input raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.") @@ -404,9 +407,11 @@ def forward(self, input: Tensor, params: Optional[List[PatchParamItem]] = None) if params is None: params = self.forward_parameters(input.shape) - self.run_callbacks("on_sequential_forward_start", input=input, params=params) + self.run_callbacks( + "on_sequential_forward_start", input=[input], module=self, params=params, data_keys=[DataKey.INPUT]) output = self.transform_inputs(input, params=params) - self.run_callbacks("on_sequential_forward_end", input=output, params=params) + self.run_callbacks( + "on_sequential_forward_end", input=[output], module=self, params=params, data_keys=[DataKey.INPUT]) self._params = params diff --git a/kornia/utils/draw.py b/kornia/utils/draw.py index abaff6e0a1..a8ad2bfd45 100644 --- a/kornia/utils/draw.py +++ b/kornia/utils/draw.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Union import torch -from torch import Tensor +from kornia.core import Tensor from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE From 88470fadf89555815701e1992b1b52e8a62a19b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:20:55 +0000 Subject: [PATCH 14/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/augmentation/callbacks/base.py | 6 +++--- kornia/augmentation/callbacks/wandb_logger.py | 4 +--- kornia/augmentation/container/augment.py | 12 ++++++------ kornia/augmentation/container/mixins.py | 2 +- kornia/augmentation/container/patch.py | 12 ++++++++---- kornia/utils/draw.py | 2 +- 6 files changed, 20 insertions(+), 18 deletions(-) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index d764e8cae6..b930570033 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -1,4 +1,4 @@ -from typing import cast, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, cast # NOTE: fix circular import import kornia.augmentation as K @@ -284,7 +284,7 @@ def on_sequential_forward_end( ) -> None: """Called when `forward` ends for `AugmentationSequential`.""" output_data: List["K.container.data_types.DataType"] = [] - + # Log all the indices if self.log_indices is None: self.log_indices = list(range(len(data_keys))) @@ -296,7 +296,7 @@ def on_sequential_forward_end( postproc = None if self.postprocessing is not None: postproc = self.postprocessing[self.log_indices[i]] - data = arg[:self.num_to_log] + data = arg[: self.num_to_log] if postproc is not None: data = postproc(data) diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index 6ee90de8e0..eccae44bcf 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -47,9 +47,7 @@ def __init__( self.contains_duplicated_keys(data_keys) - def contains_duplicated_keys( - self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None - ) -> None: + def contains_duplicated_keys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None) -> None: # WANDB only supports visualization without duplication ... diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index ccaf51ea51..3057f621d3 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -308,7 +308,8 @@ def inverse( # type: ignore[override] params = self._params self.run_callbacks( - "on_sequential_inverse_start", input=in_args, module=self, params=params, data_keys=data_keys) + "on_sequential_inverse_start", input=in_args, module=self, params=params, data_keys=data_keys + ) outputs: List[DataType] = in_args for param in params[::-1]: @@ -325,8 +326,7 @@ def inverse( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} - self.run_callbacks( - "on_sequential_inverse_end", input=outputs, module=self, params=params, data_keys=data_keys) + self.run_callbacks("on_sequential_inverse_end", input=outputs, module=self, params=params, data_keys=data_keys) if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -427,7 +427,8 @@ def forward( # type: ignore[override] raise ValueError("`params` must be provided whilst INPUT is not in data_keys.") self.run_callbacks( - "on_sequential_forward_start", input=in_args, module=self, params=params, data_keys=data_keys) + "on_sequential_forward_start", input=in_args, module=self, params=params, data_keys=data_keys + ) outputs: Union[Tensor, List[DataType]] = in_args for param in params: @@ -449,8 +450,7 @@ def forward( # type: ignore[override] if isinstance(original_keys, tuple): return {k: v for v, k in zip(outputs, original_keys)} - self.run_callbacks( - "on_sequential_forward_end", input=outputs, module=self, params=params, data_keys=data_keys) + self.run_callbacks("on_sequential_forward_end", input=outputs, module=self, params=params, data_keys=data_keys) if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] diff --git a/kornia/augmentation/container/mixins.py b/kornia/augmentation/container/mixins.py index 87948a4427..b44d9a15a6 100644 --- a/kornia/augmentation/container/mixins.py +++ b/kornia/augmentation/container/mixins.py @@ -22,7 +22,7 @@ def callbacks( def register_callbacks(self, callbacks: List[AugmentationCallbackBase]) -> None: for cb in callbacks: - self._callbacks.append(cb) + self._callbacks.append(cb) def run_callbacks(self, hook: str, *args, **kwargs) -> None: # type: ignore for cb in self._callbacks: diff --git a/kornia/augmentation/container/patch.py b/kornia/augmentation/container/patch.py index e47cb77788..0ae7fc6685 100644 --- a/kornia/augmentation/container/patch.py +++ b/kornia/augmentation/container/patch.py @@ -391,9 +391,11 @@ def inverse( # type: ignore[override] """ if self.is_intensity_only(): self.run_callbacks( - "on_sequential_inverse_start", input=[input], module=self, params=params, data_keys=[DataKey.INPUT]) + "on_sequential_inverse_start", input=[input], module=self, params=params, data_keys=[DataKey.INPUT] + ) self.run_callbacks( - "on_sequential_inverse_end", input=[input], module=self, params=params, data_keys=[DataKey.INPUT]) + "on_sequential_inverse_end", input=[input], module=self, params=params, data_keys=[DataKey.INPUT] + ) return input raise NotImplementedError("PatchSequential inverse cannot be used with geometric transformations.") @@ -408,10 +410,12 @@ def forward(self, input: Tensor, params: Optional[List[PatchParamItem]] = None) params = self.forward_parameters(input.shape) self.run_callbacks( - "on_sequential_forward_start", input=[input], module=self, params=params, data_keys=[DataKey.INPUT]) + "on_sequential_forward_start", input=[input], module=self, params=params, data_keys=[DataKey.INPUT] + ) output = self.transform_inputs(input, params=params) self.run_callbacks( - "on_sequential_forward_end", input=[output], module=self, params=params, data_keys=[DataKey.INPUT]) + "on_sequential_forward_end", input=[output], module=self, params=params, data_keys=[DataKey.INPUT] + ) self._params = params diff --git a/kornia/utils/draw.py b/kornia/utils/draw.py index a8ad2bfd45..7e7db6e521 100644 --- a/kornia/utils/draw.py +++ b/kornia/utils/draw.py @@ -1,8 +1,8 @@ from typing import List, Optional, Tuple, Union import torch -from kornia.core import Tensor +from kornia.core import Tensor from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE # TODO: implement width of the line From c1689d0a7fc5afaf99ee46260d10bb2a692b6357 Mon Sep 17 00:00:00 2001 From: Jian S Date: Mon, 15 Apr 2024 21:56:15 +0300 Subject: [PATCH 15/16] update --- kornia/augmentation/callbacks/base.py | 22 ++++++++++++------- kornia/augmentation/callbacks/local_logger.py | 10 ++------- kornia/augmentation/callbacks/wandb_logger.py | 8 +------ kornia/utils/draw.py | 8 +++---- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index b930570033..cfaeecff92 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -267,12 +267,15 @@ def __init__( self.num_to_log = num_to_log def _make_mask_data(self, mask: Tensor) -> Tensor: - raise NotImplementedError + return mask - def _make_bbox_data(self, bbox: Boxes) -> Boxes: - raise NotImplementedError + def _make_bbox_data(self, bbox: Boxes) -> Tensor: + return cast(Tensor, bbox.to_tensor("xyxy", as_padded_sequence=True)) + + def _make_keypoints_data(self, keypoints: Keypoints) -> Tensor: + return cast(Tensor, keypoints.to_tensor(as_padded_sequence=True)) - def _log_data(self, data: List["K.container.data_types.DataType"]) -> None: + def _log_data(self, data: List[Tensor]) -> None: raise NotImplementedError def on_sequential_forward_end( @@ -283,8 +286,8 @@ def on_sequential_forward_end( data_keys: Union[List[str], List[int], List[DataKey]], ) -> None: """Called when `forward` ends for `AugmentationSequential`.""" - output_data: List["K.container.data_types.DataType"] = [] - + output_data: List[Tensor] = [] + # Log all the indices if self.log_indices is None: self.log_indices = list(range(len(data_keys))) @@ -296,15 +299,18 @@ def on_sequential_forward_end( postproc = None if self.postprocessing is not None: postproc = self.postprocessing[self.log_indices[i]] - data = arg[: self.num_to_log] + data = arg[:self.num_to_log] + if postproc is not None: data = postproc(data) if data_key in [DataKey.INPUT]: - output_data.append(data) + output_data.append(cast(Tensor, data)) if data_key in [DataKey.MASK]: output_data.append(self._make_mask_data(cast(Tensor, data))) if data_key in [DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY]: output_data.append(self._make_bbox_data(cast(Boxes, data))) + if data_key in [DataKey.KEYPOINTS]: + output_data.append(self._make_keypoints_data(cast(Keypoints, data))) self._log_data(output_data) diff --git a/kornia/augmentation/callbacks/local_logger.py b/kornia/augmentation/callbacks/local_logger.py index 19476bad67..2bf889ff60 100644 --- a/kornia/augmentation/callbacks/local_logger.py +++ b/kornia/augmentation/callbacks/local_logger.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import cast, List, Optional, Union from kornia.augmentation.container.ops import DataType from kornia.constants import DataKey @@ -41,11 +41,5 @@ def __init__( ) self.log_dir = log_dir - def _make_mask_data(self, mask: Tensor) -> Tensor: - raise NotImplementedError - - def _make_bbox_data(self, bbox: Boxes) -> Boxes: - raise NotImplementedError - - def _log_data(self, data: List[DataType]) -> None: + def _log_data(self, data: List[Tensor]) -> None: raise NotImplementedError diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index eccae44bcf..f18596cb9d 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -51,13 +51,7 @@ def contains_duplicated_keys(self, data_keys: Optional[Union[List[str], List[int # WANDB only supports visualization without duplication ... - def _make_mask_data(self, mask: Tensor) -> Tensor: - raise NotImplementedError - - def _make_bbox_data(self, bbox: Boxes) -> Boxes: - raise NotImplementedError - - def _log_data(self, data: List[DataType]) -> None: + def _log_data(self, data: List[Tensor]) -> None: ... # assert self.data_keys no duplication, ... # for i, (value, key) in enumerate(zip(data, self.data_keys)): diff --git a/kornia/utils/draw.py b/kornia/utils/draw.py index 7e7db6e521..f60419c3db 100644 --- a/kornia/utils/draw.py +++ b/kornia/utils/draw.py @@ -35,7 +35,7 @@ def draw_point2d(image: Tensor, points: Tensor, color: Tensor) -> Tensor: return image -def _draw_pixel(image: torch.Tensor, x: int, y: int, color: torch.Tensor) -> None: +def _draw_pixel(image: Tensor, x: int, y: int, color: Tensor) -> None: r"""Draws a pixel into an image. Args: @@ -50,7 +50,7 @@ def _draw_pixel(image: torch.Tensor, x: int, y: int, color: torch.Tensor) -> Non image[:, y, x] = color -def draw_line(image: torch.Tensor, p1: torch.Tensor, p2: torch.Tensor, color: torch.Tensor) -> torch.Tensor: +def draw_line(image: Tensor, p1: Tensor, p2: Tensor, color: Tensor) -> Tensor: r"""Draw a single line into an image. Args: @@ -176,8 +176,8 @@ def draw_line(image: torch.Tensor, p1: torch.Tensor, p2: torch.Tensor, color: to def draw_rectangle( - image: torch.Tensor, rectangle: torch.Tensor, color: Optional[torch.Tensor] = None, fill: Optional[bool] = None -) -> torch.Tensor: + image: Tensor, rectangle: Tensor, color: Optional[Tensor] = None, fill: Optional[bool] = None +) -> Tensor: r"""Draw N rectangles on a batch of image tensors. Args: From 2eca35af8247a70dfbfe3335df3929a2ae2ff4f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:56:32 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/augmentation/callbacks/base.py | 4 ++-- kornia/augmentation/callbacks/local_logger.py | 4 +--- kornia/augmentation/callbacks/wandb_logger.py | 2 -- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/kornia/augmentation/callbacks/base.py b/kornia/augmentation/callbacks/base.py index cfaeecff92..accf9547ae 100644 --- a/kornia/augmentation/callbacks/base.py +++ b/kornia/augmentation/callbacks/base.py @@ -287,7 +287,7 @@ def on_sequential_forward_end( ) -> None: """Called when `forward` ends for `AugmentationSequential`.""" output_data: List[Tensor] = [] - + # Log all the indices if self.log_indices is None: self.log_indices = list(range(len(data_keys))) @@ -299,7 +299,7 @@ def on_sequential_forward_end( postproc = None if self.postprocessing is not None: postproc = self.postprocessing[self.log_indices[i]] - data = arg[:self.num_to_log] + data = arg[: self.num_to_log] if postproc is not None: data = postproc(data) diff --git a/kornia/augmentation/callbacks/local_logger.py b/kornia/augmentation/callbacks/local_logger.py index 2bf889ff60..8192719f6f 100644 --- a/kornia/augmentation/callbacks/local_logger.py +++ b/kornia/augmentation/callbacks/local_logger.py @@ -1,9 +1,7 @@ -from typing import cast, List, Optional, Union +from typing import List, Optional, Union -from kornia.augmentation.container.ops import DataType from kornia.constants import DataKey from kornia.core import Module, Tensor -from kornia.geometry.boxes import Boxes from .base import AugmentationCallback diff --git a/kornia/augmentation/callbacks/wandb_logger.py b/kornia/augmentation/callbacks/wandb_logger.py index f18596cb9d..11e5d283d8 100644 --- a/kornia/augmentation/callbacks/wandb_logger.py +++ b/kornia/augmentation/callbacks/wandb_logger.py @@ -1,10 +1,8 @@ import importlib from typing import List, Optional, Union -from kornia.augmentation.container.ops import DataType from kornia.constants import DataKey from kornia.core import Module, Tensor -from kornia.geometry.boxes import Boxes from .base import AugmentationCallback