From 6bff330cdef9b4301fb6028462a2d7ac3a12942d Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 12 Apr 2024 05:38:17 -0700
Subject: [PATCH] Add NCU Trace generation

Summary: Generate NCU Trace for the triton kernel and input batch

Reviewed By: chenyang78

Differential Revision: D56047231

fbshipit-source-id: a0a18f12daeeeae9f5c9e8adc1568f3be98bd9b1
---
 torchbenchmark/_components/ncu/__init__.py | 60 +++++++++++++++++++++
 torchbenchmark/util/triton_op.py           | 61 ++++++++++++++++++----
 2 files changed, 110 insertions(+), 11 deletions(-)
 create mode 100644 torchbenchmark/_components/ncu/__init__.py

diff --git a/torchbenchmark/_components/ncu/__init__.py b/torchbenchmark/_components/ncu/__init__.py
new file mode 100644
index 0000000000..2b65a9ac40
--- /dev/null
+++ b/torchbenchmark/_components/ncu/__init__.py
@@ -0,0 +1,60 @@
+
+from typing import Callable
+
+def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=True, output_dir=None) -> None:
+    """
+    Run the provided function under NCU (NVIDIA Nsight Compute) profiling. The function
+    is first warmed up, then executed once between cudaProfilerStart and cudaProfilerStop.
+
+    :param fn: Function to profile
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param grad_to_none: Reset the gradients of the provided tensors to None
+    :type grad_to_none: list of torch.Tensor, optional
+    :param fast_flush: Use faster kernel to flush L2 between measurements
+    :type fast_flush: bool
+    :param output_dir: Output directory to store the trace
+    :type output_dir: str, optional
+    """
+    import torch
+
+    fn()
+    torch.cuda.synchronize()
+
+    # We maintain a buffer of 256 MB that we clear
+    # before each kernel call to make sure that the L2 cache
+    # doesn't contain any input data before the run
+    if fast_flush:
+        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+    else:
+        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+
+    # Estimate the runtime of the function
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    torch.cuda.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # Compute the number of warmup iterations
+    n_warmup = max(1, int(warmup / estimate_ms))
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+    # Start ncu profiling
+    torch.cuda.cudart().cudaProfilerStart()
+    # We don't want `fn` to accumulate gradient values
+    # if it contains a backward pass. So we clear the
+    # provided gradients
+    if grad_to_none is not None:
+        for x in grad_to_none:
+            x.grad = None
+    # Clear the L2 cache before the profiled run
+    cache.zero_()
+    fn()
+    torch.cuda.cudart().cudaProfilerStop()
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 05ea5ba059..f8e5237e81 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -24,7 +24,7 @@
 REGISTERED_BENCHMARKS: Dict[str, List[str]] = {}
 REGISTERED_METRICS: Dict[str, List[str]] = {}
 BASELINE_BENCHMARKS: Dict[str, str] = {}
-BUILTIN_METRICS = ["latency", "tflops", "speedup", "accuracy", "compile_time"]
+BUILTIN_METRICS = ["latency", "tflops", "speedup", "accuracy", "compile_time", "ncu_trace"]
 BASELINE_SKIP_METRICS = ["speedup", "accuracy"]
 PRECISION_DTYPE_MAPPING = {
     "fp32": torch.float32,
@@ -70,6 +70,15 @@ def do_bench_walltime(fn, warmup=25, rep=100):
         wall_time_ms = (end_time - start_time) * 1e3 / n_repeat
     return wall_time_ms
 
+def _find_param_loc(l, key: str) -> int:
+    try:
+        return l.index(key)
+    except ValueError:
+        return -1
+def _remove_params(l, loc):
+    if loc == -1:
+        return l
+    return l[:loc] + l[loc+2:]
 
 @dataclass
 class BenchmarkOperatorMetrics:
@@ -85,6 +94,8 @@ class BenchmarkOperatorMetrics:
     walltime: Optional[float]
     # compile time
     compile_time: Optional[float]
+    # ncu trace file
+    ncu_trace: Optional[str]
     # error message
     error_msg: Optional[str]
     # extra metrics
@@ -544,6 +555,7 @@
             accuracy=accuracy,
             walltime=walltime,
             compile_time=None,
+            ncu_trace=None,
             error_msg=error_msg,
             extra_metrics={},
         )
@@ -551,6 +563,8 @@
             metric.tflops = self.tflops(fn_name, self.example_inputs, metric)
         if "compile_time" in self.required_metrics:
             metric.compile_time = self.compile_time(batch_id, fn_name, metric)
+        if "ncu_trace" in self.required_metrics:
+            metric.ncu_trace = self.ncu_trace(batch_id, fn_name)
         extra_metrics = {}
         # run the hidden metric "_compile_time_in_task"
         # to get the compile time in parent process
@@ -559,6 +573,13 @@
                 "_compile_time_in_task must be measured by itself. " \
                 f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
             extra_metrics["_compile_time_in_task"] = self._compile_time_in_task(fn)
+        if "_ncu_trace_in_task" in self.required_metrics:
+            assert self.required_metrics == ["_ncu_trace_in_task"] and self._only and (self._batch_id is not None), \
+                "_ncu_trace_in_task must be measured by itself. " \
+                f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
+            from torchbenchmark._components.ncu import do_bench_ncu_in_task
+            do_bench_ncu_in_task(fn=fn, warmup=warmup, grad_to_none=self.get_grad_to_none(self.example_inputs))
+            extra_metrics["_ncu_trace_in_task"] = "success"
         # generate customized metrics
         if self.name in REGISTERED_METRICS:
             for metric_name in REGISTERED_METRICS[self.name]:
@@ -577,29 +598,47 @@
             accuracy=None,
             walltime=None,
             compile_time=None,
+            ncu_trace=None,
             error_msg="CUDA OOM",
             extra_metrics={},
         )
         return metric
 
 
+    @register_metric()
+    def ncu_trace(self, batch_id: int, fn_name: str) -> str:
+        # Re-run this benchmark in a subprocess under ncu to collect the trace
+        import sys
+        import subprocess
+        from pathlib import Path
+        op_task_args = copy.deepcopy(sys.argv)
+        for override_option in ["--only", "--batch-id", "--metrics"]:
+            op_task_args = _remove_params(op_task_args, _find_param_loc(op_task_args, override_option))
+        op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_ncu_trace_in_task"])
+        # Disable DCGM
+        try:
+            disable_dcgm = ["sudo", "dyno", "dcgm_profiling", "--mute=true", "--duration=1000_s"]
+            subprocess.run(disable_dcgm, check=True)
+        except (subprocess.SubprocessError, FileNotFoundError):
+            warnings.warn("Cannot find dyno to disable DCGM. Proceeding to collect the NCU trace anyway.")
+        ncu_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn_name}_{batch_id}")
+        ncu_output_dir.mkdir(parents=True, exist_ok=True)
+        ncu_output_file = ncu_output_dir.joinpath("ncu_output.csv").resolve()
+        ncu_args = ["ncu", "--set", "full", "--replay-mode", "range", "--target-processes", "all",
+                    "--csv", "-f", "--log-file", str(ncu_output_file)]
+        ncu_args.extend(op_task_args)
+        subprocess.check_call(ncu_args)
+        return str(ncu_output_file)
+
+
     @register_metric()
     def compile_time(self, batch_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics) -> float:
         # We need to spawn a subprocess when user wants to measure the compile time
         # of multiple batches and backends.
-        def _find_loc(l, key: str) -> int:
-            try:
-                return l.index(key)
-            except ValueError:
-                return -1
-        def _remove_element(l, loc):
-            if loc == -1:
-                return l
-            return l[:loc] + l[loc+2:]
         from torchbenchmark.operators.op_task import OpTask
         op_task_args = copy.deepcopy(self._raw_extra_args)
         for override_option in ["--only", "--batch-id", "--metrics"]:
-            op_task_args = _remove_element(op_task_args, _find_loc(op_task_args, override_option))
+            op_task_args = _remove_params(op_task_args, _find_param_loc(op_task_args, override_option))
         op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_compile_time_in_task"])
         op_task = OpTask(name=self.name)
         op_task.make_operator_instance(mode=self.mode.value, device=self.device, extra_args=op_task_args)