diff --git a/torchbenchmark/util/env_check.py b/torchbenchmark/util/env_check.py index f282c2b7e5..62809c6d65 100644 --- a/torchbenchmark/util/env_check.py +++ b/torchbenchmark/util/env_check.py @@ -367,9 +367,9 @@ def forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs= def run_n_iterations(mod, inputs, contexts, optimizer=None, is_training=False, iterations=STABLENESS_CHECK_ROUNDS): def _model_iter_fn(mod, inputs, contexts, optimizer, collect_outputs): if is_training: - forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs) + return forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs) else: - forward_pass(mod, inputs, contexts, collect_outputs) + return forward_pass(mod, inputs, contexts, collect_outputs) for _ in range(iterations - 1): _model_iter_fn(mod, inputs, contexts, optimizer, collect_outputs=False) return _model_iter_fn(mod, inputs, contexts, optimizer, collect_outputs=True)