From bf981d98acddd34e9140dc7b976e624e6618ecdd Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:52:08 -0700 Subject: [PATCH] [userbenchmark] Fix JSON file format Torch-TRT - Remove error message strings in JSON, opting for -1.0 as the error code instead - A future PR will add separate `.txt` files for each erroring model --- userbenchmark/torch_trt/run.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/userbenchmark/torch_trt/run.py b/userbenchmark/torch_trt/run.py index 0cec4725fa..a539902352 100644 --- a/userbenchmark/torch_trt/run.py +++ b/userbenchmark/torch_trt/run.py @@ -102,6 +102,7 @@ def run_single_model( precision = model.get_model_attribute("dargs", "precision") else: pt2_compilation_time = getattr(model, "pt2_compilation_time", None) + pt2_graph_breaks = getattr(model, "pt2_graph_breaks", None) name = getattr(model, "name", None) batch_size = getattr(model, "batch_size", None) precision = getattr(model, "precision", None) @@ -111,6 +112,12 @@ def run_single_model( f"{name}.bs_{batch_size}.precision_{precision}." + f"ir_{selected_ir}.pt2_compilation_time" ] = pt2_compilation_time + + if pt2_graph_breaks is not None and pt2_graph_breaks: + metrics[ + f"{name}.bs_{batch_size}.precision_{precision}." + + f"ir_{selected_ir}.pt2_graph_breaks" + ] = pt2_graph_breaks except: pass @@ -277,9 +284,7 @@ def run(args: List[str]): print( f"\nBenchmarking model {model_name} failed with:\n{e}\nSkipping the model.\n" ) - metrics = { - model_name: f"Failed to run benchmark: {traceback.format_exc()}" - } + metrics = {model_name: -1.0} # Halt further model runs on KeyboardInterrupt except KeyboardInterrupt: @@ -290,9 +295,7 @@ def run(args: List[str]): print( f"\nBenchmarking model {model_name} failed.\nSkipping the model.\n" ) - metrics = { - model_name: f"Failed to run benchmark: Error" - } + metrics = {model_name: -1.0} all_metrics = {**all_metrics, **metrics}