From 8451c1f8fae21d78d21f5f02595fe15ce7c28fb0 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Thu, 20 Jun 2024 10:49:46 -0700
Subject: [PATCH] Fix test_bench script (#2320)

Summary:
Fixes https://github.com/pytorch/benchmark/issues/2315

Pull Request resolved: https://github.com/pytorch/benchmark/pull/2320

Test Plan:
```
$ pytest test_bench.py -k "test_eval[BERT_pytorch-cpu]" --ignore_machine_config
============================================================ test session starts ============================================================
platform linux -- Python 3.11.5, pytest-7.4.3, pluggy-1.0.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /home/xz/git/benchmark
plugins: benchmark-4.0.0, hypothesis-6.98.15
collected 411 items / 410 deselected / 1 selected

test_bench.py .                                                                                                                        [100%]

------------------------------------------------- benchmark 'hub': 1 tests -------------------------------------------------
Name (time in ms)                 Min       Max      Mean  StdDev    Median     IQR  Outliers     OPS  Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------
test_eval[BERT_pytorch-cpu]  114.2104  117.3853  115.4276  1.0485  115.3054  1.4325       4;0  8.6634       9           1
-----------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
================================================================== 1 passed, 410 deselected in 5.68s ==================================================================
```

Reviewed By: aaronenyeshi

Differential Revision: D58823072

Pulled By: xuzhao9

fbshipit-source-id: 172be1d922b2a51ec2df08b822102dc0a20818ac
---
 test_bench.py | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/test_bench.py b/test_bench.py
index 8be059f95d..9dc14c68fc 100644
--- a/test_bench.py
+++ b/test_bench.py
@@ -42,11 +42,10 @@ def pytest_generate_tests(metafunc):
 
     if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
         paths = _list_model_paths()
-        model_names = [os.path.basename(path) for path in paths]
         metafunc.parametrize(
-            "model_name",
-            model_names,
-            ids=model_names,
+            "model_path",
+            paths,
+            ids=[os.path.basename(path) for path in paths],
             scope="class",
         )
 
@@ -62,13 +61,14 @@
 )
 class TestBenchNetwork:
 
-    def test_train(self, model_name, device, compiler, benchmark):
+    def test_train(self, model_path, device, benchmark):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="train",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -91,13 +91,14 @@ def test_train(self, model_name, device, compiler, benchmark):
         except NotImplementedError:
             print(f"Test train on {device} is not implemented, skipping...")
 
-    def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
+    def test_eval(self, model_path, device, benchmark, pytestconfig):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="eval",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -110,16 +111,15 @@ def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
 
             task.make_model_instance(test="eval", device=device)
 
-            with task.no_grad(disable_nograd=pytestconfig.getoption("disable_nograd")):
-                benchmark(task.invoke)
-                benchmark.extra_info["machine_state"] = get_machine_state()
-                benchmark.extra_info["batch_size"] = task.get_model_attribute(
-                    "batch_size"
-                )
-                benchmark.extra_info["precision"] = task.get_model_attribute(
-                    "dargs", "precision"
-                )
-                benchmark.extra_info["test"] = "eval"
+            benchmark(task.invoke)
+            benchmark.extra_info["machine_state"] = get_machine_state()
+            benchmark.extra_info["batch_size"] = task.get_model_attribute(
+                "batch_size"
+            )
+            benchmark.extra_info["precision"] = task.get_model_attribute(
+                "dargs", "precision"
+            )
+            benchmark.extra_info["test"] = "eval"
 
         except NotImplementedError:
             print(f"Test eval on {device} is not implemented, skipping...")