Skip to content

Commit

Permalink
Fix an issue with recent batch fusion update
Browse files Browse the repository at this point in the history
Summary:
Fix two bugs introduced by D49793793 and D49871774.

1. [framework] We still need to import the model by hardcoding the prefix ".models", because the name of the model is in the format of "dots", not "slashes". Calling `os.path.dirname` will generate empty string.
2. [inductor_speedup] Remove the 2 "_optimus" tests that are already deprecated in the Ads test set.

Reviewed By: aaronenyeshi

Differential Revision: D49889169

fbshipit-source-id: 968aef2d8c217859ce8c89af2142df64d7a0e8a4
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Oct 4, 2023
1 parent c3d8280 commit 159c010
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions torchbenchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def __init__(
if _is_internal_model(model_path):
model_path = f"{internal_model_dir}.{model_path}"
self._worker = Worker(timeout=timeout, extra_env=extra_env)
self.worker.run("import torch")

self.worker.run("import torch")
self._details: ModelDetails = ModelDetails(
**self._maybe_import_model(
package=__name__,
Expand Down Expand Up @@ -299,10 +299,9 @@ def _maybe_import_model(package: str, model_path: str) -> Dict[str, Any]:
import traceback

model_name = os.path.basename(model_path)
model_dir = os.path.basename(os.path.dirname(model_path))
diagnostic_msg = ""
try:
module = importlib.import_module(f'.{model_dir}.{model_name}', package=package)
module = importlib.import_module(f'.models.{model_name}', package=package)
if accelerator_backend := os.getenv("ACCELERATOR_BACKEND"):
setattr(module, accelerator_backend, importlib.import_module(accelerator_backend))
Model = getattr(module, 'Model', None)
Expand Down

0 comments on commit 159c010

Please sign in to comment.