permute_multi_embs
Differential Revision: D52354486
TroyGarden authored and facebook-github-bot committed Jul 20, 2024
1 parent 56a3f45 commit 5203613
Showing 3 changed files with 342 additions and 0 deletions.
54 changes: 54 additions & 0 deletions torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -18,6 +18,7 @@
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import (
_desugar_keyed_tensors,
_fbgemm_permute_pooled_embs,
_regroup_keyed_tensors,
KeyedJaggedTensor,
@@ -26,6 +27,37 @@
regroup_kts,
)
from torchrec.sparse.tests.utils import build_groups, build_kts
from torchrec.sparse.triton_ops import (
triton_permute_multi_embs,
triton_permute_pooled_embs,
)
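
# The Triton ops imported above are wrapped below with the same
# (keyed_tensors, groups) call signature as _regroup_keyed_tensors, so they can
# be benchmarked against the existing implementations with identical inputs.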


@torch.fx.wrap
def _triton_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
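    # _desugar_keyed_tensors flattens the KeyedTensors into parallel lists:
    # per-tensor key names, per-key embedding lengths, and the value tensors.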
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
permuted_values, splits = triton_permute_pooled_embs(
values,
keys,
lengths,
groups,
)
return list(torch.split(permuted_values, splits, dim=1))


@torch.fx.wrap
def _triton_permute_multi_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
return triton_permute_multi_embs(
values,
keys,
lengths,
groups,
)


class DummyModel(torch.nn.Module):
@@ -283,6 +315,28 @@ def main(
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"[Triton] permute_pooled_embs",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_triton_permute_pooled_embs,
{"keyed_tensors": kts, "groups": groups},
profile,
)
bench(
"[Triton] permute_multi_embs",
labels,
batch_size,
n_dense + n_sparse,
device_type,
run_backward,
_triton_permute_multi_embs,
{"keyed_tensors": kts, "groups": groups},
profile,
)


if __name__ == "__main__":
69 changes: 69 additions & 0 deletions torchrec/sparse/tests/test_triton_ops.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import unittest

import torch

from torchrec.sparse.jagged_tensor import _desugar_keyed_tensors, _regroup_keyed_tensors
from torchrec.sparse.tests.utils import build_groups, build_kts
from torchrec.sparse.triton_ops import (
triton_permute_multi_embs,
triton_permute_pooled_embs,
)


class TestPermutePooledEmbs(unittest.TestCase):
# pyre-ignore[56]
@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
def test_triton_permute_pooled_embs_forward(self) -> None:
kts = build_kts(
dense_features=2,
sparse_features=2,
dim_dense=16,
dim_sparse=16,
batch_size=8,
device=torch.device("cuda"),
run_backward=False,
)
groups = build_groups(
kts,
4,
)
keys, lengths, values = _desugar_keyed_tensors(kts)
output, splits = triton_permute_pooled_embs(values, keys, lengths, groups)
refs = _regroup_keyed_tensors(kts, groups)
outputs = torch.split(output, splits, dim=1)
for ref, output in zip(refs, outputs):
torch.testing.assert_close(ref, output)


class TestPermuteMultiEmbs(unittest.TestCase):
# pyre-ignore[56]
@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
def test_triton_permute_multi_embs_forward(self) -> None:
kts = build_kts(
dense_features=2,
sparse_features=2,
dim_dense=16,
dim_sparse=16,
batch_size=8,
device=torch.device("cuda"),
run_backward=False,
)
groups = build_groups(
kts,
4,
)
keys, lengths, values = _desugar_keyed_tensors(kts)
outputs = triton_permute_multi_embs(values, keys, lengths, groups)
refs = _regroup_keyed_tensors(kts, groups)
for ref, output in zip(refs, outputs):
torch.testing.assert_close(ref, output)
219 changes: 219 additions & 0 deletions torchrec/sparse/triton_ops.py
@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Dict, List, Tuple

import torch

# @manual=//triton:triton
import triton

# @manual=//triton:triton
import triton.language as tl


def triton_permute_pooled_embs(
values: List[torch.Tensor],
keys: List[List[str]],
lengths: List[List[int]],
groups: List[List[str]],
) -> Tuple[torch.Tensor, List[int]]:
"""
Permute the values of a KeyedTensor based on the groups.
"""
assert len(values) == len(keys)
assert len(values) == len(lengths)
P = sum(len(g) for g in groups)
B = values[0].shape[0]
device = values[0].device
in_length: int = 0
out_length: int = 0
splits: List[int] = [0] * len(groups)

# permute: [in_offset, out_offset, length, next]
permutes: List[List[int]] = [[0] * 4 for _ in range(P)]
# key -> (in_tensor, in_offset, length)
lookup: Dict[str, Tuple[int, int, int]] = {}
for i, (key, length) in enumerate(zip(keys, lengths)):
for k, l in zip(key, length):
lookup[k] = (i, in_length, l)
in_length += l

curr = 0
for j, group in enumerate(groups):
for k in group:
in_tensor, in_offset, length = lookup[k]
permutes[curr][:] = [in_offset, out_length, length, 0]
out_length += length
splits[j] += length
curr += 1
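    # Worked example (hypothetical inputs): with keys=[["f1", "f2"], ["f3"]],
    # lengths=[[16, 16], [16]] and groups=[["f1", "f3"], ["f2"]], the plan is
    # permutes=[[0, 0, 16, 0], [32, 16, 16, 0], [16, 32, 16, 0]] and
    # splits=[32, 16]; e.g. "f3" starts at column 32 of the concatenated input
    # and lands at column 16 of the output.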

permute_tensor = torch.tensor(permutes, dtype=torch.int32).to(
device, non_blocking=True
)
output: torch.Tensor = torch.empty(B, out_length, device=device)
permute_pooled_embeddings_kernel[(B, P)](
torch.concat(values, dim=1),
output,
permute_tensor,
in_length,
out_length,
)
return output, splits


@triton.jit
def permute_pooled_embeddings_kernel(
values,
outputs,
permutes,
in_length,
out_length,
):
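    # 2D launch grid: program_id(0) selects the batch row, program_id(1) selects
    # one entry of the permute plan; each program copies that key's slice of
    # columns in BLOCK_SIZE-sized chunks.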
batch_id = tl.program_id(0)
pid = tl.program_id(1)
in_offset = tl.load(permutes + 4 * pid)
out_offset = tl.load(permutes + 4 * pid + 1)
length = tl.load(permutes + 4 * pid + 2)
BLOCK_SIZE: tl.constexpr = 32

idx = tl.arange(0, BLOCK_SIZE)
in_ptr = values + batch_id * in_length + in_offset + idx
out_ptr = outputs + batch_id * out_length + out_offset + idx

for k in range(0, length, BLOCK_SIZE):
inputs = tl.load(in_ptr + k, mask=idx < length - k)
tl.store(out_ptr + k, inputs, mask=idx < length - k)


def triton_permute_multi_embs(
values: List[torch.Tensor],
keys: List[List[str]],
lengths: List[List[int]],
groups: List[List[str]],
) -> List[torch.Tensor]:
"""
Permute the values of a KeyedTensor based on the groups.
"""
assert len(values) == len(keys)
assert len(values) == len(lengths)
P = sum(len(g) for g in groups)
B = values[0].shape[0]
device = values[0].device
in_lengths: List[int] = [0] * len(values)
out_lengths: List[int] = [0] * len(groups)

inputs: torch.Tensor = torch.tensor(
[v.data_ptr() for v in values], dtype=torch.int64
).to(device, non_blocking=True)
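    # `inputs` is a device-side table of raw data_ptr values, one per input
    # tensor, so the kernel can address a variable number of inputs through a
    # single int64 array.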

# permute: [in_tensor, out_tensor, in_offset, out_offset, length, next]
permutes: List[List[int]] = [[0] * 6 for _ in range(P)]
# key -> (in_tensor, in_offset, length)
lookup: Dict[str, Tuple[int, int, int]] = {}
for i, (key, length) in enumerate(zip(keys, lengths)):
for k, l in zip(key, length):
lookup[k] = (i, in_lengths[i], l)
in_lengths[i] += l

curr = 0
for out_tensor, group in enumerate(groups):
for k in group:
in_tensor, in_offset, length = lookup[k]
permutes[curr][:] = [
in_tensor,
out_tensor,
in_offset,
out_lengths[out_tensor],
length,
0,
]
out_lengths[out_tensor] += length
curr += 1
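    # Worked example (hypothetical inputs): with keys=[["f1", "f2"], ["f3"]],
    # lengths=[[16, 16], [16]] and groups=[["f1", "f3"], ["f2"]], the plan is
    # permutes=[[0, 0, 0, 0, 16, 0], [1, 0, 0, 16, 16, 0], [0, 1, 16, 0, 16, 0]]
    # and out_lengths=[32, 16]; e.g. "f3" comes from input tensor 1 at offset 0
    # and is written to output tensor 0 at offset 16.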

permute_tensor = torch.tensor(permutes, dtype=torch.int64).to(
device, non_blocking=True
)
outputs: List[torch.Tensor] = [
torch.empty(B, L, device=device) for L in out_lengths
]
output: torch.Tensor = torch.tensor(
[o.data_ptr() for o in outputs], dtype=torch.int64
).to(device, non_blocking=True)
in_lengths_ptr: torch.Tensor = torch.tensor(in_lengths, dtype=torch.int64).to(
device, non_blocking=True
)
out_lengths_ptr: torch.Tensor = torch.tensor(out_lengths, dtype=torch.int64).to(
device, non_blocking=True
)
permute_multi_embeddings_kernel[(B, P)](
values[0],
inputs,
output,
permute_tensor,
in_lengths_ptr,
out_lengths_ptr,
)
return outputs


@triton.jit
def permute_multi_embeddings_kernel(
example,
inputs,
output,
permutes,
in_lengths,
out_lengths,
):
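    # Same 2D grid as permute_pooled_embeddings_kernel, but the input and output
    # base addresses are read from the data_ptr tables; `example` (values[0]) is
    # only used to recover the element dtype for the pointer bitcast below.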
batch_id = tl.program_id(0)
pid = tl.program_id(1)
in_tensor = tl.load(permutes + 6 * pid)
out_tensor = tl.load(permutes + 6 * pid + 1)
in_offset = tl.load(permutes + 6 * pid + 2)
out_offset = tl.load(permutes + 6 * pid + 3)
length = tl.load(permutes + 6 * pid + 4)

in_length = tl.load(in_lengths + in_tensor)
out_length = tl.load(out_lengths + out_tensor)

BLOCK_SIZE: tl.constexpr = 32
idx = tl.arange(0, BLOCK_SIZE)

in_ptr = (
tl.load(inputs + in_tensor).to(example.dtype, bitcast=True)
+ batch_id * in_length
+ in_offset
+ idx
)
out_ptr = (
tl.load(output + out_tensor).to(example.dtype, bitcast=True)
+ batch_id * out_length
+ out_offset
+ idx
)

for k in range(0, length, BLOCK_SIZE):
in_data = tl.load(in_ptr + k, mask=idx < length - k)
tl.store(out_ptr + k, in_data, mask=idx < length - k)


# @custom_impl("torchrec::permute_multi_embeddings", "CUDA")
# @custom_impl("torchrec::permute_multi_embeddings", "AutogradCUDA")
# def permute_multi_embeddings(
# values: List[torch.Tensor],
# keys: List[List[str]],
# lengths: List[List[int]],
# groups: List[List[str]],
# ) -> List[torch.Tensor]:
# """
# Permute the values of a KeyedTensor based on the groups.
# """
# assert len(values) == len(keys)
# assert len(values) == len(lengths)
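
A minimal usage sketch for the multi-output path (not part of this commit; the
key names, dimensions, and batch size are illustrative, and a CUDA device is
assumed):

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, _desugar_keyed_tensors
from torchrec.sparse.triton_ops import triton_permute_multi_embs

# Two pooled-embedding KeyedTensors with a batch size of 8.
kt0 = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[16, 16],
    values=torch.randn(8, 32, device="cuda"),
)
kt1 = KeyedTensor(
    keys=["f3"],
    length_per_key=[16],
    values=torch.randn(8, 16, device="cuda"),
)
groups = [["f1", "f3"], ["f2"]]

# Flatten the KeyedTensors, then regroup with the Triton kernel.
keys, lengths, values = _desugar_keyed_tensors([kt0, kt1])
outputs = triton_permute_multi_embs(values, keys, lengths, groups)
# outputs[0] has shape (8, 32) for ["f1", "f3"]; outputs[1] has shape (8, 16) for ["f2"].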
