Add the test_bench userbenchmark (#2052)
Summary:
The plan is to use the `test_bench` userbenchmark to deprecate `test_bench.py`.

It supports:
1. Running multiple models, each in its own subprocess.
2. Running with extra args, such as `--torchdynamo inductor`, which are passed through to each model.
3. A `--debug` option that saves each subprocess's output under an `output-%Y%m%d%H%M%S` directory (a usage sketch follows this list).
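As a rough usage sketch (assuming the `run_benchmark.py` dispatcher forwards the remaining CLI arguments to `run()` in `userbenchmark/test_bench/run.py`; the model names below are placeholders, not part of this change):

```
# Sketch only: drive the new userbenchmark programmatically. The import path
# and model names are assumptions; substitute models present in your checkout.
from userbenchmark.test_bench.run import run

# Features 1 and 2: several models, each loaded in its own subprocess, with
# extra args ("--torchdynamo inductor") passed through to each model.
run(["resnet50,alexnet", "-d", "cuda", "-t", "eval", "--torchdynamo", "inductor"])

# Feature 3: keep each subprocess's stdout/stderr under output-%Y%m%d%H%M%S.
run(["resnet50", "-d", "cuda", "-t", "eval", "--debug"])
```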

Pull Request resolved: #2052

Test Plan:
```
$ python run_benchmark.py test_bench llama_v2_7b_16h -d cuda -t eval --accuracy --debug
Running TorchBenchModelConfig(name='llama_v2_7b_16h', test='eval', device='cuda', batch_size=None, extra_args=['--accuracy'], extra_env=None, output_dir=PosixPath("/data/users/xzhao9/git/benchmark/.userbenchmark/test_bench/output-20231122193846/model=llama_v2_7b_16h, test=eval, device=cuda, bs=None, extra_args=['--accuracy']")) ...[done]
```

The stderr of the failing run is saved to `test_bench/output-20231122193846/model=llama_v2_7b_16h, test=eval, device=cuda, bs=None, extra_args=['--accuracy']/stderr.log`:

```
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 424, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 321, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xzhao9/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
```

Reviewed By: aaronenyeshi

Differential Revision: D51542761

Pulled By: xuzhao9

fbshipit-source-id: acf0616c791a72c3d7f015a1b77cba4a017d915d
xuzhao9 authored and facebook-github-bot committed Nov 23, 2023
1 parent ec23124 commit 7a8b39c
Showing 8 changed files with 197 additions and 17 deletions.
30 changes: 18 additions & 12 deletions components/_impl/workers/subprocess_worker.py
@@ -3,18 +3,15 @@
import io
import os
import marshal
import pathlib
import shutil
import signal
import subprocess
import sys
import tempfile
import textwrap
import time
import typing
from pathlib import Path
from typing import Optional, Dict, List, Any, Tuple

import components
from components._impl.workers import base
from components._impl.workers import subprocess_rpc

@@ -49,15 +46,21 @@ def args(self) -> typing.List[str]:
"""

_working_dir: str
_user_save_output: bool = False
_alive: bool = False
_bootstrap_timeout: int = 10 # seconds

def __init__(self, timeout: typing.Optional[float] = None, extra_env: typing.Optional[typing.Dict[str, str]]=None) -> None:
def __init__(self, timeout: Optional[float] = None,
extra_env: Optional[Dict[str, str]] = None,
save_output_dir: Optional[Path] = None) -> None:
super().__init__()

if save_output_dir and save_output_dir.is_dir():
self._working_dir = save_output_dir.absolute()
self._user_save_output = True
# Log inputs and outputs for debugging.
self._command_log = os.path.join(self.working_dir, "commands.log")
pathlib.Path(self._command_log).touch()
Path(self._command_log).touch()

self._stdout_f: io.FileIO = io.FileIO(
os.path.join(self.working_dir, "stdout.txt"), mode="w",
@@ -148,13 +151,13 @@ def working_dir(self) -> str:
return self._working_dir

@property
def args(self) -> typing.List[str]:
def args(self) -> List[str]:
return [sys.executable, "-i", "-u"]

def run(self, snippet: str) -> None:
self._run(snippet)

def store(self, name: str, value: typing.Any, in_memory: bool = False) -> None:
def store(self, name: str, value: Any, in_memory: bool = False) -> None:
if in_memory:
raise NotImplementedError("SubprocessWorker does not support `in_memory`")

@@ -165,7 +168,7 @@ def store(self, name: str, value: typing.Any, in_memory: bool = False) -> None:
)
""")

def load(self, name: str) -> typing.Any:
def load(self, name: str) -> Any:
self._run(f"""
{subprocess_rpc.WORKER_IMPL_NAMESPACE}["load_pipe"].write(
{subprocess_rpc.WORKER_IMPL_NAMESPACE}["marshal"].dumps({name})
@@ -278,7 +281,7 @@ def watch_stdout_stderr(self):
stdout_stat = os.stat(self._stdout_f.name)
stderr_stat = os.stat(self._stderr_f.name)

def get() -> typing.Tuple[str, str]:
def get() -> Tuple[str, str]:
with open(self._stdout_f.name, "rb") as f:
_ = f.seek(stdout_stat.st_size)
stdout = f.read().decode("utf-8").strip()
@@ -371,5 +374,8 @@ def __del__(self) -> None:
self._stdout_f.close()
self._stderr_f.close()

# Finally, make sure we don't leak any files.
shutil.rmtree(self._working_dir, ignore_errors=True)
# We deliberately keep the output files when user explicitly
# specifies the output file dir.
# Otherwise, delete all the subprocess files.
if not self._user_save_output:
shutil.rmtree(self._working_dir, ignore_errors=True)
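A minimal sketch of the new `save_output_dir` behavior (paths are hypothetical, and the class path is assumed to match the file shown above): when an existing directory is passed, the worker writes its command/stdout/stderr logs there and skips the cleanup in `__del__`; otherwise it keeps using a temporary working directory that is removed on teardown.

```
# Sketch, assuming the constructor signature shown in this diff.
from pathlib import Path
from components._impl.workers.subprocess_worker import SubprocessWorker

out_dir = Path("/tmp/tb_worker_debug")      # hypothetical location
out_dir.mkdir(parents=True, exist_ok=True)  # must exist: is_dir() is checked

worker = SubprocessWorker(timeout=600, save_output_dir=out_dir)
worker.run("x = 1 + 1")     # executed inside the child interpreter
print(worker.load("x"))     # round-trips the value back to the parent -> 2
# When the worker is torn down, out_dir and its logs are preserved because a
# save_output_dir was supplied; without it the directory would be deleted.
```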
3 changes: 2 additions & 1 deletion torchbenchmark/__init__.py
@@ -256,14 +256,15 @@ def __init__(
model_path: str,
timeout: Optional[float] = None,
extra_env: Optional[Dict[str, str]] = None,
save_output_dir: Optional[pathlib.Path] = None,
) -> None:
gc.collect() # Make sure previous task has a chance to release the lock
assert self._lock.acquire(blocking=False), "Failed to acquire lock."

self._model_path = model_path
if _is_internal_model(model_path):
model_path = f"{internal_model_dir}.{model_path}"
self._worker = Worker(timeout=timeout, extra_env=extra_env)
self._worker = Worker(timeout=timeout, extra_env=extra_env, save_output_dir=save_output_dir)

self.worker.run("import torch")
self._details: ModelDetails = ModelDetails(
7 changes: 4 additions & 3 deletions torchbenchmark/util/experiment/instantiator.py
@@ -4,7 +4,7 @@
They expect callers handle all exceptions.
"""
import os
import importlib
import pathlib
import dataclasses
from typing import Optional, List, Dict
from torchbenchmark.util.model import BenchmarkModel
@@ -21,6 +21,7 @@ class TorchBenchModelConfig:
batch_size: Optional[int]
extra_args: List[str]
extra_env: Optional[Dict[str, str]] = None
output_dir: Optional[pathlib.Path] = None

def _set_extra_env(extra_env):
if not extra_env:
@@ -32,8 +33,8 @@ def inject_model_invoke(model_task: ModelTask, inject_function):
model_task.replace_invoke(inject_function.__module__, inject_function.__name__)

def load_model_isolated(config: TorchBenchModelConfig, timeout: float=WORKER_TIMEOUT) -> ModelTask:
""" Load and return the model in a subprocess. """
task = ModelTask(config.name, timeout=timeout, extra_env=config.extra_env)
""" Load and return the model in a subprocess. Optionally, save its stdout and stderr to the specified directory. """
task = ModelTask(config.name, timeout=timeout, extra_env=config.extra_env, save_output_dir=config.output_dir)
if not task.model_details.exists:
raise ValueError(f"Failed to import model task: {config.name}. Please run the model manually to make sure it succeeds, or report a bug.")
task.make_model_instance(test=config.test, device=config.device, batch_size=config.batch_size, extra_args=config.extra_args)
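For context, a sketch of how a per-config output directory is threaded into an isolated load (the model name and paths are placeholders; `userbenchmark/test_bench/run.py` below does the same thing for every generated config):

```
# Sketch, assuming the TorchBenchModelConfig fields and helpers in this diff.
from pathlib import Path
from torchbenchmark.util.experiment.instantiator import (
    TorchBenchModelConfig,
    load_model_isolated,
)
from torchbenchmark.util.experiment.metrics import get_model_test_metrics

config = TorchBenchModelConfig(
    name="resnet50",          # placeholder model name
    device="cuda",
    test="eval",
    batch_size=None,
    extra_args=[],
    extra_env=None,
    output_dir=Path(".userbenchmark/test_bench/debug/resnet50"),  # hypothetical
)
config.output_dir.mkdir(parents=True, exist_ok=True)

task = load_model_isolated(config)   # subprocess stdout/stderr land in output_dir
metrics = get_model_test_metrics(task, metrics=["latencies"])
print(metrics.latencies)
```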
3 changes: 2 additions & 1 deletion torchbenchmark/util/experiment/metrics.py
@@ -3,6 +3,7 @@
"""
import torch
import time
import pathlib
import dataclasses
from torchbenchmark.util.model import BenchmarkModel
from torchbenchmark.util.experiment.instantiator import TorchBenchModelConfig
@@ -140,7 +141,7 @@ def get_model_test_metrics(model: Union[BenchmarkModel, ModelTask], metrics=[],
if isinstance(model, ModelTask) else model.ttfb
return TorchBenchModelMetrics(latencies, throughputs, cpu_peak_mem, gpu_peak_mem, ttfb, pt2_compilation_time, pt2_graph_breaks, model_flops)

def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True) -> str:
def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True, save_output_dir: Optional[pathlib.Path]=None) -> str:
import copy
from torchbenchmark.util.experiment.instantiator import load_model_isolated, load_model
# Try load minimal batch size, if fail, load the default batch size
1 change: 1 addition & 0 deletions userbenchmark/test_bench/__init__.py
@@ -0,0 +1 @@
BM_NAME = "test_bench"
Empty file.
164 changes: 164 additions & 0 deletions userbenchmark/test_bench/run.py
@@ -0,0 +1,164 @@
"""
Run PyTorch nightly benchmarking.
"""
import argparse
import itertools
import pathlib
import json
import os
import shutil
import numpy

from typing import List, Dict, Optional, Any, Union
from ..utils import REPO_PATH, add_path, get_output_json, get_default_output_json_path, get_default_debug_output_dir
from . import BM_NAME

with add_path(REPO_PATH):
from torchbenchmark.util.experiment.instantiator import list_models, load_model_isolated, TorchBenchModelConfig, \
list_devices, list_tests
from torchbenchmark.util.experiment.metrics import TorchBenchModelMetrics, get_model_test_metrics, get_model_accuracy

CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))

def config_to_str(config: TorchBenchModelConfig) -> str:
metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \
f" bs={config.batch_size}, extra_args={config.extra_args}"
return metrics_base

def generate_model_configs(devices: List[str], tests: List[str], batch_sizes: List[str], model_names: List[str], extra_args: List[str]) -> List[TorchBenchModelConfig]:
"""Use the default batch size and default mode."""
if not model_names:
model_names = list_models()
cfgs = itertools.product(*[devices, tests, batch_sizes, model_names])
result = [TorchBenchModelConfig(
name=model_name,
device=device,
test=test,
batch_size=None if not batch_size else int(batch_size),
extra_args=extra_args,
extra_env=None,
) for device, test, batch_size, model_name in cfgs]
return result

def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Path) -> List[TorchBenchModelConfig]:
result = []
for config in configs:
config_str = config_to_str(config)
config.output_dir = output_dir.joinpath(config_str)
if config.output_dir.exists():
shutil.rmtree(config.output_dir)
config.output_dir.mkdir(parents=True)
result.append(config)
return result

def get_metrics(config: TorchBenchModelConfig) -> List[str]:
if "--accuracy" in config.extra_args:
return ["accuracy"]
return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]

def validate(candidates: List[str], choices: List[str]) -> List[str]:
"""Validate the candidates provided by the user is valid"""
for candidate in candidates:
assert candidate in choices, f"Specified {candidate}, but not in available list: {choices}."
return candidates

def parse_str_to_list(candidates: Optional[str]) -> List[str]:
if isinstance(candidates, list):
return candidates
elif candidates == None:
return [""]
candidates = list(map(lambda x: x.strip(), candidates.split(",")))
return candidates

def metrics_to_dict(metrics: Union[TorchBenchModelMetrics, Dict[str, str]]) -> Dict[str, Union[str, float]]:
if isinstance(metrics, TorchBenchModelMetrics):
pass
return metrics

def run_config(config: TorchBenchModelConfig, metrics: List[str], dryrun: bool=False) -> Dict[str, Union[str, float]]:
"""This function only handles NotImplementedError, all other errors will fail."""
print(f"Running {config} ...", end='', flush=True)
if dryrun:
print(" [skip_by_dryrun]", flush=True)
return dict.fromkeys(metrics, "skip_by_dryrun")
# We do not allow RuntimeError in this test
try:
# load the model instance in subprocess
model = load_model_isolated(config)
# get the model test metrics
metrics_output: TorchBenchModelMetrics = get_model_test_metrics(model, metrics=metrics)
result = {}
for metric in metrics:
if metric == "latency" and metrics_output.latencies:
result[metric] = numpy.median(metrics_output.latencies)
if not result[metric]:
result[metric] = "failed"
print(" [done]", flush=True)
return result
except NotImplementedError as e:
print(" [not_implemented]", flush=True)
return dict.fromkeys(metrics, "not_implemented")

def run_config_accuracy(config: TorchBenchModelConfig, metrics: List[str], dryrun: bool=False) -> Dict[str, str]:
assert metrics == ["accuracy"], f"When running accuracy test, others metrics are not supported: {metrics}."
print(f"Running {config} ...", end='', flush=True)
if dryrun:
print(" [skip_by_dryrun]", flush=True)
return {"accuracy": "skip_by_dryrun"}
try:
accuracy = get_model_accuracy(config)
print(" [done]", flush=True)
return {"accuracy": accuracy}
except NotImplementedError:
print(" [not_implemented]", flush=True)
return {"accuracy": "not_implemented"}

def parse_known_args(args):
parser = argparse.ArgumentParser()
default_device = "cuda" if "cuda" in list_devices() else "cpu"
parser.add_argument(
"models",
help="Name of models to run, split by comma.",
)
parser.add_argument("--device", "-d", default=default_device, help="Devices to run, splited by comma.")
parser.add_argument("--test", "-t", default="eval", help="Tests to run, splited by comma.")
parser.add_argument("--bs", default=None, help="Optionally, specify the batch size.")
parser.add_argument("--config", "-c", default=None, help="YAML config to specify tests to run.")
parser.add_argument("--run-bisect", help="Run with the output of regression detector.")
parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Specify the path of the output file")
parser.add_argument("--debug", action="store_true", help="Save the debug output.")
return parser.parse_known_args(args)

def run(args: List[str]):
args, extra_args = parse_known_args(args)
# If not specified, use the entire model set
if not args.models:
args.models = list_models()
debug_output_dir = get_default_debug_output_dir(args.output) if args.debug else None
devices = validate(parse_str_to_list(args.device), list_devices())
tests = validate(parse_str_to_list(args.test), list_tests())
batch_sizes = parse_str_to_list(args.bs)
models = validate(parse_str_to_list(args.models), list_models())
configs = generate_model_configs(devices, tests, batch_sizes, model_names=models, extra_args=extra_args)
configs = init_output_dir(configs, debug_output_dir) if debug_output_dir else configs
results = {}
try:
for config in configs:
metrics = get_metrics(config)
if "accuracy" in metrics:
metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
else:
metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
config_str = config_to_str(config)
for metric in metrics_dict:
results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
except KeyboardInterrupt:
print("User keyboard interrupted!")
if not args.dryrun:
result = get_output_json(BM_NAME, results)
if args.device == 'cuda':
import torch
result["environ"]["device"] = torch.cuda.get_device_name()
with open(args.output, 'w') as f:
json.dump(result, f, indent=4)
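For reference, a hedged sketch of the metric keys this run loop emits: each key is the `config_to_str()` string plus the metric name, and `get_output_json()` in `userbenchmark/utils.py` wraps the dict with run metadata, including `environ["device"]` on CUDA runs. The values and the latency unit below are illustrative assumptions, not output from this change.

```
# Illustrative shape of the `results` dict written to the output JSON.
results = {
    "model=llama_v2_7b_16h, test=eval, device=cuda, bs=None, "
    "extra_args=['--accuracy'], metric=accuracy": "pass",       # value illustrative
    "model=resnet50, test=eval, device=cuda, bs=None, "
    "extra_args=[], metric=latencies": 12.34,                   # median latency, ms (assumed unit)
}
```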
6 changes: 6 additions & 0 deletions userbenchmark/utils.py
@@ -109,6 +109,12 @@ def get_default_output_json_path(bm_name: str, target_dir: Path=None) -> str:
full_fname = os.path.join(target_dir, fname)
return full_fname

def get_default_debug_output_dir(metrics_json: str) -> Path:
metrics_json_path = Path(metrics_json)
metrics_json_dir = metrics_json_path.parent
metrics_datetime = datetime.strptime(metrics_json_path.name, "metrics-%Y%m%d%H%M%S.json")
debug_output_dir = metrics_json_dir.joinpath("output-" + datetime.strftime(metrics_datetime, "%Y%m%d%H%M%S"))
return debug_output_dir

def dump_output(bm_name: str, output: Any, target_dir: Path=None) -> None:
full_fname = get_default_output_json_path(bm_name, target_dir=target_dir)
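A small example of the filename convention `get_default_debug_output_dir` relies on: the debug directory sits next to the metrics JSON and reuses its timestamp (the path below is hypothetical).

```
# Sketch, assuming the helper shown in this diff.
from userbenchmark.utils import get_default_debug_output_dir

out = get_default_debug_output_dir(
    "/tmp/.userbenchmark/test_bench/metrics-20231122193846.json"
)
print(out)  # -> /tmp/.userbenchmark/test_bench/output-20231122193846
```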
