diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 3629659db1..0c36da6d3d 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -12,7 +12,7 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence, Sized from typing import TYPE_CHECKING, Any import torch @@ -121,24 +121,24 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, ) -> None: - if iteration_update is not None: - super().__init__(iteration_update) - else: - super().__init__(self._iteration) + super().__init__(self._iteration if iteration_update is None else iteration_update) if isinstance(data_loader, DataLoader): - sampler = data_loader.__dict__["sampler"] + sampler = getattr(data_loader, "sampler", None) + + # set the epoch value for DistributedSampler objects when an epoch starts if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - if epoch_length is None: + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader, Sized): + try: epoch_length = len(data_loader) - else: - if epoch_length is None: - raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") + except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( @@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None: iteration=0, epoch=0, max_epochs=max_epochs, - epoch_length=epoch_length, + epoch_length=epoch_length, # None when the dataset is iterable and so has no length output=None, batch=None, metrics={}, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index cfa711e4c0..fb554e391c 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -18,8 +18,10 @@ import nibabel as nib import numpy as np +import torch.nn as nn from monai.data import DataLoader, Dataset, IterableDataset +from monai.engines import SupervisedEvaluator from monai.transforms import Compose, LoadImaged, SimulateDelayd @@ -59,6 +61,17 @@ def test_shape(self): for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) + def test_supervisedevaluator(self): + """ + Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader. + """ + data = list(range(10)) + dl = DataLoader(IterableDataset(data)) + evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity()) + evaluator.run() # fails if the epoch length or other internal setup is not done correctly + + self.assertEqual(evaluator.state.iteration, len(data)) + if __name__ == "__main__": unittest.main()