diff --git a/torchbenchmark/__init__.py b/torchbenchmark/__init__.py index 2a186723b2..4d157e248f 100644 --- a/torchbenchmark/__init__.py +++ b/torchbenchmark/__init__.py @@ -264,8 +264,8 @@ def __init__( if _is_internal_model(model_path): model_path = f"{internal_model_dir}.{model_path}" self._worker = Worker(timeout=timeout, extra_env=extra_env) - self.worker.run("import torch") + self.worker.run("import torch") self._details: ModelDetails = ModelDetails( **self._maybe_import_model( package=__name__, @@ -299,10 +299,9 @@ def _maybe_import_model(package: str, model_path: str) -> Dict[str, Any]: import traceback model_name = os.path.basename(model_path) - model_dir = os.path.basename(os.path.dirname(model_path)) diagnostic_msg = "" try: - module = importlib.import_module(f'.{model_dir}.{model_name}', package=package) + module = importlib.import_module(f'.models.{model_name}', package=package) if accelerator_backend := os.getenv("ACCELERATOR_BACKEND"): setattr(module, accelerator_backend, importlib.import_module(accelerator_backend)) Model = getattr(module, 'Model', None)