Add timm and huggingface model suites support (#2197)
Summary:
Dynamobench supports extra Huggingface and TIMM models beyond the existing TorchBench model set.
This PR adds support for those models as well; they can be invoked through `run.py` or in the `group_bench` userbenchmark.

Pull Request resolved: #2197

Test Plan:
TIMM model example:
```
$ python run.py convit_base -d cpu -t eval
Running eval method from convit_base on cpu in eager mode with input batch size 64 and precision fp32.
CPU Wall Time per batch: 4419.601 milliseconds
CPU Wall Time:       4419.601 milliseconds
Time to first batch:         2034.6840 ms
CPU Peak Memory:                0.6162 GB
```

```
$ python run.py convit_base -d cpu -t train
Running train method from convit_base on cpu in eager mode with input batch size 64 and precision fp32.
CPU Wall Time per batch: 17044.825 milliseconds
CPU Wall Time:       17044.825 milliseconds
Time to first batch:         1616.9790 ms
CPU Peak Memory:                7.3408 GB
```

Huggingface model example:
```
$ python run.py MBartForCausalLM -d cuda -t train
Running train method from MBartForCausalLM on cuda in eager mode with input batch size 4 and precision fp32.
GPU Time per batch:  839.994 milliseconds
CPU Wall Time per batch: 842.323 milliseconds
CPU Wall Time:       842.323 milliseconds
Time to first batch:         5390.2949 ms
GPU 0 Peak Memory:             19.7418 GB
CPU Peak Memory:                0.9121 GB
```

Fixes #2170

Reviewed By: HDCharles

Differential Revision: D54953131

Pulled By: xuzhao9

fbshipit-source-id: e63e5d5ed7fc36e4500439fbc8d6a7825b7514bf
xuzhao9 authored and facebook-github-bot committed Mar 20, 2024
1 parent 91a1d32 commit 2196021
Showing 23 changed files with 931 additions and 198 deletions.
33 changes: 6 additions & 27 deletions run.py
@@ -10,8 +10,6 @@
 """
 import argparse
 import time
-
-import traceback
 from functools import partial
 
 import numpy as np
@@ -23,6 +21,7 @@
     load_model_by_name,
     ModelNotFoundError,
 )
+from torchbenchmark.util.experiment.instantiator import load_model, TorchBenchModelConfig
 from torchbenchmark.util.experiment.metrics import get_model_flops, get_peak_memory
 

@@ -498,42 +497,22 @@ def main() -> None:
     # Log the tool usage
     usage_report_logger()
 
-    found = False
-    Model = None
-
-    try:
-        Model = load_model_by_name(args.model)
-    except ModuleNotFoundError:
-        traceback.print_exc()
-        exit(-1)
-    except ModelNotFoundError:
-        print(f"Warning: The model {args.model} cannot be found at core set.")
-    if not Model:
-        try:
-            Model = load_canary_model_by_name(args.model)
-        except ModuleNotFoundError:
-            traceback.print_exc()
-            exit(-1)
-        except ModelNotFoundError:
-            print(
-                f"Error: The model {args.model} cannot be found at either core or canary model set."
-            )
-            exit(-1)
-
-    m = Model(
-        device=args.device,
+    config = TorchBenchModelConfig(
+        name=args.model,
         test=args.test,
+        device=args.device,
         batch_size=args.bs,
         extra_args=extra_args,
     )
+    m = load_model(config)
     if m.dynamo:
         mode = f"dynamo {m.opt_args.torchdynamo}"
     elif m.opt_args.backend:
         mode = f"{m.opt_args.backend}"
     else:
         mode = "eager"
     print(
-        f"Running {args.test} method from {Model.name} on {args.device} in {mode} mode with input batch size {m.batch_size} and precision {m.dargs.precision}."
+        f"Running {args.test} method from {m.name} on {args.device} in {mode} mode with input batch size {m.batch_size} and precision {m.dargs.precision}."
     )
     if "--accuracy" in extra_args:
         print("{:<20} {:>20}".format("Accuracy: ", str(m.accuracy)), sep="")
2 changes: 1 addition & 1 deletion test.py
@@ -77,7 +77,7 @@ def example_fn(self):
                 task.del_model_instance()
         except NotImplementedError as e:
             self.skipTest(
-                f'Method `get_module()` on {device} is not implemented because "{e}", skipping...'
+                f'Accuracy check on {device} is not implemented because "{e}", skipping...'
             )
 
     def train_fn(self):
44 changes: 32 additions & 12 deletions torchbenchmark/__init__.py
@@ -667,26 +667,46 @@ def list_models(model_match=None):
     return models
 
 
-def load_model_by_name(model):
+def load_model_by_name(model_name: str):
     models = filter(
-        lambda x: model.lower() == x.lower(),
+        lambda x: model_name.lower() == x.lower(),
         map(lambda y: os.path.basename(y), _list_model_paths()),
     )
     models = list(models)
+    cls_name = "Model"
     if not models:
-        raise ModelNotFoundError(f"{model} is not found in the core model list.")
+        # If the model is in TIMM or Huggingface extended model list
+        from torchbenchmark.util.framework.huggingface.extended_configs import (
+            list_extended_huggingface_models
+        )
+        from torchbenchmark.util.framework.timm.extended_configs import (
+            list_extended_timm_models
+        )
+        if model_name in list_extended_huggingface_models():
+            cls_name = "ExtendedHuggingFaceModel"
+            module_path = ".util.framework.huggingface.model_factory"
+            models.append(model_name)
+        elif model_name in list_extended_timm_models():
+            cls_name = "ExtendedTimmModel"
+            module_path = ".util.framework.timm.model_factory"
+            models.append(model_name)
+        else:
+            raise ModelNotFoundError(f"{model_name} is not found in the core model list.")
+    else:
+        model_name = models[0]
+        model_pkg = (
+            model_name
+            if not _is_internal_model(model_name)
+            else f"{internal_model_dir}.{model_name}"
+        )
+        module_path = f".models.{model_pkg}"
     assert (
         len(models) == 1
-    ), f"Found more than one models {models} with the exact name: {model}"
-    model_name = models[0]
-    model_pkg = (
-        model_name
-        if not _is_internal_model(model_name)
-        else f"{internal_model_dir}.{model_name}"
-    )
-    module = importlib.import_module(f".models.{model_pkg}", package=__name__)
+    ), f"Found more than one models {models} with the exact name: {model_name}"
+
+    module = importlib.import_module(module_path, package=__name__)
 
-    Model = getattr(module, "Model", None)
+    Model = getattr(module, cls_name, None)
     if Model is None:
         print(f"Warning: {module} does not define attribute Model, skip it")
         return None
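
`load_model_by_name` now falls back to the extended suites before raising `ModelNotFoundError`, and returns the suite-specific wrapper class (`ExtendedHuggingFaceModel` or `ExtendedTimmModel`) instead of a per-model `Model`. A hedged sketch of the resulting behavior (the constructor arguments follow the usual TorchBench model contract):

```
from torchbenchmark import load_model_by_name

# "MBartForCausalLM" has no directory under torchbenchmark/models, so the
# core lookup misses; the name then matches the extended Huggingface list
# and the ExtendedHuggingFaceModel class from model_factory is returned.
Model = load_model_by_name("MBartForCausalLM")
m = Model(test="train", device="cuda")
```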
2 changes: 1 addition & 1 deletion torchbenchmark/canary_models/hf_MPT_7b_instruct/install.py
@@ -10,4 +10,4 @@ def pip_install_requirements():
     pip_install_requirements()
     patch_transformers()
     model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
-    cache_model(model_name, trust_remote_code=True)
+    cache_model(model_name)
2 changes: 1 addition & 1 deletion torchbenchmark/canary_models/hf_Yi/install.py
@@ -10,4 +10,4 @@ def pip_install_requirements():
     pip_install_requirements()
     patch_transformers()
     model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
-    cache_model(model_name, trust_remote_code=True)
+    cache_model(model_name)
2 changes: 1 addition & 1 deletion torchbenchmark/canary_models/hf_mixtral/install.py
@@ -10,4 +10,4 @@ def pip_install_requirements():
     pip_install_requirements()
     patch_transformers()
     model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
-    cache_model(model_name, trust_remote_code=True)
+    cache_model(model_name)
2 changes: 1 addition & 1 deletion torchbenchmark/canary_models/phi_1_5/install.py
@@ -10,4 +10,4 @@ def pip_install_requirements():
     pip_install_requirements()
     patch_transformers()
     model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
-    cache_model(model_name, trust_remote_code=True)
+    cache_model(model_name)
2 changes: 1 addition & 1 deletion torchbenchmark/canary_models/phi_2/install.py
@@ -10,4 +10,4 @@ def pip_install_requirements():
     pip_install_requirements()
     patch_transformers()
     model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
-    cache_model(model_name, trust_remote_code=True)
+    cache_model(model_name)
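
These five `install.py` tweaks all drop the per-call `trust_remote_code=True`, since the flag is now tracked centrally in `HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE` (see `basic_configs.py` below). A hypothetical sketch of how `cache_model` can consult that list — not the PR's actual implementation:

```
from torchbenchmark.util.framework.huggingface.basic_configs import (
    HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE,
)

def cache_model(model_name: str, **kwargs):
    # Hypothetical: derive trust_remote_code from the central list so that
    # install scripts no longer pass it at every call site.
    if model_name in HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE:
        kwargs.setdefault("trust_remote_code", True)
    ...  # download the config and weights with kwargs as before
```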
5 changes: 4 additions & 1 deletion torchbenchmark/models/hf_Whisper/__init__.py
@@ -18,7 +18,10 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
 
     def train(self):
         raise NotImplementedError("Training is not implemented.")
 
+    def get_module(self):
+        return self.model, (self.example_inputs["input_ids"], )
+
     def eval(self):
         self.model.eval()
         with torch.no_grad():
5 changes: 4 additions & 1 deletion torchbenchmark/models/hf_distil_whisper/__init__.py
@@ -17,9 +17,12 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
         self.example_inputs = {"input_features": self.input_features.to(self.device), "input_ids" : self.input_features.to(self.device)}
         self.model.to(self.device)
 
+    def get_module(self):
+        return self.model, (self.example_inputs["input_ids"], )
+
     def train(self):
         raise NotImplementedError("Training is not implemented.")
 
     def eval(self):
         self.model.eval()
         with torch.no_grad():
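
Both Whisper variants gain a `get_module` implementation, returning the model together with a tuple of positional example inputs. A sketch of how a `get_module()` result is conventionally consumed by TorchBench tooling (e.g. accuracy checks or backend wrapping):

```
module, example_inputs = m.get_module()
output = module(*example_inputs)  # example_inputs is a tuple of positional args
```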
12 changes: 12 additions & 0 deletions torchbenchmark/util/experiment/instantiator.py
@@ -76,3 +76,15 @@ def list_models() -> List[str]:
     model_paths = _list_model_paths()
     model_names = list(map(lambda x: os.path.basename(x), model_paths))
     return model_names
+
+def list_extended_models(suite_name: str="all") -> List[str]:
+    from torchbenchmark.util.framework.huggingface.extended_configs import list_extended_huggingface_models
+    from torchbenchmark.util.framework.timm.extended_configs import list_extended_timm_models
+    if suite_name == "huggingface":
+        return list_extended_huggingface_models()
+    elif suite_name == "timm":
+        return list_extended_timm_models()
+    elif suite_name == "all":
+        return list_extended_huggingface_models() + list_extended_timm_models()
+    else:
+        assert False, "Currently, we only support extended model set huggingface or timm."
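
Example usage of the new helper (suite names as defined in this diff):

```
from torchbenchmark.util.experiment.instantiator import list_extended_models

hf_models = list_extended_models("huggingface")  # extended Huggingface suite
timm_models = list_extended_models("timm")       # extended TIMM suite
all_models = list_extended_models()              # suite_name defaults to "all"
```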
149 changes: 149 additions & 0 deletions torchbenchmark/util/framework/huggingface/basic_configs.py
@@ -0,0 +1,149 @@
import transformers
import os
import re
import torch
from typing import List

HUGGINGFACE_MODELS = {
    # 'name': (train_max_length, eval_max_length, config, model)
    'hf_GPT2': (512, 1024, 'AutoConfig.from_pretrained("gpt2")', 'AutoModelForCausalLM'),
    'hf_GPT2_large': (512, 1024, 'AutoConfig.from_pretrained("gpt2-large")', 'AutoModelForCausalLM'),
    'hf_T5': (1024, 2048, 'AutoConfig.from_pretrained("t5-small")', 'AutoModelForSeq2SeqLM'),
    'hf_T5_base': (1024, 2048, 'AutoConfig.from_pretrained("t5-base")', 'AutoModelForSeq2SeqLM'),
    'hf_T5_large': (512, 512, 'AutoConfig.from_pretrained("t5-large")', 'AutoModelForSeq2SeqLM'),
    'hf_Bart': (512, 512, 'AutoConfig.from_pretrained("facebook/bart-base")', 'AutoModelForSeq2SeqLM'),
    'hf_Reformer': (4096, 4096, 'ReformerConfig(num_buckets=128)', 'AutoModelForMaskedLM'),
    'hf_BigBird': (1024, 4096, 'BigBirdConfig(attention_type="block_sparse",)', 'AutoModelForMaskedLM'),
    'hf_Albert': (512, 512, 'AutoConfig.from_pretrained("albert-base-v2")', 'AutoModelForMaskedLM'),
    'hf_DistilBert': (512, 512, 'AutoConfig.from_pretrained("distilbert-base-uncased")', 'AutoModelForMaskedLM'),
    'hf_Longformer': (1024, 4096, 'AutoConfig.from_pretrained("allenai/longformer-base-4096")', 'AutoModelForMaskedLM'),
    'hf_Bert': (512, 512, 'BertConfig()', 'AutoModelForMaskedLM'),
    # see https://huggingface.co/bert-large-cased
    'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'),
    'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'),
    'hf_distil_whisper': (1024, 1024, 'AutoConfig.from_pretrained("distil-whisper/distil-medium.en")', 'AutoModelForAudioClassification'),
    'hf_mixtral': (512, 512, 'AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")', 'AutoModelForCausalLM'),
    # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real
    'llama_v2_7b_16h': (128, 512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'),
    'hf_MPT_7b_instruct': (512, 512, 'AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)', 'AutoModelForCausalLM'),
    'llava': (512, 512, 'AutoConfig.from_pretrained("liuhaotian/llava-v1.5-13b")', 'LlavaForConditionalGeneration'),
    'llama_v2_7b': (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")', 'AutoModelForCausalLM'),
    'llama_v2_13b': (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-13b-hf")', 'AutoModelForCausalLM'),
    'llama_v2_70b': (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")', 'AutoModelForMaskedLM'),
    'codellama': (512, 512, 'AutoConfig.from_pretrained("codellama/CodeLlama-7b-hf")', 'AutoModelForCausalLM'),
    'phi_1_5': (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)', 'AutoModelForCausalLM'),
    'phi_2': (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-2", trust_remote_code=True)', 'AutoModelForCausalLM'),
    'moondream': (512, 512, 'PhiConfig.from_pretrained("vikhyatk/moondream1")', 'PhiForCausalLM'),
    # as per this page https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 trust_remote_code=True is not required
    'mistral_7b_instruct': (128, 128, 'AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")', 'AutoModelForCausalLM'),
    'hf_Yi': (512, 512, 'AutoConfig.from_pretrained("01-ai/Yi-6B", trust_remote_code=True)', 'AutoModelForCausalLM'),
    'orca_2': (512, 512, 'AutoConfig.from_pretrained("microsoft/Orca-2-13b")', 'AutoModelForCausalLM'),
}

CPU_INPUT_SLICE = {
    'hf_BigBird': 5,
    'hf_Longformer': 8,
    'hf_T5': 4,
    'hf_GPT2': 4,
    'hf_Reformer': 2,
}

HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE = [
    "hf_Falcon_7b",
    "hf_MPT_7b_instruct",
    "phi_1_5",
    "phi_2",
    "hf_Yi",
    "hf_mixtral",
]

HUGGINGFACE_MODELS_SGD_OPTIMIZER = [
    "llama_v2_7b_16h",
]


def is_basic_huggingface_models(model_name: str) -> bool:
    return model_name in HUGGINGFACE_MODELS

def list_basic_huggingface_models() -> List[str]:
    return HUGGINGFACE_MODELS.keys()

def generate_inputs_for_model(
    model_cls, model, model_name, bs, device, is_training=False,
):
    if is_training:
        max_length = HUGGINGFACE_MODELS[model_name][0]
    else:
        max_length = HUGGINGFACE_MODELS[model_name][1]
    # populate these on-demand to avoid wasting memory when not used
    if is_training:
        input_ids = torch.randint(0, model.config.vocab_size, (bs, max_length)).to(device)
        decoder_ids = torch.randint(0, model.config.vocab_size, (bs, max_length)).to(device)
        example_inputs = {'input_ids': input_ids, 'labels': decoder_ids}
    else:
        # Cut the length of sentence when running on CPU, to reduce test time
        if device == "cpu" and model_name in CPU_INPUT_SLICE:
            max_length = int(max_length / CPU_INPUT_SLICE[model_name])
        eval_context = torch.randint(0, model.config.vocab_size, (bs, max_length)).to(device)
        example_inputs = {'input_ids': eval_context, }
        if model_cls.__name__ in [
            "AutoModelForSeq2SeqLM"
        ]:
            example_inputs['decoder_input_ids'] = eval_context
    return example_inputs

def generate_input_iter_for_model(
    model_cls, model, model_name, bs, device, is_training=False,
):
    import math
    import random
    nbuckets = 8
    if is_training:
        max_length = HUGGINGFACE_MODELS[model_name][0]
    else:
        max_length = HUGGINGFACE_MODELS[model_name][1]
    n = int(math.log2(max_length))
    buckets = [2**n for n in range(n - nbuckets, n)]
    if model_cls.__name__ == 'AutoModelForSeq2SeqLM':
        raise NotImplementedError("AutoModelForSeq2SeqLM is not yet supported")
    while True:
        # randomize bucket_len
        bucket_len = random.choice(buckets)
        dict_input = {
            'input_ids': torch.randint(0, model.config.vocab_size, (bs, bucket_len)).to(device),
            'labels': torch.randint(0, model.config.vocab_size, (bs, bucket_len)).to(device),
        }
        yield dict_input

def download_model(model_name):
    def _extract_config_cls_name(config_cls_ctor: str) -> str:
        """Extract the class name from the given string of config object creation.
        For example,
        if the constructor runs like `AutoConfig.from_pretrained("gpt2")`, return "AutoConfig".
        if the constructor runs like `LlamaConfig(num_hidden_layers=16)`, return "LlamaConfig"."""
        pattern = r'([A-Za-z0-9_]*)[\(\.].*'
        m = re.match(pattern, config_cls_ctor)
        return m.groups()[0]
    config_cls_name = _extract_config_cls_name(HUGGINGFACE_MODELS[model_name][2])
    exec(f"from transformers import {config_cls_name}")
    config = eval(HUGGINGFACE_MODELS[model_name][2])
    model_cls = getattr(transformers, HUGGINGFACE_MODELS[model_name][3])
    kwargs = {}
    if model_name in HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE:
        kwargs["trust_remote_code"] = True
    if hasattr(model_cls, "from_config"):
        model = model_cls.from_config(config, **kwargs)
    else:
        model = model_cls(config, **kwargs)
    return model_cls, model

def generate_optimizer_for_model(model, model_name):
    from torch import optim
    if model_name in HUGGINGFACE_MODELS_SGD_OPTIMIZER:
        return optim.SGD(model.parameters(), lr=0.001)
    return optim.Adam(
        model.parameters(),
        lr=0.001,
        # TODO resolve https://github.com/pytorch/torchdynamo/issues/1083
        capturable=bool(int(os.getenv("ADAM_CAPTURABLE", 0)
    )))
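
`download_model` keeps each config constructor as a string, extracts the class name with a regex so it can be imported from `transformers`, then `eval`s the constructor. A small self-contained demonstration of that regex; the assertions restate the docstring's own examples:

```
import re

def extract_config_cls_name(config_cls_ctor: str) -> str:
    # The class name is everything before the first "(" or "."
    m = re.match(r'([A-Za-z0-9_]*)[\(\.].*', config_cls_ctor)
    return m.groups()[0]

assert extract_config_cls_name('AutoConfig.from_pretrained("gpt2")') == "AutoConfig"
assert extract_config_cls_name('LlamaConfig(num_hidden_layers=16)') == "LlamaConfig"
```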