Skip to content

Commit

Permalink
Set directory location is SSD TBE benchmarks (#2579)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2579

Set ssd location options in training benchmark.
Some Other small nits, improving log messages.

Reviewed By: sryap

Differential Revision: D57219714

fbshipit-source-id: 19dc975000f53e123c8062c674b9764e5c13cf16
  • Loading branch information
pranjalssh authored and facebook-github-bot committed May 21, 2024
1 parent 66efb75 commit b7b12ba
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def benchmark_ssd_function(
end = time.time_ns()
total_time_read_ns += read_end - start
total_time_write_ns += end - read_end
if i % 10 == 0:
if i % 100 == 0:
logging.info(
f"{i}, {(read_end - start) / 10**6}, {(end - read_end) / 10**6}"
f"{i}, {(read_end - start) / 10**3} us, {(end - read_end) / 10**3} us"
)
return (total_time_read_ns / iters, total_time_write_ns / iters)

Expand Down Expand Up @@ -154,8 +154,7 @@ def benchmark_read_write(
gibps_wr = byte_seconds_per_ns / (write_lat_ns * 2**30)
gibps_tot = 2 * byte_seconds_per_ns / ((read_lat_ns + write_lat_ns) * 2**30)
logging.info(
f"Batch Size: {batch_size}, "
f"Bag_size: {bag_size:3d}, "
f"Total bytes: {total_bytes/1e9:0.2f} GB, "
f"Read_us: {read_lat_ns / 1000:8.0f}, "
f"Write_us: {write_lat_ns / 1000:8.0f}, "
f"Total_us: {(read_lat_ns + write_lat_ns) / 1000:8.0f}, "
Expand All @@ -171,9 +170,9 @@ def benchmark_read_write(
# @click.option("--num-tables", default=64)
@click.option("--num-embeddings", default=int(1.5e9))
@click.option("--embedding-dim", default=128)
@click.option("--batch-size", default=1024)
@click.option("--bag-size", default=1)
@click.option("--iters", default=1000)
@click.option("--batch-size", default=4096)
@click.option("--bag-size", default=10)
@click.option("--iters", default=400)
@click.option("--warmup-iters", default=100)
@click.option(
"--ssd-prefix", default="/tmp/ssd_benchmark_embedding"
Expand Down Expand Up @@ -212,7 +211,7 @@ def ssd_read_write(
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
@click.option("--stoc", is_flag=True, default=False)
@click.option("--iters", default=100)
@click.option("--iters", default=500)
@click.option("--warmup-runs", default=0)
@click.option("--managed", default="device")
@click.option("--mixed", is_flag=True, default=False)
Expand All @@ -227,6 +226,7 @@ def ssd_read_write(
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--ssd-prefix", type=str, default="/tmp/ssd_benchmark")
def ssd_training( # noqa C901
alpha: float,
bag_size: int,
Expand All @@ -249,6 +249,7 @@ def ssd_training( # noqa C901
output_dtype: SparseType,
requests_data_file: Optional[str],
tables: Optional[str],
ssd_prefix: Optional[str],
) -> None:
np.random.seed(42)
torch.manual_seed(42)
Expand Down Expand Up @@ -343,8 +344,9 @@ def gen_split_tbe_generator(
"SSD": lambda: SSDTableBatchedEmbeddingBags(
embedding_specs=[(E, d) for d in Ds],
cache_sets=cache_set,
ssd_storage_directory=tempfile.mkdtemp(),
ssd_storage_directory=tempfile.mkdtemp(prefix=ssd_prefix),
ssd_cache_location=EmbeddingLocation.MANAGED,
ssd_shards=8,
**common_args,
),
}
Expand Down Expand Up @@ -384,6 +386,8 @@ def gen_split_tbe_generator(
+ param_size_multiplier * B * sum(Ds) * L
)

logging.info(f"Batch read write bytes: {read_write_bytes/1.0e9: .2f} GB")

# Compute width of test name and bandwidth widths to improve report
# readability
name_width = 0
Expand Down Expand Up @@ -411,13 +415,15 @@ def gen_forward_func(
emb = generator().to(get_device())

# Forward
test_name = f"{prefix} Forward"
logging.info(f"Running benchmark: {test_name}")
time_per_iter = benchmark_requests(
requests,
gen_forward_func(emb, feature_requires_grad),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
)
test_name = f"{prefix} Forward,"

bw = f"{read_write_bytes / time_per_iter / 1.0e9: .2f}"
report.append(
f"{test_name: <{name_width}} B: {B}, "
Expand All @@ -427,6 +433,8 @@ def gen_forward_func(
)

# Backward
test_name = f"{prefix} Backward,"
logging.info(f"Running benchmark: {test_name}")
time_per_iter = benchmark_requests(
requests,
gen_forward_func(emb, feature_requires_grad),
Expand All @@ -436,7 +444,6 @@ def gen_forward_func(
num_warmups=warmup_runs,
)

test_name = f"{prefix} Backward,"
bw = f"{2 * read_write_bytes / time_per_iter / 1.0e9: .2f}"
report.append(
f"{test_name: <{name_width}} B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
Expand Down

0 comments on commit b7b12ba

Please sign in to comment.