Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implemented MultiReader #68

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions minerva/data/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from .reader import _Reader
from .tiff_reader import TiffReader
from .zarr_reader import PatchedZarrReader
from .multi_reader import MultiReader

# Public names re-exported by the readers package. Every entry must match a
# name imported above, otherwise ``from minerva.data.readers import *`` fails.
__all__ = [
    "PatchedArrayReader",
    "PatchedZarrReader",
    "PNGReader",
    "TiffReader",
    "MultiReader",
    "_Reader",
]
69 changes: 69 additions & 0 deletions minerva/data/readers/multi_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Any, Callable, Optional, Sequence
import numpy as np

from minerva.data.readers import _Reader


class MultiReader(_Reader):
    """Reader that composes items from other readers.

    Its i-th item is the i-th item of each of the child-readers, each passed
    through ``preprocess`` and then merged together by ``collate_fn``.
    """

    def __init__(
        self,
        readers: Sequence[_Reader],
        preprocess: Optional[Callable] = None,
        collate_fn: Optional[Callable] = np.stack,
    ):
        """Collects data from multiple readers and collates them.

        Parameters
        ----------
        readers: Sequence[_Reader]
            The readers from which the data will be collected. At least one
            must be provided. If the readers have different lengths, data will
            only be collected up until the length of the smallest child-reader.
        preprocess: Optional[Callable]
            A function to be applied individually to each item read from the
            child-readers. Defaults to an identity function (i.e. no changes
            to the data).
        collate_fn: Optional[Callable]
            A function that receives a list of items read from the
            child-readers after preprocessing and returns a single item for
            this reader. Defaults to numpy.stack (also used when ``None`` is
            passed), which means it must be provided if the preprocessing
            function does not always return same-shape numpy arrays.

        Raises
        ------
        AssertionError
            If ``readers`` is empty.
        """
        # Kept as an assertion so the exception type seen by existing callers
        # does not change.
        assert len(readers) > 0, "MultiReader expects at least one reader as argument."

        self._readers = readers
        # Compare against None explicitly: a callable object may legitimately
        # be falsy (e.g. if it defines __len__ or __bool__).
        self.preprocess = preprocess if preprocess is not None else self._identity
        # An explicit None falls back to the documented default instead of
        # failing later with "'NoneType' object is not callable".
        self.collate_fn = collate_fn if collate_fn is not None else np.stack

    @staticmethod
    def _identity(item: Any) -> Any:
        """Return ``item`` unchanged; the default ``preprocess``."""
        # A named function rather than a lambda, so this default does not by
        # itself prevent the reader from being pickled.
        return item

    def __len__(self) -> int:
        """Returns the length of the reader, defined as the length of the
        smallest child-reader.

        Returns
        -------
        int
            The length of the reader.
        """
        return min(len(reader) for reader in self._readers)

    def __getitem__(self, index: int) -> Any:
        """Retrieves the items from each reader at the specified index and
        collates them accordingly.

        Parameters
        ----------
        index : int
            Index of the item to retrieve. Negative values count from the end
            of this reader (i.e. relative to the smallest child-reader).

        Returns
        -------
        Any
            The result of ``collate_fn`` applied to the list of preprocessed
            items, one per child-reader.

        Raises
        ------
        IndexError
            If ``index`` falls outside ``[-len(self), len(self))``.
        """
        length = len(self)
        # Resolve negative indices against this reader's own length. Passing
        # them straight to the children would pick misaligned items whenever
        # the children have different lengths.
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(
                f"Index {index} is out of range for MultiReader of length {length}."
            )

        return self.collate_fn(
            [self.preprocess(reader[position]) for reader in self._readers]
        )


62 changes: 62 additions & 0 deletions tests/data/readers/test_multi_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import numpy as np

from minerva.data.readers import MultiReader, PatchedArrayReader


def test_multi_reader_identity():
    """Check length, first and last items of a MultiReader built with the
    default preprocess and collate functions over two readers of different
    sizes."""
    larger = PatchedArrayReader(
        np.arange(15**2).reshape(1, 15, 15), data_shape=(1, 5, 5)
    )
    smaller = PatchedArrayReader(
        np.arange(10**2).reshape(1, 10, 10), data_shape=(1, 5, 5)
    )
    combined = MultiReader([larger, smaller])

    # The composed reader is as long as its shortest child.
    shortest = min(len(larger), len(smaller))
    assert len(combined) == shortest, "Reader has incorrect length"

    # First item: plain stack of the children's first items.
    first_expected = np.stack([larger[0], smaller[0]])
    assert np.all(combined[0] == first_expected), "Reader's first element is incorrect"

    # Last valid item of the composed reader.
    last = len(combined) - 1
    last_expected = np.stack([larger[last], smaller[last]])
    assert np.all(
        combined[last] == last_expected
    ), "Reader's last element is incorrect"


def test_multi_reader_squeeze():
    """Check length, first and last items of a MultiReader that applies
    np.squeeze to every child item before stacking."""
    larger = PatchedArrayReader(
        np.arange(15**2).reshape(1, 15, 15), data_shape=(1, 5, 5)
    )
    smaller = PatchedArrayReader(
        np.arange(10**2).reshape(1, 10, 10), data_shape=(1, 5, 5)
    )
    # np.squeeze is passed positionally, i.e. as the ``preprocess`` argument.
    combined = MultiReader([larger, smaller], np.squeeze)

    shortest = min(len(larger), len(smaller))
    assert len(combined) == shortest, "Reader has incorrect length"

    def squeezed_pair(position):
        # Expected item: each child's item squeezed, then stacked.
        return np.stack(
            [np.squeeze(larger[position]), np.squeeze(smaller[position])]
        )

    assert np.all(
        combined[0] == squeezed_pair(0)
    ), "Reader's first element is incorrect"

    last = len(combined) - 1
    assert np.all(
        combined[last] == squeezed_pair(last)
    ), "Reader's last element is incorrect"
Loading