Skip to content

Commit

Permalink
add benchmark function for PT2 compile time (pytorch#2288)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2288

# context
* provide a benchmark function entry point and command-line for running torch.compile with a baby model containing EBC
* current supported arguments:
```
    rank: int = 0,
    world_size: int = 2,
    num_features: int = 5,
    batch_size: int = 10,
```
* run command
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=32
```

# results
* on a DevGPU machine
```
rank: 0, world_size: 2, num_features: 16, batch_size: 10, time: 6.65s
rank: 0, world_size: 2, num_features: 32, batch_size: 10, time: 10.99s
rank: 0, world_size: 2, num_features: 64, batch_size: 10, time: 61.55s
rank: 0, world_size: 2, num_features: 128, batch_size: 10, time: 429.14s
```

Differential Revision: D57501708
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 25, 2024
1 parent 85f994d commit a09ec0a
Showing 1 changed file with 68 additions and 32 deletions.
100 changes: 68 additions & 32 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

#!/usr/bin/env python3

import timeit
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import click
import fbgemm_gpu.sparse_ops # noqa: F401, E402

import torch
import torchrec
import torchrec.pt2.checks
Expand Down Expand Up @@ -504,23 +505,22 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
def _test_compile_fake_pg_fn(
rank: int,
world_size: int,
num_features: int = 5,
batch_size: int = 10,
num_embeddings: int = 256,
) -> None:
sharding_type = ShardingType.TABLE_WISE.value
input_type = _InputType.SINGLE_BATCH
torch_compile_backend = "eager"
config = _TestConfig()
num_embeddings = 256
# emb_dim must be % 4 == 0 for fbgemm
emb_dim = 12
batch_size = 10
num_features: int = 5

num_float_features: int = 8
num_weighted_features: int = 1

device: torch.Device = torch.device("cuda")
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
dist.init_process_group(backend="fake", rank=rank, world_size=2, store=store)
pg: ProcessGroup = dist.distributed_c10d._get_default_group()

topology: Topology = Topology(world_size=world_size, compute_device="cuda")
Expand Down Expand Up @@ -601,29 +601,17 @@ def _dmp(m: torch.nn.Module) -> DistributedModelParallel: # pyre-ignore
ins = []

for _ in range(1 + n_extra_numerics_checks):
if input_type == _InputType.VARIABLE_BATCH:
(
_,
local_model_inputs,
) = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
# pyre-ignore
tables=mi.tables,
)
else:
(
_,
local_model_inputs,
) = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
tables=mi.tables,
weighted_tables=mi.weighted_tables,
variable_batch_size=False,
)
(
_,
local_model_inputs,
) = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
tables=mi.tables,
weighted_tables=mi.weighted_tables,
variable_batch_size=False,
)
ins.append(local_model_inputs)

local_model_input = ins[0][rank].to(device)
Expand Down Expand Up @@ -655,9 +643,6 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
eager_out = dmp(kjt_ft, ff)
reduce_to_scalar_loss(eager_out).backward()

if torch_compile_backend is None:
return

##### COMPILE #####
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
Expand Down Expand Up @@ -773,3 +758,54 @@ def test_compile_multiprocess_fake_pg(
rank=0,
world_size=2,
)


@click.command()
@click.option(
"--repeat",
type=int,
default=1,
help="repeat times",
)
@click.option(
"--rank",
type=int,
default=0,
help="rank in the test",
)
@click.option(
"--world-size",
type=int,
default=2,
help="world_size in the test",
)
@click.option(
"--num-features",
type=int,
default=5,
help="num_features in the test",
)
@click.option(
"--batch-size",
type=int,
default=10,
help="batch_size in the test",
)
def compile_benchmark(
rank: int, world_size: int, num_features: int, batch_size: int, repeat: int
) -> None:
run: str = (
f"_test_compile_fake_pg_fn(rank={rank}, world_size={world_size}, "
f"num_features={num_features}, batch_size={batch_size})"
)
print("*" * 20 + " compile_benchmark started " + "*" * 20)
t = timeit.timeit(stmt=run, number=repeat, globals=globals())
print("*" * 20 + " compile_benchmark completed " + "*" * 20)
print(
f"rank: {rank}, world_size: {world_size}, "
f"num_features: {num_features}, batch_size: {batch_size}, time: {t:.2f}s"
)


if __name__ == "__main__":
compile_benchmark()

0 comments on commit a09ec0a

Please sign in to comment.