diff --git a/userbenchmark/torch_trt/run.py b/userbenchmark/torch_trt/run.py index b2f087d773..651b30b48c 100644 --- a/userbenchmark/torch_trt/run.py +++ b/userbenchmark/torch_trt/run.py @@ -195,7 +195,10 @@ def run(args: List[str]): ir_idx = unknown_args.index("--ir") selected_ir = unknown_args[ir_idx + 1] except (ValueError, IndexError): + # If no IR was specified, default to torch.compile selected_ir = "torch_compile" + unknown_args.append("--ir") + unknown_args.append(selected_ir) # Parse model string if specified, otherwise run all models # Adapted from benchmark/run.py