diff --git a/CHANGELOG.md b/CHANGELOG.md index 0954d213b7..99f5a00651 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Pass Keyword Argument to TabularDataBase () +- Enable dtype argument when calling media.data + () ### Bug fixes - Preserve end_frame information of a video when it is zero. diff --git a/src/datumaro/components/media.py b/src/datumaro/components/media.py index 0445a95417..d348c03f87 100644 --- a/src/datumaro/components/media.py +++ b/src/datumaro/components/media.py @@ -11,6 +11,7 @@ import shutil from copy import deepcopy from enum import IntEnum +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -39,6 +40,7 @@ copyto_image, decode_image, lazy_image, + load_image, save_image, ) @@ -224,6 +226,7 @@ def __init__( f"{self.__class__.__name__}.from_numpy(), {self.__class__.__name__}.from_bytes())." ) super().__init__(*args, **kwargs) + self._dtype = np.uint8 if ext is not None: if not ext.startswith("."): @@ -322,6 +325,8 @@ def data(self) -> Optional[np.ndarray]: if not self.has_data: return None + if self.__data._dtype != self._dtype: + self.__data._loader = partial(load_image, dtype=self._dtype) data = self.__data() if self._size is None and data is not None: @@ -368,6 +373,11 @@ def set_crypter(self, crypter: Crypter): if isinstance(self.__data, lazy_image): self.__data._crypter = crypter + def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]: + """Get image data with a specific data type""" + self._dtype = dtype + return self.data + class ImageFromData(FromDataMixin, Image): def save( @@ -400,8 +410,8 @@ def data(self) -> Optional[np.ndarray]: data = super().data - if isinstance(data, np.ndarray) and data.dtype != np.uint8: - data = np.clip(data, 0.0, 255.0).astype(np.uint8) + if isinstance(data, np.ndarray) and data.dtype != self._dtype: + data = np.clip(data, 0.0, 255.0).astype(self._dtype) if self._size is None and data is not None: if not 2 <= data.ndim <= 3: raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.") @@ -413,6 +423,11 @@ def has_size(self) -> bool: """Indicates that size info is cached and won't require image loading""" return self._size is not None or isinstance(self._data, np.ndarray) + def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]: + """Get image data with a specific data type""" + self._dtype = dtype + return self.data + class ImageFromBytes(ImageFromData): _FORMAT_MAGICS = ( @@ -446,13 +461,21 @@ def data(self) -> Optional[np.ndarray]: data = super().data if isinstance(data, bytes): - data = decode_image(data, dtype=np.uint8) + data = decode_image(data, dtype=self._dtype) if self._size is None and data is not None: if not 2 <= data.ndim <= 3: raise MediaShapeError("An image should have 2 (gray) or 3 (bgra) dims.") self._size = tuple(map(int, data.shape[:2])) return data + def get_data_as_dtype(self, dtype: Optional[np.dtype] = np.uint8) -> Optional[np.ndarray]: + """Get image data with a specific data type""" + + if dtype != np.uint8: + raise ValueError("ImageFromBytes only support `dtype=np.uint8`.") + self._dtype = dtype + return self.data + class VideoFrame(ImageFromNumpy): _type = MediaType.VIDEO_FRAME diff --git a/src/datumaro/util/image.py b/src/datumaro/util/image.py index 4a31d35645..4b94447726 100644 --- a/src/datumaro/util/image.py +++ b/src/datumaro/util/image.py @@ -60,9 +60,9 @@ class ImageColorChannel(Enum): COLOR_BGR = 1 COLOR_RGB = 2 - def decode_by_cv2(self, image_bytes: bytes) -> np.ndarray: + def decode_by_cv2(self, image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray: """Convert image color channel for OpenCV image (np.ndarray).""" - image_buffer = np.frombuffer(image_bytes, dtype=np.uint8) + image_buffer = np.frombuffer(image_bytes, dtype=dtype) if self == ImageColorChannel.UNCHANGED: return cv2.imdecode(image_buffer, cv2.IMREAD_UNCHANGED) @@ -283,15 +283,26 @@ def encode_image(image: np.ndarray, ext: str, dtype: DTypeLike = np.uint8, **kwa raise NotImplementedError() -def decode_image(image_bytes: bytes, dtype: DTypeLike = np.uint8) -> np.ndarray: +def decode_image(image_bytes: bytes, dtype: np.dtype = np.uint8) -> np.ndarray: ctx_color_scale = IMAGE_COLOR_CHANNEL.get() - if IMAGE_BACKEND.get() == ImageBackend.cv2: - image = ctx_color_scale.decode_by_cv2(image_bytes) - elif IMAGE_BACKEND.get() == ImageBackend.PIL: - image = ctx_color_scale.decode_by_pil(image_bytes) + if np.issubdtype(dtype, np.floating): + # PIL doesn't support floating point image loading + # CV doesn't support floating point image with color channel setting (IMREAD_COLOR) + with decode_image_context( + image_backend=ImageBackend.cv2, image_color_channel=ImageColorChannel.UNCHANGED + ): + image = ctx_color_scale.decode_by_cv2(image_bytes, dtype=dtype) + image = image[..., ::-1] + if ctx_color_scale == ImageColorChannel.COLOR_BGR: + image = image[..., ::-1] else: - raise NotImplementedError() + if IMAGE_BACKEND.get() == ImageBackend.cv2: + image = ctx_color_scale.decode_by_cv2(image_bytes) + elif IMAGE_BACKEND.get() == ImageBackend.PIL: + image = ctx_color_scale.decode_by_pil(image_bytes) + else: + raise NotImplementedError() image = image.astype(dtype) @@ -376,6 +387,7 @@ def __init__( assert isinstance(cache, (ImageCache, bool)) self._cache = cache self._crypter = crypter + self._dtype = dtype def __call__(self) -> np.ndarray: image = None diff --git a/tests/requirements.txt b/tests/requirements.txt index 4a0346e718..b58a3f3492 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -7,3 +7,5 @@ pytest-stress pytest-html coverage pytest-csv + +tifffile diff --git a/tests/unit/test_images.py b/tests/unit/test_images.py index bb4a56b5bb..0862abd268 100644 --- a/tests/unit/test_images.py +++ b/tests/unit/test_images.py @@ -193,6 +193,26 @@ def test_ext_detection_failure(self): image = Image.from_bytes(data=image_bytes) self.assertEqual(image.ext, None) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_floating_image_from_numpy(self): + image_float = np.random.rand(32, 32, 3).astype(np.float16) * 255.0 + media = Image.from_numpy(image_float) + data = media.get_data_as_dtype(dtype=np.float16) + self.assertTrue(np.all(image_float == data)) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_floating_image_from_file(self): + import tifffile + + with TestDir() as test_dir: + image_float = np.random.rand(32, 32, 3).astype(np.float32) * 255.0 + image_path = osp.join(test_dir, "floating_image.tiff") + tifffile.imwrite(image_path, image_float) + + media = Image.from_file(image_path) + data = media.get_data_as_dtype(dtype=np.float32) + self.assertTrue(np.all(image_float == data)) + class RoIImageTest(TestCase): def _test_ctors(self, img_ctor, args_list, test_dir, is_bytes=False):