diff --git a/torchbenchmark/util/experiment/metrics.py b/torchbenchmark/util/experiment/metrics.py
index e63431e4a1..9c05ce926b 100644
--- a/torchbenchmark/util/experiment/metrics.py
+++ b/torchbenchmark/util/experiment/metrics.py
@@ -141,7 +141,7 @@ def work_func():
 
 def get_model_test_metrics(
     model: Union[BenchmarkModel, ModelTask],
-    metrics=[],
+    required_metrics=[],
     export_metrics_file=False,
     metrics_gpu_backend="nvml",
     nwarmup=WARMUP_ROUNDS,
@@ -172,36 +172,36 @@ def get_model_test_metrics(
         if isinstance(model, BenchmarkModel)
         else model.get_model_attribute("device")
     )
-    if "latencies" in metrics or "throughputs" in metrics:
+    if "latencies" in required_metrics or "throughputs" in required_metrics:
         metrics.latencies = get_latencies(
             model.invoke, device, nwarmup=nwarmup, num_iter=num_iter
         )
-    if "cpu_peak_mem" in metrics or "gpu_peak_mem" in metrics:
+    if "cpu_peak_mem" in required_metrics or "gpu_peak_mem" in required_metrics:
         metrics.cpu_peak_mem, _device_id, metrics.gpu_peak_mem = get_peak_memory(
             model.invoke,
             device,
             export_metrics_file=export_metrics_file,
-            metrics_needed=metrics,
+            metrics_needed=required_metrics,
             metrics_gpu_backend=metrics_gpu_backend,
             cpu_monitored_pid=model_pid,
         )
-    if "throughputs" in metrics:
+    if "throughputs" in required_metrics:
         metrics.throughputs = [model.batch_size * 1000 / latency for latency in metrics.latencies]
-    if "pt2_compilation_time" in metrics:
+    if "pt2_compilation_time" in required_metrics:
         metrics.pt2_compilation_time = (
             model.get_model_attribute("pt2_compilation_time")
             if isinstance(model, ModelTask)
             else model.pt2_compilation_time
         )
-    if "pt2_graph_breaks" in metrics:
+    if "pt2_graph_breaks" in required_metrics:
         metrics.pt2_graph_breaks = (
             model.get_model_attribute("pt2_graph_breaks")
             if isinstance(model, ModelTask)
             else model.pt2_graph_breaks
         )
-    if "model_flops" in metrics:
+    if "model_flops" in required_metrics:
         metrics.model_flops = get_model_flops(model)
-    if "ttfb" in metrics:
+    if "ttfb" in required_metrics:
         metrics.ttfb = (
             model.get_model_attribute("ttfb")
             if isinstance(model, ModelTask)
@@ -270,7 +270,7 @@ def run_config(config: TorchBenchModelConfig,
     from torchbenchmark.util.experiment.instantiator import (
         load_model_isolated,
     )
-    model_task = load_model_isolated(config.name)
-    metrics = get_model_test_metrics(model_task, metrics=required_metrics)
+    model_task = load_model_isolated(config)
+    metrics = get_model_test_metrics(model_task, required_metrics=required_metrics)
     if "accuracy" in required_metrics:
         metrics.accuracy = accuracy