diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 9ec6c6e75e0bd..e1ede580e8324 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -601,9 +601,9 @@ def test_fit_loop_reset(tmp_path): assert epoch_loop.batch_progress.total.ready == 4 assert epoch_loop.batch_progress.total.processed == 4 assert epoch_loop.batch_progress.total.completed == 4 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 4 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 4 - assert epoch_loop.batch_progress.current.completed == 4 + assert epoch_loop.batch_progress.current.ready == 0 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 0 + assert epoch_loop.batch_progress.current.completed == 0 @pytest.mark.parametrize(