diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 6d354900a..363060cf9 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -94,7 +94,7 @@ def forward( @click.option("--num-tasks", default=3) @click.option("--repeats", default=100) # pyre-fixme[2]: Parameter must be annotated. -def main(batch_size, num_tables, num_tasks, repeats) -> None: +def cli(batch_size, num_tables, num_tasks, repeats) -> None: device = torch.device("cuda", 0) torch.cuda.set_device(device) hash_sizes = list(np.random.choice(range(50, 250), size=(num_tables))) @@ -169,4 +169,4 @@ def main(batch_size, num_tables, num_tasks, repeats) -> None: if __name__ == "__main__": - main() + cli() diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index e43106b8c..e089e0888 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -49,7 +49,7 @@ def benchmark_hbc_function( @click.command() @click.option("--iters", default=100) @click.option("--warmup-runs", default=2) -def main( +def cli( iters: int, warmup_runs: int, ) -> None: @@ -276,4 +276,4 @@ def fbgemm_generic_hbc_by_feature_gpu( if __name__ == "__main__": - main() + cli() diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 4b0b48f70..6ed8ebb26 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -486,7 +486,7 @@ def pool_func_with_quantization( @click.option("--num_of_embeddings", default=100000, type=int) @click.option("--pooling_factor", default=25, type=int) @click.option("--sweep", is_flag=True, default=False) -def main( +def cli( all_to_one_only: bool, sum_reduce_to_one_only: bool, num_ads: int, @@ -573,4 +573,4 @@ def handler(signum, frame): if __name__ == "__main__": - main() + cli()