diff --git a/torchbenchmark/_components/ncu/__init__.py b/torchbenchmark/_components/ncu/__init__.py
index 4111e0327..f3b6a8c42 100644
--- a/torchbenchmark/_components/ncu/__init__.py
+++ b/torchbenchmark/_components/ncu/__init__.py
@@ -1,11 +1,28 @@
 from typing import Callable
 
+import torch
 
-def do_bench_ncu_in_task(
+
+class cuda_profiler_range:
+    def __init__(self, use_cuda_profiler_range):
+        self.use_cuda_profiler_range = use_cuda_profiler_range
+
+    def __enter__(self):
+        if self.use_cuda_profiler_range:
+            torch.cuda.cudart().cudaProfilerStart()
+
+    def __exit__(self, *exc_info):
+        if self.use_cuda_profiler_range:
+            torch.cuda.cudart().cudaProfilerStop()
+
+
+def do_bench_in_task(
     fn: Callable,
     grad_to_none=None,
-    fast_flush=True,
     range_name: str = "",
+    warmup: bool = False,
+    warmup_time: int = 25,
+    use_cuda_profiler_range: bool = False,
 ) -> None:
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -15,23 +32,30 @@ def do_bench_ncu_in_task(
     :type fn: Callable
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
-    :param fast_flush: Use faster kernel to flush L2 between measurements
-    :type fast_flush: bool
-    :param output_dir: Output directory to store the trace
-    :type output_dir: str, optional
     """
-    import torch
 
     fn()
     torch.cuda.synchronize()
 
-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2
-    # doesn't contain any input data before the run
-    if fast_flush:
-        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
-    else:
-        cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
+    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
+
+    if warmup:
+        # Estimate the runtime of the function
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            cache.zero_()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        # compute the number of warmup iterations
+        n_warmup = max(1, int(warmup_time / estimate_ms))
+        # Warm-up
+        for _ in range(n_warmup):
+            fn()
 
     # we don't want `fn` to accumulate gradient values
     # if it contains a backward pass. So we clear the
@@ -41,5 +65,7 @@ def do_bench_ncu_in_task(
             x.grad = None
     # we clear the L2 cache before run
     cache.zero_()
-    with torch.cuda.nvtx.range(range_name):
+    with cuda_profiler_range(use_cuda_profiler_range), torch.cuda.nvtx.range(
+        range_name
+    ):
         fn()
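
Note on the capture flow in `do_bench_in_task` (commentary, not part of the patch): `cuda_profiler_range` brackets only the final measured `fn()` call with `cudaProfilerStart()`/`cudaProfilerStop()`, which is what `nsys profile -c cudaProfilerApi` keys its capture range on, so the event-based runtime estimation and the warm-up iterations stay out of the report. A minimal standalone sketch of that gating, assuming a CUDA-capable machine and that the script is launched under something like `nsys profile -c cudaProfilerApi -t cuda,nvtx -o demo python demo.py` (file name, tensor shape, and range name below are made up for illustration):

import torch

x = torch.randn(4096, 4096, device="cuda")

# Launched before cudaProfilerStart(): executed, but excluded from the report
# because nsys was started with `-c cudaProfilerApi`.
for _ in range(3):
    x @ x
torch.cuda.synchronize()

torch.cuda.cudart().cudaProfilerStart()
with torch.cuda.nvtx.range("measured_region"):
    x @ x  # only this launch (and its NVTX range) is captured
torch.cuda.synchronize()
torch.cuda.cudart().cudaProfilerStop()
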
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index f9b5d647a..d28fdb43e 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -177,6 +177,8 @@ class BenchmarkOperatorMetrics:
     ncu_rep: Optional[str] = None
     # ncu replay file with TTGIR line numbers
     ncu_rep_ir: Optional[str] = None
+    # nsys replay file
+    nsys_rep: Optional[str] = None
     # kineto trace file
     kineto_trace: Optional[str] = None
     # cpu peak memory
@@ -859,6 +861,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
                 metrics.ncu_rep_ir = self.ncu_trace(
                     input_id, fn_name, replay=True, profile_ir=True
                 )
+            if "nsys_rep" in self.required_metrics:
+                metrics.nsys_rep = self.nsys_rep(input_id, fn_name)
             if "kineto_trace" in self.required_metrics:
                 metrics.kineto_trace = self.kineto_trace(input_id, fn)
             if "best_config" in self.required_metrics:
@@ -886,14 +890,33 @@ def _init_extra_metrics() -> Dict[str, Any]:
                     "_ncu_trace_in_task must be measured by itself. "
                     f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
                 )
-                from torchbenchmark._components.ncu import do_bench_ncu_in_task
+                from torchbenchmark._components.ncu import do_bench_in_task
 
-                do_bench_ncu_in_task(
+                do_bench_in_task(
                     fn=fn,
                     grad_to_none=self.get_grad_to_none(self.example_inputs),
                     range_name=_RANGE_NAME,
                 )
                 metrics.extra_metrics["_ncu_trace_in_task"] = "success"
+            if "_nsys_rep_in_task" in self.required_metrics:
+                assert (
+                    self.required_metrics == ["_nsys_rep_in_task"]
+                    and len(self._only) == 1
+                    and (self._input_id is not None)
+                ), (
+                    "_nsys_rep_in_task must be measured by itself. "
+                    f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
+                )
+                from torchbenchmark._components.ncu import do_bench_in_task
+
+                do_bench_in_task(
+                    fn=fn,
+                    grad_to_none=self.get_grad_to_none(self.example_inputs),
+                    range_name=_RANGE_NAME,
+                    warmup=True,
+                    use_cuda_profiler_range=True,
+                )
+                metrics.extra_metrics["_nsys_rep_in_task"] = "success"
             # generate customized metrics
             if self.name in REGISTERED_METRICS:
                 for metric_name in REGISTERED_METRICS[self.name]:
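
For context (not part of the patch): the public `nsys_rep` metric runs in the parent process, while `_nsys_rep_in_task` is the internal metric measured by a relaunched child; the assertion above forces the child to measure exactly one benchmarked function (`--only`) and a single input id so the profiler capture range corresponds to one measured call. The helper added in the next hunk effectively reruns the harness under Nsight Systems with a command of roughly this shape (angle-bracket parts are placeholders, and `<original argv>` is the current invocation with any existing `--only`/`--input-id`/`--num-inputs`/`--metrics` flags stripped):

nsys profile -c cudaProfilerApi -t nvtx,osrt,cuda,cudnn,cublas -w true -f true \
    -o <temp_path>/nsys_traces/<fn_name>_<input_id>/nsys_output.nsys-rep \
    python <original argv> --only <fn_name> --num-inputs 1 --input-id <input_id> --metrics _nsys_rep_in_task
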
" + f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}" + ) + from torchbenchmark._components.ncu import do_bench_in_task + + do_bench_in_task( + fn=fn, + grad_to_none=self.get_grad_to_none(self.example_inputs), + range_name=_RANGE_NAME, + warmup=True, + use_cuda_profiler_range=True, + ) + metrics.extra_metrics["_nsys_rep_in_task"] = "success" # generate customized metrics if self.name in REGISTERED_METRICS: for metric_name in REGISTERED_METRICS[self.name]: @@ -925,6 +948,54 @@ def get_peak_mem( metrics_gpu_backend="nvml", ) + def nsys_rep(self, input_id: int, fn_name: str) -> str: + import subprocess + import sys + + op_task_args = [] if IS_FBCODE else [sys.executable] + op_task_args.extend(copy.deepcopy(sys.argv)) + for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]: + op_task_args = _remove_params( + op_task_args, _find_param_loc(op_task_args, override_option) + ) + op_task_args.extend( + [ + "--only", + fn_name, + "--num-inputs", + str(1), + "--input-id", + str(input_id), + "--metrics", + "_nsys_rep_in_task", + ] + ) + nsys_output_dir = self.get_temp_path(f"nsys_traces/{fn_name}_{input_id}") + nsys_output_dir.mkdir(parents=True, exist_ok=True) + ext = ".nsys-rep" + nsys_output_file = nsys_output_dir.joinpath(f"nsys_output{ext}").resolve() + nsys_trace_cmd = [ + "nsys", + "profile", + "-c", + "cudaProfilerApi", + "-t", + "nvtx,osrt,cuda,cudnn,cublas", + "-w", + "true", + "-f", + "true", + "-o", + nsys_output_file, + ] + nsys_trace_cmd.extend(op_task_args) + try: + subprocess.check_call(nsys_trace_cmd) + except subprocess.CalledProcessError: + # FIXME: calling nsys on Tritonbench will throw SIGTERM with error code 143 + pass + return str(nsys_output_file.resolve()) + def ncu_trace( self, input_id: int, fn_name: str, replay: bool = False, profile_ir=False ) -> str: