Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] init callbacks #23

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion kornia/augmentation/auto/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from kornia.augmentation.auto.operations.base import OperationBase
from kornia.augmentation.auto.operations.policy import PolicySequential
from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn
from kornia.augmentation.container.base import ImageSequentialBase
from kornia.augmentation.container.mixins import TransformMatrixMinIn
from kornia.augmentation.container.ops import InputSequentialOps
from kornia.augmentation.container.params import ParamItem
from kornia.core import Module, Tensor
Expand Down
11 changes: 9 additions & 2 deletions kornia/augmentation/auto/operations/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import kornia.augmentation as K
from kornia.augmentation.auto.operations import OperationBase
from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn
from kornia.augmentation.callbacks import AugmentationCallbackBase
from kornia.augmentation.container.base import ImageSequentialBase
from kornia.augmentation.container.mixins import TransformMatrixMinIn
from kornia.augmentation.container.ops import InputSequentialOps
from kornia.augmentation.container.params import ParamItem
from kornia.augmentation.utils import _transform_input, override_parameters
Expand All @@ -19,10 +21,15 @@ class PolicySequential(TransformMatrixMinIn, ImageSequentialBase):
operations: a list of operations to perform.
"""

def __init__(self, *operations: OperationBase) -> None:
def __init__(
self,
*operations: OperationBase,
callbacks: List[AugmentationCallbackBase] = [],
) -> None:
self.validate_operations(*operations)
super().__init__(*operations)
self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,)
self.register_callbacks(callbacks)

def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
self._transform_matrices.append(module.transform_matrix)
Expand Down
12 changes: 11 additions & 1 deletion kornia/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def forward(self, input: Tensor, params: Optional[Dict[str, Tensor]] = None, **k
params, flags = self._process_kwargs_to_params_and_flags(params, self.flags, **kwargs)

output = self.apply_func(in_tensor, params, flags)
return self.transform_output_tensor(output, input_shape) if self.keepdim else output
output = self.transform_output_tensor(output, input_shape) if self.keepdim else output
return output


class _AugmentationBase(_BasicAugmentationBase):
Expand All @@ -249,6 +250,15 @@ class _AugmentationBase(_BasicAugmentationBase):
to the batch form ``False``.
"""

def __init__(
self,
p: float = 0.5,
p_batch: float = 1.0,
same_on_batch: bool = False,
keepdim: bool = False,
) -> None:
super().__init__(p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim)

def apply_transform(
self,
input: Tensor,
Expand Down
1 change: 1 addition & 0 deletions kornia/augmentation/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base import AugmentationCallbackBase
5 changes: 5 additions & 0 deletions kornia/augmentation/callbacks/_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base import AugmentationCallbackBase


class Logger(AugmentationCallbackBase):
"""Generic logging module."""
316 changes: 316 additions & 0 deletions kornia/augmentation/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
from typing import Dict, List, Optional, Union, cast

# NOTE: fix circular import
import kornia.augmentation as K
from kornia.constants import DataKey
from kornia.core import Module, Tensor
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints

__all__ = [
"AugmentationCallbackBase",
]


class AugmentationCallbackBase(Module):
"""A Meta Callback base class."""

def on_transform_inputs_start(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_inputs` begins."""
...

def on_transform_inputs_end(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_inputs` ends."""
...

def on_transform_masks_start(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_masks` begins."""
...

def on_transform_masks_end(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_masks` ends."""
...

def on_transform_classes_start(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_classes` begins."""
...

def on_transform_classes_end(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_classes` ends."""
...

def on_transform_boxes_start(
self,
input: Boxes,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_boxes` begins."""
...

def on_transform_boxes_end(
self,
input: Boxes,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_boxes` ends."""
...

def on_transform_keypoints_start(
self,
input: Keypoints,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_keypoints` begins."""
...

def on_transform_keypoints_end(
self,
input: Keypoints,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `transform_keypoints` ends."""
...

def on_inverse_inputs_start(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_input` begins."""
...

def on_inverse_inputs_end(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_inputs` ends."""
...

def on_inverse_masks_start(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_masks` begins."""
...

def on_inverse_masks_end(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_masks` ends."""
...

def on_inverse_classes_start(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_classes` begins."""
...

def on_inverse_classes_end(
self,
input: Tensor,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_classes` ends."""
...

def on_inverse_boxes_start(
self,
input: Boxes,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_boxes` begins."""
...

def on_inverse_boxes_end(
self,
input: Boxes,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_boxes` ends."""
...

def on_inverse_keypoints_start(
self,
input: Keypoints,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_keypoints` begins."""
...

def on_inverse_keypoints_end(
self,
input: Keypoints,
params: Dict[str, Tensor],
module: object,
) -> None:
"""Called when `inverse_keypoints` ends."""
...

def on_sequential_forward_start(
self,
*args: "K.container.data_types.DataType",
module: "K.AugmentationSequential",
params: List["K.container.params.ParamItem"],
data_keys: Union[List[str], List[int], List[DataKey]],
) -> None:
"""Called when `forward` begins for `AugmentationSequential`."""
...

def on_sequential_forward_end(
self,
*args: "K.container.data_types.DataType",
module: "K.AugmentationSequential",
params: List["K.container.params.ParamItem"],
data_keys: Union[List[str], List[int], List[DataKey]],
) -> None:
"""Called when `forward` ends for `AugmentationSequential`."""
...

def on_sequential_inverse_start(
self,
*args: "K.container.data_types.DataType",
module: "K.AugmentationSequential",
params: List["K.container.params.ParamItem"],
data_keys: Union[List[str], List[int], List[DataKey]],
) -> None:
"""Called when `inverse` begins for `AugmentationSequential`."""
...

def on_sequential_inverse_end(
self,
*args: "K.container.data_types.DataType",
module: "K.AugmentationSequential",
params: List["K.container.params.ParamItem"],
data_keys: Union[List[str], List[int], List[DataKey]],
) -> None:
"""Called when `inverse` ends for `AugmentationSequential`."""
...


class AugmentationCallback(AugmentationCallbackBase):
"""Logging images for `AugmentationSequential`.

Args:
batches_to_save: the number of batches to be logged. -1 is to save all batches.
num_to_log: number of images to log in a batch.
log_indices: only selected input types are logged. If `log_indices=[0, 2]` and
`data_keys=["input", "bbox", "mask"]`, only the images and masks
will be logged.
data_keys: the input type sequential. Accepts "input", "image", "mask",
"bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
postprocessing: add postprocessing for images if needed. If not None, the length
must match `data_keys`.
"""

def __init__(
self,
batches_to_save: int = 10,
num_to_log: int = 4,
log_indices: Optional[List[int]] = None,
data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
postprocessing: Optional[List[Optional[Module]]] = None,
):
super().__init__()
self.batches_to_save = batches_to_save
self.log_indices = log_indices
self.data_keys = data_keys
self.postprocessing = postprocessing
self.num_to_log = num_to_log

def _make_mask_data(self, mask: Tensor) -> Tensor:
return mask

def _make_bbox_data(self, bbox: Boxes) -> Tensor:
return cast(Tensor, bbox.to_tensor("xyxy", as_padded_sequence=True))

def _make_keypoints_data(self, keypoints: Keypoints) -> Tensor:
return cast(Tensor, keypoints.to_tensor(as_padded_sequence=True))

def _log_data(self, data: List[Tensor]) -> None:
raise NotImplementedError

def on_sequential_forward_end(
self,
*args: "K.container.data_types.DataType",
module: "K.AugmentationSequential",
params: List["K.container.params.ParamItem"],
data_keys: Union[List[str], List[int], List[DataKey]],
) -> None:
"""Called when `forward` ends for `AugmentationSequential`."""
output_data: List[Tensor] = []

# Log all the indices
if self.log_indices is None:
self.log_indices = list(range(len(data_keys)))

for i, (arg, data_key) in enumerate(zip(args, data_keys)):
if i not in self.log_indices:
continue

postproc = None
if self.postprocessing is not None:
postproc = self.postprocessing[self.log_indices[i]]
data = arg[: self.num_to_log]

if postproc is not None:
data = postproc(data)

if data_key in [DataKey.INPUT]:
output_data.append(cast(Tensor, data))
if data_key in [DataKey.MASK]:
output_data.append(self._make_mask_data(cast(Tensor, data)))
if data_key in [DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY]:
output_data.append(self._make_bbox_data(cast(Boxes, data)))
if data_key in [DataKey.KEYPOINTS]:
output_data.append(self._make_keypoints_data(cast(Keypoints, data)))

self._log_data(output_data)
Loading
Loading