[backends] Add functionality to TRT backend
- Add argument parsing for backend arguments to pass to TRT
- Add capability to specify the IR via the CLI
- Add functionality to compilation path and clean up code
gs-olive committed Jul 7, 2023
1 parent ae3534a · commit bb75b05
Showing 1 changed file with 43 additions and 9 deletions.
52 changes: 43 additions & 9 deletions torchbenchmark/util/backends/trt.py
@@ -1,9 +1,24 @@
 from typing import List
 import torch
+import argparse
 
 from torchbenchmark.util.backends import create_backend
 from torchbenchmark.util.env_check import is_hf_model
 
+def parse_torch_trt_args(backend_args: List[str]):
+    """Parses CLI-provided backend arguments to extract Torch-TRT keywords
+    Returns kwargs dictionary and remainder arguments which were unrecognized
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--truncate_long_and_double", default=False, action="store_true")
+    arg_parser.add_argument("--workspace_size", type=int)
+    arg_parser.add_argument("--min_block_size", type=int)
+    arg_parser.add_argument("--ir", type=str, default="ts")
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    return vars(args), unknown
+
 @create_backend
 def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
     FP16 = True if model.dargs.precision == "fp16" else False
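
The new parse_torch_trt_args helper relies on argparse's parse_known_args, so any flags it does not recognize are handed back to the caller rather than raising an error. A minimal sketch of the resulting split (a hypothetical invocation with illustrative values, not part of the commit):

    # Hypothetical call to the helper above; the flags shown are illustrative.
    kwargs, leftover = parse_torch_trt_args(
        ["--truncate_long_and_double", "--ir", "dynamo_compile", "--unrelated_flag"]
    )
    # kwargs   == {"truncate_long_and_double": True, "workspace_size": None,
    #              "min_block_size": None, "ir": "dynamo_compile"}
    # leftover == ["--unrelated_flag"]
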
@@ -40,18 +55,37 @@ def _fx2trt():

 @create_backend
 def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+    """Backend for Torch-TRT
+    Can be directly invoked from the command line, for example via:
+    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double
+    Options include:
+        --truncate_long_and_double: Whether to automatically truncate long and double operations
+        --min_block_size: Minimum number of operations in an accelerated TRT block
+        --workspace_size: Size of workspace allotted to TensorRT
+        --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
+    """
     FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert model.device == "cuda" and model.test == "eval", f"Torch-TRT only works on CUDA inference tests."
+
+    # Extract relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
 
     def _torch_trt():
+        """Helper function for invoking Torch-TRT
+        """
         import torch_tensorrt
         module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        trt_input = [torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
+                     for input_ in example_inputs]
+
+        trt_module = torch_tensorrt.compile(module,
+                                            inputs=trt_input,
+                                            enabled_precisions={torch_dtype_precision},
+                                            **torch_trt_kwargs)
         model.set_module(trt_module)
 
     return _torch_trt, backend_args
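
The rewritten compilation path builds one torch_tensorrt.Input spec per example input (instead of only the first) and forwards the parsed CLI keywords straight to the compiler. A standalone sketch of the same call pattern, assuming torch_tensorrt is installed and CUDA is available; the model and inputs below are placeholders, not the benchmark harness:

    import torch
    import torch_tensorrt

    module = torch.nn.Linear(8, 4).eval().cuda()          # placeholder model
    example_inputs = (torch.randn(2, 8, device="cuda"),)  # placeholder inputs

    # One Input spec per example input, mirroring its shape and dtype
    trt_input = [torch_tensorrt.Input(shape=i.shape, dtype=i.dtype)
                 for i in example_inputs]

    trt_module = torch_tensorrt.compile(module,
                                        inputs=trt_input,
                                        enabled_precisions={torch.float32},  # a set, as in the new call
                                        ir="ts",  # forwarded via **torch_trt_kwargs in the diff
                                        truncate_long_and_double=True)

Note that enabled_precisions is now passed as a set; the previous code passed a single torch_tensorrt.dtype value.
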
