diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py index f8ab0821a..ad2c58c2a 100644 --- a/userbenchmark/triton/run.py +++ b/userbenchmark/triton/run.py @@ -29,7 +29,12 @@ def get_parser(args=None): parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument("--op", type=str, required=False, help="Operator to benchmark.") + parser.add_argument( + "--op", + type=str, + required=False, + help="Operators to benchmark. Split with comma if multiple.", + ) parser.add_argument( "--mode", choices=["fwd", "bwd", "fwd_bwd", "fwd_no_grad"], @@ -188,5 +193,11 @@ def run(args: List[str] = []): run_ci() return + if args.op: + ops = args.op.split(",") + else: + ops = [] with gpu_lockdown(args.gpu_lockdown): - _run(args, extra_args) + for op in ops: + args.op = op + _run(args, extra_args)