benchmark on SplitTBE module with different emb-dim
Summary:
# context
* benchmark results w/ small batch_size (result names encode `num_embeddings-embedding_dim-num_tables-batch_size-bag_size`)
```
SplitTableBatchedEmbeddingBagsCodegen-1000-16-4-1024-10 | Runtime (P90): 0.289728 ms | Memory (P90): 1 MB
SplitTableBatchedEmbeddingBagsCodegen-1000-32-4-1024-10 | Runtime (P90): 0.304992 ms | Memory (P90): 1 MB
SplitTableBatchedEmbeddingBagsCodegen-1000-64-4-1024-10 | Runtime (P90): 0.315392 ms | Memory (P90): 3 MB
SplitTableBatchedEmbeddingBagsCodegen-1000-128-4-1024-10 | Runtime (P90): 0.295552 ms | Memory (P90): 5 MB
SplitTableBatchedEmbeddingBagsCodegen-1000-256-4-1024-10 | Runtime (P90): 0.30096 ms | Memory (P90): 9 MB
SplitTableBatchedEmbeddingBagsCodegen-1000-512-4-1024-10 | Runtime (P90): 0.317056 ms | Memory (P90): 18 MB
SplitTableBatchedEmbeddingBagsCodegen-1000-1024-4-1024-10 | Runtime (P90): 0.31792 ms | Memory (P90): 36 MB
```
* benchmark results w/ large batch_size
```
SplitTableBatchedEmbeddingBagsCodegen-100000-16-16-131072-10 | Runtime (P90): 0.587584 ms | Memory (P90): 0.31 GB
SplitTableBatchedEmbeddingBagsCodegen-100000-32-16-131072-10 | Runtime (P90): 0.722624 ms | Memory (P90): 0.51 GB
SplitTableBatchedEmbeddingBagsCodegen-100000-64-16-131072-10 | Runtime (P90): 1.29395 ms | Memory (P90): 0.92 GB
SplitTableBatchedEmbeddingBagsCodegen-100000-128-16-131072-10 | Runtime (P90): 2.73472 ms | Memory (P90): 1.7 GB
SplitTableBatchedEmbeddingBagsCodegen-100000-256-16-131072-10 | Runtime (P90): 6.5608 ms | Memory (P90): 3.4 GB
SplitTableBatchedEmbeddingBagsCodegen-100000-512-16-131072-10 | Runtime (P90): 14.8527 ms | Memory (P90): 6.6 GB
SplitTableBatchedEmbeddingBagsCodegen-100000-1024-16-131072-10 | Runtime (P90): 31.055 ms | Memory (P90): 13 GB
```
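For reference, the two sweeps above map onto the script added in the diff below roughly as follows (the filename is hypothetical; `--embedding-dim 0` triggers the dim sweep):
```
# small-batch sweep (hypothetical filename; flags are the click options below)
python split_tbe_bench.py --num-embeddings 1000 --embedding-dim 0 --num-tables 4 --batch-size 1024 --bag-size 10
# large-batch sweep
python split_tbe_bench.py --num-embeddings 100000 --embedding-dim 0 --num-tables 16 --batch-size 131072 --bag-size 10
```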
# traces
* [trace files](https://fburl.com/gdrive/t8qaitoi)
* batch_size - 128
 {F1848010270}
* batch_size - 32
 {F1848011914}

# conclusions
* the operator runs two major parts on the GPU: 1) a lightweight `direct_copy_kernel_cuda`, and 2) a heavy-lifting `split_embedding_codegen_forward_unweighted_kernel`.
* details of the `direct_copy_kernel_cuda`
{F1848017719}
* the CPU runtime for launching these two GPU kernels (~3.8 ms in the traces) is the bottleneck of this operator; this matches the small-batch results above, where runtime stays nearly constant as the embedding dim grows.
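A minimal sketch of how the kernel breakdown can be reproduced with `torch.profiler` (assuming `emb`, `values`, and `offsets` as constructed in the script below; not the tooling that produced the linked traces):
```
import torch
from torch.profiler import ProfilerActivity, profile

def profile_forward(emb, values, offsets) -> None:
    # one profiled forward pass; the kernel table should show both
    # direct_copy_kernel_cuda and split_embedding_codegen_forward_unweighted_kernel,
    # and the CPU rows include the launch overhead discussed above
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        emb.forward(values, offsets)
        torch.cuda.synchronize()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```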

Differential Revision: D62254614
TroyGarden authored and facebook-github-bot committed Sep 5, 2024
1 parent 81ae241 commit ec784a3
Showing 2 changed files with 127 additions and 0 deletions.
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import click

import torch

from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType

from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    BoundsCheckMode,
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@click.command()
@click.option("--num-embeddings", default=100_000)
@click.option("--embedding-dim", default=128)
@click.option("--num-tables", default=4)
@click.option("--batch-size", default=262144)
@click.option("--bag-size", default=10)
def main(
    num_embeddings: int,
    embedding_dim: int,
    num_tables: int,
    batch_size: int,
    bag_size: int,
) -> None:
    if embedding_dim == 0:
        for embedding_dim in [16, 32, 64, 128, 256, 512, 1024]:
            op_bench(num_embeddings, embedding_dim, num_tables, batch_size, bag_size)
    else:
        op_bench(num_embeddings, embedding_dim, num_tables, batch_size, bag_size)


def op_bench(
    num_embeddings: int,
    embedding_dim: int,
    num_tables: int,
    batch_size: int,
    bag_size: int,
) -> None:
    # replicate one table spec num_tables times; tables live in device memory
    # (EmbeddingLocation.DEVICE) when CUDA is available
    emb = SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                num_embeddings,
                embedding_dim,
                EmbeddingLocation.DEVICE,
                (
                    ComputeDevice.CUDA
                    if torch.cuda.is_available()
                    else ComputeDevice.CPU
                ),
            )
        ]
        * num_tables,
        optimizer=OptimType.EXACT_ADAGRAD,
        learning_rate=0.1,
        eps=0.1,
        weights_precision=SparseType.FP32,
        stochastic_rounding=False,
        output_dtype=SparseType.FP32,
        pooling_mode=PoolingMode.SUM,
        bounds_check_mode=BoundsCheckMode.NONE,
    )

    def _func_to_benchmark(
        kjt: KeyedJaggedTensor,
        model: torch.nn.Module,
    ) -> torch.Tensor:
        return model.forward(kjt.values(), kjt.offsets())

    tables = [
        EmbeddingBagConfig(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]
    inputs = ModelInput.generate(
        tables=tables,
        weighted_tables=[],
        batch_size=batch_size,
        world_size=1,
        num_float_features=0,
        pooling_avg=bag_size,  # average number of ids per bag
        device=torch.device("cuda"),
    )[0].idlist_features

    result = benchmark_func(
        name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
        bench_inputs=inputs,  # pyre-ignore
        prof_inputs=inputs,  # pyre-ignore
        num_benchmarks=10,
        num_profiles=10,
        profile_dir=".",
        world_size=1,
        func_to_benchmark=_func_to_benchmark,
        benchmark_func_kwargs={"model": emb},
        rank=0,
    )

    print(result)


if __name__ == "__main__":
    main()
torchrec/distributed/benchmark/benchmark_utils.py (3 additions, 0 deletions)
@@ -108,6 +108,9 @@ class BenchmarkResult:
    max_mem_allocated: List[int]  # megabytes
    rank: int = -1

    def __str__(self) -> str:
        return f"{self.short_name: <{35}} | Runtime (P90): {self.runtime_percentile(90):g} ms | Memory (P90): {self.max_mem_percentile(90)/1000:.2g} GB"

    def runtime_percentile(
        self, percentile: int = 50, interpolation: str = "nearest"
    ) -> torch.Tensor:
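As a quick sanity check of the format specs in the new `__str__` (a standalone snippet with sample values from the large-batch table above; `max_mem_allocated` stores megabytes, hence the `/1000`):
```
runtime_p90_ms = 31.055  # ms, as returned by runtime_percentile(90)
mem_p90_mb = 13000       # MB, as returned by max_mem_percentile(90)
print(f"Runtime (P90): {runtime_p90_ms:g} ms | Memory (P90): {mem_p90_mb / 1000:.2g} GB")
# -> Runtime (P90): 31.055 ms | Memory (P90): 13 GB
```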
