diff --git a/components/_impl/workers/subprocess_worker.py b/components/_impl/workers/subprocess_worker.py
index 9652b70b7b..c8f727909b 100644
--- a/components/_impl/workers/subprocess_worker.py
+++ b/components/_impl/workers/subprocess_worker.py
@@ -3,18 +3,15 @@
 import io
 import os
 import marshal
-import pathlib
 import shutil
 import signal
 import subprocess
 import sys
 import tempfile
 import textwrap
-import time
-import typing
 from pathlib import Path
+from typing import Optional, Dict, List, Any, Tuple
 
-import components
 from components._impl.workers import base
 from components._impl.workers import subprocess_rpc
 
@@ -49,15 +46,21 @@ def args(self) -> typing.List[str]:
         """
 
     _working_dir: str
+    _user_save_output: bool = False
     _alive: bool = False
     _bootstrap_timeout: int = 10  # seconds
 
-    def __init__(self, timeout: typing.Optional[float] = None, extra_env: typing.Optional[typing.Dict[str, str]]=None) -> None:
+    def __init__(self, timeout: Optional[float] = None,
+                 extra_env: Optional[Dict[str, str]] = None,
+                 save_output_dir: Optional[Path] = None) -> None:
         super().__init__()
+        if save_output_dir and save_output_dir.is_dir():
+            self._working_dir = save_output_dir.absolute()
+            self._user_save_output = True
 
         # Log inputs and outputs for debugging.
         self._command_log = os.path.join(self.working_dir, "commands.log")
-        pathlib.Path(self._command_log).touch()
+        Path(self._command_log).touch()
         self._stdout_f: io.FileIO = io.FileIO(
             os.path.join(self.working_dir, "stdout.txt"),
             mode="w",
@@ -148,13 +151,13 @@ def working_dir(self) -> str:
         return self._working_dir
 
     @property
-    def args(self) -> typing.List[str]:
+    def args(self) -> List[str]:
         return [sys.executable, "-i", "-u"]
 
     def run(self, snippet: str) -> None:
         self._run(snippet)
 
-    def store(self, name: str, value: typing.Any, in_memory: bool = False) -> None:
+    def store(self, name: str, value: Any, in_memory: bool = False) -> None:
         if in_memory:
             raise NotImplementedError("SubprocessWorker does not support `in_memory`")
 
@@ -165,7 +168,7 @@ def store(self, name: str, value: typing.Any, in_memory: bool = False) -> None:
             )
         """)
 
-    def load(self, name: str) -> typing.Any:
+    def load(self, name: str) -> Any:
         self._run(f"""
             {subprocess_rpc.WORKER_IMPL_NAMESPACE}["load_pipe"].write(
                 {subprocess_rpc.WORKER_IMPL_NAMESPACE}["marshal"].dumps({name})
@@ -278,7 +281,7 @@ def watch_stdout_stderr(self):
         stdout_stat = os.stat(self._stdout_f.name)
         stderr_stat = os.stat(self._stderr_f.name)
 
-        def get() -> typing.Tuple[str, str]:
+        def get() -> Tuple[str, str]:
             with open(self._stdout_f.name, "rb") as f:
                 _ = f.seek(stdout_stat.st_size)
                 stdout = f.read().decode("utf-8").strip()
@@ -371,5 +374,8 @@ def __del__(self) -> None:
         self._stdout_f.close()
         self._stderr_f.close()
 
-        # Finally, make sure we don't leak any files.
-        shutil.rmtree(self._working_dir, ignore_errors=True)
+        # We deliberately keep the output files when the user explicitly
+        # specifies the output dir.
+        # Otherwise, delete all the subprocess files.
+        if not self._user_save_output:
+            shutil.rmtree(self._working_dir, ignore_errors=True)
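For context, a minimal usage sketch of the change above, assuming the class in this file is the `SubprocessWorker` used by torchbenchmark (the class name itself is not shown in these hunks): when `save_output_dir` points at an existing directory, the worker writes `commands.log`, `stdout.txt`, and `stderr.txt` there, and `__del__` skips the cleanup.

# Sketch only; the directory must already exist because __init__ checks is_dir().
from pathlib import Path
from components._impl.workers.subprocess_worker import SubprocessWorker

out_dir = Path("./worker_debug")
out_dir.mkdir(parents=True, exist_ok=True)

worker = SubprocessWorker(timeout=180, save_output_dir=out_dir)
worker.run("print('hello from the worker subprocess')")
del worker  # _user_save_output is True, so the log files in ./worker_debug are kept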
diff --git a/torchbenchmark/__init__.py b/torchbenchmark/__init__.py
index bc7dd88dab..1445b05122 100644
--- a/torchbenchmark/__init__.py
+++ b/torchbenchmark/__init__.py
@@ -256,6 +256,7 @@ def __init__(
         model_path: str,
         timeout: Optional[float] = None,
         extra_env: Optional[Dict[str, str]] = None,
+        save_output_dir: Optional[pathlib.Path] = None,
     ) -> None:
         gc.collect()  # Make sure previous task has a chance to release the lock
         assert self._lock.acquire(blocking=False), "Failed to acquire lock."
@@ -263,7 +264,7 @@ def __init__(
         self._model_path = model_path
         if _is_internal_model(model_path):
             model_path = f"{internal_model_dir}.{model_path}"
-        self._worker = Worker(timeout=timeout, extra_env=extra_env)
+        self._worker = Worker(timeout=timeout, extra_env=extra_env, save_output_dir=save_output_dir)
         self.worker.run("import torch")
 
         self._details: ModelDetails = ModelDetails(
diff --git a/torchbenchmark/util/experiment/instantiator.py b/torchbenchmark/util/experiment/instantiator.py
index 9c53c59d6c..34c074e6b4 100644
--- a/torchbenchmark/util/experiment/instantiator.py
+++ b/torchbenchmark/util/experiment/instantiator.py
@@ -4,7 +4,7 @@
     They expect callers handle all exceptions.
 """
 import os
-import importlib
+import pathlib
 import dataclasses
 from typing import Optional, List, Dict
 from torchbenchmark.util.model import BenchmarkModel
@@ -21,6 +21,7 @@ class TorchBenchModelConfig:
     batch_size: Optional[int]
     extra_args: List[str]
     extra_env: Optional[Dict[str, str]] = None
+    output_dir: Optional[pathlib.Path] = None
 
 def _set_extra_env(extra_env):
     if not extra_env:
@@ -32,8 +33,8 @@ def inject_model_invoke(model_task: ModelTask, inject_function):
     model_task.replace_invoke(inject_function.__module__, inject_function.__name__)
 
 def load_model_isolated(config: TorchBenchModelConfig, timeout: float=WORKER_TIMEOUT) -> ModelTask:
-    """ Load and return the model in a subprocess. """
-    task = ModelTask(config.name, timeout=timeout, extra_env=config.extra_env)
+    """ Load and return the model in a subprocess. Optionally, save its stdout and stderr to the specified directory. """
+    task = ModelTask(config.name, timeout=timeout, extra_env=config.extra_env, save_output_dir=config.output_dir)
     if not task.model_details.exists:
         raise ValueError(f"Failed to import model task: {config.name}. Please run the model manually to make sure it succeeds, or report a bug.")
     task.make_model_instance(test=config.test, device=config.device, batch_size=config.batch_size, extra_args=config.extra_args)
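A sketch of how the new plumbing fits together end to end, assuming "resnet50" is one of the names returned by `list_models()` (the model name and paths here are placeholders): setting `output_dir` on a `TorchBenchModelConfig` makes `load_model_isolated` pass it to `ModelTask` as `save_output_dir`, so the subprocess logs for that model survive in the chosen directory.

# Sketch only; run from the TorchBench repo root so the imports resolve.
import pathlib
from torchbenchmark.util.experiment.instantiator import TorchBenchModelConfig, load_model_isolated
from torchbenchmark.util.experiment.metrics import get_model_test_metrics

config = TorchBenchModelConfig(
    name="resnet50",
    device="cuda",
    test="eval",
    batch_size=None,
    extra_args=[],
    output_dir=pathlib.Path("debug/resnet50-cuda-eval"),
)
config.output_dir.mkdir(parents=True, exist_ok=True)
task = load_model_isolated(config)   # commands.log / stdout.txt / stderr.txt land in output_dir
print(get_model_test_metrics(task, metrics=["latencies"]))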
diff --git a/torchbenchmark/util/experiment/metrics.py b/torchbenchmark/util/experiment/metrics.py
index 881e9c7ef2..c3cbd1e7c3 100644
--- a/torchbenchmark/util/experiment/metrics.py
+++ b/torchbenchmark/util/experiment/metrics.py
@@ -3,6 +3,7 @@
 """
 import torch
 import time
+import pathlib
 import dataclasses
 from torchbenchmark.util.model import BenchmarkModel
 from torchbenchmark.util.experiment.instantiator import TorchBenchModelConfig
@@ -140,7 +141,7 @@ def get_model_test_metrics(model: Union[BenchmarkModel, ModelTask], metrics=[],
                    if isinstance(model, ModelTask) else model.ttfb
     return TorchBenchModelMetrics(latencies, throughputs, cpu_peak_mem, gpu_peak_mem, ttfb, pt2_compilation_time, pt2_graph_breaks, model_flops)
 
-def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True) -> str:
+def get_model_accuracy(model_config: TorchBenchModelConfig, isolated: bool=True, save_output_dir: Optional[pathlib.Path]=None) -> str:
     import copy
     from torchbenchmark.util.experiment.instantiator import load_model_isolated, load_model
     # Try load minimal batch size, if fail, load the default batch size
diff --git a/userbenchmark/test_bench/__init__.py b/userbenchmark/test_bench/__init__.py
new file mode 100644
index 0000000000..01e78c979b
--- /dev/null
+++ b/userbenchmark/test_bench/__init__.py
@@ -0,0 +1 @@
+BM_NAME = "test_bench"
diff --git a/userbenchmark/test_bench/regression_detector.py b/userbenchmark/test_bench/regression_detector.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/userbenchmark/test_bench/run.py b/userbenchmark/test_bench/run.py
new file mode 100644
index 0000000000..02d8229148
--- /dev/null
+++ b/userbenchmark/test_bench/run.py
@@ -0,0 +1,164 @@
+"""
+Run TorchBench test benchmarking.
+"""
+import argparse
+import itertools
+import pathlib
+import json
+import os
+import shutil
+import numpy
+
+from typing import List, Dict, Optional, Any, Union
+from ..utils import REPO_PATH, add_path, get_output_json, get_default_output_json_path, get_default_debug_output_dir
+from . import BM_NAME
+
+with add_path(REPO_PATH):
+    from torchbenchmark.util.experiment.instantiator import list_models, load_model_isolated, TorchBenchModelConfig, \
+        list_devices, list_tests
+    from torchbenchmark.util.experiment.metrics import TorchBenchModelMetrics, get_model_test_metrics, get_model_accuracy
+
+CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
+
+def config_to_str(config: TorchBenchModelConfig) -> str:
+    metrics_base = f"model={config.name}, test={config.test}, device={config.device}," + \
+                   f" bs={config.batch_size}, extra_args={config.extra_args}"
+    return metrics_base
+
+def generate_model_configs(devices: List[str], tests: List[str], batch_sizes: List[str], model_names: List[str], extra_args: List[str]) -> List[TorchBenchModelConfig]:
+    """Use the default batch size and default mode."""
+    if not model_names:
+        model_names = list_models()
+    cfgs = itertools.product(*[devices, tests, batch_sizes, model_names])
+    result = [TorchBenchModelConfig(
+        name=model_name,
+        device=device,
+        test=test,
+        batch_size=None if not batch_size else int(batch_size),
+        extra_args=extra_args,
+        extra_env=None,
+    ) for device, test, batch_size, model_name in cfgs]
+    return result
+
+def init_output_dir(configs: List[TorchBenchModelConfig], output_dir: pathlib.Path) -> List[TorchBenchModelConfig]:
+    result = []
+    for config in configs:
+        config_str = config_to_str(config)
+        config.output_dir = output_dir.joinpath(config_str)
+        if config.output_dir.exists():
+            shutil.rmtree(config.output_dir)
+        config.output_dir.mkdir(parents=True)
+        result.append(config)
+    return result
+
+def get_metrics(config: TorchBenchModelConfig) -> List[str]:
+    if "--accuracy" in config.extra_args:
+        return ["accuracy"]
+    return ["latencies", "cpu_peak_mem", "gpu_peak_mem"]
+
+def validate(candidates: List[str], choices: List[str]) -> List[str]:
+    """Validate that every candidate provided by the user is an available choice."""
+    for candidate in candidates:
+        assert candidate in choices, f"Specified {candidate}, but not in available list: {choices}."
+    return candidates
+
+def parse_str_to_list(candidates: Optional[str]) -> List[str]:
+    if isinstance(candidates, list):
+        return candidates
+    elif candidates is None:
+        return [""]
+    candidates = list(map(lambda x: x.strip(), candidates.split(",")))
+    return candidates
+
+def metrics_to_dict(metrics: Union[TorchBenchModelMetrics, Dict[str, str]]) -> Dict[str, Union[str, float]]:
+    if isinstance(metrics, TorchBenchModelMetrics):
+        pass
+    return metrics
+
+def run_config(config: TorchBenchModelConfig, metrics: List[str], dryrun: bool=False) -> Dict[str, Union[str, float]]:
+    """This function only handles NotImplementedError; all other errors will fail."""
+    print(f"Running {config} ...", end='', flush=True)
+    if dryrun:
+        print(" [skip_by_dryrun]", flush=True)
+        return dict.fromkeys(metrics, "skip_by_dryrun")
+    # We do not allow RuntimeError in this test
+    try:
+        # load the model instance in subprocess
+        model = load_model_isolated(config)
+        # get the model test metrics
+        metrics_output: TorchBenchModelMetrics = get_model_test_metrics(model, metrics=metrics)
+        result = {}
+        for metric in metrics:
+            if metric == "latencies" and metrics_output.latencies:
+                result[metric] = numpy.median(metrics_output.latencies)
+            if not result.get(metric, None):
+                result[metric] = "failed"
+        print(" [done]", flush=True)
+        return result
+    except NotImplementedError:
+        print(" [not_implemented]", flush=True)
+        return dict.fromkeys(metrics, "not_implemented")
+
+def run_config_accuracy(config: TorchBenchModelConfig, metrics: List[str], dryrun: bool=False) -> Dict[str, str]:
+    assert metrics == ["accuracy"], f"When running an accuracy test, other metrics are not supported: {metrics}."
+    print(f"Running {config} ...", end='', flush=True)
+    if dryrun:
+        print(" [skip_by_dryrun]", flush=True)
+        return {"accuracy": "skip_by_dryrun"}
+    try:
+        accuracy = get_model_accuracy(config)
+        print(" [done]", flush=True)
+        return {"accuracy": accuracy}
+    except NotImplementedError:
+        print(" [not_implemented]", flush=True)
+        return {"accuracy": "not_implemented"}
+
+def parse_known_args(args):
+    parser = argparse.ArgumentParser()
+    default_device = "cuda" if "cuda" in list_devices() else "cpu"
+    parser.add_argument(
+        "models",
+        help="Name of models to run, split by comma.",
+    )
+    parser.add_argument("--device", "-d", default=default_device, help="Devices to run, split by comma.")
+    parser.add_argument("--test", "-t", default="eval", help="Tests to run, split by comma.")
+    parser.add_argument("--bs", default=None, help="Optionally, specify the batch size.")
+    parser.add_argument("--config", "-c", default=None, help="YAML config to specify tests to run.")
+    parser.add_argument("--run-bisect", help="Run with the output of the regression detector.")
+    parser.add_argument("--dryrun", action="store_true", help="Dryrun the command.")
+    parser.add_argument("--output", default=get_default_output_json_path(BM_NAME), help="Specify the path of the output file.")
+    parser.add_argument("--debug", action="store_true", help="Save the debug output.")
+    return parser.parse_known_args(args)
+
+def run(args: List[str]):
+    args, extra_args = parse_known_args(args)
+    # If not specified, use the entire model set
+    if not args.models:
+        args.models = list_models()
+    debug_output_dir = get_default_debug_output_dir(args.output) if args.debug else None
+    devices = validate(parse_str_to_list(args.device), list_devices())
+    tests = validate(parse_str_to_list(args.test), list_tests())
+    batch_sizes = parse_str_to_list(args.bs)
+    models = validate(parse_str_to_list(args.models), list_models())
+    configs = generate_model_configs(devices, tests, batch_sizes, model_names=models, extra_args=extra_args)
+    configs = init_output_dir(configs, debug_output_dir) if debug_output_dir else configs
+    results = {}
+    try:
+        for config in configs:
+            metrics = get_metrics(config)
+            if "accuracy" in metrics:
+                metrics_dict = run_config_accuracy(config, metrics, dryrun=args.dryrun)
+            else:
+                metrics_dict = run_config(config, metrics, dryrun=args.dryrun)
+            config_str = config_to_str(config)
+            for metric in metrics_dict:
+                results[f"{config_str}, metric={metric}"] = metrics_dict[metric]
+    except KeyboardInterrupt:
+        print("User keyboard interrupted!")
+    if not args.dryrun:
+        result = get_output_json(BM_NAME, results)
+        if args.device == 'cuda':
+            import torch
+            result["environ"]["device"] = torch.cuda.get_device_name()
+        with open(args.output, 'w') as f:
+            json.dump(result, f, indent=4)
diff --git a/userbenchmark/utils.py b/userbenchmark/utils.py
index a5154cc408..25b3cf3664 100644
--- a/userbenchmark/utils.py
+++ b/userbenchmark/utils.py
@@ -109,6 +109,12 @@ def get_default_output_json_path(bm_name: str, target_dir: Path=None) -> str:
     full_fname = os.path.join(target_dir, fname)
     return full_fname
 
+def get_default_debug_output_dir(metrics_json: str) -> Path:
+    metrics_json_path = Path(metrics_json)
+    metrics_json_dir = metrics_json_path.parent
+    metrics_datetime = datetime.strptime(metrics_json_path.name, "metrics-%Y%m%d%H%M%S.json")
+    debug_output_dir = metrics_json_dir.joinpath("output-" + datetime.strftime(metrics_datetime, "%Y%m%d%H%M%S"))
+    return debug_output_dir
 def dump_output(bm_name: str, output: Any, target_dir: Path=None) -> None:
     full_fname = get_default_output_json_path(bm_name, target_dir=target_dir)
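Putting it together, a debug run of the new test_bench userbenchmark might look like the sketch below; calling `run()` directly is equivalent to what a dispatcher such as TorchBench's run_benchmark.py would do (an assumption, since that entry point is not part of this patch). With `--debug`, each model config's subprocess logs are kept under the `output-<timestamp>/` directory produced by `get_default_debug_output_dir()` next to the metrics JSON.

# Sketch only; "resnet50" is a placeholder model name.
from userbenchmark.test_bench.run import run

run(["resnet50", "--device", "cuda", "--test", "eval", "--debug"])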