Add interface to support multiple batches (#1914)
Summary:
The interface `self.get_input_iter()` returns an input generator that yields the next available batch of input.

We support dynamic shapes on 3 categories of models:
- huggingface (hf_*): we randomize the `bucket_len` (sequence-length bucket) of the input data (a small worked example of the bucketing follows this list)
- timm: we randomize the batch_size of the input data
- torchvision: we randomize the batch_size of the input data
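
For reference, a small worked example of the HuggingFace bucketing scheme used further down in this diff; the concrete `max_length` here is an illustrative value, not one taken from `class_models`:

```python
import math, random

max_length = 512                                    # illustrative; each model reads its own value from class_models
nbuckets = 8
n = int(math.log2(max_length))                      # n = 9 for max_length = 512
buckets = [2 ** k for k in range(n - nbuckets, n)]  # [2, 4, 8, 16, 32, 64, 128, 256]
bucket_len = random.choice(buckets)                 # randomized sequence length for the next batch
```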

This PR also introduces a new test type, `train_dynamic`, which runs the train test over a list of batches. By default, it runs 10 batches of inputs returned from the generator; a sketch of such a driver loop follows.
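
A minimal sketch of how a driver could consume this interface for `train_dynamic`. The loop below is illustrative only: it relies on the `get_input_iter`/`forward`/`backward`/`optimizer_step` hooks touched in this diff, and handing each batch to the model via `example_inputs` is an assumption, not the actual runner code.

```python
import itertools

def run_train_dynamic(model, num_batch=10):
    """Run num_batch training iterations, one per generated batch (illustrative sketch)."""
    input_iter = model.get_input_iter()            # infinite generator of randomized batches
    for example_inputs in itertools.islice(input_iter, num_batch):
        model.example_inputs = example_inputs      # assumption: the model reads self.example_inputs
        loss = model.forward()                     # forward pass, returns the loss
        model.backward(loss)                       # backward pass
        model.optimizer_step()                     # parameter update
```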

Pull Request resolved: #1914

Reviewed By: davidberard98

Differential Revision: D49504281

Pulled By: xuzhao9

fbshipit-source-id: 322015ff85725869fe8573dc709a04796332dfed
xuzhao9 authored and facebook-github-bot committed Sep 22, 2023
1 parent d00b942 commit 842f11d
Showing 10 changed files with 155 additions and 122 deletions.
16 changes: 6 additions & 10 deletions run.py
@@ -87,17 +87,13 @@ def printResultSummaryTime(result_summary, metrics_needed=[], model=None, flops_
if args.device == "cuda":
gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))
if hasattr(model, "NUM_BATCHES"):
print('{:<20} {:>20}'.format("GPU Time per batch:", "%.3f milliseconds" %
(gpu_time / model.NUM_BATCHES), sep=''))
print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" %
(cpu_walltime / model.NUM_BATCHES), sep=''))
else:
print('{:<20} {:>20}'.format("GPU Time:", "%.3f milliseconds" % gpu_time, sep=''))
print('{:<20} {:>20}'.format("CPU Total Wall Time:", "%.3f milliseconds" % cpu_walltime, sep=''))
print('{:<20} {:>20}'.format("GPU Time per batch:", "%.3f milliseconds" %
(gpu_time / model.num_batch), sep=''))
print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" %
(cpu_walltime / model.num_batch), sep=''))
else:
cpu_walltime = np.median(list(map(lambda x: x[0], result_summary)))
print('{:<20} {:>20}'.format("CPU Total Wall Time:", "%.3f milliseconds" % cpu_walltime, sep=''))
print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" % (cpu_walltime / model.num_batch), sep=''))
# if model_flops is not None, output the TFLOPs per sec
if 'flops' in metrics_needed:
if flops_model_analyzer.metrics_backend_mapping['flops'] == 'dcgm':
@@ -298,7 +294,7 @@ def _validate_profile_options(profile_options: str):
parser.add_argument(
"-t",
"--test",
choices=["eval", "train"],
choices=["eval", "train", "train_dynamic"],
default="eval",
help="Which test to run.",
)
17 changes: 8 additions & 9 deletions torchbenchmark/models/detectron2_maskrcnn/__init__.py
@@ -69,26 +69,25 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
self.model.eval()
test_loader = instantiate(data_cfg.test)
self.example_inputs = prefetch(itertools.islice(test_loader, 100), self.device)
self.NUM_BATCHES = len(self.example_inputs)

def get_module(self):
return self.model, (self.example_inputs[0], )

def train(self):
self.model.train()
idx = 0
with EventStorage():
for idx in range(self.NUM_BATCHES):
losses = self.model(self.example_inputs[idx])
loss = sum(losses.values())
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
losses = self.model(self.example_inputs[idx])
loss = sum(losses.values())
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()

def eval(self) -> Tuple[torch.Tensor]:
self.model.eval()
idx = 0
with torch.no_grad():
for idx in range(self.NUM_BATCHES):
out = self.model(self.example_inputs[idx])
out = self.model(self.example_inputs[idx])
# retrieve output tensors
outputs = []
for item in out:
4 changes: 2 additions & 2 deletions torchbenchmark/models/fambench_xlmr/__init__.py
@@ -46,7 +46,6 @@ class Model(BenchmarkModel):
# We use the same batch size for train and inference (96), ...
# ... but runs only 1 batch
DEFAULT_FAM_CONFIG = "fb-1dev-A"
DEFAULT_NUM_BATCHES = 1
DEFAULT_TRAIN_BSIZE = 96
DEFAULT_EVAL_BSIZE = 96
DEFAULT_SEQ_LENGTH = 64
@@ -58,10 +57,11 @@ class Model(BenchmarkModel):

def __init__(self, test, device, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
num_batches = 1
self.xlmr = fairseq.models.roberta.XLMRModel.from_pretrained("xlmr.large")
parser = init_argparse()
args = parser.parse_args([f"--famconfig={self.DEFAULT_FAM_CONFIG}",
f"--num-batches={self.DEFAULT_NUM_BATCHES}", f"--batch-size={self.batch_size} ", \
f"--num-batches={num_batches}", f"--batch-size={self.batch_size} ", \
f"--sequence-length={self.DEFAULT_SEQ_LENGTH}", f"--vocab-size={self.DEFAULT_VOCAB_SIZE}"])
if self.device == "cuda":
args.use_gpu = True
1 change: 1 addition & 0 deletions torchbenchmark/util/extra_args.py
@@ -72,6 +72,7 @@ def parse_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', ext
parser.add_argument("--accuracy", action="store_true", help="Check accuracy of the model only instead of running the performance test.")
parser.add_argument("--use_cosine_similarity", action='store_true', help="use cosine similarity for correctness check")
parser.add_argument("--quant-engine", choices=QUANT_ENGINES, default='x86', help=f"choose quantization engine for fx_int8 precision from {QUANT_ENGINES}")
parser.add_argument("--num-batch", type=int, help="Number of batches if running the train_dynamic test.")
dargs, opt_args = parser.parse_known_args(extra_args)
if not check_precision(model, dargs.precision):
raise NotImplementedError(f"precision value: {dargs.precision}, "
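A likely way to exercise the new flag, assuming the existing `run.py` CLI conventions (the exact command line is not part of this diff): `python run.py <model> -t train_dynamic --num-batch 20`. When the flag is omitted, the per-model `DEFAULT_NUM_BATCH` of 10 presumably applies.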
28 changes: 13 additions & 15 deletions torchbenchmark/util/framework/detectron2/model_factory.py
@@ -100,7 +100,7 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]):
self.model = instantiate(cfg.model).to(self.device)

# setup model and return the dataloader
if self.test == "train":
if self.test == "train" or self.test == "train_dynamic":
if hasattr(self, "FCOS_USE_BN") and self.FCOS_USE_BN:
raise NotImplementedError("FCOS train is not supported by upstream detectron2. " \
"See GH Issue: https://github.com/facebookresearch/detectron2/issues/4369.")
@@ -110,8 +110,6 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]):
loader = self.setup_eval(cfg, args)

self.example_inputs = prefetch(itertools.islice(loader, 100), self.device)
# torchbench: only run 1 batch
self.NUM_BATCHES = 1

def setup_train(self):
if hasattr(self, "FCOS_USE_BN") and self.FCOS_USE_BN:
@@ -163,22 +161,22 @@ def enable_fp16(self):
self.example_inputs = prefetch(self.example_inputs, self.device, self.dargs.precision)

def train(self):
batch_id = 0
with EventStorage():
for batch_id in range(self.NUM_BATCHES):
loss_dict = self.model(self.example_inputs[batch_id])
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
self.optimizer.zero_grad()
losses.backward()
self.optimizer.step()
loss_dict = self.model(self.example_inputs[batch_id])
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
self.optimizer.zero_grad()
losses.backward()
self.optimizer.step()

def eval(self) -> Tuple[torch.Tensor]:
batch_id = 0
with torch.no_grad():
for batch_id in range(self.NUM_BATCHES):
out = self.model(self.example_inputs[batch_id])
out = self.model(self.example_inputs[batch_id])
# retrieve output tensors
outputs = []
for item in out:
49 changes: 25 additions & 24 deletions torchbenchmark/util/framework/huggingface/model_factory.py
@@ -56,6 +56,8 @@ class HuggingFaceModel(BenchmarkModel):
HF_MODEL = True
# Default eval precision on CUDA device is fp16(half mode)
DEFAULT_EVAL_CUDA_PRECISION = "fp16"
# When running the train_dynamic test, run 10 batches of input by default
DEFAULT_NUM_BATCH = 10

# If you suffix a model with '_generate', we will instead wrap the
# unsuffixed model with GenerationWrapper which will make it do
@@ -72,7 +74,7 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
self.is_generate = False
self.unqual_name = name
name = self.unqual_name # we don't want to refer to the qualified name anymore
if test == "train":
if test == "train" or test == "train_dynamic":
self.max_length = class_models[name][0]
elif test == "eval":
self.max_length = class_models[name][1]
@@ -97,7 +99,6 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):

# populate these on-demand to avoid wasting memory when not used
self.vocab_size = config.vocab_size
self.dynamic_example_inputs = None

if test == "train":
input_ids = torch.randint(0, config.vocab_size, (self.batch_size, self.max_length)).to(device)
@@ -113,7 +114,6 @@ def __init__(self, name, test, device, batch_size=None, extra_args=[]):
if class_models[name][3] == 'AutoModelForSeq2SeqLM':
self.example_inputs['decoder_input_ids'] = eval_context
self.model.eval()

self.amp_context = nullcontext

def get_module(self, wrap_model=True):
@@ -126,30 +126,31 @@ def get_module(self, wrap_model=True):
self.example_inputs['input_ids'], self.example_inputs[k])
return self.model, (self.example_inputs["input_ids"], )

def get_dynamic_shapes_module(self):
if self.dynamic_example_inputs is None:
nbuckets = 8
nsamples = 32
n = int(math.log2(self.max_length))
buckets = [2**n for n in range(n - nbuckets, n)]
self.dynamic_example_inputs = [
{
'input_ids': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device),
'labels': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device)}
for bucket_len in random.choices(buckets, k=nsamples)
]

def get_input_iter(self):
"""Yield randomized bucket length of inputs."""
nbuckets = 8
n = int(math.log2(self.max_length))
buckets = [2**n for n in range(n - nbuckets, n)]
if class_models[self.unqual_name][3] == 'AutoModelForSeq2SeqLM':
raise NotImplementedError("Not yet supported")

# TODO(whc) why is labels not passed through?
return self.model, [(i['input_ids'],) for i in self.dynamic_example_inputs]

def train(self):
raise NotImplementedError("AutoModelForSeq2SeqLM is not yet supported")
while True:
# randomize bucket_len
bucket_len = random.choice(buckets)
dict_input = {
'input_ids': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device),
'labels': torch.randint(0, self.vocab_size, (self.batch_size, bucket_len)).to(self.device),
}
yield dict_input

def forward(self):
with self.amp_context():
outputs = self.model(**self.example_inputs)
loss = outputs.loss
loss.backward()
return outputs.loss

def backward(self, losses):
losses.backward()

def optimizer_step(self):
self.optimizer.step()

def eval(self) -> Tuple[torch.Tensor]:
23 changes: 12 additions & 11 deletions torchbenchmark/util/framework/timm/model_factory.py
@@ -14,6 +14,8 @@ class TimmModel(BenchmarkModel):
DEFAULT_EVAL_BSIZE = None
# Default eval precision on CUDA device is fp16
DEFAULT_EVAL_CUDA_PRECISION = "fp16"
# When running the train_dynamic test, run 10 batches of input by default
DEFAULT_NUM_BATCH = 10

def __init__(self, model_name, test, device, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
@@ -27,22 +29,21 @@ def __init__(self, model_name, test, device, batch_size=None, extra_args=[]):
self.model.to(
device=self.device
)
if test == "train":
if test == "train" or test == "train_dynamic":
self.model.train()
elif test == "eval":
self.model.eval()
self.amp_context = suppress

def gen_inputs(self, num_batches:int=1) -> Tuple[Generator, Optional[int]]:
def _gen_inputs():
while True:
result = []
for _i in range(num_batches):
result.append((self._gen_input(self.batch_size), ))
if self.dargs.precision == "fp16":
result = list(map(lambda x: (x[0].half(), ), result))
yield result
return (_gen_inputs(), None)
def get_input_iter(self):
"""Yield randomized batch size of inputs."""
import math, random
n = int(math.log2(self.batch_size))
buckets = [2**n for n in range(n)]
while True:
random_batch_size = random.choice(buckets)
example_input = (self._gen_input(random_batch_size), )
yield example_input

def _gen_input(self, batch_size):
return torch.randn((batch_size,) + self.cfg.input_size, device=self.device)
43 changes: 23 additions & 20 deletions torchbenchmark/util/framework/vision/model_factory.py
@@ -5,7 +5,6 @@
import torchvision.models as models
from contextlib import nullcontext
from torchbenchmark.util.model import BenchmarkModel
from typing import Tuple, Generator, Optional

class TorchVisionModel(BenchmarkModel):
# To recognize this is a torchvision model
@@ -17,6 +16,8 @@ class TorchVisionModel(BenchmarkModel):
DEFAULT_EVAL_CUDA_PRECISION = "fp16"
# Whether to skip the opt zero grad
SKIP_ZERO_GRAD = False
# When running the train_dynamic test, run 10 batches of input by default
DEFAULT_NUM_BATCH = 10

def __init__(self, model_name, test, device, batch_size=None, weights=None, extra_args=[]):
super().__init__(test=test, device=device, batch_size=batch_size, extra_args=extra_args)
@@ -28,7 +29,7 @@ def __init__(self, model_name, test, device, batch_size=None, weights=None, extra_args=[]):
else:
self.model = getattr(models, model_name)(weights=weights).to(self.device)
self.example_inputs = (torch.randn((self.batch_size, 3, 224, 224)).to(self.device), )
if test == "train":
if test == "train" or test == "train_dynamic":
# compute loss
with torch.no_grad():
self.example_outputs = (torch.rand_like(self.model(*self.example_inputs)), )
@@ -58,29 +59,31 @@ def get_flops(self):
self.flops = self.flops * FLOPS_FMA
return self.flops

def gen_inputs(self, num_batches:int=1) -> Tuple[Generator, Optional[int]]:
def _gen_inputs():
while True:
result = []
for _i in range(num_batches):
result.append((torch.randn((self.batch_size, 3, 224, 224)).to(self.device),))
if self.dargs.precision == "fp16":
result = list(map(lambda x: (x[0].half(), ), result))
yield result
return (_gen_inputs(), None)
def get_input_iter(self):
"""Yield randomized batch size of inputs."""
import math, random
n = int(math.log2(self.batch_size))
buckets = [2**n for n in range(n)]
while True:
random_batch_size = random.choice(buckets)
example_input = (torch.randn((random_batch_size, 3, 224, 224)).to(self.device), )
yield example_input

def get_module(self):
return self.model, self.example_inputs

def train(self):
if self.opt and not self.SKIP_ZERO_GRAD:
self.opt.zero_grad()
def forward(self):
with torch.no_grad():
self.example_outputs = (torch.rand_like(self.model(*self.example_inputs)), )
for data, target in zip(self.example_inputs, self.example_outputs):
with self.amp_context():
pred = self.model(data)
self.loss_fn(pred, target).backward()
if self.opt:
self.opt.step()
pred = self.model(data)
return self.loss_fn(pred, target)

def backward(self, loss):
loss.backward()

def optimizer_step(self):
self.opt.step()

def cudagraph_train(self):
for data, target in zip(self.real_input, self.real_output):
@@ -1,12 +1,18 @@
import torch
from torch.utils._pytree import tree_map
from dataclasses import dataclass, field
from typing import Dict, Any, Optional

def inputs_cast(cond, action, example_inputs):
@dataclass
class ModelInputDescriptor:
pass

def input_cast(cond, action, example_inputs):
"""Traverse the input batch pytree, and cast tensor with `action` if it satisfies `cond`."""
if isinstance(example_inputs, torch.Tensor) and cond(example_inputs):
return action(example_inputs)
elif isinstance(example_inputs, (tuple, list, dict)):
return tree_map(lambda x: inputs_cast(cond, action, x), example_inputs)
return tree_map(lambda x: input_cast(cond, action, x), example_inputs)
elif example_inputs is None or \
isinstance(example_inputs, str) or \
isinstance(example_inputs, int) or \
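
A hypothetical usage of the renamed `input_cast` helper, shown only to illustrate the pytree traversal; the tensors and the fp16 cast below are made-up examples, not code from this commit:

```python
import torch

# input_cast is the function defined in the diff above.
# Cast every fp32 tensor leaf of a nested input batch to fp16.
batch = {"inputs_embeds": torch.randn(8, 128, 768), "attention_mask": torch.ones(8, 128)}
fp16_batch = input_cast(lambda t: t.dtype == torch.float32, lambda t: t.half(), batch)
```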