Fix the CI by using model_name to initialize a ModelTask (#2252)
Summary: Pull Request resolved: #2252

Test Plan: OSS CI

Reviewed By: aaronenyeshi

Differential Revision: D56755572

Pulled By: xuzhao9

fbshipit-source-id: 5e3f74edf66994cd3ed998933cf6e027c737ac8a
xuzhao9 authored and facebook-github-bot committed Apr 30, 2024
1 parent 2e67641 commit 7decf17
Showing 5 changed files with 24 additions and 21 deletions.
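
For context, a minimal sketch of the calling pattern this commit changes. The path, model name, and TIMEOUT value below are illustrative only; ModelTask is the TorchBench task wrapper touched in the diffs.

import os

# Hypothetical values for illustration; TIMEOUT stands in for the constant used by the tests.
TIMEOUT = 300
path = "/repo/torchbenchmark/models/resnet50"
model_name = os.path.basename(path)  # -> "resnet50"

# Before this commit, the task was constructed from the full model path:
#     task = ModelTask(path, timeout=TIMEOUT)
# After this commit, it is constructed from the model name:
#     task = ModelTask(model_name, timeout=TIMEOUT)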
4 changes: 2 additions & 2 deletions gen_summary_metadata.py
@@ -60,7 +60,7 @@ def _extract_detail(path: str) -> Dict[str, Any]:
     t_detail = None
     e_detail = None
     # Separate train and eval to isolated processes.
-    task_t = ModelTask(path, timeout=TIMEOUT)
+    task_t = ModelTask(name, timeout=TIMEOUT)
     try:
         task_t.make_model_instance(device=device)
         task_t.set_train()
@@ -72,7 +72,7 @@ def _extract_detail(path: str) -> Dict[str, Any]:
        print(f"Model {name} train is not fully implemented. skipping...")
     del task_t
 
-    task_e = ModelTask(path, timeout=TIMEOUT)
+    task_e = ModelTask(name, timeout=TIMEOUT)
     try:
         task_e.make_model_instance(device=device)
         task_e.set_eval()
14 changes: 8 additions & 6 deletions test.py
@@ -55,6 +55,9 @@ def _create_example_model_instance(task: ModelTask, device: str):
 
 
 def _load_test(path, device):
+
+    model_name = os.path.basename(path)
+
     def _skip_cuda_memory_check_p(metadata):
         if device != "cuda":
             return True
@@ -63,7 +66,7 @@ def _skip_cuda_memory_check_p(metadata):
         return False
 
     def example_fn(self):
-        task = ModelTask(path, timeout=TIMEOUT)
+        task = ModelTask(model_name, timeout=TIMEOUT)
         with task.watch_cuda_memory(
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
@@ -83,7 +86,7 @@ def example_fn(self):
 
     def train_fn(self):
         metadata = get_metadata_from_yaml(path)
-        task = ModelTask(path, timeout=TIMEOUT)
+        task = ModelTask(model_name, timeout=TIMEOUT)
         allow_customize_batch_size = task.get_model_attribute(
             "ALLOW_CUSTOMIZE_BSIZE", classattr=True
         )
@@ -106,7 +109,7 @@ def train_fn(self):
 
     def eval_fn(self):
         metadata = get_metadata_from_yaml(path)
-        task = ModelTask(path, timeout=TIMEOUT)
+        task = ModelTask(model_name, timeout=TIMEOUT)
         allow_customize_batch_size = task.get_model_attribute(
             "ALLOW_CUSTOMIZE_BSIZE", classattr=True
         )
@@ -129,7 +132,7 @@ def eval_fn(self):
         )
 
     def check_device_fn(self):
-        task = ModelTask(path, timeout=TIMEOUT)
+        task = ModelTask(model_name, timeout=TIMEOUT)
         with task.watch_cuda_memory(
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
@@ -142,7 +145,6 @@ def check_device_fn(self):
                     f'Method check_device on {device} is not implemented because "{e}", skipping...'
                 )
 
-    name = os.path.basename(path)
     metadata = get_metadata_from_yaml(path)
     for fn, fn_name in zip(
         [example_fn, train_fn, eval_fn, check_device_fn],
@@ -151,7 +153,7 @@ def check_device_fn(self):
         # set exclude list based on metadata
         setattr(
             TestBenchmark,
-            f"test_{name}_{fn_name}_{device}",
+            f"test_{model_name}_{fn_name}_{device}",
             (
                 unittest.skipIf(
                     skip_by_metadata(
23 changes: 12 additions & 11 deletions test_bench.py
@@ -42,10 +42,11 @@ def pytest_generate_tests(metafunc):
 
     if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
         paths = _list_model_paths()
+        model_names = [os.path.basename(path) for path in paths]
         metafunc.parametrize(
-            "model_path",
-            paths,
-            ids=[os.path.basename(path) for path in paths],
+            "model_name",
+            model_names,
+            ids=model_names,
             scope="class",
         )
 
@@ -61,20 +62,20 @@ def pytest_generate_tests(metafunc):
 )
 class TestBenchNetwork:
 
-    def test_train(self, model_path, device, compiler, benchmark):
+    def test_train(self, model_name, device, compiler, benchmark):
         try:
             if skip_by_metadata(
                 test="train",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_path),
+                metadata=get_metadata_from_yaml(model_name),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
             # api, enable after PyTorch 1.13 release
-            if "quantized" in model_path:
+            if "quantized" in model_name:
                 return
-            task = ModelTask(model_path)
+            task = ModelTask(model_name)
             if not task.model_details.exists:
                 return  # Model is not supported.
 
@@ -90,20 +91,20 @@ def test_train(self, model_path, device, compiler, benchmark):
         except NotImplementedError:
             print(f"Test train on {device} is not implemented, skipping...")
 
-    def test_eval(self, model_path, device, compiler, benchmark, pytestconfig):
+    def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
         try:
             if skip_by_metadata(
                 test="eval",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_path),
+                metadata=get_metadata_from_yaml(model_name),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
             # api, enable after PyTorch 1.13 release
-            if "quantized" in model_path:
+            if "quantized" in model_name:
                 return
-            task = ModelTask(model_path)
+            task = ModelTask(model_name)
             if not task.model_details.exists:
                 return  # Model is not supported.
 
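
To illustrate the effect of the parametrization change above, a short sketch (the model paths are hypothetical stand-ins for what _list_model_paths() returns):

import os

paths = [
    "/repo/torchbenchmark/models/BERT_pytorch",
    "/repo/torchbenchmark/models/resnet50",
]
model_names = [os.path.basename(p) for p in paths]  # ["BERT_pytorch", "resnet50"]

# TestBenchNetwork tests are now parametrized over bare model names, e.g.
#     metafunc.parametrize("model_name", model_names, ids=model_names, scope="class")
# so each test receives the name that ModelTask(model_name) and
# get_metadata_from_yaml(model_name) are given after this change.
print(model_names)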
2 changes: 1 addition & 1 deletion torchbenchmark/__init__.py
@@ -584,7 +584,7 @@ def watch_cuda_memory(
 
 
 def list_models_details(workers: int = 1) -> List[ModelDetails]:
-    return [ModelTask(model_path).model_details for model_path in _list_model_paths()]
+    return [ModelTask(os.path.basename(model_path)).model_details for model_path in _list_model_paths()]
 
 
 def list_models(model_match=None):
2 changes: 1 addition & 1 deletion userbenchmark/test_bench/run.py
@@ -202,7 +202,7 @@ def assertEqual(x, y):
     model_name = config.name
     model_path = os.path.join(REPO_PATH, "torchbenchmark", "models", model_name)
     metadata = get_metadata_from_yaml(model_path)
-    task = ModelTask(model_path, timeout=TIMEOUT)
+    task = ModelTask(model_name, timeout=TIMEOUT)
     allow_customize_batch_size = task.get_model_attribute(
         "ALLOW_CUSTOMIZE_BSIZE", classattr=True
     )
