diff --git a/minerva/data/readers/__init__.py b/minerva/data/readers/__init__.py
index ab97212..093600f 100644
--- a/minerva/data/readers/__init__.py
+++ b/minerva/data/readers/__init__.py
@@ -3,11 +3,13 @@
 from .reader import _Reader
 from .tiff_reader import TiffReader
 from .zarr_reader import PatchedZarrReader
+from .multi_reader import MultiReader
 
 __all__ = [
     "PatchedArrayReader",
     "PatchedZarrReader",
     "PNGReader",
     "TiffReader",
+    "MultiReader",
     "_Reader",
 ]
diff --git a/minerva/data/readers/multi_reader.py b/minerva/data/readers/multi_reader.py
new file mode 100644
index 0000000..683be23
--- /dev/null
+++ b/minerva/data/readers/multi_reader.py
@@ -0,0 +1,79 @@
+from typing import Any, Callable, Optional, Sequence
+
+import numpy as np
+
+from minerva.data.readers.reader import _Reader
+
+
+class MultiReader(_Reader):
+    """Reader that composes items from other readers.
+
+    Its i-th item is the i-th item of each of the child-readers merged
+    together according to a collate_fn function."""
+
+    def __init__(
+        self,
+        readers: Sequence[_Reader],
+        preprocess: Optional[Callable] = None,
+        collate_fn: Optional[Callable] = np.stack,
+    ):
+        """Collects data from multiple readers and collates them
+
+        Parameters
+        ----------
+        readers: Sequence[_Reader]
+            The readers from which the data will be collected. At least one must be
+            provided. If the readers have different lengths, data will only be
+            collected up until the length of the smallest child-reader.
+        preprocess: Optional[Callable]
+            A function to be applied individually to each item read from the child-readers.
+            Defaults to an identity function (i.e. no changes to the data).
+        collate_fn: Optional[Callable]
+            A function that receives a list of items read from the child-readers after
+            preprocessing and returns a single item for this reader.
+            Defaults to numpy.stack (also used when None is passed), which means it
+            must be provided if the preprocessing function does not always return
+            same-shape numpy arrays.
+
+        Raises
+        ------
+        ValueError
+            If `readers` is empty.
+        """
+        # Explicit raise instead of assert: asserts are stripped under `python -O`.
+        if len(readers) == 0:
+            raise ValueError("MultiReader expects at least one reader as argument.")
+
+        self._readers = readers
+        self.preprocess = preprocess if preprocess is not None else (lambda x: x)
+        # The signature allows None, so fall back to the documented default
+        # rather than failing later inside __getitem__.
+        self.collate_fn = collate_fn if collate_fn is not None else np.stack
+
+    def __len__(self) -> int:
+        """Returns the length of the reader, defined as the length of the smallest
+        child-reader
+
+        Returns
+        -------
+        int
+            The length of the reader."""
+        return min(len(reader) for reader in self._readers)
+
+    def __getitem__(self, index: int) -> Any:
+        """Retrieves the items from each reader at the specified index and collates them
+        accordingly.
+
+        Parameters
+        ----------
+        index : int
+            Index of the item to retrieve.
+
+        Returns
+        -------
+        Any
+            An item from the reader.
+        """
+        return self.collate_fn(
+            [self.preprocess(reader[index]) for reader in self._readers]
+        )
diff --git a/tests/data/readers/test_multi_reader.py b/tests/data/readers/test_multi_reader.py
new file mode 100644
index 0000000..a7dc629
--- /dev/null
+++ b/tests/data/readers/test_multi_reader.py
@@ -0,0 +1,82 @@
+import numpy as np
+import pytest
+
+from minerva.data.readers import MultiReader, PatchedArrayReader
+
+
+def test_multi_reader_identity():
+
+    reader1 = PatchedArrayReader(
+        np.arange(15**2).reshape(1, 15, 15),
+        data_shape=(1, 5, 5),
+    )
+
+    reader2 = PatchedArrayReader(
+        np.arange(10**2).reshape(1, 10, 10), data_shape=(1, 5, 5)
+    )
+
+    multireader = MultiReader([reader1, reader2])
+
+    assert len(multireader) == min(
+        len(reader1), len(reader2)
+    ), "Reader has incorrect length"
+
+    assert np.all(
+        multireader[0] == np.stack([reader1[0], reader2[0]])
+    ), "Reader's first element is incorrect"
+
+    assert np.all(
+        multireader[len(multireader) - 1]
+        == np.stack([reader1[len(multireader) - 1], reader2[len(multireader) - 1]])
+    ), "Reader's last element is incorrect"
+
+
+def test_multi_reader_squeeze():
+
+    reader1 = PatchedArrayReader(
+        np.arange(15**2).reshape(1, 15, 15),
+        data_shape=(1, 5, 5),
+    )
+
+    reader2 = PatchedArrayReader(
+        np.arange(10**2).reshape(1, 10, 10), data_shape=(1, 5, 5)
+    )
+
+    multireader = MultiReader([reader1, reader2], np.squeeze)
+
+    assert len(multireader) == min(
+        len(reader1), len(reader2)
+    ), "Reader has incorrect length"
+
+    assert np.all(
+        multireader[0] == np.stack([np.squeeze(reader1[0]), np.squeeze(reader2[0])])
+    ), "Reader's first element is incorrect"
+
+    assert np.all(
+        multireader[len(multireader) - 1]
+        == np.stack(
+            [
+                np.squeeze(reader1[len(multireader) - 1]),
+                np.squeeze(reader2[len(multireader) - 1]),
+            ]
+        )
+    ), "Reader's last element is incorrect"
+
+
+def test_multi_reader_none_collate_fn():
+
+    reader = PatchedArrayReader(
+        np.arange(10**2).reshape(1, 10, 10), data_shape=(1, 5, 5)
+    )
+
+    multireader = MultiReader([reader, reader], collate_fn=None)
+
+    assert np.all(
+        multireader[0] == np.stack([reader[0], reader[0]])
+    ), "None collate_fn should fall back to numpy.stack"
+
+
+def test_multi_reader_requires_readers():
+
+    with pytest.raises(ValueError):
+        MultiReader([])