Add FusedLinearCrossEntropy (#2485)
Summary:
As discussed in pytorch/pytorch#136168, I'm going to migrate operator benchmark implementations. This PR adds several implementations of FusedLinearCrossEntropy as a starting example.

Execution command:
```
python run_benchmark.py triton --op FusedLinearCrossEntropy
```
Example output:
```
x_val    LMHeadCE-latency    LigerLMHeadCE-latency    inductor_fused_linear_cross_entropy-latency
-------  ------------------  -----------------------  ---------------------------------------------
      0             98.0041                  389.87                                         95.0412
      1            196.12                    652.619                                       193.219
      2            417.242                  1248.75                                        416.725
      3            824.906                  2356.25                                        809.56
```
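
The operator additionally exposes shape flags of its own, `--hidden-size` (default 4096) and `--vocab-size` (default 128256), parsed by `parse_op_args` in operator.py below. Assuming extra command-line arguments are forwarded to the operator, a run overriding the defaults might look like:
```
python run_benchmark.py triton --op FusedLinearCrossEntropy --hidden-size 2048 --vocab-size 32000
```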

Pull Request resolved: #2485

Reviewed By: xuzhao9

Differential Revision: D63859871

Pulled By: FindHao

fbshipit-source-id: 4b73a2144702c1f8f3ae5ed15e76112d03f12b87
FindHao authored and facebook-github-bot committed Oct 4, 2024
1 parent a1f4b2e commit dde8528
Showing 4 changed files with 131 additions and 0 deletions.
12 changes: 12 additions & 0 deletions pyproject.toml
```
[build-system]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"


[tool.black]
line-length = 88
target-version = ["py38"]
exclude = '''/submodules/.*'''

[tool.usort]
excludes = ["**/submodules/**"]
```
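
Assuming `black` and `usort` are installed, both tools pick this configuration up automatically when run from the repository root, e.g.:
```
black .
usort format .
```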
1 change: 1 addition & 0 deletions torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
```
from .operator import Operator
```
108 changes: 108 additions & 0 deletions torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
```
import argparse
from typing import Callable, Generator, List, Optional

import torch

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
    from liger_kernel.transformers.fused_linear_cross_entropy import (
        LigerFusedLinearCrossEntropyLoss,
    )
except ModuleNotFoundError:
    LigerFusedLinearCrossEntropyLoss = None

# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


def parse_op_args(args: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size")
    parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size")
    return parser.parse_args(args)


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation of the linear fused with torch based cross entropy loss.
    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        logits = self.lin(input)
        return self.ce_loss(logits, target)


class LigerLMHeadCE(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        return self.ce_loss(self.lin.weight, input, target)


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        op_args = parse_op_args(self.extra_args)
        self.hidden_size = op_args.hidden_size
        self.vocab_size = op_args.vocab_size
        self.baseline_model = TorchLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.liger_model = LigerLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                self.hidden_size,
                requires_grad=True,
                dtype=self.dtype,
                device=self.device,
            )
            target = torch.randint(
                self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
            ).squeeze(1)
            yield _input, target

    @register_benchmark(baseline=True)
    def LMHeadCE(self, input, target) -> Callable:
        return lambda: self.baseline_model(input, target)

    @register_benchmark()
    def LigerLMHeadCE(self, input, target) -> Callable:
        return lambda: self.liger_model(input, target)

    @register_benchmark()
    def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
        compiled = torch.compile(self.baseline_model, dynamic=False)
        return lambda: compiled(input, target)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        y = fwd_fn()
        return lambda: y.backward(retain_graph=True)
```
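
As an illustrative aside (not part of this commit), the baseline and Liger paths should produce the same loss. A minimal sanity check, assuming liger-kernel is installed, a CUDA device is available, and using small made-up shapes:
```
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

# Small illustrative shapes; the benchmark defaults are H=4096, V=128256.
H, V, BT = 256, 1024, 64
device = "cuda"

lin = torch.nn.Linear(H, V, bias=False, device=device)
x = torch.randn(BT, H, device=device)
target = torch.randint(V, (BT,), device=device)

# Eager baseline: materialize the (BT, V) logits, then cross entropy.
eager_loss = torch.nn.CrossEntropyLoss()(lin(x), target)

# Liger fused path: takes the LM head weight directly, avoiding a full
# logits materialization.
fused_loss = LigerFusedLinearCrossEntropyLoss()(lin.weight, x, target)

torch.testing.assert_close(eager_loss, fused_loss, rtol=1e-4, atol=1e-4)
```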
10 changes: 10 additions & 0 deletions userbenchmark/triton/install.py
```
@@ -66,6 +66,13 @@ def install_fa3():
    subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))


def install_liger():
    # Liger-kernel's `triton` dependency conflicts with pytorch's,
    # so we install it without dependencies.
    cmd = ["pip", "install", "liger-kernel", "--no-deps"]
    subprocess.check_call(cmd)


def install_tk():
    try:
        from .tk.install import install_tk
@@ -88,6 +95,7 @@ def install_tk():
    )
    parser.add_argument("--jax", action="store_true", help="Install jax nightly")
    parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
    parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
    parser.add_argument("--test", action="store_true", help="Run test")
    args = parser.parse_args()

@@ -105,3 +113,5 @@ def install_tk():
        install_jax()
    if args.tk and not args.test:
        install_tk()
    if args.liger and not args.test:
        install_liger()
```
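
With the new flag, Liger can be pulled in before a benchmark run; one plausible invocation, assuming the installer is executed as a module from the repository root:
```
python -m userbenchmark.triton.install --liger
python run_benchmark.py triton --op FusedLinearCrossEntropy
```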
