Skip to content

Commit

Permalink
Modify Workflow to Allow IterableDataset Inputs (#8263)
Browse files Browse the repository at this point in the history
### Description

This modifies the behaviour of `Workflow` to permit `IterableDataset` to
be used correctly. A check against the `epoch_length` value is removed,
to allow that value to be `None`, and a test is added to verify this.
The length of a data loader is not defined when using iterable datasets,
so a try/except is added to allow that length to be queried safely. This is
related to my work on the streaming support: in my [prototype
gist](https://gist.github.com/ericspod/1904713716b45631260784ac3fcd6fb3)
I had to provide a bogus epoch length value and then change it to
`None` later, once the evaluator object was created. This PR will remove
the need for this hack.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
Signed-off-by: Eric Kerfoot <eric.kerfoot@gmail>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <eric.kerfoot@gmail>
  • Loading branch information
3 people authored Dec 19, 2024
1 parent 21920a3 commit e1e3d8e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
22 changes: 11 additions & 11 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

import warnings
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence, Sized
from typing import TYPE_CHECKING, Any

import torch
Expand Down Expand Up @@ -121,24 +121,24 @@ def __init__(
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
if iteration_update is not None:
super().__init__(iteration_update)
else:
super().__init__(self._iteration)
super().__init__(self._iteration if iteration_update is None else iteration_update)

if isinstance(data_loader, DataLoader):
sampler = data_loader.__dict__["sampler"]
sampler = getattr(data_loader, "sampler", None)

# set the epoch value for DistributedSampler objects when an epoch starts
if isinstance(sampler, DistributedSampler):

@self.on(Events.EPOCH_STARTED)
def set_sampler_epoch(engine: Engine) -> None:
sampler.set_epoch(engine.state.epoch)

if epoch_length is None:
# if the epoch_length isn't given, attempt to get it from the length of the data loader
if epoch_length is None and isinstance(data_loader, Sized):
try:
epoch_length = len(data_loader)
else:
if epoch_length is None:
raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type
pass # deliberately leave epoch_length as None

# set all sharable data for the workflow based on Ignite engine.state
self.state: Any = State(
Expand All @@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
iteration=0,
epoch=0,
max_epochs=max_epochs,
epoch_length=epoch_length,
epoch_length=epoch_length, # None when the dataset is iterable and so has no length
output=None,
batch=None,
metrics={},
Expand Down
13 changes: 13 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import nibabel as nib
import numpy as np
import torch.nn as nn

from monai.data import DataLoader, Dataset, IterableDataset
from monai.engines import SupervisedEvaluator
from monai.transforms import Compose, LoadImaged, SimulateDelayd


Expand Down Expand Up @@ -59,6 +61,17 @@ def test_shape(self):
for d in dataloader:
self.assertTupleEqual(d["image"].shape[1:], expected_shape)

def test_supervisedevaluator(self):
"""
Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.

The evaluator is constructed without an explicit epoch length and is expected to run to completion,
performing as many iterations as there are items in the source data.
"""
# ten plain integers stand in for the streamed items
data = list(range(10))
# no batch size is passed to the loader; the final assertion expects one iteration per item
dl = DataLoader(IterableDataset(data))
# an identity network keeps the test about the engine setup rather than about inference
evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity())
evaluator.run() # fails if the epoch length or other internal setup is not done correctly

# the engine's iteration counter should equal the number of items in the source list
self.assertEqual(evaluator.state.iteration, len(data))


if __name__ == "__main__":
unittest.main()

0 comments on commit e1e3d8e

Please sign in to comment.