From 159c010c2cedc2eba5839244905b99d042458854 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 4 Oct 2023 07:03:17 -0700 Subject: [PATCH] Fix an issue with recent batch fusion update Summary: Fix two bugs introduced by D49793793 and D49871774. 1. [framework] We still need to import the model by hardcoding the prefix ".models", because the name of the model is in the format of "dots", not "slashes". Calling `os.path.dirname` will generate empty string. 2. [inductor_speedup] Remove the 2 "_optimus" tests that are already deprecated in the Ads test set. Reviewed By: aaronenyeshi Differential Revision: D49889169 fbshipit-source-id: 968aef2d8c217859ce8c89af2142df64d7a0e8a4 --- torchbenchmark/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchbenchmark/__init__.py b/torchbenchmark/__init__.py index 2a186723b2..4d157e248f 100644 --- a/torchbenchmark/__init__.py +++ b/torchbenchmark/__init__.py @@ -264,8 +264,8 @@ def __init__( if _is_internal_model(model_path): model_path = f"{internal_model_dir}.{model_path}" self._worker = Worker(timeout=timeout, extra_env=extra_env) - self.worker.run("import torch") + self.worker.run("import torch") self._details: ModelDetails = ModelDetails( **self._maybe_import_model( package=__name__, @@ -299,10 +299,9 @@ def _maybe_import_model(package: str, model_path: str) -> Dict[str, Any]: import traceback model_name = os.path.basename(model_path) - model_dir = os.path.basename(os.path.dirname(model_path)) diagnostic_msg = "" try: - module = importlib.import_module(f'.{model_dir}.{model_name}', package=package) + module = importlib.import_module(f'.models.{model_name}', package=package) if accelerator_backend := os.getenv("ACCELERATOR_BACKEND"): setattr(module, accelerator_backend, importlib.import_module(accelerator_backend)) Model = getattr(module, 'Model', None)