Backward kernel for layernorm
Summary: as title

Reviewed By: chenyang78, sijiac

Differential Revision: D55819296

fbshipit-source-id: dbaacbee61d342d6f2ce38b02d78cd7ed0198b75
bertmaher authored and facebook-github-bot committed Apr 8, 2024
1 parent 37dc5a4 commit 1b2a583
Showing 1 changed file with 23 additions and 8 deletions.
torchbenchmark/operators/layer_norm/__init__.py (23 additions, 8 deletions)
@@ -4,29 +4,39 @@
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    Mode,
     register_benchmark,
     register_metric,
 )
 
+from typing import Callable
+
 from . import tutorial
 
 
 class Operator(BenchmarkOperator):
     @register_benchmark()
-    def layer_norm_fwd(self, *args):
+    def triton_layer_norm(self, *args):
         return lambda: tutorial.layer_norm(*args)
 
     @register_benchmark(baseline=True)
-    def layer_norm_fwd_baseline(self, *args):
+    def torch_layer_norm(self, *args):
         return lambda: F.layer_norm(*args)
 
+    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
+        y = fwd_fn()
+        dy = 0.1 * torch.randn_like(y)
+        return lambda: y.backward(dy, retain_graph=True)
+
     def get_input_iter(self):
         M = 4096
         eps = 1e-5
         for N in [512 * i for i in range(2, 32)]:
             x_shape = (M, N)
             w_shape = (x_shape[-1],)
-            x = -2.3 + 0.5 * torch.randn(x_shape, dtype=self.dtype, device="cuda")
+            x = -2.3 + 0.5 * torch.randn(
+                x_shape, dtype=self.dtype, device="cuda", requires_grad=True
+            )
             weight = torch.rand(
                 w_shape, dtype=self.dtype, device="cuda", requires_grad=True
             )
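
The new get_bwd_fn hook isolates the backward pass: the forward runs once outside the timed region to build the autograd graph, a fixed upstream gradient dy is pre-allocated, and retain_graph=True lets the same graph be backpropagated on every timed call. A minimal standalone sketch of the same pattern (the shape, dtype, and CUDA device below are illustrative, not taken from the benchmark):

import torch
import torch.nn.functional as F

M, N = 4096, 1024  # illustrative shape; the benchmark sweeps N
x = -2.3 + 0.5 * torch.randn(M, N, device="cuda", requires_grad=True)
weight = torch.rand(N, device="cuda", requires_grad=True)
bias = torch.rand(N, device="cuda", requires_grad=True)

# Forward once, outside the timed region, so only backward is measured.
y = F.layer_norm(x, (N,), weight, bias, 1e-5)
dy = 0.1 * torch.randn_like(y)  # fixed upstream gradient, as in the diff

bwd_fn = lambda: y.backward(dy, retain_graph=True)
bwd_fn()  # each timed invocation replays just the backward pass
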
@@ -44,7 +54,12 @@ def gbps(self, fn_name, args, metrics: BenchmarkOperatorMetrics) -> float:
         x = args[0]
 
         def gbps(ms):
-            return 2 * x.numel() * x.element_size() / ms * 1e-6
+            base = x.numel() * x.element_size() / ms * 1e-6
+            return {
+                Mode.FWD: 2 * base,
+                Mode.BWD: 3 * base,
+                Mode.FWD_BWD: 5 * base,
+            }[self.mode]
 
         return list(map(gbps, metrics.latency))

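The gbps metric now scales its traffic estimate by mode, in units of one x-sized tensor moved through memory. One plausible accounting for the multipliers: forward reads x and writes y (2x); backward reads dy and x and writes dx (3x, treating the saved mean and rstd as negligible); fwd_bwd is their sum (5x). A worked example with assumed numbers:

M, N = 4096, 2048
element_size = 2  # fp16 bytes per element
ms = 0.25  # hypothetical measured latency in milliseconds
base = M * N * element_size / ms * 1e-6  # bytes/ms * 1e-6 == GB/s per x-sized tensor
print(2 * base)  # FWD:     ~134.2 GB/s
print(3 * base)  # BWD:     ~201.3 GB/s
print(5 * base)  # FWD_BWD: ~335.5 GB/s
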
@@ -55,12 +70,12 @@ def plot(self):
             x_vals=self.output.x_vals,
             line_arg="provider",
             line_vals=[
-                "layer_norm_fwd",
-                "layer_norm_fwd_baseline",
+                "triton_layer_norm",
+                "torch_layer_norm",
             ],
             line_names=[
-                "Triton",
-                "Torch",
+                "triton_layer_norm",
+                "torch_layer_norm",
             ],
             styles=[("blue", "-"), ("green", "-")],
             ylabel="GB/s",
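
With line_vals and line_names renamed consistently, the plot legend now matches the registered benchmark names. To exercise the new backward path, one would presumably invoke the repo's triton userbenchmark with the operator's bwd mode, along the lines of python run_benchmark.py triton --op layer_norm --mode bwd (the entry point and flag names here are an assumption; check the runner's argument parser for the exact interface).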
