Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed May 10, 2024
1 parent 0c984c5 commit 8ab53be
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions torchbenchmark/util/experiment/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def work_func():

def get_model_test_metrics(
model: Union[BenchmarkModel, ModelTask],
metrics=[],
required_metrics=[],
export_metrics_file=False,
metrics_gpu_backend="nvml",
nwarmup=WARMUP_ROUNDS,
Expand Down Expand Up @@ -172,36 +172,36 @@ def get_model_test_metrics(
if isinstance(model, BenchmarkModel)
else model.get_model_attribute("device")
)
if "latencies" in metrics or "throughputs" in metrics:
if "latencies" in required_metrics or "throughputs" in required_metrics:
metrics.latencies = get_latencies(
model.invoke, device, nwarmup=nwarmup, num_iter=num_iter
)
if "cpu_peak_mem" in metrics or "gpu_peak_mem" in metrics:
if "cpu_peak_mem" in required_metrics or "gpu_peak_mem" in required_metrics:
metrics.cpu_peak_mem, _device_id, metrics.gpu_peak_mem = get_peak_memory(
model.invoke,
device,
export_metrics_file=export_metrics_file,
metrics_needed=metrics,
metrics_needed=required_metrics,
metrics_gpu_backend=metrics_gpu_backend,
cpu_monitored_pid=model_pid,
)
if "throughputs" in metrics:
if "throughputs" in required_metrics:
metrics.throughputs = [model.batch_size * 1000 / latency for latency in metrics.latencies]
if "pt2_compilation_time" in metrics:
if "pt2_compilation_time" in required_metrics:
metrics.pt2_compilation_time = (
model.get_model_attribute("pt2_compilation_time")
if isinstance(model, ModelTask)
else model.pt2_compilation_time
)
if "pt2_graph_breaks" in metrics:
if "pt2_graph_breaks" in required_metrics:
metrics.pt2_graph_breaks = (
model.get_model_attribute("pt2_graph_breaks")
if isinstance(model, ModelTask)
else model.pt2_graph_breaks
)
if "model_flops" in metrics:
if "model_flops" in required_metrics:
metrics.model_flops = get_model_flops(model)
if "ttfb" in metrics:
if "ttfb" in required_metrics:
metrics.ttfb = (
model.get_model_attribute("ttfb")
if isinstance(model, ModelTask)
Expand Down Expand Up @@ -270,7 +270,7 @@ def run_config(config: TorchBenchModelConfig,
from torchbenchmark.util.experiment.instantiator import (
load_model_isolated,
)
model_task = load_model_isolated(config.name)
model_task = load_model_isolated(config)
metrics = get_model_test_metrics(model_task, metrics=required_metrics)
if "accuracy" in required_metrics:
metrics.accuracy = accuracy
Expand Down

0 comments on commit 8ab53be

Please sign in to comment.