diff --git a/torchbenchmark/operators/layer_norm/__init__.py b/torchbenchmark/operators/layer_norm/__init__.py
index 35513e5c6d..c4660b8511 100644
--- a/torchbenchmark/operators/layer_norm/__init__.py
+++ b/torchbenchmark/operators/layer_norm/__init__.py
@@ -4,29 +4,39 @@
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    Mode,
     register_benchmark,
     register_metric,
 )
 
+from typing import Callable
+
 from . import tutorial
 
 
 class Operator(BenchmarkOperator):
     @register_benchmark()
-    def layer_norm_fwd(self, *args):
+    def triton_layer_norm(self, *args):
         return lambda: tutorial.layer_norm(*args)
 
     @register_benchmark(baseline=True)
-    def layer_norm_fwd_baseline(self, *args):
+    def torch_layer_norm(self, *args):
         return lambda: F.layer_norm(*args)
 
+    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
+        y = fwd_fn()
+        dy = 0.1 * torch.randn_like(y)
+        return lambda: y.backward(dy, retain_graph=True)
+
     def get_input_iter(self):
         M = 4096
         eps = 1e-5
         for N in [512 * i for i in range(2, 32)]:
             x_shape = (M, N)
             w_shape = (x_shape[-1],)
-            x = -2.3 + 0.5 * torch.randn(x_shape, dtype=self.dtype, device="cuda")
+            x = -2.3 + 0.5 * torch.randn(
+                x_shape, dtype=self.dtype, device="cuda", requires_grad=True
+            )
             weight = torch.rand(
                 w_shape, dtype=self.dtype, device="cuda", requires_grad=True
             )
@@ -44,7 +54,12 @@ def gbps(self, fn_name, args, metrics: BenchmarkOperatorMetrics) -> float:
         x = args[0]
 
         def gbps(ms):
-            return 2 * x.numel() * x.element_size() / ms * 1e-6
+            base = x.numel() * x.element_size() / ms * 1e-6
+            return {
+                Mode.FWD: 2 * base,
+                Mode.BWD: 3 * base,
+                Mode.FWD_BWD: 5 * base,
+            }[self.mode]
 
         return list(map(gbps, metrics.latency))
 
@@ -55,12 +70,12 @@ def plot(self):
                 x_vals=self.output.x_vals,
                 line_arg="provider",
                 line_vals=[
-                    "layer_norm_fwd",
-                    "layer_norm_fwd_baseline",
+                    "triton_layer_norm",
+                    "torch_layer_norm",
                 ],
                 line_names=[
-                    "Triton",
-                    "Torch",
+                    "triton_layer_norm",
+                    "torch_layer_norm",
                 ],
                 styles=[("blue", "-"), ("green", "-")],
                 ylabel="GB/s",
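
The sketch below is a standalone illustration of the pattern the new `get_bwd_fn` hook uses: run the forward pass once, fix a random upstream gradient, and return a closure that only replays the backward pass. It is not part of the diff; the shapes and `eps` simply mirror `get_input_iter`, and the benchmark harness is replaced by plain calls. The mode-dependent `gbps` factors in the diff appear to follow the same memory-traffic accounting as the Triton layer-norm tutorial (forward reads `x` and writes `y`, 2x; backward reads `x` and `dy` and writes `dx`, 3x; forward+backward, 5x), though that reasoning is an assumption, not stated in the patch.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes mirroring get_input_iter; any (M, N) pair works.
M, N, eps = 4096, 1024, 1e-5
x = -2.3 + 0.5 * torch.randn((M, N), device="cuda", requires_grad=True)
weight = torch.rand((N,), device="cuda", requires_grad=True)
bias = torch.rand((N,), device="cuda", requires_grad=True)

# Stand-in for the registered forward benchmark (torch_layer_norm).
fwd_fn = lambda: F.layer_norm(x, (N,), weight, bias, eps)

# Same pattern as Operator.get_bwd_fn: forward once, then a closure that
# re-runs only backward on the retained graph each time it is called.
y = fwd_fn()
dy = 0.1 * torch.randn_like(y)
bwd_fn = lambda: y.backward(dy, retain_graph=True)

bwd_fn()  # each call times just the backward pass
```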