From 2196021e9bc0b72a547121bbf298ae854a85a21a Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 20 Mar 2024 08:56:20 -0700 Subject: [PATCH] Add timm and huggingface model suites support (#2197) Summary: Dynamobench supports extra huggingface and timm models beyond the existing model set in TorchBench. This PR will add support to those models as well, and they can be invoked with `run.py` or in the `group_bench` userbenchmarks. Pull Request resolved: https://github.com/pytorch/benchmark/pull/2197 Test Plan: TIMM model example: ``` $ python run.py convit_base -d cpu -t eval Running eval method from convit_base on cpu in eager mode with input batch size 64 and precision fp32. CPU Wall Time per batch: 4419.601 milliseconds CPU Wall Time: 4419.601 milliseconds Time to first batch: 2034.6840 ms CPU Peak Memory: 0.6162 GB ``` ``` $ python run.py convit_base -d cpu -t train Running train method from convit_base on cpu in eager mode with input batch size 64 and precision fp32. CPU Wall Time per batch: 17044.825 milliseconds CPU Wall Time: 17044.825 milliseconds Time to first batch: 1616.9790 ms CPU Peak Memory: 7.3408 GB ``` Huggingface model example: ``` python run.py MBartForCausalLM -d cuda -t train Running train method from MBartForCausalLM on cuda in eager mode with input batch size 4 and precision fp32. GPU Time per batch: 839.994 milliseconds CPU Wall Time per batch: 842.323 milliseconds CPU Wall Time: 842.323 milliseconds Time to first batch: 5390.2949 ms GPU 0 Peak Memory: 19.7418 GB CPU Peak Memory: 0.9121 GB ``` Fixes https://github.com/pytorch/benchmark/issues/2170 Reviewed By: HDCharles Differential Revision: D54953131 Pulled By: xuzhao9 fbshipit-source-id: e63e5d5ed7fc36e4500439fbc8d6a7825b7514bf --- run.py | 33 +- test.py | 2 +- torchbenchmark/__init__.py | 44 +- .../hf_MPT_7b_instruct/install.py | 2 +- torchbenchmark/canary_models/hf_Yi/install.py | 2 +- .../canary_models/hf_mixtral/install.py | 2 +- .../canary_models/phi_1_5/install.py | 2 +- torchbenchmark/canary_models/phi_2/install.py | 2 +- torchbenchmark/models/hf_Whisper/__init__.py | 5 +- .../models/hf_distil_whisper/__init__.py | 5 +- .../util/experiment/instantiator.py | 12 + .../framework/huggingface/basic_configs.py | 149 ++++++ .../framework/huggingface/extended_configs.py | 498 ++++++++++++++++++ .../framework/huggingface/model_factory.py | 196 +++---- .../util/framework/huggingface/patch_hf.py | 15 +- .../util/framework/timm/extended_configs.py | 82 +++ .../util/framework/timm/model_factory.py | 39 +- torchbenchmark/util/model.py | 3 +- userbenchmark/dynamo/__init__.py | 3 + userbenchmark/dynamo/run.py | 5 +- .../group_bench/configs/torch_ao.yaml | 3 + userbenchmark/group_bench/run.py | 16 +- userbenchmark/test_bench/run.py | 9 +- 23 files changed, 931 insertions(+), 198 deletions(-) create mode 100644 torchbenchmark/util/framework/huggingface/basic_configs.py create mode 100644 torchbenchmark/util/framework/huggingface/extended_configs.py create mode 100644 torchbenchmark/util/framework/timm/extended_configs.py diff --git a/run.py b/run.py index 1d0fb6dbb0..a89b31de84 100644 --- a/run.py +++ b/run.py @@ -10,8 +10,6 @@ """ import argparse import time - -import traceback from functools import partial import numpy as np @@ -23,6 +21,7 @@ load_model_by_name, ModelNotFoundError, ) +from torchbenchmark.util.experiment.instantiator import load_model, TorchBenchModelConfig from torchbenchmark.util.experiment.metrics import get_model_flops, get_peak_memory @@ -498,34 +497,14 @@ def main() -> None: # Log the tool usage 
usage_report_logger() - found = False - Model = None - - try: - Model = load_model_by_name(args.model) - except ModuleNotFoundError: - traceback.print_exc() - exit(-1) - except ModelNotFoundError: - print(f"Warning: The model {args.model} cannot be found at core set.") - if not Model: - try: - Model = load_canary_model_by_name(args.model) - except ModuleNotFoundError: - traceback.print_exc() - exit(-1) - except ModelNotFoundError: - print( - f"Error: The model {args.model} cannot be found at either core or canary model set." - ) - exit(-1) - - m = Model( - device=args.device, + config = TorchBenchModelConfig( + name=args.model, test=args.test, + device=args.device, batch_size=args.bs, extra_args=extra_args, ) + m = load_model(config) if m.dynamo: mode = f"dynamo {m.opt_args.torchdynamo}" elif m.opt_args.backend: @@ -533,7 +512,7 @@ def main() -> None: else: mode = "eager" print( - f"Running {args.test} method from {Model.name} on {args.device} in {mode} mode with input batch size {m.batch_size} and precision {m.dargs.precision}." + f"Running {args.test} method from {m.name} on {args.device} in {mode} mode with input batch size {m.batch_size} and precision {m.dargs.precision}." ) if "--accuracy" in extra_args: print("{:<20} {:>20}".format("Accuracy: ", str(m.accuracy)), sep="") diff --git a/test.py b/test.py index ecac74cc94..65991df875 100644 --- a/test.py +++ b/test.py @@ -77,7 +77,7 @@ def example_fn(self): task.del_model_instance() except NotImplementedError as e: self.skipTest( - f'Method `get_module()` on {device} is not implemented because "{e}", skipping...' + f'Accuracy check on {device} is not implemented because "{e}", skipping...' ) def train_fn(self): diff --git a/torchbenchmark/__init__.py b/torchbenchmark/__init__.py index 5156e81ae6..32879a0b46 100644 --- a/torchbenchmark/__init__.py +++ b/torchbenchmark/__init__.py @@ -667,26 +667,46 @@ def list_models(model_match=None): return models -def load_model_by_name(model): +def load_model_by_name(model_name: str): models = filter( - lambda x: model.lower() == x.lower(), + lambda x: model_name.lower() == x.lower(), map(lambda y: os.path.basename(y), _list_model_paths()), ) models = list(models) + cls_name = "Model" if not models: - raise ModelNotFoundError(f"{model} is not found in the core model list.") + # If the model is in TIMM or Huggingface extended model list + from torchbenchmark.util.framework.huggingface.extended_configs import ( + list_extended_huggingface_models + ) + from torchbenchmark.util.framework.timm.extended_configs import ( + list_extended_timm_models + ) + if model_name in list_extended_huggingface_models(): + cls_name = "ExtendedHuggingFaceModel" + module_path = ".util.framework.huggingface.model_factory" + models.append(model_name) + elif model_name in list_extended_timm_models(): + cls_name = "ExtendedTimmModel" + module_path = ".util.framework.timm.model_factory" + models.append(model_name) + else: + raise ModelNotFoundError(f"{model_name} is not found in the core model list.") + else: + model_name = models[0] + model_pkg = ( + model_name + if not _is_internal_model(model_name) + else f"{internal_model_dir}.{model_name}" + ) + module_path = f".models.{model_pkg}" assert ( len(models) == 1 - ), f"Found more than one models {models} with the exact name: {model}" - model_name = models[0] - model_pkg = ( - model_name - if not _is_internal_model(model_name) - else f"{internal_model_dir}.{model_name}" - ) - module = importlib.import_module(f".models.{model_pkg}", package=__name__) + ), f"Found more than one 
models {models} with the exact name: {model_name}" + + module = importlib.import_module(module_path, package=__name__) - Model = getattr(module, "Model", None) + Model = getattr(module, cls_name, None) if Model is None: print(f"Warning: {module} does not define attribute Model, skip it") return None diff --git a/torchbenchmark/canary_models/hf_MPT_7b_instruct/install.py b/torchbenchmark/canary_models/hf_MPT_7b_instruct/install.py index 64e5b1127e..1a49905932 100644 --- a/torchbenchmark/canary_models/hf_MPT_7b_instruct/install.py +++ b/torchbenchmark/canary_models/hf_MPT_7b_instruct/install.py @@ -10,4 +10,4 @@ def pip_install_requirements(): pip_install_requirements() patch_transformers() model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__))) - cache_model(model_name, trust_remote_code=True) \ No newline at end of file + cache_model(model_name) \ No newline at end of file diff --git a/torchbenchmark/canary_models/hf_Yi/install.py b/torchbenchmark/canary_models/hf_Yi/install.py index 64e5b1127e..1a49905932 100644 --- a/torchbenchmark/canary_models/hf_Yi/install.py +++ b/torchbenchmark/canary_models/hf_Yi/install.py @@ -10,4 +10,4 @@ def pip_install_requirements(): pip_install_requirements() patch_transformers() model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__))) - cache_model(model_name, trust_remote_code=True) \ No newline at end of file + cache_model(model_name) \ No newline at end of file diff --git a/torchbenchmark/canary_models/hf_mixtral/install.py b/torchbenchmark/canary_models/hf_mixtral/install.py index 64e5b1127e..1a49905932 100644 --- a/torchbenchmark/canary_models/hf_mixtral/install.py +++ b/torchbenchmark/canary_models/hf_mixtral/install.py @@ -10,4 +10,4 @@ def pip_install_requirements(): pip_install_requirements() patch_transformers() model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__))) - cache_model(model_name, trust_remote_code=True) \ No newline at end of file + cache_model(model_name) \ No newline at end of file diff --git a/torchbenchmark/canary_models/phi_1_5/install.py b/torchbenchmark/canary_models/phi_1_5/install.py index 64e5b1127e..1a49905932 100644 --- a/torchbenchmark/canary_models/phi_1_5/install.py +++ b/torchbenchmark/canary_models/phi_1_5/install.py @@ -10,4 +10,4 @@ def pip_install_requirements(): pip_install_requirements() patch_transformers() model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__))) - cache_model(model_name, trust_remote_code=True) \ No newline at end of file + cache_model(model_name) \ No newline at end of file diff --git a/torchbenchmark/canary_models/phi_2/install.py b/torchbenchmark/canary_models/phi_2/install.py index 64e5b1127e..1a49905932 100644 --- a/torchbenchmark/canary_models/phi_2/install.py +++ b/torchbenchmark/canary_models/phi_2/install.py @@ -10,4 +10,4 @@ def pip_install_requirements(): pip_install_requirements() patch_transformers() model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__))) - cache_model(model_name, trust_remote_code=True) \ No newline at end of file + cache_model(model_name) \ No newline at end of file diff --git a/torchbenchmark/models/hf_Whisper/__init__.py b/torchbenchmark/models/hf_Whisper/__init__.py index b77d9ae5d6..e5a2c42abe 100644 --- a/torchbenchmark/models/hf_Whisper/__init__.py +++ b/torchbenchmark/models/hf_Whisper/__init__.py @@ -18,7 +18,10 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): def train(self): raise NotImplementedError("Training is not implemented.") - + + def 
get_module(self): + return self.model, (self.example_inputs["input_ids"], ) + def eval(self): self.model.eval() with torch.no_grad(): diff --git a/torchbenchmark/models/hf_distil_whisper/__init__.py b/torchbenchmark/models/hf_distil_whisper/__init__.py index 62be026bfd..1c56a4a91a 100644 --- a/torchbenchmark/models/hf_distil_whisper/__init__.py +++ b/torchbenchmark/models/hf_distil_whisper/__init__.py @@ -17,9 +17,12 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): self.example_inputs = {"input_features": self.input_features.to(self.device), "input_ids" : self.input_features.to(self.device)} self.model.to(self.device) + def get_module(self): + return self.model, (self.example_inputs["input_ids"], ) + def train(self): raise NotImplementedError("Training is not implemented.") - + def eval(self): self.model.eval() with torch.no_grad(): diff --git a/torchbenchmark/util/experiment/instantiator.py b/torchbenchmark/util/experiment/instantiator.py index 34c074e6b4..f86080e876 100644 --- a/torchbenchmark/util/experiment/instantiator.py +++ b/torchbenchmark/util/experiment/instantiator.py @@ -76,3 +76,15 @@ def list_models() -> List[str]: model_paths = _list_model_paths() model_names = list(map(lambda x: os.path.basename(x), model_paths)) return model_names + +def list_extended_models(suite_name: str="all") -> List[str]: + from torchbenchmark.util.framework.huggingface.extended_configs import list_extended_huggingface_models + from torchbenchmark.util.framework.timm.extended_configs import list_extended_timm_models + if suite_name == "huggingface": + return list_extended_huggingface_models() + elif suite_name == "timm": + return list_extended_timm_models() + elif suite_name == "all": + return list_extended_huggingface_models() + list_extended_timm_models() + else: + assert False, "Currently, we only support extended model set huggingface or timm." 
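Note (illustrative, not part of the patch): the new `list_extended_models` helper added above, together with the config-based loading path that `run.py` now uses, can be exercised roughly as follows. This is a minimal sketch assuming an OSS checkout where the Dynamobench model list files are present, so the extended suites are actually populated.

```
# Sketch only: query the extended suites and load one extended model through
# the same TorchBenchModelConfig path that run.py uses in this patch.
from torchbenchmark.util.experiment.instantiator import (
    TorchBenchModelConfig,
    list_extended_models,
    load_model,
)

# suite_name may be "huggingface", "timm", or "all" (the default);
# any other value trips the assertion added in instantiator.py.
extended = list_extended_models(suite_name="timm")
print(len(list(extended)), "extended TIMM models")

# Extended models are instantiated like core models; "convit_base" is the
# TIMM example from the Test Plan above.
config = TorchBenchModelConfig(
    name="convit_base",
    test="eval",
    device="cpu",
    batch_size=None,   # None falls back to the recorded per-model batch size
    extra_args=[],
)
m = load_model(config)
print(m.name, m.batch_size)
```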
diff --git a/torchbenchmark/util/framework/huggingface/basic_configs.py b/torchbenchmark/util/framework/huggingface/basic_configs.py new file mode 100644 index 0000000000..61064fb9bc --- /dev/null +++ b/torchbenchmark/util/framework/huggingface/basic_configs.py @@ -0,0 +1,149 @@ +import transformers +import os +import re +import torch +from typing import List + +HUGGINGFACE_MODELS = { + # 'name': (train_max_length, eval_max_length, config, model) + 'hf_GPT2': (512, 1024, 'AutoConfig.from_pretrained("gpt2")', 'AutoModelForCausalLM'), + 'hf_GPT2_large': (512, 1024, 'AutoConfig.from_pretrained("gpt2-large")', 'AutoModelForCausalLM'), + 'hf_T5': (1024, 2048, 'AutoConfig.from_pretrained("t5-small")', 'AutoModelForSeq2SeqLM'), + 'hf_T5_base': (1024, 2048, 'AutoConfig.from_pretrained("t5-base")', 'AutoModelForSeq2SeqLM'), + 'hf_T5_large': (512, 512, 'AutoConfig.from_pretrained("t5-large")', 'AutoModelForSeq2SeqLM'), + 'hf_Bart': (512, 512, 'AutoConfig.from_pretrained("facebook/bart-base")', 'AutoModelForSeq2SeqLM'), + 'hf_Reformer': (4096, 4096, 'ReformerConfig(num_buckets=128)', 'AutoModelForMaskedLM'), + 'hf_BigBird': (1024, 4096, 'BigBirdConfig(attention_type="block_sparse",)', 'AutoModelForMaskedLM'), + 'hf_Albert': (512, 512, 'AutoConfig.from_pretrained("albert-base-v2")', 'AutoModelForMaskedLM'), + 'hf_DistilBert': (512, 512, 'AutoConfig.from_pretrained("distilbert-base-uncased")', 'AutoModelForMaskedLM'), + 'hf_Longformer': (1024, 4096, 'AutoConfig.from_pretrained("allenai/longformer-base-4096")', 'AutoModelForMaskedLM'), + 'hf_Bert': (512, 512, 'BertConfig()', 'AutoModelForMaskedLM'), + # see https://huggingface.co/bert-large-cased + 'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'), + 'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'), + 'hf_distil_whisper': (1024, 1024, 'AutoConfig.from_pretrained("distil-whisper/distil-medium.en")', 'AutoModelForAudioClassification'), + 'hf_mixtral' : (512,512, 'AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")', 'AutoModelForCausalLM'), + # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real + 'llama_v2_7b_16h' : (128,512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'), + 'hf_MPT_7b_instruct': (512, 512, 'AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)', 'AutoModelForCausalLM'), + 'llava' : (512,512, 'AutoConfig.from_pretrained("liuhaotian/llava-v1.5-13b")', 'LlavaForConditionalGeneration'), + 'llama_v2_7b' : (512,512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")', 'AutoModelForCausalLM'), + 'llama_v2_13b' : (512,512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-13b-hf")', 'AutoModelForCausalLM'), + 'llama_v2_70b' : (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")', 'AutoModelForMaskedLM'), + 'codellama' : (512,512, 'AutoConfig.from_pretrained("codellama/CodeLlama-7b-hf")', 'AutoModelForCausalLM'), + 'phi_1_5' : (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)', 'AutoModelForCausalLM'), + 'phi_2' : (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-2", trust_remote_code=True)', 'AutoModelForCausalLM'), + 'moondream' : (512, 512, 'PhiConfig.from_pretrained("vikhyatk/moondream1")', 'PhiForCausalLM'), + # as per this page https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 trust_remote_code=True is not required + 'mistral_7b_instruct' : (128, 128, 
'AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")', 'AutoModelForCausalLM'), + 'hf_Yi' : (512, 512, 'AutoConfig.from_pretrained("01-ai/Yi-6B", trust_remote_code=True)', 'AutoModelForCausalLM'), + 'orca_2' : (512, 512, 'AutoConfig.from_pretrained("microsoft/Orca-2-13b")', 'AutoModelForCausalLM'), +} + +CPU_INPUT_SLICE = { + 'hf_BigBird': 5, + 'hf_Longformer': 8, + 'hf_T5': 4, + 'hf_GPT2': 4, + 'hf_Reformer': 2, +} + +HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE = [ + "hf_Falcon_7b", + "hf_MPT_7b_instruct", + "phi_1_5", + "phi_2", + "hf_Yi", + "hf_mixtral", +] + +HUGGINGFACE_MODELS_SGD_OPTIMIZER = [ + "llama_v2_7b_16h", +] + + +def is_basic_huggingface_models(model_name: str) -> bool: + return model_name in HUGGINGFACE_MODELS + +def list_basic_huggingface_models() -> List[str]: + return HUGGINGFACE_MODELS.keys() + +def generate_inputs_for_model( + model_cls, model, model_name, bs, device, is_training=False, +): + if is_training: + max_length = HUGGINGFACE_MODELS[model_name][0] + else: + max_length = HUGGINGFACE_MODELS[model_name][1] + # populate these on-demand to avoid wasting memory when not used + if is_training: + input_ids = torch.randint(0, model.config.vocab_size, (bs, max_length)).to(device) + decoder_ids = torch.randint(0, model.config.vocab_size, (bs, max_length)).to(device) + example_inputs = {'input_ids': input_ids, 'labels': decoder_ids} + else: + # Cut the length of sentence when running on CPU, to reduce test time + if device == "cpu" and model_name in CPU_INPUT_SLICE: + max_length = int(max_length / CPU_INPUT_SLICE[model_name]) + eval_context = torch.randint(0, model.config.vocab_size, (bs, max_length)).to(device) + example_inputs = {'input_ids': eval_context, } + if model_cls.__name__ in [ + "AutoModelForSeq2SeqLM" + ]: + example_inputs['decoder_input_ids'] = eval_context + return example_inputs + +def generate_input_iter_for_model( + model_cls, model, model_name, bs, device, is_training=False, +): + import math + import random + nbuckets = 8 + if is_training: + max_length = HUGGINGFACE_MODELS[model_name][0] + else: + max_length = HUGGINGFACE_MODELS[model_name][1] + n = int(math.log2(max_length)) + buckets = [2**n for n in range(n - nbuckets, n)] + if model_cls.__name__ == 'AutoModelForSeq2SeqLM': + raise NotImplementedError("AutoModelForSeq2SeqLM is not yet supported") + while True: + # randomize bucket_len + bucket_len = random.choice(buckets) + dict_input = { + 'input_ids': torch.randint(0, model.config.vocab_size, (bs, bucket_len)).to(device), + 'labels': torch.randint(0, model.config.vocab_size, (bs, bucket_len)).to(device), + } + yield dict_input + +def download_model(model_name): + def _extract_config_cls_name(config_cls_ctor: str) -> str: + """Extract the class name from the given string of config object creation. + For example, + if the constructor runs like `AutoConfig.from_pretrained("gpt2")`, return "AutoConfig". 
+ if the constructor runs like `LlamaConfig(num_hidden_layers=16)`, return "LlamaConfig".""" + pattern = r'([A-Za-z0-9_]*)[\(\.].*' + m = re.match(pattern, config_cls_ctor) + return m.groups()[0] + config_cls_name = _extract_config_cls_name(HUGGINGFACE_MODELS[model_name][2]) + exec(f"from transformers import {config_cls_name}") + config = eval(HUGGINGFACE_MODELS[model_name][2]) + model_cls = getattr(transformers, HUGGINGFACE_MODELS[model_name][3]) + kwargs = {} + if model_name in HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE: + kwargs["trust_remote_code"] = True + if hasattr(model_cls, "from_config"): + model = model_cls.from_config(config, **kwargs) + else: + model = model_cls(config, **kwargs) + return model_cls, model + +def generate_optimizer_for_model(model, model_name): + from torch import optim + if model_name in HUGGINGFACE_MODELS_SGD_OPTIMIZER: + return optim.SGD(model.parameters(), lr= 0.001) + return optim.Adam( + model.parameters(), + lr=0.001, + # TODO resolve https://github.com/pytorch/torchdynamo/issues/1083 + capturable=bool(int(os.getenv("ADAM_CAPTURABLE", 0) + ))) diff --git a/torchbenchmark/util/framework/huggingface/extended_configs.py b/torchbenchmark/util/framework/huggingface/extended_configs.py new file mode 100644 index 0000000000..b50bd07dc3 --- /dev/null +++ b/torchbenchmark/util/framework/huggingface/extended_configs.py @@ -0,0 +1,498 @@ +# Extended huggingface model configs from Dynamobench +import logging +import torch +import os +import importlib +from typing import List +from userbenchmark.dynamo import DYNAMOBENCH_PATH + +# These models contain the models present in huggingface_models_list. It is a +# combination of models supported by HF Fx parser and some manually supplied +# models. For these models, we already know the largest batch size that can fit +# on A100 GPUs - 40 GB. 
+BATCH_SIZE_KNOWN_MODELS = dict() + +# Get the list of models and their batch sizes +# Only load the extended models in OSS +if hasattr(torch.version, "git_version"): + MODELS_FILENAME = os.path.join(DYNAMOBENCH_PATH, "huggingface_models_list.txt") + assert os.path.exists(MODELS_FILENAME) + with open(MODELS_FILENAME, "r") as fh: + lines = fh.readlines() + lines = [line.rstrip() for line in lines] + for line in lines: + model_name, batch_size = line.split(",") + batch_size = int(batch_size) + BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size + assert len(BATCH_SIZE_KNOWN_MODELS) + +def is_extended_huggingface_models(model_name: str) -> bool: + return model_name in BATCH_SIZE_KNOWN_MODELS + +def list_extended_huggingface_models() -> List[str]: + return list(BATCH_SIZE_KNOWN_MODELS.keys()) + +imports = [ + "AlbertForPreTraining", + "AutoConfig", + "AutoModelForCausalLM", + "AutoModelForMaskedLM", + "AutoModelForSeq2SeqLM", + "BigBirdConfig", + "BlenderbotForConditionalGeneration", + "BlenderbotModel", + "BlenderbotSmallForConditionalGeneration", + "BlenderbotSmallModel", + "CLIPModel", + "CLIPVisionModel", + "ElectraForPreTraining", + "GPT2ForSequenceClassification", + "GPTJForSequenceClassification", + "GPTNeoForSequenceClassification", + "HubertForSequenceClassification", + "LxmertForPreTraining", + "LxmertForQuestionAnswering", + "MarianForCausalLM", + "MarianModel", + "MarianMTModel", + "PegasusForConditionalGeneration", + "PegasusModel", + "ReformerConfig", + "ViTForImageClassification", + "ViTForMaskedImageModeling", + "ViTModel", +] + + +mod = importlib.import_module("transformers") +for cls in imports: + exec(f"from transformers import {cls}") + + +log = logging.getLogger(__name__) + +SKIP = { + # Difficult to setup accuracy test because .eval() not supported + "Reformer", + # Fails deepcopy + "BlenderbotForConditionalGeneration", + "GPTNeoForCausalLM", + "GPTNeoForSequenceClassification", + # Fails with even batch size = 1 + "GPTJForCausalLM", + "GPTJForQuestionAnswering", +} + +# These models currently fail accuracy with eager Adam optimizer +# so we use SGD when running the full benchmarks +# https://github.com/pytorch/pytorch/issues/115966 +BENCHMARK_USE_SGD = { + # TorchBench + "BERT_pytorch", + "LearningToPaint", + "alexnet", + "dcgan", + "demucs", + "densenet121", + "dlrm", + "fastNLP_Bert", + "mobilenet_v2", + "phlippe_densenet", + "phlippe_resnet", + "pytorch_stargan", + "resnet18", + "shufflenet_v2_x1_0", + "speech_transformer", + "squeezenet1_1", + "stable_diffusion_text_encoder", + "timm_efficientdet", + "timm_nfnet", + "timm_regnet", + "timm_vision_transformer", + "timm_vovnet", + "vgg16", + "hf_T5", # Fails dynamic https://github.com/pytorch/pytorch/issues/115968 + # HF + "AlbertForMaskedLM", + "BartForCausalLM", + "BartForConditionalGeneration", + "BlenderbotSmallForCausalLM", + "BlenderbotSmallForConditionalGeneration", + "DebertaV2ForQuestionAnswering", # eager OOM + "ElectraForCausalLM", + "M2M100ForConditionalGeneration", + "MBartForCausalLM", + "MBartForConditionalGeneration", + "OPTForCausalLM", + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PegasusForCausalLM", + "Speech2Text2ForCausalLM", + "TrOCRForCausalLM", + "XGLMForCausalLM", + # TIMM + "adv_inception_v3", + "botnet26t_256", + "cait_m36_384", # OOM + "coat_lite_mini", + "convit_base", + "dpn107", + "fbnetv3_b", + "gernet_l", + "lcnet_050", + "mixnet_l", + "res2net101_26w_4s", + "res2net50_14w_8s", + "res2next50", + "resnest101e", + "sebotnet33ts_256", + "swsl_resnext101_32x16d", + 
"tf_efficientnet_b0", + "ghostnet_100", + "gmixer_24_224", + "tinynet_a", +} + +# TODO - Fails even after fake tensors +BATCH_SIZE_DIVISORS = { + "AlbertForMaskedLM": 2, + "AlbertForQuestionAnswering": 2, + "AllenaiLongformerBase": 2, + "BartForCausalLM": 2, + "BartForConditionalGeneration": 2, + "BertForMaskedLM": 2, + "BertForQuestionAnswering": 2, + "BlenderbotForCausalLM": 8, + # "BlenderbotForConditionalGeneration" : 16, + "BlenderbotSmallForCausalLM": 4, + "BlenderbotSmallForConditionalGeneration": 2, + "CamemBert": 2, + "DebertaForMaskedLM": 4, + "DebertaForQuestionAnswering": 2, + "DebertaV2ForMaskedLM": 4, + "DebertaV2ForQuestionAnswering": 8, + "DistilBertForMaskedLM": 2, + "DistilBertForQuestionAnswering": 2, + "DistillGPT2": 2, + "ElectraForCausalLM": 2, + "ElectraForQuestionAnswering": 2, + "GPT2ForSequenceClassification": 2, + # "GPTJForCausalLM" : 2, + # "GPTJForQuestionAnswering" : 2, + # "GPTNeoForCausalLM" : 32, + # "GPTNeoForSequenceClassification" : 2, + "GoogleFnet": 2, + "LayoutLMForMaskedLM": 2, + "LayoutLMForSequenceClassification": 2, + "M2M100ForConditionalGeneration": 4, + "MBartForCausalLM": 2, + "MBartForConditionalGeneration": 2, + "MT5ForConditionalGeneration": 2, + "MegatronBertForCausalLM": 4, + "MegatronBertForQuestionAnswering": 2, + "MobileBertForMaskedLM": 2, + "MobileBertForQuestionAnswering": 2, + "OPTForCausalLM": 2, + "PLBartForCausalLM": 2, + "PLBartForConditionalGeneration": 2, + "PegasusForCausalLM": 4, + "PegasusForConditionalGeneration": 2, + "RobertaForCausalLM": 2, + "RobertaForQuestionAnswering": 2, + "Speech2Text2ForCausalLM": 4, + "T5ForConditionalGeneration": 2, + "T5Small": 2, + "TrOCRForCausalLM": 2, + "XGLMForCausalLM": 4, + "XLNetLMHeadModel": 2, + "YituTechConvBert": 2, +} + +try: + EXTRA_MODELS = { + "AllenaiLongformerBase": ( + AutoConfig.from_pretrained("allenai/longformer-base-4096"), + AutoModelForMaskedLM, + ), + "Reformer": ( + ReformerConfig(), + AutoModelForMaskedLM, + ), + "T5Small": ( + AutoConfig.from_pretrained("t5-small"), + AutoModelForSeq2SeqLM, + ), + # "BigBird": ( + # BigBirdConfig(attention_type="block_sparse"), + # AutoModelForMaskedLM, + # ), + "DistillGPT2": ( + AutoConfig.from_pretrained("distilgpt2"), + AutoModelForCausalLM, + ), + "GoogleFnet": ( + AutoConfig.from_pretrained("google/fnet-base"), + AutoModelForMaskedLM, + ), + "YituTechConvBert": ( + AutoConfig.from_pretrained("YituTech/conv-bert-base"), + AutoModelForMaskedLM, + ), + "CamemBert": ( + AutoConfig.from_pretrained("camembert-base"), + AutoModelForMaskedLM, + ), + } +except OSError: + # Extra models are only available when Internet access is available + EXTRA_MODELS = {} + +SKIP_ACCURACY_CHECK_MODELS = { + # Models too large to have eager, dynamo and fp64_numbers simultaneosuly + # even for 40 GB machine. + "DebertaV2ForMaskedLM", + "BlenderbotForCausalLM", +} + +SKIP_DUE_TO_CONTROL_FLOW = {"AllenaiLongformerBase"} + + +REQUIRE_HIGHER_TOLERANCE_TRAINING = { + "MT5ForConditionalGeneration", + # AlbertForQuestionAnswering fails in CI GCP A100 but error does not seem + # harmful. 
+ "AlbertForQuestionAnswering", +} +REQUIRE_HIGHER_TOLERANCE_INFERENCE = { + "GPT2ForSequenceClassification", + "RobertaForQuestionAnswering", +} + + +SKIP_FOR_CPU = { + "OPTForCausalLM", # OOMs +} + +ONLY_EVAL_MODE = { + "M2M100ForConditionalGeneration", # Fails with dynamo for train mode +} + +FP32_ONLY_MODELS = { + "GoogleFnet", +} + +def get_sequence_length(model_cls, model_name): + if model_name.startswith(("Blenderbot",)): + seq_length = 128 + elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")): + seq_length = 1024 + elif model_name in ("AllenaiLongformerBase", "BigBird"): + seq_length = 1024 + elif model_name.startswith("OPT"): + seq_length = 2048 + elif "Reformer" in model_name: + seq_length = 4096 + elif model_name.startswith( + ( + "Albert", + "Deberta", + "Layout", + "Electra", + "XLNet", + "MegatronBert", + "Bert", + "Roberta", + ) + ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"): + seq_length = 512 + elif model_name in ("TrOCRForCausalLM"): + seq_length = 256 + elif model_name.startswith("MobileBert"): + seq_length = 128 + elif model_name.startswith("Wav2Vec2"): + # If too short, will fail with something like + # ValueError: `mask_length` has to be smaller than `sequence_length`, + # but got `mask_length`: 10 and `sequence_length`: 9` + seq_length = 10000 # NB: a more realistic size is 155136 + else: + log.info( + f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily" + ) + seq_length = 128 + return seq_length + + +def rand_int_tensor(device, low, high, shape): + return torch.randint( + low, + high, + shape, + device=device, + dtype=torch.int64, + requires_grad=False, + ) + + +def generate_inputs_for_model( + model_cls, model, model_name, bs, device, include_loss_args=False +): + # TODO - Check if following values are representative + num_choices = 3 + num_visual_features = 42 + seq_length = get_sequence_length(model_cls, model_name) + vocab_size = model.config.vocab_size + + if model_name.startswith("Wav2Vec2"): + # TODO: If we add more input_values style models, try to work this + # into the overall control flow + target_length = 100 + return { + "input_values": torch.randn((bs, seq_length), device=device), + # Added because that's what the example training script has + "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)), + "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)), + } + + if model_name.endswith("MultipleChoice"): + input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length)) + elif model_name.startswith("Roberta"): + input = rand_int_tensor(device, 0, 1, (bs, seq_length)) + else: + input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length)) + + if "Bart" in model_name: + input[:, -1] = model.config.eos_token_id + + input_dict = {"input_ids": input} + + if ( + model_name.startswith("T5") + or model_name.startswith("M2M100") + or model_name.startswith("MT5") + or model_cls + in [ + BlenderbotModel, + BlenderbotSmallModel, + BlenderbotForConditionalGeneration, + BlenderbotSmallForConditionalGeneration, + PegasusModel, + PegasusForConditionalGeneration, + MarianModel, + MarianMTModel, + ] + ): + input_dict["decoder_input_ids"] = input + + if model_name.startswith("Lxmert"): + visual_feat_dim, visual_pos_dim = ( + model.config.visual_feat_dim, + model.config.visual_pos_dim, + ) + input_dict["visual_feats"] = torch.randn( + bs, num_visual_features, visual_feat_dim + ) + input_dict["visual_pos"] = torch.randn(bs, num_visual_features, 
visual_pos_dim) + + if include_loss_args: + if model_name.endswith("PreTraining"): + if model_cls in [ElectraForPreTraining, LxmertForPreTraining]: + input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length)) + else: + label_name = ( + "sentence_order_label" + if model_cls in [AlbertForPreTraining] + else "next_sentence_label" + ) + input_dict["labels"] = ( + rand_int_tensor(device, 0, vocab_size, (bs, seq_length)), + ) + input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,)) + elif model_name.endswith("QuestionAnswering"): + input_dict["start_positions"] = rand_int_tensor( + device, 0, seq_length, (bs,) + ) + input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,)) + elif ( + model_name.endswith("MaskedLM") + or model_name.endswith("HeadModel") + or model_name.endswith("CausalLM") + or model_name.endswith("DoubleHeadsModel") + ): + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size, (bs, seq_length) + ) + elif model_name.endswith("TokenClassification"): + input_dict["labels"] = rand_int_tensor( + device, 0, model.config.num_labels - 1, (bs, seq_length) + ) + elif model_name.endswith("MultipleChoice"): + input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,)) + elif model_name.endswith("SequenceClassification"): + input_dict["labels"] = rand_int_tensor( + device, 0, model.config.num_labels - 1, (bs,) + ) + elif model_name.endswith("NextSentencePrediction"): + input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,)) + elif model_name.endswith("ForConditionalGeneration"): + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size - 1, (bs, seq_length) + ) + elif model_name in EXTRA_MODELS: + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size, (bs, seq_length) + ) + else: + raise NotImplementedError( + f"Class {model_name} unsupported for training test " + ) + + return input_dict + +def get_module_cls_by_model_name(model_cls_name): + _module_by_model_name = { + "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2", + "TrOCRDecoder": "transformers.models.trocr.modeling_trocr", + } + module_name = _module_by_model_name.get(model_cls_name, "transformers") + module = importlib.import_module(module_name) + return getattr(module, model_cls_name) + +def _get_model_cls_and_config(model_name): + if model_name not in EXTRA_MODELS: + model_cls = get_module_cls_by_model_name(model_name) + config_cls = model_cls.config_class + config = config_cls() + + # NB: some models need a pad token defined to handle BS > 1 + if ( + model_cls + in [ + GPT2ForSequenceClassification, + GPTNeoForSequenceClassification, + GPTJForSequenceClassification, + ] + or model_cls.__name__.startswith("Roberta") + or model_cls.__name__.startswith("Marian") + ): + config.pad_token_id = 0 + + else: + config, model_cls = EXTRA_MODELS[model_name] + + return model_cls, config + +def download_model(model_name): + model_cls, config = _get_model_cls_and_config(model_name) + if "auto" in model_cls.__module__: + # Handle auto classes + model = model_cls.from_config(config) + else: + model = model_cls(config) + return model_cls, model + +def generate_optimizer_for_model(model, model_name): + if model_name in BENCHMARK_USE_SGD: + return torch.optim.SGD(model.parameters(), lr=0.01, foreach=True) + return torch.optim.Adam( + model.parameters(), lr=0.01, capturable=True, foreach=True + ) diff --git a/torchbenchmark/util/framework/huggingface/model_factory.py b/torchbenchmark/util/framework/huggingface/model_factory.py index 
e9ace9fe1c..52406a7033 100644 --- a/torchbenchmark/util/framework/huggingface/model_factory.py +++ b/torchbenchmark/util/framework/huggingface/model_factory.py @@ -3,70 +3,15 @@ import os import torch from contextlib import nullcontext -from torch import optim + import torch.nn as nn from torchbenchmark.util.model import BenchmarkModel from torchbenchmark.tasks import NLP -import transformers -from transformers import AutoConfig, ReformerConfig, BertConfig, GenerationConfig, WhisperConfig, LlamaConfig -# PhiConfig is only available in newer version of transformers -try: - from transformers import PhiConfig -except: - pass from typing import Tuple +from transformers import GenerationConfig -class_models = { - # 'name': (train_max_length, eval_max_length, config, model) - 'hf_GPT2': (512, 1024, 'AutoConfig.from_pretrained("gpt2")', 'AutoModelForCausalLM'), - 'hf_GPT2_large': (512, 1024, 'AutoConfig.from_pretrained("gpt2-large")', 'AutoModelForCausalLM'), - 'hf_T5': (1024, 2048, 'AutoConfig.from_pretrained("t5-small")', 'AutoModelForSeq2SeqLM'), - 'hf_T5_base': (1024, 2048, 'AutoConfig.from_pretrained("t5-base")', 'AutoModelForSeq2SeqLM'), - 'hf_T5_large': (512, 512, 'AutoConfig.from_pretrained("t5-large")', 'AutoModelForSeq2SeqLM'), - 'hf_Bart': (512, 512, 'AutoConfig.from_pretrained("facebook/bart-base")', 'AutoModelForSeq2SeqLM'), - 'hf_Reformer': (4096, 4096, 'ReformerConfig()', 'AutoModelForMaskedLM'), - 'hf_BigBird': (1024, 4096, 'BigBirdConfig(attention_type="block_sparse",)', 'AutoModelForMaskedLM'), - 'hf_Albert': (512, 512, 'AutoConfig.from_pretrained("albert-base-v2")', 'AutoModelForMaskedLM'), - 'hf_DistilBert': (512, 512, 'AutoConfig.from_pretrained("distilbert-base-uncased")', 'AutoModelForMaskedLM'), - 'hf_Longformer': (1024, 4096, 'AutoConfig.from_pretrained("allenai/longformer-base-4096")', 'AutoModelForMaskedLM'), - 'hf_Bert': (512, 512, 'BertConfig()', 'AutoModelForMaskedLM'), - # see https://huggingface.co/bert-large-cased - 'hf_Bert_large': (512, 512, 'BertConfig(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16)', 'AutoModelForMaskedLM'), - 'hf_Whisper': (1024, 1024, 'WhisperConfig()', 'AutoModelForAudioClassification'), - 'hf_distil_whisper': (1024, 1024, 'AutoConfig.from_pretrained("distil-whisper/distil-medium.en")', 'AutoModelForAudioClassification'), - 'hf_mixtral' : (512,512, 'AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")', 'AutoModelForCausalLM'), - # default num_hidden_layers=32 but that OOMs, feel free to change this config to something more real - 'llama_v2_7b_16h' : (128,512, 'LlamaConfig(num_hidden_layers=16)', 'AutoModelForCausalLM'), - 'hf_MPT_7b_instruct': (512, 512, 'AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)', 'AutoModelForCausalLM'), - 'llava' : (512,512, 'AutoConfig.from_pretrained("liuhaotian/llava-v1.5-13b")', 'LlavaForConditionalGeneration'), - 'llama_v2_7b' : (512,512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")', 'AutoModelForCausalLM'), - 'llama_v2_13b' : (512,512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-13b-hf")', 'AutoModelForCausalLM'), - 'llama_v2_70b' : (512, 512, 'AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")', 'AutoModelForMaskedLM'), - 'codellama' : (512,512, 'AutoConfig.from_pretrained("codellama/CodeLlama-7b-hf")', 'AutoModelForCausalLM'), - 'phi_1_5' : (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)', 'AutoModelForCausalLM'), - 'phi_2' : (512, 512, 'AutoConfig.from_pretrained("microsoft/phi-2", 
trust_remote_code=True)', 'AutoModelForCausalLM'), - 'moondream' : (512, 512, 'PhiConfig.from_pretrained("vikhyatk/moondream1")', 'PhiForCausalLM'), - # as per this page https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 trust_remote_code=True is not required - 'mistral_7b_instruct' : (128, 128, 'AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")', 'AutoModelForCausalLM'), - 'hf_Yi' : (512, 512, 'AutoConfig.from_pretrained("01-ai/Yi-6B", trust_remote_code=True)', 'AutoModelForCausalLM'), - 'orca_2' : (512, 512, 'AutoConfig.from_pretrained("microsoft/Orca-2-13b")', 'AutoModelForCausalLM'), -} - -cpu_input_slice = { - 'hf_BigBird': 5, - 'hf_Longformer': 8, - 'hf_T5': 4, - 'hf_GPT2': 4, - 'hf_Reformer': 2, -} - -class ArgsToKwargsWrapper(torch.nn.Module): - def __init__(self, model): - super(ArgsToKwargsWrapper, self).__init__() - self.model = model - - def forward(self, input_ids, decoder_input_ids): - return self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) +from .basic_configs import is_basic_huggingface_models +from .extended_configs import is_extended_huggingface_models, BATCH_SIZE_KNOWN_MODELS, BATCH_SIZE_DIVISORS class HuggingFaceModel(BenchmarkModel): HF_MODEL = True @@ -88,80 +33,59 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]): self.is_generate = False self.unqual_name = name name = self.unqual_name # we don't want to refer to the qualified name anymore - if test == "train": - self.max_length = class_models[name][0] - elif test == "eval": - self.max_length = class_models[name][1] - # workaround the bigbird config import - if name == "hf_BigBird": - from transformers import BigBirdConfig - config = eval(class_models[name][2]) - if class_models[name][2] == "ReformerConfig()" and not config.num_buckets: - # silence "config.num_buckets is not set. 
Setting config.num_buckets to 128" - config.num_buckets = 128 - class_ctor = getattr(transformers, class_models[name][3]) - kwargs = {} - hugging_face_models_requiring_trust_remote_code = ["hf_Falcon_7b", "hf_MPT_7b_instruct", "phi_1_5", "phi_2", "hf_Yi"] - if name in hugging_face_models_requiring_trust_remote_code: - kwargs["trust_remote_code"] = True - if hasattr(class_ctor, "from_config"): - self.model = class_ctor.from_config(config, **kwargs).to(device) + is_training = self.test == "train" + if is_basic_huggingface_models(name): + from .basic_configs import download_model, generate_inputs_for_model, generate_optimizer_for_model + self.model_cls, self.model = download_model(name) + self.model = self.model.to(self.device) + self.example_inputs = generate_inputs_for_model( + self.model_cls, + self.model, + name, + self.batch_size, + self.device, + is_training, + ) + if is_training: + self.optimizer = generate_optimizer_for_model(self.model, name) + elif is_extended_huggingface_models(name): + from .extended_configs import download_model, generate_inputs_for_model, generate_optimizer_for_model + self.model_cls, self.model = download_model(name) + self.model = self.model.to(self.device) + self.example_inputs = generate_inputs_for_model( + self.model_cls, + self.model, + name, + self.batch_size, + self.device, + include_loss_args=True + ) + if is_training: + self.optimizer = generate_optimizer_for_model(self.model, name) else: - self.model = class_ctor(config, **kwargs).to(device) - self.optimizer = optim.Adam( - self.model.parameters(), - lr=0.001, - # TODO resolve https://github.com/pytorch/torchdynamo/issues/1083 - capturable=bool(int(os.getenv("ADAM_CAPTURABLE", 0) - ))) - - if name in ["llama_v2_7b_16h"]: - self.optimizer = optim.SGD(self.model.parameters(), lr= 0.001) - - # populate these on-demand to avoid wasting memory when not used - self.vocab_size = config.vocab_size - - if test == "train": - input_ids = torch.randint(0, config.vocab_size, (self.batch_size, self.max_length)).to(device) - decoder_ids = torch.randint(0, config.vocab_size, (self.batch_size, self.max_length)).to(device) - self.example_inputs = {'input_ids': input_ids, 'labels': decoder_ids} + assert False, f"Huggingface model {name} is not supported yet." 
+ + if is_training: self.model.train() - elif test == "eval": - # Cut the length of sentence when running on CPU, to reduce test time - if self.device == "cpu" and name in cpu_input_slice: - self.max_length = int(self.max_length / cpu_input_slice[name]) - eval_context = torch.randint(0, config.vocab_size, (self.batch_size, self.max_length)).to(device) - self.example_inputs = {'input_ids': eval_context, } - if class_models[name][3] == 'AutoModelForSeq2SeqLM': - self.example_inputs['decoder_input_ids'] = eval_context + else: self.model.eval() self.amp_context = nullcontext - def get_module(self, wrap_model=True): - if not self.is_generate and class_models[self.unqual_name][3] == 'AutoModelForSeq2SeqLM': - k = 'labels' if self.test == 'train' else 'decoder_input_ids' - if not wrap_model: - return self.model, ( - self.example_inputs['input_ids'], self.example_inputs[k]) - return ArgsToKwargsWrapper(self.model), ( - self.example_inputs['input_ids'], self.example_inputs[k]) - return self.model, (self.example_inputs["input_ids"], ) + def get_module(self): + return self.model, self.example_inputs def get_input_iter(self): """Yield randomized bucket length of inputs.""" - nbuckets = 8 - n = int(math.log2(self.max_length)) - buckets = [2**n for n in range(n - nbuckets, n)] - if class_models[self.unqual_name][3] == 'AutoModelForSeq2SeqLM': - raise NotImplementedError("AutoModelForSeq2SeqLM is not yet supported") - while True: - # randomize bucket_len - bucket_len = random.choice(buckets) - dict_input = { - 'input_ids': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device), - 'labels': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device), - } - yield dict_input + from .basic_configs import generate_input_iter_for_model + generator = generate_input_iter_for_model( + self.model_cls, + self.model, + self.unqual_name, + self.batch_size, + self.device, + self.test == "train", + ) + yield next(generator) def forward(self): with self.amp_context(): @@ -188,6 +112,7 @@ def eval(self) -> Tuple[torch.Tensor]: else: return (out["logits"], ) + class HuggingFaceAuthMixin: def __init__(self): if not 'HUGGING_FACE_HUB_TOKEN' in os.environ: @@ -217,6 +142,7 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]): use_cache=True, ) self.model = GenerationWrapper(self.model, generation_config) + self.example_inputs = (self.example_inputs['input_ids'], ) def train(self): raise NotImplementedError("_generate variant doesn't train") @@ -224,7 +150,7 @@ def train(self): def eval(self) -> Tuple[torch.Tensor]: with torch.no_grad(): with self.amp_context(): - out = self.model(self.example_inputs['input_ids']) + out = self.model(*self.example_inputs) return (out,) @@ -236,3 +162,21 @@ def __init__(self, model, generation_config): def forward(self, inputs): return self.model.generate(inputs, self.generation_config) + +class ExtendedHuggingFaceModel(HuggingFaceModel): + DEFAULT_TRAIN_BSIZE = None + DEFAULT_EVAL_BSIZE = None + def __init__(self, test, device, batch_size=None, extra_args=[]): + recorded_batch_size = BATCH_SIZE_KNOWN_MODELS[self.name] + if self.name in BATCH_SIZE_DIVISORS: + recorded_batch_size = max( + int(recorded_batch_size / BATCH_SIZE_DIVISORS[self.name]), 1 + ) + self.DEFAULT_TRAIN_BSIZE = recorded_batch_size + self.DEFAULT_EVAL_BSIZE = recorded_batch_size + super().__init__( + name=self.name, + test=test, + device=device, + batch_size=batch_size, + extra_args=extra_args) diff --git a/torchbenchmark/util/framework/huggingface/patch_hf.py 
b/torchbenchmark/util/framework/huggingface/patch_hf.py index 89483e104d..5f777db262 100644 --- a/torchbenchmark/util/framework/huggingface/patch_hf.py +++ b/torchbenchmark/util/framework/huggingface/patch_hf.py @@ -4,21 +4,12 @@ import os import subprocess import sys -from .model_factory import class_models -from transformers import AutoConfig, ReformerConfig, BigBirdConfig, BertConfig, WhisperConfig, LlamaConfig, PhiConfig -import inspect +from .basic_configs import download_model PATCH_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "patches") -def cache_model(name: str, **kwargs): - import transformers - model_config = eval(class_models[name][2]) - model_ctor = getattr(transformers, class_models[name][3]) - if not hasattr(model_ctor, "from_config"): - model_ctor(model_config, **kwargs) - else: - model_ctor.from_config(model_config, **kwargs) - +def cache_model(name: str): + download_model(name) def patch_transformers(): import transformers diff --git a/torchbenchmark/util/framework/timm/extended_configs.py b/torchbenchmark/util/framework/timm/extended_configs.py new file mode 100644 index 0000000000..dbf67fce0f --- /dev/null +++ b/torchbenchmark/util/framework/timm/extended_configs.py @@ -0,0 +1,82 @@ +# Extended timm model configs from Dynamobench +from typing import List +import os +import torch +from userbenchmark.dynamo import DYNAMOBENCH_PATH + +TIMM_MODELS = dict() +# Only load the extended models in OSS +if hasattr(torch.version, "git_version"): + filename = os.path.join(DYNAMOBENCH_PATH, "timm_models_list.txt") + with open(filename) as fh: + lines = fh.readlines() + lines = [line.rstrip() for line in lines] + for line in lines: + model_name, batch_size = line.split(" ") + TIMM_MODELS[model_name] = int(batch_size) + +def is_extended_timm_models(model_name: str) -> bool: + return model_name in TIMM_MODELS + +def list_extended_timm_models() -> List[str]: + return TIMM_MODELS.keys() + +# TODO - Figure out the reason of cold start memory spike +BATCH_SIZE_DIVISORS = { + "beit_base_patch16_224": 2, + "cait_m36_384": 8, + "convit_base": 2, + "convmixer_768_32": 2, + "convnext_base": 2, + "cspdarknet53": 2, + "deit_base_distilled_patch16_224": 2, + "dpn107": 2, + "gluon_xception65": 2, + "mobilevit_s": 2, + "pit_b_224": 2, + "pnasnet5large": 2, + "poolformer_m36": 2, + "res2net101_26w_4s": 2, + "resnest101e": 2, + "sebotnet33ts_256": 2, + "swin_base_patch4_window7_224": 2, + "swsl_resnext101_32x16d": 2, + "twins_pcpvt_base": 2, + "vit_base_patch16_224": 2, + "volo_d1_224": 2, + "jx_nest_base": 4, + "xcit_large_24_p8_224": 4, +} + +REQUIRE_HIGHER_TOLERANCE = { + "fbnetv3_b", + "gmixer_24_224", + "hrnet_w18", + "inception_v3", + "sebotnet33ts_256", + "selecsls42b", +} +REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = { + "adv_inception_v3", + "botnet26t_256", + "gluon_inception_v3", + "selecsls42b", + "swsl_resnext101_32x16d", +} + +SCALED_COMPUTE_LOSS = { + "ese_vovnet19b_dw", + "fbnetc_100", + "mnasnet_100", + "mobilevit_s", + "sebotnet33ts_256", +} + +FORCE_AMP_FOR_FP16_BF16_MODELS = { + "convit_base", + "xcit_large_24_p8_224", +} + +SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = { + "xcit_large_24_p8_224", +} diff --git a/torchbenchmark/util/framework/timm/model_factory.py b/torchbenchmark/util/framework/timm/model_factory.py index 56528db612..8f3d71d213 100644 --- a/torchbenchmark/util/framework/timm/model_factory.py +++ b/torchbenchmark/util/framework/timm/model_factory.py @@ -4,7 +4,13 @@ import timm from torchbenchmark.util.model import BenchmarkModel from 
.timm_config import TimmConfig -from typing import Generator, Tuple, Optional +from .extended_configs import BATCH_SIZE_DIVISORS, TIMM_MODELS + +# No pretrained weights exist for specific TIMM models +DISABLE_PRETRAINED_WEIGHTS = [ + "vovnet39a", + "vit_giant_patch14_224", +] class TimmModel(BenchmarkModel): # To recognize this is a timm model @@ -20,7 +26,18 @@ def __init__(self, model_name, test, device, batch_size=None, extra_args=[]): torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = True - self.model = timm.create_model(model_name, pretrained=False, scriptable=True) + pretrained_weights = True if not model_name in DISABLE_PRETRAINED_WEIGHTS else False + self.model = timm.create_model( + model_name, + in_chans=3, + scriptable=False, + num_classes=None, + drop_rate=0.0, + drop_path_rate=None, + drop_block_rate=None, + pretrained=pretrained_weights, + ) + self.cfg = TimmConfig(model = self.model, device = device) self.example_inputs = self._gen_input(self.batch_size) @@ -86,3 +103,21 @@ def eval(self) -> typing.Tuple[torch.Tensor]: with self.amp_context(): out = self._step_eval() return (out, ) + +class ExtendedTimmModel(TimmModel): + DEFAULT_TRAIN_BSIZE = None + DEFAULT_EVAL_BSIZE = None + def __init__(self, test, device, batch_size=None, extra_args=[]): + recorded_batch_size = TIMM_MODELS[self.name] + if self.name in BATCH_SIZE_DIVISORS: + recorded_batch_size = max( + int(recorded_batch_size / BATCH_SIZE_DIVISORS[self.name]), 1 + ) + self.DEFAULT_EVAL_BSIZE = recorded_batch_size + self.DEFAULT_TRAIN_BSIZE = recorded_batch_size + super().__init__( + model_name=self.name, + test=test, + device=device, + batch_size=batch_size, + extra_args=extra_args) diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 79a11df999..e51e142ffb 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -211,7 +211,8 @@ def _determine_batch_size(self, user_specified_batch_size=None): def _load_metadata(self): relative_path = self.__class__.__module__.split(".") - self.name = relative_path[-1] + if getattr(self, "name", None) == None: + self.name = relative_path[-1] metadata_loc = Path(REPO_PATH).joinpath(*relative_path).joinpath("metadata.yaml") if not metadata_loc.exists(): return None diff --git a/userbenchmark/dynamo/__init__.py b/userbenchmark/dynamo/__init__.py index 8b13789179..8b5bf2e90c 100644 --- a/userbenchmark/dynamo/__init__.py +++ b/userbenchmark/dynamo/__init__.py @@ -1 +1,4 @@ +from torchbenchmark import REPO_PATH +BM_NAME = "dynamo" +DYNAMOBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "dynamo", "dynamobench") \ No newline at end of file diff --git a/userbenchmark/dynamo/run.py b/userbenchmark/dynamo/run.py index 42e98f2d15..8cce828475 100644 --- a/userbenchmark/dynamo/run.py +++ b/userbenchmark/dynamo/run.py @@ -1,9 +1,8 @@ import logging import warnings -from torchbenchmark import add_path, REPO_PATH - -DYNAMOBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "dynamo", "dynamobench") +from torchbenchmark import add_path +from . 
import DYNAMOBENCH_PATH try: # OSS Import diff --git a/userbenchmark/group_bench/configs/torch_ao.yaml b/userbenchmark/group_bench/configs/torch_ao.yaml index 77cbec12e8..762668ea3f 100644 --- a/userbenchmark/group_bench/configs/torch_ao.yaml +++ b/userbenchmark/group_bench/configs/torch_ao.yaml @@ -1,4 +1,7 @@ model: "*" +extended_models: + - huggingface + - timm test: eval device: cuda extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune diff --git a/userbenchmark/group_bench/run.py b/userbenchmark/group_bench/run.py index 6adb61563f..16a7f490f3 100644 --- a/userbenchmark/group_bench/run.py +++ b/userbenchmark/group_bench/run.py @@ -152,13 +152,23 @@ def run_config_accuracy(config: TorchBenchModelConfig, metrics: List[str], dryru def models_from_config(config) -> List[str]: assert "model" in config, f"We expect users to define models in config file." + basic_models_list = [] if isinstance(config["model"], str): if config["model"] == "*": - return list_models() + basic_models_list = list_models() else: - return [config["model"]] + basic_models_list = [config["model"]] assert isinstance(config["model", list]), "Config model must be a list or string." - return config["model"] + basic_models_list = config["model"] + extended_models_list = [] + if "extended_models" in config: + from torchbenchmark.util.experiment.instantiator import list_extended_models + for extended_model in config["extended_models"]: + if extended_model == "huggingface" or extended_model == "timm": + extended_models_list.extend(list_extended_models(extended_model)) + else: + extended_models_list.append(extended_model) + return basic_models_list + extended_models_list def load_group_config(config_file: str) -> TorchBenchGroupBenchConfig: if not os.path.exists(config_file): diff --git a/userbenchmark/test_bench/run.py b/userbenchmark/test_bench/run.py index efb9405cfe..74ae69353d 100644 --- a/userbenchmark/test_bench/run.py +++ b/userbenchmark/test_bench/run.py @@ -25,6 +25,7 @@ with add_path(REPO_PATH): from torchbenchmark.util.experiment.instantiator import ( list_models, + list_extended_models, load_model_isolated, TorchBenchModelConfig, list_devices, @@ -251,13 +252,14 @@ def run(args: List[str]): if args.run_bisect: configs = generate_model_configs_from_bisect_yaml(args.run_bisect) else: - # If not specified, use the entire model set + # If not specified, use the entire model + extended model set + modelset = list_models() + list_extended_models() if not args.models: - args.models = list_models() + args.models = modelset devices = validate(parse_str_to_list(args.device), list_devices()) tests = validate(parse_str_to_list(args.test), list_tests()) batch_sizes = parse_str_to_list(args.bs) - models = validate(parse_str_to_list(args.models), list_models()) + models = validate(parse_str_to_list(args.models), modelset) configs = generate_model_configs( devices, tests, batch_sizes, model_names=models, extra_args=extra_args ) @@ -281,7 +283,6 @@ def run(args: List[str]): result = get_output_json(BM_NAME, results) if args.device == "cuda": import torch - result["environ"]["device"] = torch.cuda.get_device_name() with open(args.output, "w") as f: json.dump(result, f, indent=4)
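Note (illustrative, not part of the patch): the `ExtendedHuggingFaceModel` and `ExtendedTimmModel` wrappers introduced above both derive their default train/eval batch sizes from the Dynamobench model lists, scaled down by a per-model divisor and floored at 1. A minimal sketch of that derivation, using hypothetical recorded values rather than the real `*_models_list.txt` files:

```
# Sketch only: mirrors the batch-size logic in ExtendedHuggingFaceModel /
# ExtendedTimmModel. The dictionaries below are hypothetical stand-ins for
# BATCH_SIZE_KNOWN_MODELS / TIMM_MODELS and BATCH_SIZE_DIVISORS.
def default_batch_size(name: str, known_batch_sizes: dict, divisors: dict) -> int:
    recorded = known_batch_sizes[name]            # recorded largest-fitting batch size
    if name in divisors:
        recorded = max(int(recorded / divisors[name]), 1)
    return recorded

# Hypothetical numbers: a recorded batch size of 16 with a divisor of 8 yields
# DEFAULT_TRAIN_BSIZE = DEFAULT_EVAL_BSIZE = 2 for that model.
print(default_batch_size("cait_m36_384", {"cait_m36_384": 16}, {"cait_m36_384": 8}))
```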