Skip to content

Commit

Permalink
Add decorator to rename the x_val label
Browse files Browse the repository at this point in the history
Summary: Add decorator `register_x_val(label: str)` to rename the the `x_val` label.

Reviewed By: bertmaher

Differential Revision: D57166423

fbshipit-source-id: 61851eb421e64fe4f30b34eb8f72e84be33a54db
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed May 10, 2024
1 parent 3f8bb6a commit 06eb98c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 2 additions & 0 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
register_x_val,
)

from .data_io import parse_args, read_shapes_from_csv
Expand Down Expand Up @@ -141,6 +142,7 @@ def colfax_cutlass_matmul(self, a, b, bias) -> Callable:
else:
return lambda: colfax_gemm(a, b, alpha=1.0, beta=1.0)

@register_x_val(label="(M, N, K)")
def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
# x-value: computation intensity
a, w, bias = example_inputs
Expand Down
13 changes: 12 additions & 1 deletion torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
REGISTERED_BENCHMARKS: Dict[str, List[str]] = {}
REGISTERED_METRICS: Dict[str, List[str]] = {}
REGISTERED_X_VALS: Dict[str, str] = {}
BASELINE_BENCHMARKS: Dict[str, str] = {}
BUILTIN_METRICS = [
"latency",
Expand Down Expand Up @@ -156,7 +157,7 @@ class BenchmarkOperatorResult:
def _table(self):
table = []
# generate headers
headers = ["x_val"]
headers = [REGISTERED_X_VALS[self.op_name]]
y_val = self.result[0][1]
y_val_keys = list(y_val.keys())
# move the baseline benchmark to the front of the list if exists
Expand Down Expand Up @@ -263,6 +264,14 @@ def __str__(self):
table = tabulate.tabulate(table, headers=headers, stralign="right")
return table

def register_x_val(label: str="x_val"):
def decorator(function):
operator_name = _find_op_name_from_module_path(function.__module__)
REGISTERED_X_VALS[operator_name] = label
def _inner(self, *args, **kwargs):
return function(self, *args, **kwargs)
return _inner
return decorator

def register_benchmark(baseline: bool = False, enabled: bool = True):
def decorator(function):
Expand Down Expand Up @@ -378,6 +387,8 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
self.mode = Mode.BWD
self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
if self.name not in REGISTERED_X_VALS:
REGISTERED_X_VALS[self.name] = "x_val"
# This will be changed by the time we apply the decoration args
self.dtype = PRECISION_DTYPE_MAPPING.get(self.dargs.precision, None)
self.DEFAULT_METRICS.extend(
Expand Down

0 comments on commit 06eb98c

Please sign in to comment.