From 88979410bdd3c3cf08237b9c09413ebf7aa9d384 Mon Sep 17 00:00:00 2001
From: shijianjian
Date: Tue, 10 Sep 2024 22:57:58 +0300
Subject: [PATCH] update

---
 kornia/enhance/normalize.py | 101 ++++++++++++++++++------------------
 kornia/utils/__init__.py    |   2 +
 kornia/utils/sample.py      |  81 +++++++++++++++++++++++++++++
 3 files changed, 134 insertions(+), 50 deletions(-)
 create mode 100644 kornia/utils/sample.py

diff --git a/kornia/enhance/normalize.py b/kornia/enhance/normalize.py
index 423d6ee3bd..6e758527f1 100644
--- a/kornia/enhance/normalize.py
+++ b/kornia/enhance/normalize.py
@@ -100,28 +100,34 @@ def normalize(data: Tensor, mean: Tensor, std: Tensor) -> Tensor:
         torch.Size([1, 4, 3, 3])
     """
     shape = data.shape
-    if len(mean.shape) == 0 or mean.shape[0] == 1:
-        mean = mean.expand(shape[1])
-    if len(std.shape) == 0 or std.shape[0] == 1:
-        std = std.expand(shape[1])
-    # Allow broadcast on channel dimension
-    if mean.shape and mean.shape[0] != 1:
-        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
-            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
+    if torch.onnx.is_in_onnx_export():
+        if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
+            raise ValueError("Only tensor is accepted when converting to ONNX.")
+        if mean.shape[0] != 1 or std.shape[0] != 1:
+            raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
+    else:
+        if isinstance(mean, float):
+            mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

-    # Allow broadcast on channel dimension
-    if std.shape and std.shape[0] != 1:
-        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
-            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
+        if isinstance(std, float):
+            std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

-    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
-    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
+        # Allow broadcast on channel dimension
+        if mean.shape and mean.shape[0] != 1:
+            if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
+                raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

-    if mean.shape:
-        mean = mean[..., :, None]
-    if std.shape:
-        std = std[..., :, None]
+        # Allow broadcast on channel dimension
+        if std.shape and std.shape[0] != 1:
+            if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
+                raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
+
+        mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
+        std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
+
+    mean = mean[..., :, None]
+    std = std[..., :, None]

     out: Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

@@ -203,38 +209,33 @@ def denormalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, flo
     """
     shape = data.shape

-    if isinstance(mean, float):
-        mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)
-
-    if isinstance(std, float):
-        std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)
-
-    if not isinstance(data, Tensor):
-        raise TypeError(f"data should be a tensor. Got {type(data)}")
-
-    if not isinstance(mean, Tensor):
-        raise TypeError(f"mean should be a tensor or a float. Got {type(mean)}")
-
-    if not isinstance(std, Tensor):
-        raise TypeError(f"std should be a tensor or float. Got {type(std)}")
-
-    # Allow broadcast on channel dimension
-    if mean.shape and mean.shape[0] != 1:
-        if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
-            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
-
-    # Allow broadcast on channel dimension
-    if std.shape and std.shape[0] != 1:
-        if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
-            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
-
-    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
-    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
-
-    if mean.shape:
-        mean = mean[..., :, None]
-    if std.shape:
-        std = std[..., :, None]
+    if torch.onnx.is_in_onnx_export():
+        if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
+            raise ValueError("Only tensor is accepted when converting to ONNX.")
+        if mean.shape[0] != 1 or std.shape[0] != 1:
+            raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
+    else:
+        if isinstance(mean, float):
+            mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)
+
+        if isinstance(std, float):
+            std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)
+
+        # Allow broadcast on channel dimension
+        if mean.shape and mean.shape[0] != 1:
+            if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
+                raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
+
+        # Allow broadcast on channel dimension
+        if std.shape and std.shape[0] != 1:
+            if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
+                raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
+
+        mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
+        std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
+
+    mean = mean[..., :, None]
+    std = std[..., :, None]

     out: Tensor = (data.view(shape[0], shape[1], -1) * std) + mean
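Note for reviewers (not part of the patch): the new ONNX branch above accepts only tensor statistics whose leading batch dimension is 1, so an export of the functional path would look roughly like the sketch below. The wrapper module and output file name are illustrative assumptions, not code from this change.

    import torch
    import kornia


    class _Norm(torch.nn.Module):
        # Hypothetical wrapper so the functional kornia.enhance.normalize can be traced.
        def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:
            super().__init__()
            self.mean = mean
            self.std = std

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return kornia.enhance.normalize(x, self.mean, self.std)


    # Tensors with a leading batch dimension of 1, as required by the new guard
    # that runs when torch.onnx.is_in_onnx_export() is True.
    mean = torch.tensor([[0.5, 0.5, 0.5]])
    std = torch.tensor([[0.5, 0.5, 0.5]])
    torch.onnx.export(_Norm(mean, std), torch.rand(1, 3, 32, 32), "normalize.onnx")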
diff --git a/kornia/utils/__init__.py b/kornia/utils/__init__.py
index 8fc674f6fb..95cfe1016e 100644
--- a/kornia/utils/__init__.py
+++ b/kornia/utils/__init__.py
@@ -28,6 +28,7 @@
 )
 from .one_hot import one_hot
 from .pointcloud_io import load_pointcloud_ply, save_pointcloud_ply
+from .sample import get_sample_images

 __all__ = [
     "batched_forward",
@@ -62,4 +63,5 @@
     "is_mps_tensor_safe",
     "dataclass_to_dict",
     "dict_to_dataclass",
+    "get_sample_images"
 ]
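A quick sanity check of the re-export above, assuming a build with this patch installed; the helper itself is added in the new file that follows.

    import kornia.utils

    # The new helper is reachable from the package-level namespace.
    from kornia.utils import get_sample_images
    assert "get_sample_images" in kornia.utils.__all__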
+ """ + if cache_dir is None: + cache_dir = ".kornia_hub/images" + os.makedirs(cache_dir, exist_ok=True) + tensors = [] + for path in paths: + if path.startswith("http"): + name = os.path.basename(path) + fname = os.path.join(cache_dir, name) + if not os.path.exists(fname): + logging.info(f"Downloading `{path}` to `{fname}`.") + download_image(path, fname) + else: + fname = path + img_tensor = load_image(fname) + if resize is not None: + img_tensor = kornia.geometry.resize(img_tensor, resize) + tensors.append(img_tensor) + + if resize is not None: + return stack(tensors) + else: + return tensors