Skip to content

Commit

Permalink
[userbenchmark] Fix JSON file format for Torch-TRT
Browse files · Browse the repository at this point in the history
- Remove error message strings in JSON, opting for -1.0 as the error
code instead
- A future PR will add separate `.txt` files for each erroring model
  • Loading branch information
gs-olive committed Oct 19, 2023
1 parent 5248206 commit bf981d9
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions userbenchmark/torch_trt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def run_single_model(
precision = model.get_model_attribute("dargs", "precision")
else:
pt2_compilation_time = getattr(model, "pt2_compilation_time", None)
pt2_graph_breaks = getattr(model, "pt2_graph_breaks", None)
name = getattr(model, "name", None)
batch_size = getattr(model, "batch_size", None)
precision = getattr(model, "precision", None)
Expand All @@ -111,6 +112,12 @@ def run_single_model(
f"{name}.bs_{batch_size}.precision_{precision}."
+ f"ir_{selected_ir}.pt2_compilation_time"
] = pt2_compilation_time

if pt2_graph_breaks is not None and pt2_graph_breaks:
metrics[
f"{name}.bs_{batch_size}.precision_{precision}."
+ f"ir_{selected_ir}.pt2_graph_breaks"
] = pt2_graph_breaks
except:
pass

Expand Down Expand Up @@ -277,9 +284,7 @@ def run(args: List[str]):
print(
f"\nBenchmarking model {model_name} failed with:\n{e}\nSkipping the model.\n"
)
metrics = {
model_name: f"Failed to run benchmark: {traceback.format_exc()}"
}
metrics = {model_name: -1.0}

# Halt further model runs on KeyboardInterrupt
except KeyboardInterrupt:
Expand All @@ -290,9 +295,7 @@ def run(args: List[str]):
print(
f"\nBenchmarking model {model_name} failed.\nSkipping the model.\n"
)
metrics = {
model_name: f"Failed to run benchmark: Error"
}
metrics = {model_name: -1.0}

all_metrics = {**all_metrics, **metrics}

Expand Down

0 comments on commit bf981d9

Please sign in to comment.