From 842f11db9a83a1cfc22048462980c07a7dfa8909 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 22 Sep 2023 13:00:33 -0700
Subject: [PATCH] Add interface to support multiple batches (#1914)

Summary:
The new interface `self.get_input_iter()` returns an input generator which yields the next available input batch.

We support dynamic shapes on 3 categories of models:
- huggingface (hf_*): we randomize the bucket_len of the input data
- timm: we randomize the batch_size of the input data
- torchvision: we randomize the batch_size of the input data

This PR also introduces a new type of train test, `train_dynamic`, which runs the train test with a list of batches. By default, we run 10 batches of the inputs returned from the generator. A usage sketch is included below.
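
The following sketch is illustrative only: `MyModel` is a hypothetical placeholder
for any BenchmarkModel subclass that implements the staged train interface
(`forward`/`backward`/`optimizer_step`); it is not a real torchbenchmark class.

    # Hypothetical driver code for the train_dynamic test.
    model = MyModel(test="train_dynamic", device="cuda",
                    extra_args=["--num-batch", "10"])
    input_iter = model.get_input_iter()      # infinite generator of input batches
    for _ in range(model.num_batch):         # user value, or DEFAULT_NUM_BATCH
        model.example_inputs = next(input_iter)  # pull the next randomized batch
        losses = model.forward()
        model.backward(losses)
        model.optimizer_step()
    # model.invoke() runs the same loop internally.
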
Pull Request resolved: https://github.com/pytorch/benchmark/pull/1914
Reviewed By: davidberard98
Differential Revision: D49504281
Pulled By: xuzhao9
fbshipit-source-id: 322015ff85725869fe8573dc709a04796332dfed
---
 run.py                                      | 16 ++--
 .../models/detectron2_maskrcnn/__init__.py  | 17 ++--
 .../models/fambench_xlmr/__init__.py        |  4 +-
 torchbenchmark/util/extra_args.py           |  1 +
 .../framework/detectron2/model_factory.py   | 28 +++---
 .../framework/huggingface/model_factory.py  | 49 +++++------
 .../util/framework/timm/model_factory.py    | 23 ++---
 .../util/framework/vision/model_factory.py  | 43 +++++-----
 .../util/{tensor_cast.py => input.py}       | 10 ++-
 torchbenchmark/util/model.py                | 86 ++++++++++++-------
 10 files changed, 155 insertions(+), 122 deletions(-)
 rename torchbenchmark/util/{tensor_cast.py => input.py} (74%)

diff --git a/run.py b/run.py
index d35c3c918a..27d8f31fe3 100644
--- a/run.py
+++ b/run.py
@@ -87,17 +87,13 @@ def printResultSummaryTime(result_summary, metrics_needed=[], model=None, flops_
     if args.device == "cuda":
         gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
         cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))
-        if hasattr(model, "NUM_BATCHES"):
-            print('{:<20} {:>20}'.format("GPU Time per batch:", "%.3f milliseconds" %
-                  (gpu_time / model.NUM_BATCHES), sep=''))
-            print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" %
-                  (cpu_walltime / model.NUM_BATCHES), sep=''))
-        else:
-            print('{:<20} {:>20}'.format("GPU Time:", "%.3f milliseconds" % gpu_time, sep=''))
-            print('{:<20} {:>20}'.format("CPU Total Wall Time:", "%.3f milliseconds" % cpu_walltime, sep=''))
+        print('{:<20} {:>20}'.format("GPU Time per batch:", "%.3f milliseconds" %
+              (gpu_time / model.num_batch), sep=''))
+        print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" %
+              (cpu_walltime / model.num_batch), sep=''))
     else:
         cpu_walltime = np.median(list(map(lambda x: x[0], result_summary)))
-        print('{:<20} {:>20}'.format("CPU Total Wall Time:", "%.3f milliseconds" % cpu_walltime, sep=''))
+        print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" % (cpu_walltime / model.num_batch), sep=''))
     # if model_flops is not None, output the TFLOPs per sec
     if 'flops' in metrics_needed:
         if flops_model_analyzer.metrics_backend_mapping['flops'] == 'dcgm':
@@ -298,7 +294,7 @@ def _validate_profile_options(profile_options: str):
     parser.add_argument(
         "-t",
         "--test",
-        choices=["eval", "train"],
+        choices=["eval", "train", "train_dynamic"],
         default="eval",
         help="Which test to run.",
     )
diff --git a/torchbenchmark/models/detectron2_maskrcnn/__init__.py b/torchbenchmark/models/detectron2_maskrcnn/__init__.py
index a350e33226..e641f6ac7b 100644
--- a/torchbenchmark/models/detectron2_maskrcnn/__init__.py
+++ b/torchbenchmark/models/detectron2_maskrcnn/__init__.py
@@ -69,26 +69,25 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
             self.model.eval()
             test_loader = instantiate(data_cfg.test)
             self.example_inputs = prefetch(itertools.islice(test_loader, 100), self.device)
-        self.NUM_BATCHES = len(self.example_inputs)
 
     def get_module(self):
         return self.model, (self.example_inputs[0], )
 
     def train(self):
         self.model.train()
+        idx = 0
         with EventStorage():
-            for idx in range(self.NUM_BATCHES):
-                losses = self.model(self.example_inputs[idx])
-                loss = sum(losses.values())
-                loss.backward()
-                self.optimizer.step()
-                self.optimizer.zero_grad()
+            losses = self.model(self.example_inputs[idx])
+            loss = sum(losses.values())
+            loss.backward()
+            self.optimizer.step()
+            self.optimizer.zero_grad()
 
     def eval(self) -> Tuple[torch.Tensor]:
         self.model.eval()
+        idx = 0
         with torch.no_grad():
-            for idx in range(self.NUM_BATCHES):
-                out = self.model(self.example_inputs[idx])
+            out = self.model(self.example_inputs[idx])
         # retrieve output tensors
         outputs = []
         for item in out:
diff --git a/torchbenchmark/models/fambench_xlmr/__init__.py b/torchbenchmark/models/fambench_xlmr/__init__.py
index 07a029a381..faf0489495 100644
--- a/torchbenchmark/models/fambench_xlmr/__init__.py
+++ b/torchbenchmark/models/fambench_xlmr/__init__.py
@@ -46,7 +46,6 @@ class Model(BenchmarkModel):
     # We use the same batch size for train and inference (96), ...
     # ... but runs only 1 batch
     DEFAULT_FAM_CONFIG = "fb-1dev-A"
-    DEFAULT_NUM_BATCHES = 1
     DEFAULT_TRAIN_BSIZE = 96
     DEFAULT_EVAL_BSIZE = 96
     DEFAULT_SEQ_LENGTH = 64
@@ -58,10 +57,11 @@ class Model(BenchmarkModel):
 
     def __init__(self, test, device, batch_size=None, extra_args=[]):
         super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
+        num_batches = 1
         self.xlmr = fairseq.models.roberta.XLMRModel.from_pretrained("xlmr.large")
         parser = init_argparse()
         args = parser.parse_args([f"--famconfig={self.DEFAULT_FAM_CONFIG}",
-                                  f"--num-batches={self.DEFAULT_NUM_BATCHES}", f"--batch-size={self.batch_size} ", \
+                                  f"--num-batches={num_batches}", f"--batch-size={self.batch_size} ", \
                                   f"--sequence-length={self.DEFAULT_SEQ_LENGTH}", f"--vocab-size={self.DEFAULT_VOCAB_SIZE}"])
         if self.device == "cuda":
             args.use_gpu = True
diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py
index 30d5b6ebd6..0d1b1f7754 100644
--- a/torchbenchmark/util/extra_args.py
+++ b/torchbenchmark/util/extra_args.py
@@ -72,6 +72,7 @@ def parse_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', ext
     parser.add_argument("--accuracy", action="store_true", help="Check accuracy of the model only instead of running the performance test.")
     parser.add_argument("--use_cosine_similarity", action='store_true', help="use cosine similarity for correctness check")
    parser.add_argument("--quant-engine", choices=QUANT_ENGINES, default='x86', help=f"choose quantization engine for fx_int8 precision from {QUANT_ENGINES}")
+    parser.add_argument("--num-batch", type=int, help="Number of batches if running the train_dynamic test.")
     dargs, opt_args = parser.parse_known_args(extra_args)
     if not check_precision(model, dargs.precision):
         raise NotImplementedError(f"precision value: {dargs.precision}, "
diff --git a/torchbenchmark/util/framework/detectron2/model_factory.py b/torchbenchmark/util/framework/detectron2/model_factory.py
index d31d96b023..c11e9f308e 100644
--- a/torchbenchmark/util/framework/detectron2/model_factory.py
+++ b/torchbenchmark/util/framework/detectron2/model_factory.py
@@ -100,7 +100,7 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]):
         self.model = instantiate(cfg.model).to(self.device)
 
         # setup model and return the dataloader
-        if self.test == "train":
+        if self.test == "train" or self.test == "train_dynamic":
             if hasattr(self, "FCOS_USE_BN") and self.FCOS_USE_BN:
                 raise NotImplementedError("FCOS train is not supported by upstream detectron2. " \
                     "See GH Issue: https://github.com/facebookresearch/detectron2/issues/4369.")
@@ -110,8 +110,6 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]):
             loader = self.setup_eval(cfg, args)
 
         self.example_inputs = prefetch(itertools.islice(loader, 100), self.device)
-        # torchbench: only run 1 batch
-        self.NUM_BATCHES = 1
 
     def setup_train(self):
         if hasattr(self, "FCOS_USE_BN") and self.FCOS_USE_BN:
@@ -163,22 +161,22 @@ def enable_fp16(self):
         self.example_inputs = prefetch(self.example_inputs, self.device, self.dargs.precision)
 
     def train(self):
+        batch_id = 0
         with EventStorage():
-            for batch_id in range(self.NUM_BATCHES):
-                loss_dict = self.model(self.example_inputs[batch_id])
-                if isinstance(loss_dict, torch.Tensor):
-                    losses = loss_dict
-                    loss_dict = {"total_loss": loss_dict}
-                else:
-                    losses = sum(loss_dict.values())
-                self.optimizer.zero_grad()
-                losses.backward()
-                self.optimizer.step()
+            loss_dict = self.model(self.example_inputs[batch_id])
+            if isinstance(loss_dict, torch.Tensor):
+                losses = loss_dict
+                loss_dict = {"total_loss": loss_dict}
+            else:
+                losses = sum(loss_dict.values())
+            self.optimizer.zero_grad()
+            losses.backward()
+            self.optimizer.step()
 
     def eval(self) -> Tuple[torch.Tensor]:
+        batch_id = 0
         with torch.no_grad():
-            for batch_id in range(self.NUM_BATCHES):
-                out = self.model(self.example_inputs[batch_id])
+            out = self.model(self.example_inputs[batch_id])
         # retrieve output tensors
         outputs = []
         for item in out:
diff --git a/torchbenchmark/util/framework/huggingface/model_factory.py b/torchbenchmark/util/framework/huggingface/model_factory.py
index 4ec3f7ebfe..93661b688a 100644
--- a/torchbenchmark/util/framework/huggingface/model_factory.py
+++ b/torchbenchmark/util/framework/huggingface/model_factory.py
@@ -56,6 +56,8 @@ class HuggingFaceModel(BenchmarkModel):
     HF_MODEL = True
     # Default eval precision on CUDA device is fp16(half mode)
     DEFAULT_EVAL_CUDA_PRECISION = "fp16"
+    # When running the train_dynamic test, run 10 batches of input by default
+    DEFAULT_NUM_BATCH = 10
 
     # If you suffix a model with '_generate', we will instead wrap the
     # unsuffixed model with GenerationWrapper which will make it do
@@ -72,7 +74,7 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
             self.is_generate = False
             self.unqual_name = name
         name = self.unqual_name  # we don't want to refer to the qualified name anymore
-        if test == "train":
+        if test == "train" or test == "train_dynamic":
             self.max_length = class_models[name][0]
         elif test == "eval":
             self.max_length = class_models[name][1]
@@ -97,7 +99,6 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
 
         # populate these on-demand to avoid wasting memory when not used
         self.vocab_size = config.vocab_size
-        self.dynamic_example_inputs = None
 
         if test == "train":
             input_ids = torch.randint(0, config.vocab_size, (self.batch_size, self.max_length)).to(device)
@@ -113,7 +114,6 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
             if class_models[name][3] == 'AutoModelForSeq2SeqLM':
                 self.example_inputs['decoder_input_ids'] = eval_context
             self.model.eval()
-        self.amp_context = nullcontext
 
     def get_module(self, wrap_model=True):
@@ -126,30 +126,31 @@ def get_module(self, wrap_model=True):
                 self.example_inputs['input_ids'], self.example_inputs[k])
         return self.model, (self.example_inputs["input_ids"], )
 
-    def get_dynamic_shapes_module(self):
-        if self.dynamic_example_inputs is None:
-            nbuckets = 8
-            nsamples = 32
-            n = int(math.log2(self.max_length))
-            buckets = [2**n for n in range(n - nbuckets, n)]
-            self.dynamic_example_inputs = [
-                {
-                    'input_ids': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device),
-                    'labels': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device)}
-                for bucket_len in random.choices(buckets, k=nsamples)
-            ]
-
+    def get_input_iter(self):
+        """Yield inputs with a randomized bucket length."""
+        nbuckets = 8
+        n = int(math.log2(self.max_length))
+        buckets = [2**i for i in range(n - nbuckets, n)]
         if class_models[self.unqual_name][3] == 'AutoModelForSeq2SeqLM':
-            raise NotImplementedError("Not yet supported")
-
-        # TODO(whc) why is labels not passed through?
-        return self.model, [(i['input_ids'],) for i in self.dynamic_example_inputs]
-
-    def train(self):
+            raise NotImplementedError("AutoModelForSeq2SeqLM is not yet supported")
+        while True:
+            # randomize bucket_len
+            bucket_len = random.choice(buckets)
+            dict_input = {
+                'input_ids': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device),
+                'labels': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device),
+            }
+            yield dict_input
+
+    def forward(self):
         with self.amp_context():
             outputs = self.model(**self.example_inputs)
-            loss = outputs.loss
-            loss.backward()
+            return outputs.loss
+
+    def backward(self, losses):
+        losses.backward()
+
+    def optimizer_step(self):
         self.optimizer.step()
 
     def eval(self) -> Tuple[torch.Tensor]:
diff --git a/torchbenchmark/util/framework/timm/model_factory.py b/torchbenchmark/util/framework/timm/model_factory.py
index 17010349fe..6e31b2b613 100644
--- a/torchbenchmark/util/framework/timm/model_factory.py
+++ b/torchbenchmark/util/framework/timm/model_factory.py
@@ -14,6 +14,8 @@ class TimmModel(BenchmarkModel):
     DEFAULT_EVAL_BSIZE = None
     # Default eval precision on CUDA device is fp16
     DEFAULT_EVAL_CUDA_PRECISION = "fp16"
+    # When running the train_dynamic test, run 10 batches of input by default
+    DEFAULT_NUM_BATCH = 10
 
     def __init__(self, model_name, test, device, batch_size=None, extra_args=[]):
         super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
@@ -27,22 +29,21 @@ def __init__(self, model_name, test, device, batch_size=None, extra_args=[]):
         self.model.to(
             device=self.device
         )
-        if test == "train":
+        if test == "train" or test == "train_dynamic":
             self.model.train()
         elif test == "eval":
             self.model.eval()
         self.amp_context = suppress
 
-    def gen_inputs(self, num_batches:int=1) -> Tuple[Generator, Optional[int]]:
-        def _gen_inputs():
-            while True:
-                result = []
-                for _i in range(num_batches):
-                    result.append((self._gen_input(self.batch_size), ))
-                if self.dargs.precision == "fp16":
-                    result = list(map(lambda x: (x[0].half(), ), result))
-                yield result
-        return (_gen_inputs(), None)
+    def get_input_iter(self):
+        """Yield inputs with a randomized batch size."""
+        import math, random
+        n = int(math.log2(self.batch_size))
+        buckets = [2**i for i in range(n)]
+        while True:
+            random_batch_size = random.choice(buckets)
+            example_input = (self._gen_input(random_batch_size), )
+            yield example_input
 
     def _gen_input(self, batch_size):
         return torch.randn((batch_size,) + self.cfg.input_size, device=self.device)
diff --git a/torchbenchmark/util/framework/vision/model_factory.py b/torchbenchmark/util/framework/vision/model_factory.py
index 4e442c54d3..5a49ebe455 100644
--- a/torchbenchmark/util/framework/vision/model_factory.py
+++ b/torchbenchmark/util/framework/vision/model_factory.py
@@ -5,7 +5,6 @@
 import torchvision.models as models
 from contextlib import nullcontext
 from torchbenchmark.util.model import BenchmarkModel
-from typing import Tuple, Generator, Optional
 
 class TorchVisionModel(BenchmarkModel):
     # To recognize this is a torchvision model
@@ -17,6 +16,8 @@ class TorchVisionModel(BenchmarkModel):
     DEFAULT_EVAL_CUDA_PRECISION = "fp16"
     # Whether to skip the opt zero grad
     SKIP_ZERO_GRAD = False
+    # When running the train_dynamic test, run 10 batches of input by default
+    DEFAULT_NUM_BATCH = 10
 
     def __init__(self, model_name, test, device, batch_size=None, weights=None, extra_args=[]):
         super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
@@ -28,7 +29,7 @@ def __init__(self, model_name, test, device, batch_size=None, weights=None, extr
         else:
             self.model = getattr(models, model_name)(weights=weights).to(self.device)
         self.example_inputs = (torch.randn((self.batch_size, 3, 224, 224)).to(self.device), )
-        if test == "train":
+        if test == "train" or test == "train_dynamic":
             # compute loss
             with torch.no_grad():
                 self.example_outputs = (torch.rand_like(self.model(*self.example_inputs)), )
@@ -58,29 +59,31 @@ def get_flops(self):
             self.flops = self.flops * FLOPS_FMA
             return self.flops
 
-    def gen_inputs(self, num_batches:int=1) -> Tuple[Generator, Optional[int]]:
-        def _gen_inputs():
-            while True:
-                result = []
-                for _i in range(num_batches):
-                    result.append((torch.randn((self.batch_size, 3, 224, 224)).to(self.device),))
-                if self.dargs.precision == "fp16":
-                    result = list(map(lambda x: (x[0].half(), ), result))
-                yield result
-        return (_gen_inputs(), None)
+    def get_input_iter(self):
+        """Yield inputs with a randomized batch size."""
+        import math, random
+        n = int(math.log2(self.batch_size))
+        buckets = [2**i for i in range(n)]
+        while True:
+            random_batch_size = random.choice(buckets)
+            example_input = (torch.randn((random_batch_size, 3, 224, 224)).to(self.device), )
+            yield example_input
 
     def get_module(self):
         return self.model, self.example_inputs
 
-    def train(self):
-        if self.opt and not self.SKIP_ZERO_GRAD:
-            self.opt.zero_grad()
+    def forward(self):
+        with torch.no_grad():
+            self.example_outputs = (torch.rand_like(self.model(*self.example_inputs)), )
         for data, target in zip(self.example_inputs, self.example_outputs):
-            with self.amp_context():
-                pred = self.model(data)
-            self.loss_fn(pred, target).backward()
-        if self.opt:
-            self.opt.step()
+            pred = self.model(data)
+            return self.loss_fn(pred, target)
+
+    def backward(self, loss):
+        loss.backward()
+
+    def optimizer_step(self):
+        self.opt.step()
 
     def cudagraph_train(self):
         for data, target in zip(self.real_input, self.real_output):
diff --git a/torchbenchmark/util/tensor_cast.py b/torchbenchmark/util/input.py
similarity index 74%
rename from torchbenchmark/util/tensor_cast.py
rename to torchbenchmark/util/input.py
index d2026f4af7..7e44fa1bfe 100644
--- a/torchbenchmark/util/tensor_cast.py
+++ b/torchbenchmark/util/input.py
@@ -1,12 +1,18 @@
 import torch
 from torch.utils._pytree import tree_map
+from dataclasses import dataclass, field
+from typing import Dict, Any, Optional
 
-def inputs_cast(cond, action, example_inputs):
+@dataclass
+class ModelInputDescriptor:
+    pass
+
+def input_cast(cond, action, example_inputs):
     """Traverse the input batch pytree, and cast tensor with `action` if it satisfies `cond`."""
     if isinstance(example_inputs, torch.Tensor) and cond(example_inputs):
         return action(example_inputs)
     elif isinstance(example_inputs, (tuple, list, dict)):
-        return tree_map(lambda x: inputs_cast(cond, action, x), example_inputs)
+        return tree_map(lambda x: input_cast(cond, action, x), example_inputs)
     elif example_inputs is None or \
          isinstance(example_inputs, str) or \
          isinstance(example_inputs, int) or \
diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
index b3312f34fe..c85f466539 100644
--- a/torchbenchmark/util/model.py
+++ b/torchbenchmark/util/model.py
@@ -12,7 +12,7 @@
 from torchbenchmark.util.env_check import set_random_seed, is_hf_model, \
     save_deterministic_dict, load_deterministic_dict, check_accuracy
 from torchbenchmark.util.fx_int8 import get_sub_module, prepare_sub_module, convert_sub_module
-from torchbenchmark.util.tensor_cast import inputs_cast
+from torchbenchmark.util.input import input_cast, ModelInputDescriptor
 
 SPECIAL_DEVICE_MAPPING = {
     "AMD Instinct MI210": "NVIDIA A100-SXM4-40GB"
@@ -81,14 +81,17 @@ class BenchmarkModel(metaclass=PostInitProcessor):
     def __init__(self, test: str, device: str, batch_size: Optional[int]=None, extra_args: List[str]=[]):
         self.metadata = self._load_metadata()
         self.test = test
-        assert self.test == "train" or self.test == "eval", \
-            f"Test must be 'train' or 'eval', but get {self.test}. Please submit a bug report."
+        # sanity checks of the options
+        assert self.test == "train" or self.test == "eval" or self.test == "train_dynamic", \
+            f"Test must be 'train', 'train_dynamic', or 'eval', but provided {self.test}."
+        assert self.test != "train_dynamic" or is_staged_train_test(self), \
+            "The train_dynamic test requires the model to implement the staged train test interface."
         self.device = device
         self.extra_args = extra_args
         self.opt = None
         self._skip_by_device_name()
         # contexts to run in the test function
-        if self.test == "train":
+        if self.test == "train" or self.test == "train_dynamic":
             # In train test, there are run contexts that should only be applied for forward/backward/optimizer stage
             # For example, amp only applies for the forward stage
             self.forward_contexts = []
@@ -99,8 +102,6 @@ def __init__(self, test: str, device: str, batch_size: Optional[int]=None, extra
         ]
         set_random_seed()
-        # sanity checks of the options
-        assert self.test == "train" or self.test == "eval", f"Test must be 'train' or 'eval', but provided {self.test}."
         # parse the args
         self.dargs, opt_args = parse_decoration_args(self, self.extra_args)
         if self.dargs.accuracy and not self.DISABLE_DETERMINISM:
@@ -114,6 +115,7 @@ def __init__(self, test: str, device: str, batch_size: Optional[int]=None, extra
             self.dynamo = False
         self.opt_args, self.extra_args = parse_opt_args(self, opt_args)
         self._determine_batch_size(batch_size)
+        self.num_batch = self._determine_dynamic_num_batches(self.dargs.num_batch)
 
     # Run the post processing for model acceleration
     def __post__init__(self):
@@ -161,6 +163,14 @@ def _skip_by_device_name(self):
             if skip_device == current_device_name and (not skip_test or skip_test == self.test):
                 raise NotImplementedError(f"The current device {current_device_name} is skipped by its `{self.name}/metadata.yaml`.")
 
+    def _determine_dynamic_num_batches(self, user_specified_num_batches: Optional[int]) -> int:
+        if self.test == "train" or self.test == "eval":
+            return 1
+        if user_specified_num_batches:
+            return user_specified_num_batches
+        assert hasattr(self, 'DEFAULT_NUM_BATCH'), "We expect all models that support dynamic shapes to specify the field `DEFAULT_NUM_BATCH`."
+        return self.DEFAULT_NUM_BATCH
+
     def _determine_batch_size(self, batch_size=None):
         # batch size priority for eval tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > device specified > default
         # batch size priority for train tests: not ALLOW_CUSTOMIZE_BSIZE > user specified > default
@@ -250,34 +260,50 @@ def set_module(self, new_model):
         else:
             raise NotImplementedError("The instance variable 'model' does not exist or is not type 'torch.nn.Module', implement your own `set_module()` function.")
 
-    def gen_inputs(self, num_batches: int=1) -> Tuple[Generator, Optional[int]]:
-        """Generate a tuple of (iterator of model input, the size of the iterator).
-        If size is None, the input is randomly generated and has infinite size."""
-        raise NotImplementedError("Default input generation function is not implemented. "
-                                  "Please submit an issue if you need input iterator implementation for the model.")
+    def get_input_iter(self) -> Generator:
+        """Return the dynamic input iterator for the model."""
+        raise NotImplementedError("Default dynamic input iterator is not implemented. "
+                                  f"Please submit an issue if you need a dynamic shape input iterator implementation for the model {self.name}.")
+
+    def get_input_descriptor(self) -> ModelInputDescriptor:
+        if hasattr(self, 'input_descriptor') and isinstance(self.input_descriptor, ModelInputDescriptor):
+            return self.input_descriptor
+        raise NotImplementedError("Default dynamic input descriptor is not implemented. "
+                                  f"Please submit an issue if you need a dynamic shape input descriptor implementation for the model {self.name}.")
+
+    def set_input_descriptor(self, descriptor: ModelInputDescriptor) -> None:
+        """Set the customized dynamic input descriptor for the model."""
+        if hasattr(self, 'input_descriptor') and isinstance(self.input_descriptor, ModelInputDescriptor):
+            self.input_descriptor = descriptor
+            return
+        raise NotImplementedError("Default dynamic input descriptor is not implemented. "
+                                  f"Please submit an issue if you need a dynamic shape input descriptor implementation for the model {self.name}.")
 
-    def invoke_staged_train_test(self) -> None:
+    def _invoke_staged_train_test(self, num_batch: int) -> None:
         optimizer = self.get_optimizer()
-        if optimizer is not None:
-            optimizer.zero_grad()
-
-        with nested(*self.forward_contexts):
-            losses = self.forward()
-
-        with nested(*self.backward_contexts):
-            self.backward(losses)
-
-        if optimizer is not None:
-            with nested(*self.optimizer_contexts):
-                self.optimizer_step()
-
+        input_generator = self.get_input_iter() if num_batch != 1 else None
+        for _batch_num in range(num_batch):
+            if input_generator:
+                self.example_inputs = next(input_generator)
+                # cast inputs if needed
+                apply_decoration_args(self, self.dargs)
+            if optimizer is not None:
+                optimizer.zero_grad()
+            with nested(*self.forward_contexts):
+                losses = self.forward()
+            with nested(*self.backward_contexts):
+                self.backward(losses)
+            if optimizer is not None:
+                with nested(*self.optimizer_contexts):
+                    self.optimizer_step()
         return None
 
     def invoke(self) -> Optional[Tuple[torch.Tensor]]:
-        out = None
-        if self.test == "train" and is_staged_train_test(self):
-            self.invoke_staged_train_test()
-            return out
+        if self.test == "train_dynamic" or (self.test == "train" and is_staged_train_test(self)):
+            return self._invoke_staged_train_test(num_batch=self.num_batch)
+        out = None
         with nested(*self.run_contexts):
             if self.test == "train":
                 self.train()
@@ -317,7 +343,7 @@ def _cast_to(self, cond, action):
             return self.set_module(model)
 
         if hasattr(self, 'example_inputs'):
-            self.example_inputs = inputs_cast(cond, action, self.example_inputs)
+            self.example_inputs = input_cast(cond, action, self.example_inputs)
         else:
             warnings.warn(UserWarning(f"{model_name} example inputs doesn't cast to {action} yet!"))
 
@@ -345,6 +371,8 @@ def enable_amp(self):
             self.amp_context = lambda: torch.cpu.amp.autocast()
         elif self.device == "cuda":
             self.amp_context = lambda: torch.cuda.amp.autocast()
+        if is_staged_train_test(self):
+            self.forward_contexts.append(self.amp_context)
 
     @property
     def pt2_compilation_time(self):