diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py index fc231cc3fd..7f51707b16 100644 --- a/torchbenchmark/operators/gemm/operator.py +++ b/torchbenchmark/operators/gemm/operator.py @@ -13,6 +13,7 @@ BenchmarkOperatorMetrics, register_benchmark, register_metric, + register_x_val, ) from .data_io import parse_args, read_shapes_from_csv @@ -141,6 +142,7 @@ def colfax_cutlass_matmul(self, a, b, bias) -> Callable: else: return lambda: colfax_gemm(a, b, alpha=1.0, beta=1.0) + @register_x_val(label="(M, N, K)") def get_x_val(self, example_inputs) -> Tuple[int, int, int]: # x-value: computation intensity a, w, bias = example_inputs diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index c79f60e40e..bbd7f60f64 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -29,6 +29,7 @@ DEFAULT_QUANTILES = [0.5, 0.1, 0.9] REGISTERED_BENCHMARKS: Dict[str, List[str]] = {} REGISTERED_METRICS: Dict[str, List[str]] = {} +REGISTERED_X_VALS: Dict[str, str] = {} BASELINE_BENCHMARKS: Dict[str, str] = {} BUILTIN_METRICS = [ "latency", @@ -156,7 +157,7 @@ class BenchmarkOperatorResult: def _table(self): table = [] # generate headers - headers = ["x_val"] + headers = [REGISTERED_X_VALS[self.op_name]] y_val = self.result[0][1] y_val_keys = list(y_val.keys()) # move the baseline benchmark to the front of the list if exists @@ -263,6 +264,14 @@ def __str__(self): table = tabulate.tabulate(table, headers=headers, stralign="right") return table +def register_x_val(label: str="x_val"): + def decorator(function): + operator_name = _find_op_name_from_module_path(function.__module__) + REGISTERED_X_VALS[operator_name] = label + def _inner(self, *args, **kwargs): + return function(self, *args, **kwargs) + return _inner + return decorator def register_benchmark(baseline: bool = False, enabled: bool = True): def decorator(function): @@ -378,6 +387,8 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None) ), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd." self.mode = Mode.BWD self.dargs, unprocessed_args = parse_decoration_args(self, extra_args) + if self.name not in REGISTERED_X_VALS: + REGISTERED_X_VALS[self.name] = "x_val" # This will be changed by the time we apply the decoration args self.dtype = PRECISION_DTYPE_MAPPING.get(self.dargs.precision, None) self.DEFAULT_METRICS.extend(