From bb75b05f1927eec7d5298b544da0a4329e3a5f2d Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 6 Jul 2023 19:06:45 -0700
Subject: [PATCH] [backends] Add functionality to TRT backend

- Add argument parsing for backend arguments to pass to TRT
- Add capability to specify IR via the command line
- Add functionality to compilation path and clean up code
---
 torchbenchmark/util/backends/trt.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 43 insertions(+), 9 deletions(-)

diff --git a/torchbenchmark/util/backends/trt.py b/torchbenchmark/util/backends/trt.py
index 98bbaf1942..78f0c78826 100644
--- a/torchbenchmark/util/backends/trt.py
+++ b/torchbenchmark/util/backends/trt.py
@@ -1,9 +1,24 @@
 from typing import List
 import torch
+import argparse
 
 from torchbenchmark.util.backends import create_backend
 from torchbenchmark.util.env_check import is_hf_model
 
+def parse_torch_trt_args(backend_args: List[str]):
+    """Parses CLI-provided backend arguments to extract Torch-TRT keywords
+
+    Returns a kwargs dictionary and any remaining, unrecognized arguments
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--truncate_long_and_double", default=False, action="store_true")
+    arg_parser.add_argument("--workspace_size", type=int)
+    arg_parser.add_argument("--min_block_size", type=int)
+    arg_parser.add_argument("--ir", type=str, default="ts")
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    return vars(args), unknown
+
 @create_backend
 def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
     FP16 = True if model.dargs.precision == "fp16" else False
@@ -40,18 +55,37 @@ def _fx2trt():
 
 @create_backend
 def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+    """Backend for Torch-TRT
+
+    Can be directly invoked from the command line, for example via:
+        python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double
+
+    Options include:
+        --truncate_long_and_double: Whether to automatically truncate long and double operations
+        --min_block_size: Minimum number of operations in an accelerated TRT block
+        --workspace_size: Size of workspace allotted to TensorRT
+        --ir: Which intermediate representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
+    """
     FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert model.device == "cuda" and model.test == "eval", "Torch-TRT only works on CUDA inference tests."
+
+    # Extract relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
 
     def _torch_trt():
+        """Helper function for invoking Torch-TRT
+        """
         import torch_tensorrt
         module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        trt_input = [torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
+                     for input_ in example_inputs]
+
+        trt_module = torch_tensorrt.compile(module,
+                                            inputs=trt_input,
+                                            enabled_precisions={torch_dtype_precision},
+                                            **torch_trt_kwargs)
         model.set_module(trt_module)
+
+    return _torch_trt, backend_args
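Reviewer note: a minimal standalone sketch (not part of the patch) of how the new parse_torch_trt_args helper behaves. argparse.parse_known_args splits the recognized Torch-TRT flags into a kwargs dictionary and returns everything else untouched, so unrecognized arguments can still be consumed by other handlers. The flag "--some_other_flag" below is hypothetical, used only to show the split.

import argparse
from typing import List

def parse_torch_trt_args(backend_args: List[str]):
    # Mirrors the helper added in this patch
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--truncate_long_and_double", default=False, action="store_true")
    arg_parser.add_argument("--workspace_size", type=int)
    arg_parser.add_argument("--min_block_size", type=int)
    arg_parser.add_argument("--ir", type=str, default="ts")
    args, unknown = arg_parser.parse_known_args(backend_args)
    return vars(args), unknown

kwargs, remainder = parse_torch_trt_args(
    ["--truncate_long_and_double", "--min_block_size", "3", "--some_other_flag"])
print(kwargs)
# {'truncate_long_and_double': True, 'workspace_size': None, 'min_block_size': 3, 'ir': 'ts'}
print(remainder)
# ['--some_other_flag']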
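Similarly, a sketch (assuming the torch_tensorrt package is installed; the example tensors are hypothetical) of the per-input spec construction the patch switches to: one torch_tensorrt.Input per example input, preserving each input's shape and dtype, rather than a single spec derived from the first input with a hardcoded dtype.

import torch
import torch_tensorrt  # assumed available; requires a Torch-TensorRT installation

# Hypothetical example inputs with differing shapes and dtypes
example_inputs = [torch.randn(1, 3, 224, 224),
                  torch.randn(1, 8, dtype=torch.half)]

# One Input spec per example input, as in the patched _torch_trt helper
trt_inputs = [torch_tensorrt.Input(shape=t.shape, dtype=t.dtype)
              for t in example_inputs]

This matters for models whose forward pass takes multiple inputs: the old code only described example_inputs[0], so multi-input models could not be compiled with correct shape information.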