diff --git a/userbenchmark/group_bench/run.py b/userbenchmark/group_bench/run.py index f490d13036..6adb61563f 100644 --- a/userbenchmark/group_bench/run.py +++ b/userbenchmark/group_bench/run.py @@ -219,9 +219,9 @@ def run(args: List[str]): except KeyboardInterrupt: print("User keyboard interrupted!") result = get_output_json(BM_NAME, results) - if group_config.baseline_config.device == 'cuda': + if group_config.baseline_configs[0].device == 'cuda': import torch result["environ"]["device"] = torch.cuda.get_device_name() + print(json.dumps(result)) with open(args.output, 'w') as f: json.dump(result, f, indent=4) - print(json.dumps(result))