Skip to content

Commit

Permalink
Use tqdm to trace the progress
Browse files Browse the repository at this point in the history
Summary: Sometimes the benchmark may take a long time. It would be very helpful if we could know the current progress 🙂

Reviewed By: xuzhao9

Differential Revision: D56342473

fbshipit-source-id: a0c993d12a2419d3f9b88ffbdb865f6371fe78a3
  • Loading branch information
sijiac authored and facebook-github-bot committed Apr 20, 2024
1 parent 1d39d82 commit 3f146bb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 1 addition & 2 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class Operator(BenchmarkOperator):
def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)
if not self.extra_args:
self.DEFAULT_NUM_BATCH = len(BUILDIN_SHAPES)
self.shapes = BUILDIN_SHAPES
else:
self.tbargs = parse_args(self.extra_args)
Expand All @@ -95,7 +94,7 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
self.shapes = SPLIT_K_SHAPES
else:
self.shapes = [(self.tb_args.m, self.tbargs.k, self.tbargs.n)]
self.DEFAULT_NUM_BATCH = len(self.shapes)
self.dargs.num_batch = len(self.shapes)

@register_benchmark()
def triton_tutorial_matmul(self, a, b, bias) -> Callable:
Expand Down
7 changes: 7 additions & 0 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from torchbenchmark.util.extra_args import apply_decoration_args, parse_decoration_args
from torchbenchmark.util.input import input_cast

try:
from tqdm import tqdm
except ImportError:
tqdm = None

DEFAULT_WARMUP = 25
DEFAULT_RUN_ITERS = 100
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
Expand Down Expand Up @@ -361,6 +366,8 @@ def run(
batch_range = range(self._batch_id + 1)
else:
batch_range = range(self.dargs.num_batch)
if tqdm is not None:
batch_range = tqdm(batch_range)
for batch_id in batch_range:
if self._batch_id and batch_id < self._batch_id:
continue
Expand Down

0 comments on commit 3f146bb

Please sign in to comment.