[backends] Add functionality to TRT backend
- Add argument parsing for backend arguments to pass to TRT
- Add capability to specify IR via the command line
- Add functionality to compilation path and clean up code
gs-olive committed Jul 12, 2023
1 parent 2ea018e commit 695842c
59 changes: 50 additions & 9 deletions torchbenchmark/util/backends/trt.py
@@ -1,9 +1,31 @@
 from typing import List
 import torch
+import argparse
 
 from torchbenchmark.util.backends import create_backend
 from torchbenchmark.util.env_check import is_hf_model
 
+def parse_torch_trt_args(backend_args: List[str]):
+    """Parses CLI-provided backend arguments to extract Torch-TRT keywords
+    Returns kwargs dictionary and remainder arguments which were unrecognized
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--truncate_long_and_double", default=None, action="store_true")
+    arg_parser.add_argument("--workspace_size", type=int)
+    arg_parser.add_argument("--min_block_size", type=int)
+    arg_parser.add_argument("--ir", type=str)
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    # Remove unspecified arguments from the args dictionary
+    # (Only pass through user-specified args)
+    parsed_args = vars(args)
+    for key in list(parsed_args.keys()):
+        if parsed_args[key] is None:
+            del parsed_args[key]
+
+    return parsed_args, unknown
+
 @create_backend
 def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
     FP16 = True if model.dargs.precision == "fp16" else False
@@ -40,18 +62,37 @@ def _fx2trt():

 @create_backend
 def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+    """Backend for Torch-TRT
+    Can be directly invoked from the command line, for example via:
+    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double
+    Options include:
+        --truncate_long_and_double: Whether to automatically truncate long and double operations
+        --min_block_size: Minimum number of operations in an accelerated TRT block
+        --workspace_size: Size of workspace allotted to TensorRT
+        --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
+    """
     FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert model.device == "cuda" and model.test == "eval", f"Torch-TRT only works on CUDA inference tests."
+
+    # Extract relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
 
     def _torch_trt():
+        """Helper function for invoking Torch-TRT
+        """
         import torch_tensorrt
         module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        trt_input = [torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
+                     for input_ in example_inputs]
+
+        trt_module = torch_tensorrt.compile(module,
+                                            inputs=trt_input,
+                                            enabled_precisions={torch_dtype_precision},
+                                            **torch_trt_kwargs)
         model.set_module(trt_module)
 
     return _torch_trt, backend_args

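For illustration, here is a minimal sketch of how the new parse_torch_trt_args helper behaves. The flag values and the unrecognized --some_other_flag are hypothetical; the behavior follows directly from the argparse.parse_known_args call in the diff above.

# Hypothetical invocation of the helper added in this commit.
# Recognized Torch-TRT flags become compile kwargs; anything unrecognized
# is returned untouched for the rest of the harness to consume.
kwargs, remainder = parse_torch_trt_args(
    ["--ir", "ts", "--min_block_size", "3", "--some_other_flag"]
)
# kwargs    == {"min_block_size": 3, "ir": "ts"}
# remainder == ["--some_other_flag"]

Because every option defaults to None (including the store_true flag), any option the user did not pass is stripped before the dictionary is forwarded as **torch_trt_kwargs, so torch_tensorrt.compile falls back to its own defaults instead of receiving explicit None values.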