Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 10, 2024
1 parent 0a15677 commit 8897941
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 50 deletions.
101 changes: 51 additions & 50 deletions kornia/enhance/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,34 @@ def normalize(data: Tensor, mean: Tensor, std: Tensor) -> Tensor:
torch.Size([1, 4, 3, 3])
"""
shape = data.shape
if len(mean.shape) == 0 or mean.shape[0] == 1:
mean = mean.expand(shape[1])
if len(std.shape) == 0 or std.shape[0] == 1:
std = std.expand(shape[1])

# Allow broadcast on channel dimension
if mean.shape and mean.shape[0] != 1:
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
if torch.onnx.is_in_onnx_export():
if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
raise ValueError("Only tensor is accepted when converting to ONNX.")
if mean.shape[0] != 1 or std.shape[0] != 1:
raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
else:
if isinstance(mean, float):
mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

# Allow broadcast on channel dimension
if std.shape and std.shape[0] != 1:
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
if isinstance(std, float):
std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
# Allow broadcast on channel dimension
if mean.shape and mean.shape[0] != 1:
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

if mean.shape:
mean = mean[..., :, None]
if std.shape:
std = std[..., :, None]
# Allow broadcast on channel dimension
if std.shape and std.shape[0] != 1:
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

mean = mean[..., :, None]
std = std[..., :, None]

out: Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

Expand Down Expand Up @@ -203,38 +209,33 @@ def denormalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, flo
"""
shape = data.shape

if isinstance(mean, float):
mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

if isinstance(std, float):
std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

if not isinstance(data, Tensor):
raise TypeError(f"data should be a tensor. Got {type(data)}")

if not isinstance(mean, Tensor):
raise TypeError(f"mean should be a tensor or a float. Got {type(mean)}")

if not isinstance(std, Tensor):
raise TypeError(f"std should be a tensor or float. Got {type(std)}")

# Allow broadcast on channel dimension
if mean.shape and mean.shape[0] != 1:
if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

# Allow broadcast on channel dimension
if std.shape and std.shape[0] != 1:
if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

if mean.shape:
mean = mean[..., :, None]
if std.shape:
std = std[..., :, None]
if torch.onnx.is_in_onnx_export():
if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
raise ValueError("Only tensor is accepted when converting to ONNX.")
if mean.shape[0] != 1 or std.shape[0] != 1:
raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
else:
if isinstance(mean, float):
mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

if isinstance(std, float):
std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

# Allow broadcast on channel dimension
if mean.shape and mean.shape[0] != 1:
if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

# Allow broadcast on channel dimension
if std.shape and std.shape[0] != 1:
if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

mean = mean[..., :, None]
std = std[..., :, None]

out: Tensor = (data.view(shape[0], shape[1], -1) * std) + mean

Expand Down
2 changes: 2 additions & 0 deletions kornia/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from .one_hot import one_hot
from .pointcloud_io import load_pointcloud_ply, save_pointcloud_ply
from .sample import get_sample_images

__all__ = [
"batched_forward",
Expand Down Expand Up @@ -62,4 +63,5 @@
"is_mps_tensor_safe",
"dataclass_to_dict",
"dict_to_dataclass",
"get_sample_images"
]
81 changes: 81 additions & 0 deletions kornia/utils/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import logging
import torch
import kornia
import requests
from io import BytesIO
from typing import Optional, Union

from kornia.io import load_image
from kornia.core import Tensor, stack
from kornia.core.external import PILImage as Image

# Public names exported on wildcard import; `download_image` is not among them.
__all__ = [
"get_sample_images",
]

# Common prefix shared by every bundled sample image URL.
_KORNIA_DATA_ROOT = "https://raw.githubusercontent.com/kornia/data/main"

# Default image locations used by `get_sample_images` when `paths` is not given.
IMAGE_URLS: list[str] = [
    f"{_KORNIA_DATA_ROOT}/{image_name}"
    for image_name in (
        "panda.jpg",
        "simba.png",
        "girona.png",
        "baby_giraffe.png",
        "persistencia_memoria.jpg",
        "delorean.png",
    )
]


def download_image(url: str, save_to: str, timeout: float = 30.0) -> None:
    """Download an image from a given URL and save it to a specified file path.

    Args:
        url: The URL of the image to download.
        save_to: The file path where the downloaded image will be saved.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # The context manager releases the connection even when the request fails,
    # and the timeout keeps a stalled server from blocking the caller forever.
    with requests.get(url, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # `.content` is the fully read, content-decoded payload (unlike `.raw`).
        payload = response.content

    im = Image.open(BytesIO(payload))  # type: ignore
    im.save(save_to)  # type: ignore


def get_sample_images(
    resize: Optional[tuple[int, int]] = None,
    paths: list[str] = IMAGE_URLS,
    download: bool = True,
    cache_dir: Optional[str] = None,
) -> Union[Tensor, list[Tensor]]:
    """Load multiple images from the given paths or URLs.

    Optionally download them, resize them if specified, and return them as a batch of tensors or a list of tensors.

    Args:
        resize: Optional target size for resizing all images as a tuple (height, width).
            If not provided, the images will not be resized, and their original sizes will be retained.
        paths: A list of path or URL from which to load or download images.
            Defaults to a pre-defined constant `IMAGE_URLS` if not provided.
        download: Whether to download the images if they are not already cached. Defaults to True.
        cache_dir: The directory where the downloaded images will be cached.
            Defaults to ".kornia_hub/images".

    Returns:
        If `resize` is provided, a single tensor produced by stacking the resized images.
        Otherwise, a list with one tensor per image, each kept at its loaded size.

    Raises:
        FileNotFoundError: If a URL is not present in the cache and `download` is False.
    """
    if cache_dir is None:
        cache_dir = ".kornia_hub/images"
    os.makedirs(cache_dir, exist_ok=True)

    tensors = []
    for path in paths:
        # Only real HTTP(S) URLs go through the cache; anything else is read as a local file.
        if path.startswith(("http://", "https://")):
            fname = os.path.join(cache_dir, os.path.basename(path))
            if not os.path.exists(fname):
                # Honour the `download` flag instead of silently fetching from the network.
                if not download:
                    raise FileNotFoundError(
                        f"`{fname}` is not cached and `download` is False. Cannot load `{path}`."
                    )
                logging.info("Downloading `%s` to `%s`.", path, fname)
                download_image(path, fname)
        else:
            fname = path

        img_tensor = load_image(fname)
        if resize is not None:
            img_tensor = kornia.geometry.resize(img_tensor, resize)
        tensors.append(img_tensor)

    # Stacking is only done after resizing, when every image was given the same target size.
    if resize is not None:
        return stack(tensors)
    return tensors

0 comments on commit 8897941

Please sign in to comment.