Enable timm and huggingface models internally
Summary: To generate a list of the OSS models that are available internally, we need to enable the timm and huggingface models in the internal build.

Reviewed By: aaronenyeshi

Differential Revision: D56584318

fbshipit-source-id: 018d8a621b7569d02b33ab5aae9550bb4b8b6f2d
xuzhao9 authored and facebook-github-bot committed Apr 27, 2024
1 parent d2e6075 commit fc0c752
Showing 5 changed files with 69 additions and 72 deletions.
54 changes: 10 additions & 44 deletions torchbenchmark/__init__.py
@@ -100,7 +100,7 @@ def dir_contains_file(dir, file_name) -> bool:
    return file_name in names


-def _list_model_paths() -> List[str]:
+def _list_model_paths(internal=True) -> List[str]:
    p = pathlib.Path(__file__).parent.joinpath(model_dir)
    # Only load the model directories that contain a "__init__.py" file
    models = sorted(
@@ -111,7 +111,7 @@ def _list_model_paths() -> List[str]:
        and dir_contains_file(child, "__init__.py")
    )
    p = p.joinpath(internal_model_dir)
-    if p.exists():
+    if p.exists() and internal:
        m = sorted(
            str(child.absolute())
            for child in p.iterdir()
@@ -332,29 +332,11 @@ def _maybe_import_model(package: str, model_path: str) -> Dict[str, Any]:
        import importlib
        import os
        import traceback
+        from torchbenchmark import load_model_by_name

        model_name = os.path.basename(model_path)
-        diagnostic_msg = ""
-        try:
-            module = importlib.import_module(f".models.{model_name}", package=package)
-            if accelerator_backend := os.getenv("ACCELERATOR_BACKEND"):
-                setattr(
-                    module,
-                    accelerator_backend,
-                    importlib.import_module(accelerator_backend),
-                )
-            Model = getattr(module, "Model", None)
-            if Model is None:
-                diagnostic_msg = (
-                    f"Warning: {module} does not define attribute Model, skip it"
-                )
-
-            elif not hasattr(Model, "name"):
-                Model.name = model_name
-
-        except ModuleNotFoundError as e:
-            traceback.print_exc()
-            exit(-1)
+        Model = load_model_by_name(model_name)

        # Populate global namespace so subsequent calls to worker.run can access `Model`
        globals()["Model"] = Model
@@ -401,27 +383,6 @@ def make_model_instance(
            }
        )

-    # =========================================================================
-    # == Replace the `invoke()` function in `model` instance ==================
-    # =========================================================================
-    @base_task.run_in_worker(scoped=True)
-    @staticmethod
-    def replace_invoke(module_name: str, func_name: str) -> None:
-        import importlib
-
-        # import function from pkg
-        model = globals()["model"]
-        try:
-            module = importlib.import_module(module_name)
-            inject_func = getattr(module, func_name, None)
-            if inject_func is None:
-                diagnostic_msg = (
-                    f"Warning: {module} does not define attribute {func_name}, skip it"
-                )
-        except ModuleNotFoundError as e:
-            diagnostic_msg = f"Warning: Could not find dependent module {e.name} for Model {model.name}, skip it"
-        model.invoke = inject_func.__get__(model)
-
    # =========================================================================
    # == Get Model attribute in the child process =============================
    # =========================================================================
@@ -706,7 +667,12 @@ def load_model_by_name(model_name: str):
    ), f"Found more than one models {models} with the exact name: {model_name}"

    module = importlib.import_module(module_path, package=__name__)
-
+    if accelerator_backend := os.getenv("ACCELERATOR_BACKEND"):
+        setattr(
+            module,
+            accelerator_backend,
+            importlib.import_module(accelerator_backend),
+        )
    Model = getattr(module, cls_name, None)
    if Model is None:
        print(f"Warning: {module} does not define attribute Model, skip it")
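The net effect in torchbenchmark/__init__.py is that both the worker path (_maybe_import_model) and direct callers now funnel through load_model_by_name, which is also where the ACCELERATOR_BACKEND injection happens. A minimal usage sketch, assuming "BERT_pytorch" is an available model and that whatever ACCELERATOR_BACKEND names is an importable module:

import os

# Assumption: "torch_xla" (or any other importable backend module) is installed.
os.environ["ACCELERATOR_BACKEND"] = "torch_xla"

from torchbenchmark import load_model_by_name

# Backend injection now happens in exactly one place, inside load_model_by_name.
Model = load_model_by_name("BERT_pytorch")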
4 changes: 2 additions & 2 deletions torchbenchmark/util/experiment/instantiator.py
@@ -132,9 +132,9 @@ def list_tests() -> List[str]:
    return ["train", "eval"]


-def list_models() -> List[str]:
+def list_models(internal=True) -> List[str]:
    """Return a list of names of all TorchBench models"""
-    model_paths = _list_model_paths()
+    model_paths = _list_model_paths(internal=internal)
    model_names = list(map(lambda x: os.path.basename(x), model_paths))
    return model_names

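With the new internal flag, callers can request only the OSS model set, which is what makes it possible to generate a list of OSS models from an internal build. A minimal sketch using the function above:

from torchbenchmark.util.experiment.instantiator import list_models

oss_models = list_models(internal=False)  # exclude the fb-internal model directory
all_models = list_models()                # default still includes internal models
print(f"{len(oss_models)} OSS models, {len(all_models)} total")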
21 changes: 12 additions & 9 deletions torchbenchmark/util/framework/huggingface/extended_configs.py
@@ -17,15 +17,18 @@
# Only load the extended models in OSS
if hasattr(torch.version, "git_version"):
    MODELS_FILENAME = os.path.join(DYNAMOBENCH_PATH, "huggingface_models_list.txt")
-    assert os.path.exists(MODELS_FILENAME)
-    with open(MODELS_FILENAME, "r") as fh:
-        lines = fh.readlines()
-        lines = [line.rstrip() for line in lines]
-        for line in lines:
-            model_name, batch_size = line.split(",")
-            batch_size = int(batch_size)
-            BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
-    assert len(BATCH_SIZE_KNOWN_MODELS)
+else:
+    from libfb.py import parutil
+    MODELS_FILENAME = parutil.get_file_path("caffe2/benchmarks/dynamo/huggingface_models_list.txt")
+assert os.path.exists(MODELS_FILENAME)
+with open(MODELS_FILENAME, "r") as fh:
+    lines = fh.readlines()
+    lines = [line.rstrip() for line in lines]
+    for line in lines:
+        model_name, batch_size = line.split(",")
+        batch_size = int(batch_size)
+        BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
+assert len(BATCH_SIZE_KNOWN_MODELS)


def is_extended_huggingface_models(model_name: str) -> bool:
18 changes: 11 additions & 7 deletions torchbenchmark/util/framework/timm/extended_configs.py
@@ -8,13 +8,17 @@
TIMM_MODELS = dict()
# Only load the extended models in OSS
if hasattr(torch.version, "git_version"):
-    filename = os.path.join(DYNAMOBENCH_PATH, "timm_models_list.txt")
-    with open(filename) as fh:
-        lines = fh.readlines()
-        lines = [line.rstrip() for line in lines]
-        for line in lines:
-            model_name, batch_size = line.split(" ")
-            TIMM_MODELS[model_name] = int(batch_size)
+    MODELS_FILENAME = os.path.join(DYNAMOBENCH_PATH, "timm_models_list.txt")
+else:
+    from libfb.py import parutil
+    MODELS_FILENAME = parutil.get_file_path("caffe2/benchmarks/dynamo/timm_models_list.txt")
+assert os.path.exists(MODELS_FILENAME)
+with open(MODELS_FILENAME) as fh:
+    lines = fh.readlines()
+    lines = [line.rstrip() for line in lines]
+    for line in lines:
+        model_name, batch_size = line.split(" ")
+        TIMM_MODELS[model_name] = int(batch_size)


def is_extended_timm_models(model_name: str) -> bool:
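Note that the two model-list files parsed above use different delimiters: the huggingface list is comma-separated ("model_name,batch_size" per line) while the timm list is space-separated. A toy illustration of the parsing, with made-up entries:

hf_line = "BertForMaskedLM,16"  # huggingface format: comma-separated
timm_line = "resnet50 32"       # timm format: space-separated
name, bs = hf_line.split(",")
assert (name, int(bs)) == ("BertForMaskedLM", 16)
name, bs = timm_line.split(" ")
assert (name, int(bs)) == ("resnet50", 32)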
44 changes: 34 additions & 10 deletions userbenchmark/test_bench/run.py
@@ -1,5 +1,5 @@
"""
-Run PyTorch nightly benchmarking.
+Run TorchBench test benchmarking.
"""
import argparse
import itertools
@@ -12,7 +12,7 @@
import ast
import numpy

-from typing import List, Dict, Optional, Union
+from typing import List, Dict, Optional, Union, Set
from ..utils import (
    REPO_PATH,
    add_path,
@@ -85,8 +85,6 @@ def generate_model_configs(
    extra_args: List[str],
) -> List[TorchBenchModelConfig]:
    """Use the default batch size and default mode."""
-    if not model_names:
-        model_names = list_models()
    cfgs = itertools.product(*[devices, tests, batch_sizes, model_names])
    result = [
        TorchBenchModelConfig(
@@ -133,7 +131,7 @@ def get_metrics(config: TorchBenchModelConfig) -> List[str]:
    return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]


-def validate(candidates: List[str], choices: List[str]) -> List[str]:
+def validate(candidates: List[str], choices: Union[Set[str], List[str]]) -> List[str]:
    """Validate the candidates provided by the user is valid"""
    for candidate in candidates:
        assert (
@@ -272,13 +270,14 @@ def parse_known_args(args):
    parser.add_argument(
        "--models",
        "-m",
+        nargs="*",
        default=None,
        help="Name of models to run, split by comma.",
    )
    parser.add_argument(
        "--device",
        "-d",
        default=default_device,
        choices=list_devices(),
        help="Devices to run, splited by comma.",
    )
@@ -293,6 +292,18 @@
    parser.add_argument(
        "--run-bisect", help="Run with the output of regression detector."
    )
+    parser.add_argument(
+        "--oss", action="store_true", help="[Meta-Internal Only] Run only the oss models."
+    )
+    parser.add_argument(
+        "--timm", action="store_true", help="Run with extended timm models."
+    )
+    parser.add_argument(
+        "--huggingface", action="store_true", help="Run with extended huggingface models."
+    )
+    parser.add_argument(
+        "--all", action="store_true", help="Run with all available models."
+    )
    parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
    parser.add_argument(
        "--output",
@@ -308,14 +319,27 @@ def run(args: List[str]):
    if args.run_bisect:
        configs = generate_model_configs_from_bisect_yaml(args.run_bisect)
    else:
-        # If not specified, use the entire model + extended model set
-        modelset = list_models() + list_extended_models()
+        modelset = set(list_models(internal=(not args.oss)))
+        timm_set = set(list_extended_models(suite_name="timm"))
+        huggingface_set = set(list_extended_models(suite_name="huggingface"))
+        modelset = modelset.union(timm_set).union(huggingface_set)
        if not args.models:
-            args.models = modelset
+            args.models = []
+        args.models = parse_str_to_list(args.models)
+        if args.timm:
+            args.models.extend(timm_set)
+        if args.huggingface:
+            args.models.extend(huggingface_set)
+        if args.all:
+            args.models.extend(modelset)
+        # If nothing is specified, run all built-in models by default.
+        if not args.models:
+            args.models = set(list_models(internal=(not args.oss)))

        devices = validate(parse_str_to_list(args.device), list_devices())
        tests = validate(parse_str_to_list(args.test), list_tests())
        batch_sizes = parse_str_to_list(args.bs)
-        models = validate(parse_str_to_list(args.models), modelset)
+        models = validate(args.models, modelset)
        configs = generate_model_configs(
            devices, tests, batch_sizes, model_names=models, extra_args=extra_args
        )
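Taken together, the new flags compose: explicit -m names plus --timm, --huggingface, and --all extend the requested set, --oss restricts the built-in set to OSS models, and an empty selection falls back to all built-in models. A sketch of driving this programmatically, assuming the module import path matches the file path above (the CLI equivalent would be along the lines of python run_benchmark.py test_bench --oss --timm --dryrun):

from userbenchmark.test_bench.run import run

# --dryrun only prints the resolved configs, so this is safe for experimentation.
run(["--oss", "--timm", "--dryrun"])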
