From 79928279722c54b33d791455d9ba14170872f91b Mon Sep 17 00:00:00 2001 From: Facebook Community Bot Date: Thu, 6 Jun 2024 10:16:15 -0700 Subject: [PATCH] Re-sync with internal repository (#2281) --- .../dynamo/dynamobench/huggingface.py | 685 ++++++++++++++++++ .../dynamo/dynamobench/timm_models.py | 387 ++++++++++ .../dynamo/dynamobench/torchbench.yaml | 259 +++++++ 3 files changed, 1331 insertions(+) create mode 100755 userbenchmark/dynamo/dynamobench/huggingface.py create mode 100755 userbenchmark/dynamo/dynamobench/timm_models.py create mode 100644 userbenchmark/dynamo/dynamobench/torchbench.yaml diff --git a/userbenchmark/dynamo/dynamobench/huggingface.py b/userbenchmark/dynamo/dynamobench/huggingface.py new file mode 100755 index 0000000000..dca2915a07 --- /dev/null +++ b/userbenchmark/dynamo/dynamobench/huggingface.py @@ -0,0 +1,685 @@ +#!/usr/bin/env python3 +import importlib +import logging +import os +import re +import subprocess +import sys +import warnings + +try: + from .common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state +except ImportError: + from common import BenchmarkRunner, download_retry_decorator, main, reset_rng_state + +import torch + +from torch._dynamo.testing import collect_results +from torch._dynamo.utils import clone_inputs + +log = logging.getLogger(__name__) + +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + + +def pip_install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +# Disable the flake warnings for the imports. Flake8 does not provide a way to +# disable just warning for the entire file. Disabling flake8 entirely. +# flake8: noqa +imports = [ + "AlbertForPreTraining", + "AutoConfig", + "AutoModelForCausalLM", + "AutoModelForMaskedLM", + "AutoModelForSeq2SeqLM", + "BigBirdConfig", + "BlenderbotForConditionalGeneration", + "BlenderbotModel", + "BlenderbotSmallForConditionalGeneration", + "BlenderbotSmallModel", + "CLIPModel", + "CLIPVisionModel", + "ElectraForPreTraining", + "GPT2ForSequenceClassification", + "GPTJForSequenceClassification", + "GPTNeoForSequenceClassification", + "HubertForSequenceClassification", + "LxmertForPreTraining", + "LxmertForQuestionAnswering", + "MarianForCausalLM", + "MarianModel", + "MarianMTModel", + "PegasusForConditionalGeneration", + "PegasusModel", + "ReformerConfig", + "ViTForImageClassification", + "ViTForMaskedImageModeling", + "ViTModel", +] + + +def process_hf_reformer_output(out): + assert isinstance(out, list) + # second output is unstable + return [elem for i, elem in enumerate(out) if i != 1] + + +try: + mod = importlib.import_module("transformers") + for cls in imports: + if not hasattr(mod, cls): + raise ModuleNotFoundError +except ModuleNotFoundError: + print("Installing HuggingFace Transformers...") + pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers") +finally: + for cls in imports: + exec(f"from transformers import {cls}") + + +# These models contain the models present in huggingface_models_list. It is a +# combination of models supported by HF Fx parser and some manually supplied +# models. For these models, we already know the largest batch size that can fit +# on A100 GPUs - 40 GB. 
+BATCH_SIZE_KNOWN_MODELS = dict() + + +# Get the list of models and their batch sizes +MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt") +assert os.path.exists(MODELS_FILENAME) +with open(MODELS_FILENAME, "r") as fh: + lines = fh.readlines() + lines = [line.rstrip() for line in lines] + for line in lines: + model_name, batch_size = line.split(",") + batch_size = int(batch_size) + BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size +assert len(BATCH_SIZE_KNOWN_MODELS) + + +SKIP = { + # Difficult to setup accuracy test because .eval() not supported + "Reformer", + # Fails deepcopy + "BlenderbotForConditionalGeneration", + "GPTNeoForCausalLM", + "GPTNeoForSequenceClassification", + # Fails with even batch size = 1 + "GPTJForCausalLM", + "GPTJForQuestionAnswering", +} + +# TODO - Fails even after fake tensors +BATCH_SIZE_DIVISORS = { + "AlbertForMaskedLM": 2, + "AlbertForQuestionAnswering": 2, + "AllenaiLongformerBase": 2, + "BartForCausalLM": 2, + "BartForConditionalGeneration": 2, + "BertForMaskedLM": 2, + "BertForQuestionAnswering": 2, + "BlenderbotForCausalLM": 8, + # "BlenderbotForConditionalGeneration" : 16, + "BlenderbotSmallForCausalLM": 4, + "BlenderbotSmallForConditionalGeneration": 2, + "CamemBert": 2, + "DebertaForMaskedLM": 4, + "DebertaForQuestionAnswering": 2, + "DebertaV2ForMaskedLM": 4, + "DebertaV2ForQuestionAnswering": 8, + "DistilBertForMaskedLM": 2, + "DistilBertForQuestionAnswering": 2, + "DistillGPT2": 2, + "ElectraForCausalLM": 2, + "ElectraForQuestionAnswering": 2, + "GPT2ForSequenceClassification": 2, + # "GPTJForCausalLM" : 2, + # "GPTJForQuestionAnswering" : 2, + # "GPTNeoForCausalLM" : 32, + # "GPTNeoForSequenceClassification" : 2, + "GoogleFnet": 2, + "LayoutLMForMaskedLM": 2, + "LayoutLMForSequenceClassification": 2, + "M2M100ForConditionalGeneration": 4, + "MBartForCausalLM": 2, + "MBartForConditionalGeneration": 2, + "MT5ForConditionalGeneration": 2, + "MegatronBertForCausalLM": 4, + "MegatronBertForQuestionAnswering": 2, + "MobileBertForMaskedLM": 2, + "MobileBertForQuestionAnswering": 2, + "OPTForCausalLM": 2, + "PLBartForCausalLM": 2, + "PLBartForConditionalGeneration": 2, + "PegasusForCausalLM": 4, + "PegasusForConditionalGeneration": 2, + "RobertaForCausalLM": 2, + "RobertaForQuestionAnswering": 2, + "Speech2Text2ForCausalLM": 4, + "T5ForConditionalGeneration": 2, + "T5Small": 2, + "TrOCRForCausalLM": 2, + "XGLMForCausalLM": 4, + "XLNetLMHeadModel": 2, + "YituTechConvBert": 2, +} + +SKIP_ACCURACY_CHECK_MODELS = { + # Models too large to have eager, dynamo and fp64_numbers simultaneosuly + # even for 40 GB machine. + "DebertaV2ForMaskedLM", + "BlenderbotForCausalLM", +} + +SKIP_DUE_TO_CONTROL_FLOW = {"AllenaiLongformerBase"} + + +REQUIRE_HIGHER_TOLERANCE_TRAINING = { + "MT5ForConditionalGeneration", + # AlbertForQuestionAnswering fails in CI GCP A100 but error does not seem + # harmful. 
+ "AlbertForQuestionAnswering", +} +REQUIRE_HIGHER_TOLERANCE_INFERENCE = { + "GPT2ForSequenceClassification", + "RobertaForQuestionAnswering", +} + + +SKIP_FOR_CPU = { + "OPTForCausalLM", # OOMs +} + +ONLY_EVAL_MODE = { + "M2M100ForConditionalGeneration", # Fails with dynamo for train mode +} + +FP32_ONLY_MODELS = { + "GoogleFnet", +} + + +def get_module_cls_by_model_name(model_cls_name): + _module_by_model_name = { + "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2", + "TrOCRDecoder": "transformers.models.trocr.modeling_trocr", + } + module_name = _module_by_model_name.get(model_cls_name, "transformers") + module = importlib.import_module(module_name) + return getattr(module, model_cls_name) + + +def get_sequence_length(model_cls, model_name): + if model_name.startswith(("Blenderbot",)): + seq_length = 128 + elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")): + seq_length = 1024 + elif model_name in ("AllenaiLongformerBase", "BigBird"): + seq_length = 1024 + elif model_name.startswith("OPT"): + seq_length = 2048 + elif "Reformer" in model_name: + seq_length = 4096 + elif model_name.startswith( + ( + "Albert", + "Deberta", + "Layout", + "Electra", + "XLNet", + "MegatronBert", + "Bert", + "Roberta", + ) + ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"): + seq_length = 512 + elif model_name in ("TrOCRForCausalLM"): + seq_length = 256 + elif model_name.startswith("MobileBert"): + seq_length = 128 + elif model_name.startswith("Wav2Vec2"): + # If too short, will fail with something like + # ValueError: `mask_length` has to be smaller than `sequence_length`, + # but got `mask_length`: 10 and `sequence_length`: 9` + seq_length = 10000 # NB: a more realistic size is 155136 + else: + log.info( + f"Sequence Length not defined for {model_name}. 
Choosing 128 arbitrarily" + ) + seq_length = 128 + return seq_length + + +def generate_inputs_for_model( + model_cls, model, model_name, bs, device, include_loss_args=False +): + # TODO - Check if following values are representative + num_choices = 3 + num_visual_features = 42 + seq_length = get_sequence_length(model_cls, model_name) + vocab_size = model.config.vocab_size + + if model_name.startswith("Wav2Vec2"): + # TODO: If we add more input_values style models, try to work this + # into the overall control flow + target_length = 100 + return { + "input_values": torch.randn((bs, seq_length), device=device), + # Added because that's what the example training script has + "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)), + "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)), + } + + if model_name.endswith("MultipleChoice"): + input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length)) + elif model_name.startswith("Roberta"): + input = rand_int_tensor(device, 0, 1, (bs, seq_length)) + else: + input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length)) + + if "Bart" in model_name: + input[:, -1] = model.config.eos_token_id + + input_dict = {"input_ids": input} + + if ( + model_name.startswith("T5") + or model_name.startswith("M2M100") + or model_name.startswith("MT5") + or model_cls + in [ + BlenderbotModel, + BlenderbotSmallModel, + BlenderbotForConditionalGeneration, + BlenderbotSmallForConditionalGeneration, + PegasusModel, + PegasusForConditionalGeneration, + MarianModel, + MarianMTModel, + ] + ): + input_dict["decoder_input_ids"] = input + + if model_name.startswith("Lxmert"): + visual_feat_dim, visual_pos_dim = ( + model.config.visual_feat_dim, + model.config.visual_pos_dim, + ) + input_dict["visual_feats"] = torch.randn( + bs, num_visual_features, visual_feat_dim + ) + input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim) + + if include_loss_args: + if model_name.endswith("PreTraining"): + if model_cls in [ElectraForPreTraining, LxmertForPreTraining]: + input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length)) + else: + label_name = ( + "sentence_order_label" + if model_cls in [AlbertForPreTraining] + else "next_sentence_label" + ) + input_dict["labels"] = ( + rand_int_tensor(device, 0, vocab_size, (bs, seq_length)), + ) + input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,)) + elif model_name.endswith("QuestionAnswering"): + input_dict["start_positions"] = rand_int_tensor( + device, 0, seq_length, (bs,) + ) + input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,)) + elif ( + model_name.endswith("MaskedLM") + or model_name.endswith("HeadModel") + or model_name.endswith("CausalLM") + or model_name.endswith("DoubleHeadsModel") + ): + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size, (bs, seq_length) + ) + elif model_name.endswith("TokenClassification"): + input_dict["labels"] = rand_int_tensor( + device, 0, model.config.num_labels - 1, (bs, seq_length) + ) + elif model_name.endswith("MultipleChoice"): + input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,)) + elif model_name.endswith("SequenceClassification"): + input_dict["labels"] = rand_int_tensor( + device, 0, model.config.num_labels - 1, (bs,) + ) + elif model_name.endswith("NextSentencePrediction"): + input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,)) + elif model_name.endswith("ForConditionalGeneration"): + input_dict["labels"] = rand_int_tensor( + device, 0, 
vocab_size - 1, (bs, seq_length) + ) + elif model_name in EXTRA_MODELS: + input_dict["labels"] = rand_int_tensor( + device, 0, vocab_size, (bs, seq_length) + ) + else: + raise NotImplementedError( + f"Class {model_name} unsupported for training test " + ) + + return input_dict + + +def rand_int_tensor(device, low, high, shape): + return torch.randint( + low, + high, + shape, + device=device, + dtype=torch.int64, + requires_grad=False, + ) + + +EXTRA_MODELS = { + "AllenaiLongformerBase": ( + AutoConfig.from_pretrained("allenai/longformer-base-4096"), + AutoModelForMaskedLM, + ), + "Reformer": ( + ReformerConfig(), + AutoModelForMaskedLM, + ), + "T5Small": ( + AutoConfig.from_pretrained("t5-small"), + AutoModelForSeq2SeqLM, + ), + # "BigBird": ( + # BigBirdConfig(attention_type="block_sparse"), + # AutoModelForMaskedLM, + # ), + "DistillGPT2": ( + AutoConfig.from_pretrained("distilgpt2"), + AutoModelForCausalLM, + ), + "GoogleFnet": ( + AutoConfig.from_pretrained("google/fnet-base"), + AutoModelForMaskedLM, + ), + "YituTechConvBert": ( + AutoConfig.from_pretrained("YituTech/conv-bert-base"), + AutoModelForMaskedLM, + ), + "CamemBert": ( + AutoConfig.from_pretrained("camembert-base"), + AutoModelForMaskedLM, + ), +} + + +class HuggingfaceRunner(BenchmarkRunner): + def __init__(self): + super().__init__() + self.suite_name = "huggingface" + + @property + def skip_models_for_cpu(self): + return SKIP_FOR_CPU + + @property + def fp32_only_models(self): + return FP32_ONLY_MODELS + + @property + def skip_models_due_to_control_flow(self): + return SKIP_DUE_TO_CONTROL_FLOW + + def _get_model_cls_and_config(self, model_name): + if model_name not in EXTRA_MODELS: + model_cls = get_module_cls_by_model_name(model_name) + config_cls = model_cls.config_class + config = config_cls() + + # NB: some models need a pad token defined to handle BS > 1 + if ( + model_cls + in [ + GPT2ForSequenceClassification, + GPTNeoForSequenceClassification, + GPTJForSequenceClassification, + ] + or model_cls.__name__.startswith("Roberta") + or model_cls.__name__.startswith("Marian") + ): + config.pad_token_id = 0 + + else: + config, model_cls = EXTRA_MODELS[model_name] + + return model_cls, config + + @download_retry_decorator + def _download_model(self, model_name): + model_cls, config = self._get_model_cls_and_config(model_name) + if "auto" in model_cls.__module__: + # Handle auto classes + model = model_cls.from_config(config) + else: + model = model_cls(config) + return model + + def load_model( + self, + device, + model_name, + batch_size=None, + extra_args=None, + ): + is_training = self.args.training + use_eval_mode = self.args.use_eval_mode + dtype = torch.float32 + reset_rng_state() + model_cls, config = self._get_model_cls_and_config(model_name) + model = self._download_model(model_name) + model = model.to(device, dtype=dtype) + if self.args.enable_activation_checkpointing: + model.gradient_checkpointing_enable() + if model_name in BATCH_SIZE_KNOWN_MODELS: + batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name] + elif batch_size is None: + batch_size_default = 16 + log.info( + f"Batch size not specified for {model_name}. 
Setting batch_size=16" + ) + + if batch_size is None: + batch_size = batch_size_default + if model_name in BATCH_SIZE_DIVISORS: + batch_size = max(int(batch_size / BATCH_SIZE_DIVISORS[model_name]), 1) + log.info( + f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" + ) + + example_inputs = generate_inputs_for_model( + model_cls, model, model_name, batch_size, device, include_loss_args=True + ) + + # So we can check for correct gradients without eliminating the dropout computation + for attr in dir(config): + if "drop" in attr and isinstance(getattr(config, attr), float): + setattr(config, attr, 1e-30) + + if ( + is_training + and not use_eval_mode + and not (self.args.accuracy and model_name in ONLY_EVAL_MODE) + ): + model.train() + else: + model.eval() + + self.validate_model(model, example_inputs) + return device, model_name, model, example_inputs, batch_size + + def iter_model_names(self, args): + model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys()) + model_names = set(model_names) + model_names = sorted(model_names) + + start, end = self.get_benchmark_indices(len(model_names)) + for index, model_name in enumerate(model_names): + if index < start or index >= end: + continue + if ( + not re.search("|".join(args.filter), model_name, re.I) + or re.search("|".join(args.exclude), model_name, re.I) + or model_name in args.exclude_exact + or model_name in SKIP + ): + continue + yield model_name + + @property + def skip_accuracy_checks_large_models_dashboard(self): + if self.args.dashboard or self.args.accuracy: + return SKIP_ACCURACY_CHECK_MODELS + return set() + + @property + def get_output_amp_train_process_func(self): + return {} + + def pick_grad(self, name, is_training): + if is_training: + return torch.enable_grad() + else: + return torch.no_grad() + + def get_tolerance_and_cosine_flag(self, is_training, current_device, name): + cosine = self.args.cosine + if is_training: + if name in REQUIRE_HIGHER_TOLERANCE_TRAINING: + return 2e-2, cosine + else: + return 1e-2, cosine + else: + if name in REQUIRE_HIGHER_TOLERANCE_INFERENCE: + return 4e-3, cosine + return 1e-3, cosine + + def compute_loss(self, pred): + return pred[0] + + def forward_pass(self, mod, inputs, collect_outputs=True): + with self.autocast(**self.autocast_arg): + return mod(**inputs) + + def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): + cloned_inputs = clone_inputs(inputs) + self.optimizer_zero_grad(mod) + with self.autocast(**self.autocast_arg): + pred = mod(**cloned_inputs) + loss = self.compute_loss(pred) + self.grad_scaler.scale(loss).backward() + self.optimizer_step() + if collect_outputs: + return collect_results(mod, pred, loss, cloned_inputs) + return None + + +def refresh_model_names_and_batch_sizes(): + """ + This function reads the HF Fx tracer supported models and finds the largest + batch size that could fit on the GPU with PyTorch eager. + + The resulting data is written in huggingface_models_list.txt. + + Note - We only need to run this function if we believe that HF Fx tracer now + supports more models. 
+ """ + import transformers.utils.fx as hf_fx + + family = dict() + lm_seen = set() + family_seen = set() + for cls_name in hf_fx._SUPPORTED_MODELS: + if "For" not in cls_name: + continue + + model_cls = get_module_cls_by_model_name(cls_name) + + # TODO: AttributeError: '*Config' object has no attribute 'vocab_size' + if model_cls in [ + CLIPModel, + CLIPVisionModel, + # SwinForImageClassification, + # SwinForImageClassification, + # SwinForMaskedImageModeling, + # SwinModel, + ViTForImageClassification, + ViTForMaskedImageModeling, + ViTModel, + ]: + continue + + # TODO: AssertionError: Padding_idx must be within num_embeddings + if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]: + continue + + # TODO: "model is not supported yet" from HFTracer + if model_cls in [HubertForSequenceClassification]: + continue + + # TODO: shape mismatch in loss calculation + if model_cls in [LxmertForQuestionAnswering]: + continue + + family_name = cls_name.split("For")[0] + if family_name not in family: + family[family_name] = [] + if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen: + family[family_name].append(cls_name) + lm_seen.add(family_name) + elif ( + cls_name.endswith( + ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering") + ) + and family_name not in family_seen + ): + family[family_name].append(cls_name) + family_seen.add(family_name) + elif cls_name.endswith("ImageClassification"): + family[family_name].append(cls_name) + + chosen_models = set() + for members in family.values(): + chosen_models.update(set(members)) + + # Add the EXTRA_MODELS + chosen_models.update(set(EXTRA_MODELS.keys())) + + for model_name in sorted(chosen_models): + try: + subprocess.check_call( + [sys.executable] + + sys.argv + + ["--find-batch-sizes"] + + [f"--only={model_name}"] + + [f"--output={MODELS_FILENAME}"] + ) + except subprocess.SubprocessError: + log.warning(f"Failed to find suitable batch size for {model_name}") + + +def huggingface_main(): + # Code to refresh model names and batch sizes + # if "--find-batch-sizes" not in sys.argv: + # refresh_model_names_and_batch_sizes() + logging.basicConfig(level=logging.WARNING) + warnings.filterwarnings("ignore") + main(HuggingfaceRunner()) + + +if __name__ == "__main__": + huggingface_main() diff --git a/userbenchmark/dynamo/dynamobench/timm_models.py b/userbenchmark/dynamo/dynamobench/timm_models.py new file mode 100755 index 0000000000..75a1251769 --- /dev/null +++ b/userbenchmark/dynamo/dynamobench/timm_models.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python3 +import importlib +import logging +import os +import re +import subprocess +import sys +import warnings + +try: + from .common import BenchmarkRunner, download_retry_decorator, main +except ImportError: + from common import BenchmarkRunner, download_retry_decorator, main + +import torch + +from torch._dynamo.testing import collect_results, reduce_to_scalar_loss +from torch._dynamo.utils import clone_inputs + +# Enable FX graph caching +if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + torch._inductor.config.fx_graph_cache = True + + +def pip_install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +try: + importlib.import_module("timm") +except ModuleNotFoundError: + print("Installing PyTorch Image Models...") + pip_install("git+https://github.com/rwightman/pytorch-image-models") +finally: + from timm import __version__ as timmversion + from timm.data import resolve_data_config + from timm.models import create_model 
+ +TIMM_MODELS = dict() +filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt") + +with open(filename) as fh: + lines = fh.readlines() + lines = [line.rstrip() for line in lines] + for line in lines: + model_name, batch_size = line.split(" ") + TIMM_MODELS[model_name] = int(batch_size) + + +# TODO - Figure out the reason of cold start memory spike + +BATCH_SIZE_DIVISORS = { + "beit_base_patch16_224": 2, + "convit_base": 2, + "convmixer_768_32": 2, + "convnext_base": 2, + "cspdarknet53": 2, + "deit_base_distilled_patch16_224": 2, + "gluon_xception65": 2, + "mobilevit_s": 2, + "pnasnet5large": 2, + "poolformer_m36": 2, + "resnest101e": 2, + "swin_base_patch4_window7_224": 2, + "swsl_resnext101_32x16d": 2, + "vit_base_patch16_224": 2, + "volo_d1_224": 2, + "jx_nest_base": 4, +} + +REQUIRE_HIGHER_TOLERANCE = { + "fbnetv3_b", + "gmixer_24_224", + "hrnet_w18", + "inception_v3", + "mixer_b16_224", + "sebotnet33ts_256", + "selecsls42b", +} + +REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = { + "adv_inception_v3", + "botnet26t_256", + "gluon_inception_v3", + "selecsls42b", + "swsl_resnext101_32x16d", +} + +SCALED_COMPUTE_LOSS = { + "ese_vovnet19b_dw", + "fbnetc_100", + "mnasnet_100", + "mobilevit_s", + "sebotnet33ts_256", +} + +FORCE_AMP_FOR_FP16_BF16_MODELS = { + "convit_base", + "xcit_large_24_p8_224", +} + +SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = { + "xcit_large_24_p8_224", +} + + +def refresh_model_names(): + import glob + + from timm.models import list_models + + def read_models_from_docs(): + models = set() + # TODO - set the path to pytorch-image-models repo + for fn in glob.glob("../pytorch-image-models/docs/models/*.md"): + with open(fn) as f: + while True: + line = f.readline() + if not line: + break + if not line.startswith("model = timm.create_model("): + continue + + model = line.split("'")[1] + # print(model) + models.add(model) + return models + + def get_family_name(name): + known_families = [ + "darknet", + "densenet", + "dla", + "dpn", + "ecaresnet", + "halo", + "regnet", + "efficientnet", + "deit", + "mobilevit", + "mnasnet", + "convnext", + "resnet", + "resnest", + "resnext", + "selecsls", + "vgg", + "xception", + ] + + for known_family in known_families: + if known_family in name: + return known_family + + if name.startswith("gluon_"): + return "gluon_" + name.split("_")[1] + return name.split("_")[0] + + def populate_family(models): + family = dict() + for model_name in models: + family_name = get_family_name(model_name) + if family_name not in family: + family[family_name] = [] + family[family_name].append(model_name) + return family + + docs_models = read_models_from_docs() + all_models = list_models(pretrained=True, exclude_filters=["*in21k"]) + + all_models_family = populate_family(all_models) + docs_models_family = populate_family(docs_models) + + for key in docs_models_family: + del all_models_family[key] + + chosen_models = set() + chosen_models.update(value[0] for value in docs_models_family.values()) + + chosen_models.update(value[0] for key, value in all_models_family.items()) + + filename = "timm_models_list.txt" + if os.path.exists("benchmarks"): + filename = "benchmarks/" + filename + with open(filename, "w") as fw: + for model_name in sorted(chosen_models): + fw.write(model_name + "\n") + + +class TimmRunner(BenchmarkRunner): + def __init__(self): + super().__init__() + self.suite_name = "timm_models" + + @property + def force_amp_for_fp16_bf16_models(self): + return FORCE_AMP_FOR_FP16_BF16_MODELS + + @property + def 
force_fp16_for_bf16_models(self): + return set() + + @property + def get_output_amp_train_process_func(self): + return {} + + @property + def skip_accuracy_check_as_eager_non_deterministic(self): + if self.args.accuracy and self.args.training: + return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS + return set() + + @property + def guard_on_nn_module_models(self): + return { + "convit_base", + } + + @download_retry_decorator + def _download_model(self, model_name): + model = create_model( + model_name, + in_chans=3, + scriptable=False, + num_classes=None, + drop_rate=0.0, + drop_path_rate=None, + drop_block_rate=None, + pretrained=True, + ) + return model + + def load_model( + self, + device, + model_name, + batch_size=None, + extra_args=None, + ): + if self.args.enable_activation_checkpointing: + raise NotImplementedError( + "Activation checkpointing not implemented for Timm models" + ) + + is_training = self.args.training + use_eval_mode = self.args.use_eval_mode + + channels_last = self._args.channels_last + model = self._download_model(model_name) + + if model is None: + raise RuntimeError(f"Failed to load model '{model_name}'") + model.to( + device=device, + memory_format=torch.channels_last if channels_last else None, + ) + + self.num_classes = model.num_classes + + data_config = resolve_data_config( + vars(self._args) if timmversion >= "0.8.0" else self._args, + model=model, + use_test_size=not is_training, + ) + input_size = data_config["input_size"] + recorded_batch_size = TIMM_MODELS[model_name] + + if model_name in BATCH_SIZE_DIVISORS: + recorded_batch_size = max( + int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1 + ) + batch_size = batch_size or recorded_batch_size + + torch.manual_seed(1337) + input_tensor = torch.randint( + 256, size=(batch_size,) + input_size, device=device + ).to(dtype=torch.float32) + mean = torch.mean(input_tensor) + std_dev = torch.std(input_tensor) + example_inputs = (input_tensor - mean) / std_dev + + if channels_last: + example_inputs = example_inputs.contiguous( + memory_format=torch.channels_last + ) + example_inputs = [ + example_inputs, + ] + self.target = self._gen_target(batch_size, device) + + self.loss = torch.nn.CrossEntropyLoss().to(device) + + if model_name in SCALED_COMPUTE_LOSS: + self.compute_loss = self.scaled_compute_loss + + if is_training and not use_eval_mode: + model.train() + else: + model.eval() + + self.validate_model(model, example_inputs) + + return device, model_name, model, example_inputs, batch_size + + def iter_model_names(self, args): + # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]): + model_names = sorted(TIMM_MODELS.keys()) + start, end = self.get_benchmark_indices(len(model_names)) + for index, model_name in enumerate(model_names): + if index < start or index >= end: + continue + if ( + not re.search("|".join(args.filter), model_name, re.I) + or re.search("|".join(args.exclude), model_name, re.I) + or model_name in args.exclude_exact + or model_name in self.skip_models + ): + continue + + yield model_name + + def pick_grad(self, name, is_training): + if is_training: + return torch.enable_grad() + else: + return torch.no_grad() + + def get_tolerance_and_cosine_flag(self, is_training, current_device, name): + cosine = self.args.cosine + tolerance = 1e-3 + + if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING: + # the conv-batchnorm fusion used under freezing may cause relatively + # large numerical difference. We need are larger tolerance. 
+ # Check https://github.com/pytorch/pytorch/issues/120545 for context + tolerance = 8 * 1e-2 + + if is_training: + if name in ["levit_128"]: + tolerance = 8 * 1e-2 + elif name in REQUIRE_HIGHER_TOLERANCE: + tolerance = 4 * 1e-2 + else: + tolerance = 1e-2 + return tolerance, cosine + + def _gen_target(self, batch_size, device): + return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_( + self.num_classes + ) + + def compute_loss(self, pred): + # High loss values make gradient checking harder, as small changes in + # accumulation order upsets accuracy checks. + return reduce_to_scalar_loss(pred) + + def scaled_compute_loss(self, pred): + # Loss values need zoom out further. + return reduce_to_scalar_loss(pred) / 1000.0 + + def forward_pass(self, mod, inputs, collect_outputs=True): + with self.autocast(**self.autocast_arg): + return mod(*inputs) + + def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): + cloned_inputs = clone_inputs(inputs) + self.optimizer_zero_grad(mod) + with self.autocast(**self.autocast_arg): + pred = mod(*cloned_inputs) + if isinstance(pred, tuple): + pred = pred[0] + loss = self.compute_loss(pred) + self.grad_scaler.scale(loss).backward() + self.optimizer_step() + if collect_outputs: + return collect_results(mod, pred, loss, cloned_inputs) + return None + + +def timm_main(): + logging.basicConfig(level=logging.WARNING) + warnings.filterwarnings("ignore") + main(TimmRunner()) + + +if __name__ == "__main__": + timm_main() diff --git a/userbenchmark/dynamo/dynamobench/torchbench.yaml b/userbenchmark/dynamo/dynamobench/torchbench.yaml new file mode 100644 index 0000000000..bf848e81b3 --- /dev/null +++ b/userbenchmark/dynamo/dynamobench/torchbench.yaml @@ -0,0 +1,259 @@ +# Some models have large dataset that doesn't fit in memory. Lower the batch +# size to test the accuracy. +batch_size: + training: + demucs: 4 + dlrm: 1024 + densenet121: 4 + hf_Reformer: 4 + hf_T5_base: 4 + timm_efficientdet: 1 + llama_v2_7b_16h: 1 + # reduced from 16 due to cudagraphs OOM in TorchInductor dashboard + yolov3: 8 + + inference: + timm_efficientdet: 32 + + +dont_change_batch_size: + - demucs + - pytorch_struct + - pyhpc_turbulent_kinetic_energy + # https://github.com/pytorch/benchmark/pull/1656 + - vision_maskrcnn + + +tolerance: + # Need lower tolerance on GPU. GPU kernels have non deterministic kernels for these models. 
+ higher: + - alexnet + - attention_is_all_you_need_pytorch + - densenet121 + - hf_Albert + - vgg16 + - mobilenet_v3_large + - nvidia_deeprecommender + - timm_efficientdet + + # These models need >1e-3 tolerance + even_higher: + - soft_actor_critic + - tacotron2 + + higher_fp16: + - doctr_reco_predictor + - drq + - hf_Whisper + + higher_bf16: + - doctr_reco_predictor + - drq + - hf_Whisper + + cosine: [] + + +# These benchmarks took >600s on an i9-11900K CPU +very_slow: &VERY_SLOW_MODELS + # 3339s + - hf_BigBird + # 3062s + - hf_Longformer + # 930s + - hf_T5 + + +# These benchmarks took >60s on an i9-11900K CPU +slow: + - *VERY_SLOW_MODELS + # 137s + - BERT_pytorch + # 116s + - demucs + # 242s + - fastNLP_Bert + # 221s + - hf_Albert + # 400s + - hf_Bart + # 334s + - hf_Bert + # 187s + - hf_DistilBert + # 470s + - hf_GPT2 + # 141s + - hf_Reformer + # 317s + - speech_transformer + # 99s + - vision_maskrcnn + + +non_deterministic: + # https://github.com/pytorch/pytorch/issues/98355 + - mobilenet_v3_large + - sam_fast + + +dtype: + force_amp_for_fp16_bf16_models: + - DALLE2_pytorch + - doctr_det_predictor + - doctr_reco_predictor + - Super_SloMo + - tts_angular + - pyhpc_turbulent_kinetic_energy + - detectron2_fcos_r_50_fpn + + force_fp16_for_bf16_models: + - vision_maskrcnn + + +# models in canary_models that we should run anyway +canary_models: + - torchrec_dlrm + + +detectron2_models: &DETECTRON2_MODELS + - detectron2_fasterrcnn_r_101_c4 + - detectron2_fasterrcnn_r_101_dc5 + - detectron2_fasterrcnn_r_101_fpn + - detectron2_fasterrcnn_r_50_c4 + - detectron2_fasterrcnn_r_50_dc5 + - detectron2_fasterrcnn_r_50_fpn + - detectron2_maskrcnn_r_101_c4 + - detectron2_maskrcnn_r_101_fpn + - detectron2_maskrcnn_r_50_fpn + + +# These models support only train mode. So accuracy checking can't be done in +# eval mode. 
+only_training: + - *DETECTRON2_MODELS + - tts_angular + - tacotron2 + - demucs + - hf_Reformer + - pytorch_struct + - yolov3 + + +trt_not_yet_working: + - alexnet + - resnet18 + - resnet50 + - mobilenet_v2 + - mnasnet1_0 + - squeezenet1_1 + - shufflenetv2_x1_0 + - vgg16 + - resnext50_32x4d + + +skip: + all: + # OOMs (A100 40G) + - detectron2_maskrcnn + # TIMEOUT, https://github.com/pytorch/pytorch/issues/98467 + - tacotron2 + # Failing in eager mode + - hf_clip + # multi gpu not always available in benchmark runners + - simple_gpt_tp_manual + + device: + cpu: + # OOMs + - hf_T5_generate + # model is CUDA only + - cm3leon_generate + # timeout + - nanogpt + # timeout + - sam + # model is CUDA only + - sam_fast + # model is CUDA only + - llama_v2_7b_16h + # flaky + - stable_diffusion + # requires FBGEMM, CUDA only + - torchrec_dlrm + - simple_gpt + # works on cuda, accuracy failure on cpu + - hf_Whisper + - stable_diffusion_text_encoder + - llava + + cuda: [] + + test: + training: + - *DETECTRON2_MODELS + # not designed for training + - pyhpc_equation_of_state + - pyhpc_isoneutral_mixing + - pyhpc_turbulent_kinetic_energy + - maml + - llama + - llama_v2_7b_16h + - simple_gpt + - sam_fast + # Model's DEFAULT_TRAIN_BSIZE is not implemented + - cm3leon_generate + - hf_T5_generate + - doctr_det_predictor + - doctr_reco_predictor + - moondream + # doesnt fit in memory + - phi_1_5 + - detectron2_fcos_r_50_fpn + + control_flow: + - cm3leon_generate + - detectron2_fcos_r_50_fpn + - fastNLP_Bert + - hf_Longformer + - hf_Reformer + - hf_T5_generate + - opacus_cifar10 + - speech_transformer + + # Models that should only run in --multiprocess mode + multiprocess: + - simple_gpt + + # for these models, conv-batchnorm fusing causes big numerical churn. + # Skip them + freezing: + - mnasnet1_0 + - moco + - shufflenet_v2_x1_0 + + + + +accuracy: + skip: + large_models: + # Models too large to have eager, dynamo and fp64_numbers simultaneosuly + # even for 40 GB machine. We have tested accuracy for smaller version of + # these models + - hf_GPT2_large + - hf_T5_large + - timm_vision_transformer_large + # accuracy https://github.com/pytorch/pytorch/issues/93847 + - maml + - llama_v2_7b_16h + - Background_Matting + - stable_diffusion_unet + eager_not_deterministic: + # Models that deterministic algorithms can not be turned on for eager mode. + - Background_Matting + - pytorch_unet + + max_batch_size: + hf_GPT2: 2 + pytorch_unet: 2
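
Reviewer note (not part of the patch): a minimal, illustrative sketch of how the model list introduced in huggingface.py is consumed. Each line of huggingface_models_list.txt holds a comma-separated model name and the largest batch size known to fit on a 40 GB A100, and BATCH_SIZE_DIVISORS then scales that size down before a run, never going below 1. The snippet below only mirrors that logic under the assumption of a local file in the same format; the file path and the divisor subset shown here are placeholders, not values taken from the patch.

# Illustrative sketch: parse a "name,batch_size" models list and apply the
# same divisor adjustment used in HuggingfaceRunner.load_model.
import os

BATCH_SIZE_DIVISORS = {"BertForMaskedLM": 2}  # small subset, for illustration only

def load_known_batch_sizes(path):
    # Each line: "<model_name>,<largest batch size that fits on a 40 GB A100>"
    sizes = {}
    with open(path) as fh:
        for line in fh:
            line = line.rstrip()
            if not line:
                continue
            name, batch_size = line.split(",")
            sizes[name] = int(batch_size)
    return sizes

def effective_batch_size(model_name, sizes):
    # Mirror the patch: divide the recorded size by the model's divisor,
    # clamping at a minimum batch size of 1.
    batch_size = sizes[model_name]
    if model_name in BATCH_SIZE_DIVISORS:
        batch_size = max(int(batch_size / BATCH_SIZE_DIVISORS[model_name]), 1)
    return batch_size

if __name__ == "__main__":
    path = "huggingface_models_list.txt"  # assumed to sit next to this script
    if os.path.exists(path):
        sizes = load_known_batch_sizes(path)
        print(effective_batch_size("BertForMaskedLM", sizes))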