diff --git a/torchbenchmark/util/env_check.py b/torchbenchmark/util/env_check.py index 9784e2a040..fcacf4c399 100644 --- a/torchbenchmark/util/env_check.py +++ b/torchbenchmark/util/env_check.py @@ -359,7 +359,7 @@ def forward_pass(mod, inputs, contexts, _collect_outputs=True): return mod(**inputs) else: return mod(*inputs) - + def forward_and_backward_pass(mod, inputs, contexts, optimizer, collect_outputs=True): cloned_inputs = clone_inputs(inputs) diff --git a/torchbenchmark/util/experiment/metrics.py b/torchbenchmark/util/experiment/metrics.py index 587abc8980..b5e910a9e5 100644 --- a/torchbenchmark/util/experiment/metrics.py +++ b/torchbenchmark/util/experiment/metrics.py @@ -149,4 +149,6 @@ def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True) return accuracy else: model = load_model(accuracy_model_config) - return model.accuracy + accuracy = model.accuracy + del model + return accuracy