Rename the input samples and number of input samples
Summary:
Rename the term "batch" to "input". We will stick to the following behavior:
1. "num_inputs" refers to the number of sample inputs of the benchmark. For example, "--num-inputs 4" runs 4 different inputs in the benchmark.
2. "input_id" refers to the i-th sample input. For example, "--input-id 5" starts from the 5-th input returned by the input generator.
3. When combined, e.g., "--input-id 5 --num-inputs 4", we start from the 5-th input and run 4 different inputs from there, i.e., we run the 5-th, 6-th, 7-th, and 8-th inputs returned by the input generator (see the sketch below).

We should also clarify the usage of the term "batch": it now refers only to the "B" dimension of the input tensors in batched operators, such as batched matmul or batched attention.
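
As an illustration of the combined behavior, here is a minimal, self-contained Python sketch (not part of this diff) of the intended semantics, assuming inputs are 0-indexed as returned by the input generator:

    # --input-id 5 --num-inputs 4 selects range(5, 5 + 4), i.e., inputs 5, 6, 7, 8.
    input_id, num_inputs = 5, 4
    selected = list(range(input_id, input_id + num_inputs))
    print(selected)  # [5, 6, 7, 8]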

Reviewed By: sijiac

Differential Revision: D56534046

fbshipit-source-id: 4a7ea4b48881679b18d6137545b0e3c81f717abf
xuzhao9 committed Apr 25, 2024
1 parent 70e476b commit 79c11d9
Showing 3 changed files with 76 additions and 48 deletions.
2 changes: 0 additions & 2 deletions torchbenchmark/operators/addmm/operator.py
@@ -28,11 +28,9 @@ class Operator(BenchmarkOperator):
def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
if not self.extra_args:
self.DEFAULT_NUM_BATCH = len(BUILDIN_SHAPES)
self.shapes = BUILDIN_SHAPES
else:
self.shapes = [(self.tb_args.m, self.tb_args.k, self.tb_args.n)]
self.DEFAULT_NUM_BATCH = len(self.shapes)

@register_benchmark()
def triton_addmm(self, a, mat1, mat2) -> Callable:
1 change: 0 additions & 1 deletion torchbenchmark/operators/grouped_gemm/operator.py
@@ -16,7 +16,6 @@

class Operator(BenchmarkOperator):
DEFAULT_PRECISION = "fp16"
DEFAULT_NUM_BATCH = 4
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]

@register_benchmark(baseline=True)
121 changes: 76 additions & 45 deletions torchbenchmark/util/triton_op.py
@@ -298,22 +298,43 @@ def _inner(self, *args, **kwargs):


def parse_args(
default_metrics: List[str], args: List[str]
default_metrics: List[str],
args: List[str],
) -> Tuple[argparse.Namespace, List[str]]:
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
"--metrics",
default=",".join(default_metrics),
help="Metrics to collect, split with comma. E.g., --metrics latency,tflops,speedup.",
)
parser.add_argument("--only", default=None, help="Specify one or multiple operator implementations to run.")
parser.add_argument(
"--batch-id", type=int, default=None, help="Run only the specific batch id."
"--only",
default=None,
help="Specify one or multiple operator implementations to run."
)
parser.add_argument(
"--num-inputs",
type=int,
help="Number of example inputs.",
)
parser.add_argument(
"--input-id",
type=int,
default=0,
help="Specify the start input id to run. " \
"For example, --input-id 0 runs only the first available input sample." \
"When used together like --input-id <X> --num-inputs <Y>, start from the input id <X> " \
"and run <Y> different inputs."
)
return parser.parse_known_args(args)

class PostInitProcessor(type):
def __call__(cls, *args, **kwargs):
obj = type.__call__(cls, *args, **kwargs)
obj.__post__init__()
return obj

class BenchmarkOperator:
class BenchmarkOperator(metaclass=PostInitProcessor):
mode: Mode = Mode.FWD
test: str = "eval"
device: str = "cuda"
@@ -324,8 +345,6 @@ class BenchmarkOperator:
# By default, only collect latency metrics
# Each operator can override to define their own default metrics
DEFAULT_METRICS = ["latency"]
# By default, generate 100 data points
DEFAULT_NUM_BATCH = 100

"""
A base class for adding operators to torch benchmark.
Expand All @@ -349,16 +368,24 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
# This will be changed by the time we apply the decoration args
self.dtype = PRECISION_DTYPE_MAPPING.get(self.dargs.precision, None)
if self.dargs.num_batch is None:
self.dargs.num_batch = self.DEFAULT_NUM_BATCH
self.DEFAULT_METRICS.extend(REGISTERED_METRICS.get(self.name, []))
self.DEFAULT_METRICS.extend(
[x for x in REGISTERED_METRICS.get(self.name, []) if x not in BUILTIN_METRICS]
)
self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
self.tb_args, self.extra_args = parse_args(
self.DEFAULT_METRICS, unprocessed_args
self.DEFAULT_METRICS,
unprocessed_args
)
self.required_metrics = list(set(self.tb_args.metrics.split(",")))
self._only = _split_params_by_comma(self.tb_args.only)
self._batch_id = self.tb_args.batch_id
self._input_id = self.tb_args.input_id
self._num_inputs = self.tb_args.num_inputs

# Run the post initialization
def __post__init__(self):
self._available_num_inputs = self.count_example_inputs()
if self._num_inputs is None:
self._num_inputs = self._available_num_inputs - self._input_id + 1

def _get_bm_func(self, bm_func_name: str):
fwd_fn_lambda = getattr(self, bm_func_name, None)
@@ -388,21 +415,18 @@ def run(
) -> BenchmarkOperatorResult:
"""Benchmarking the operator and returning its metrics."""
metrics = []
if self._batch_id is not None:
# Run only the user-specific batch id
batch_range = range(self._batch_id + 1)
else:
batch_range = range(self.dargs.num_batch)
input_id_range = range(self._input_id, self._input_id+self._num_inputs)
if tqdm is not None:
batch_range = tqdm(batch_range)
for batch_id in batch_range:
if self._batch_id and batch_id < self._batch_id:
continue
input_id_range = tqdm(input_id_range)
if self._input_id:
for _dryrun_input_id in range(self._input_id):
self.example_inputs = self.get_example_inputs()
for input_id in input_id_range:
self.example_inputs = self.get_example_inputs()
if self.example_inputs is None:
warnings.warn(
UserWarning(
f"The input generator get_input_iter() has depleted. Maximum input batches {batch_id}."
f"The input generator get_input_iter() has depleted at id {input_id}. Available number of inputs: {self._available_num_inputs}."
)
)
break
@@ -444,7 +468,7 @@ def _reduce_benchmarks(acc, bm_name: str):
else False
)
acc[bm_name] = self._do_bench(
batch_id=batch_id,
input_id=input_id,
fn_name=bm_name,
warmup=warmup,
rep=rep,
@@ -559,6 +583,9 @@ def enable_channels_last(self):
tensor_cond, tensor_action, self.example_inputs
)

def count_example_inputs(self):
return sum(1 for _ in self.get_input_iter())

def get_example_inputs(self):
if self._input_iter is None:
self._input_iter = self.get_input_iter()
@@ -589,7 +616,7 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:

def _do_bench(
self,
batch_id: int,
input_id: int,
fn_name: str,
warmup=DEFAULT_WARMUP,
rep=DEFAULT_RUN_ITERS,
@@ -666,32 +693,32 @@ def _do_bench(
if "tflops" in self.required_metrics:
metric.tflops = self.tflops(fn_name, self.example_inputs, metric)
if "compile_time" in self.required_metrics:
metric.compile_time = self.compile_time(batch_id, fn_name, metric)
metric.compile_time = self.compile_time(input_id, fn_name, metric)
if "ncu_trace" in self.required_metrics:
metric.ncu_trace = self.ncu_trace(batch_id, fn_name)
metric.ncu_trace = self.ncu_trace(input_id, fn_name)
if "kineto_trace" in self.required_metrics:
metric.kineto_trace = self.kineto_trace(batch_id, fn)
metric.kineto_trace = self.kineto_trace(input_id, fn)
extra_metrics = {}
# run the hidden metric "_compile_time_in_task"
# to get the compile time in parent process
if "_compile_time_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_compile_time_in_task"]
and self._only
and (self._batch_id is not None)
and len(self._only) == 1
and (self._input_id is not None)
), (
"_compile_time_in_task must be measured by itself. "
f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
)
extra_metrics["_compile_time_in_task"] = self._compile_time_in_task(fn)
if "_ncu_trace_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_ncu_trace_in_task"]
and self._only
and (self._batch_id is not None)
and len(self._only) == 1
and (self._input_id is not None)
), (
"_ncu_trace_in_task must be measured by itself. "
f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
)
from torchbenchmark._components.ncu import do_bench_ncu_in_task

@@ -720,7 +747,7 @@ def _do_bench(
walltime=None,
compile_time=None,
ncu_trace=None,
hw_roofline=self.hw_roofline(),
hw_roofline=self.hw_roofline() if "hw_roofline" in self.required_metrics else None,
kineto_trace=None,
cpu_peak_mem=None,
gpu_peak_mem=None,
@@ -740,23 +767,25 @@ def get_peak_mem(
)

@register_metric()
def ncu_trace(self, batch_id: int, fn_name: str) -> str:
def ncu_trace(self, input_id: int, fn_name: str) -> str:
# collect the ncu trace
import sys
import subprocess
from pathlib import Path

op_task_args = copy.deepcopy(sys.argv)
for override_option in ["--only", "--batch-id", "--metrics"]:
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args.extend(
[
"--only",
fn_name,
"--batch-id",
str(batch_id),
"--num-inputs",
str(1),
"--input-id",
str(input_id),
"--metrics",
"_ncu_trace_in_task",
]
@@ -775,7 +804,7 @@ def ncu_trace(self, batch_id: int, fn_name: str) -> str:
warnings.warn(
"Cannot find dyno to disable DCGM. Proceed to collect NCU Trace."
)
ncu_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn_name}_{batch_id}")
ncu_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn_name}_{input_id}")
ncu_output_dir.mkdir(parents=True, exist_ok=True)
ncu_output_file = ncu_output_dir.joinpath("ncu_output.csv").resolve()
ncu_args = [
@@ -796,11 +825,11 @@ def ncu_trace(self, batch_id: int, fn_name: str) -> str:
return str(ncu_output_file.resolve())

@register_metric()
def kineto_trace(self, batch_id: int, fn: Callable) -> str:
def kineto_trace(self, input_id: int, fn: Callable) -> str:
from pathlib import Path
from torchbenchmark._components.kineto import do_bench_kineto

kineto_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn._name}_{batch_id}")
kineto_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn._name}_{input_id}")
kineto_output_dir.mkdir(parents=True, exist_ok=True)
return do_bench_kineto(
fn=fn,
@@ -810,23 +839,25 @@ def kineto_trace(self, batch_id: int, fn: Callable) -> str:

@register_metric()
def compile_time(
self, batch_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics
self, input_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics
) -> float:
# We need to spawn a subprocess when user wants to measure the compile time
# of multiple batches and backends.
# of multiple sample inputs and backends.
from torchbenchmark.operators.op_task import OpTask

op_task_args = copy.deepcopy(self._raw_extra_args)
for override_option in ["--only", "--batch-id", "--metrics"]:
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args.extend(
[
"--only",
fn_name,
"--batch-id",
str(batch_id),
"--num-inputs",
str(1),
"--input-id",
str(input_id),
"--metrics",
"_compile_time_in_task",
]
