Add ttfb to available metrics
Summary: Add ttfb (time-to-first-batch) to the list of available metrics

Reviewed By: dshi7

Differential Revision: D50225322

fbshipit-source-id: e1814224332bcd2c9345653dceb325fca3571646
xuzhao9 authored and facebook-github-bot committed Oct 13, 2023
1 parent c744f0c · commit 7dc8e81
Showing 3 changed files with 18 additions and 3 deletions.
run.py: 4 additions & 2 deletions

```diff
@@ -103,6 +103,8 @@ def printResultSummaryTime(result_summary, model, metrics_needed=[], flops_model
         flops = model.get_flops()
         tflops = flops / (cpu_walltime / 1.0e3) / 1.0e12
         print('{:<20} {:>20}'.format("GPU FLOPS:", "%.4f TFLOPs per second" % tflops, sep=''))
+    if 'ttfb' in metrics_needed:
+        print('{:<20} {:>20}'.format("Time to first batch:", "%.4f ms" % model.ttfb, sep=''))
     if model_flops is not None:
         tflops = model_flops / (cpu_walltime / 1.0e3) / 1.0e12
         print('{:<20} {:>20}'.format("Model Flops:", "%.4f TFLOPs per second" % tflops, sep=''))
@@ -356,8 +358,8 @@ def _validate_profile_options(profile_options: str):
     parser.add_argument(
         "--metrics",
         type=str,
-        default="cpu_peak_mem,gpu_peak_mem",
-        help="Specify metrics [cpu_peak_mem,gpu_peak_mem,flops,model_flops] to be collected. You can also set `none` to disable all metrics. The metrics are separated by comma such as cpu_peak_mem,gpu_peak_mem.",
+        default="cpu_peak_mem,gpu_peak_mem,ttfb",
+        help="Specify metrics [cpu_peak_mem,gpu_peak_mem,ttfb,flops,model_flops] to be collected. You can also set `none` to disable all metrics. The metrics are separated by comma such as cpu_peak_mem,gpu_peak_mem.",
     )
     parser.add_argument(
         "--metrics-gpu-backend",
```
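With this change, ttfb is collected by default (for example, `python run.py <model> --metrics cpu_peak_mem,gpu_peak_mem,ttfb`; only the `--metrics` flag is confirmed by this diff, the rest of the invocation is illustrative). A minimal sketch of the line the new print path emits, with a made-up value; note that the `sep=''` argument is inert, since `str.format` silently ignores unused keyword arguments:

```python
# Sketch of the ttfb line printed by printResultSummaryTime.
ttfb_ms = 1.2346  # hypothetical measurement in milliseconds, not a real result
print('{:<20} {:>20}'.format("Time to first batch:", "%.4f ms" % ttfb_ms))
# -> Time to first batch:            1.2346 ms
#    (label left-aligned in a 20-column field, value right-aligned in another)
```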
torchbenchmark/util/experiment/metrics.py: 6 additions & 1 deletion

```diff
@@ -20,6 +20,7 @@ class TorchBenchModelMetrics:
     throughputs: List[float]
     cpu_peak_mem: Optional[float]
     gpu_peak_mem: Optional[float]
+    ttfb: Optional[float]  # time-to-first-batch
     pt2_compilation_time: Optional[float]
     pt2_graph_breaks: Optional[float]
     model_flops: Optional[float]
@@ -112,6 +113,7 @@ def get_model_test_metrics(model: Union[BenchmarkModel, ModelTask], metrics=[],
     throughputs = None
     cpu_peak_mem = None
     gpu_peak_mem = None
+    ttfb = None
     pt2_compilation_time = None
     pt2_graph_breaks = None
     model_flops = None
@@ -133,7 +135,10 @@ def get_model_test_metrics(model: Union[BenchmarkModel, ModelTask], metrics=[],
             if isinstance(model, ModelTask) else model.pt2_graph_breaks
     if 'model_flops' in metrics:
         model_flops = get_model_flops(model)
-    return TorchBenchModelMetrics(latencies, throughputs, cpu_peak_mem, gpu_peak_mem, pt2_compilation_time, pt2_graph_breaks, model_flops)
+    if 'ttfb' in metrics:
+        ttfb = model.get_model_attribute('ttfb') \
+            if isinstance(model, ModelTask) else model.ttfb
+    return TorchBenchModelMetrics(latencies, throughputs, cpu_peak_mem, gpu_peak_mem, ttfb, pt2_compilation_time, pt2_graph_breaks, model_flops)
 
 def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True) -> str:
     import copy
```
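The new field follows the same two-way dispatch as the existing metrics: a `ModelTask` proxies attribute access through `get_model_attribute` (the model may live in a separate worker process), while a plain `BenchmarkModel` exposes `ttfb` directly. A self-contained sketch of that pattern, using stand-in classes (the names and the 12.5 ms value are illustrative, not the repo's code):

```python
from typing import Optional, Union

class StubBenchmarkModel:
    """Stand-in for BenchmarkModel: ttfb is a plain attribute."""
    ttfb: float = 12.5  # ms, illustrative

class StubModelTask:
    """Stand-in for ModelTask: attributes are fetched by name."""
    def get_model_attribute(self, name: str) -> Optional[float]:
        return {"ttfb": 12.5}.get(name)  # illustrative lookup

def read_ttfb(model: Union[StubBenchmarkModel, StubModelTask]) -> Optional[float]:
    # Same branch shape as the return-path change in the diff above.
    return model.get_model_attribute('ttfb') \
        if isinstance(model, StubModelTask) else model.ttfb

assert read_ttfb(StubBenchmarkModel()) == read_ttfb(StubModelTask()) == 12.5
```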
torchbenchmark/util/model.py: 8 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@
 from contextlib import contextmanager, ExitStack
 import warnings
 import yaml
+import time
 from pathlib import Path
 from typing import ContextManager, Optional, List, Tuple, Generator
 from torchbenchmark import REPO_PATH
@@ -79,6 +80,7 @@ class BenchmarkModel(metaclass=PostInitProcessor):
     See [Adding Models](#../models/ADDING_MODELS.md)
     """
     def __init__(self, test: str, device: str, batch_size: Optional[int]=None, extra_args: List[str]=[]):
+        self._start_init_time = time.time_ns()
         self.metadata = self._load_metadata()
         self.test = test
         # sanity checks of the options
@@ -149,6 +151,7 @@ def __post__init__(self):
         # Need to clean up the cache because we run deep copy within correctness check
         if self.device == "cuda":
             torch.cuda.empty_cache()
+        self._end_init_time = time.time_ns()
 
     def _skip_by_device_name(self):
         if not self.device == "cuda":
@@ -390,3 +393,8 @@ def pt2_graph_breaks(self):
         from torch._dynamo.utils import counters
         num_graph_breaks = len(counters["graph_break"].keys())
         return num_graph_breaks
+
+    @property
+    def ttfb(self):
+        """Return the time taken to the first batch in ms."""
+        return (self._end_init_time - self._start_init_time) / 1_000_000
```
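As wired here, ttfb brackets model construction: `__init__` stamps `_start_init_time`, the post-init hook stamps `_end_init_time`, and the property converts the nanosecond difference to milliseconds. The `PostInitProcessor` metaclass that invokes `__post__init__` is not part of this diff; the sketch below uses a stand-in that plausibly matches its behavior:

```python
import time

class PostInitProcessor(type):
    # Stand-in for the repo's metaclass of the same name: run
    # __post__init__ automatically once __init__ has finished.
    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)
        obj.__post__init__()
        return obj

class TimedModel(metaclass=PostInitProcessor):
    def __init__(self):
        self._start_init_time = time.time_ns()
        time.sleep(0.01)  # pretend to load weights and build inputs

    def __post__init__(self):
        self._end_init_time = time.time_ns()

    @property
    def ttfb(self):
        """Construction time in ms, mirroring the property added above."""
        return (self._end_init_time - self._start_init_time) / 1_000_000

print(f"ttfb: {TimedModel().ttfb:.4f} ms")  # roughly 10 ms with the sleep above
```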
