Add NCU Trace generation
Summary: Generate NCU Trace for the triton kernel and input batch

Reviewed By: chenyang78

Differential Revision: D56047231

fbshipit-source-id: a0a18f12daeeeae9f5c9e8adc1568f3be98bd9b1
xuzhao9 authored and facebook-github-bot committed Apr 12, 2024
1 parent 1a1a1f8 commit 6bff330
Showing 2 changed files with 110 additions and 11 deletions.
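
In short: requesting the new ncu_trace metric makes the benchmark re-launch itself under NVIDIA Nsight Compute (ncu) and records a per-batch CSV trace. A hypothetical invocation (the entry-point script and operator name are assumptions, not part of this commit; only the --metrics value is introduced here):

python run_benchmark.py triton --op softmax --metrics ncu_trace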
60 changes: 60 additions & 0 deletions torchbenchmark/_components/ncu/__init__.py
@@ -0,0 +1,60 @@

from typing import Callable

def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=True, output_dir=None) -> None:
    """
    Run the provided function under the CUDA profiler so that an external NCU
    process can capture its kernels. The function is first warmed up for roughly
    :code:`warmup` milliseconds, then executed exactly once between
    :code:`cudaProfilerStart()` and :code:`cudaProfilerStop()`.
    :param fn: Function to profile
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param grad_to_none: Reset the gradients of the provided tensors to None
    :type grad_to_none: torch.tensor, optional
    :param fast_flush: Use faster kernel to flush L2 cache between measurements
    :type fast_flush: bool
    :param output_dir: Output directory to store the trace (currently unused)
    :type output_dir: str, optional
    """
    import torch

    fn()
    torch.cuda.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2 cache
    # doesn't contain any input data before the run
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # Compute the number of warmup iterations
    n_warmup = max(1, int(warmup / estimate_ms))
    # Warm up
    for _ in range(n_warmup):
        fn()
    # Start NCU profiling
    torch.cuda.cudart().cudaProfilerStart()
    # We don't want `fn` to accumulate gradient values
    # if it contains a backward pass, so we clear the
    # provided gradients
    if grad_to_none is not None:
        for x in grad_to_none:
            x.grad = None
    # Clear the L2 cache before the profiled run
    cache.zero_()
    fn()
    torch.cuda.cudart().cudaProfilerStop()
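
For orientation, a minimal sketch of how this helper is driven (the workload below is a hypothetical stand-in; in this commit the call is made inside a child process that the ncu_trace metric launches under ncu):

import torch

from torchbenchmark._components.ncu import do_bench_ncu_in_task

# Hypothetical workload standing in for a Triton kernel benchmark.
x = torch.randn(4096, 4096, device="cuda")

def fn():
    return torch.softmax(x, dim=-1)

# fn() runs once between cudaProfilerStart()/cudaProfilerStop(), which
# delimit the range that `ncu --replay-mode range` captures.
do_bench_ncu_in_task(fn=fn, warmup=25)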
61 changes: 50 additions & 11 deletions torchbenchmark/util/triton_op.py
@@ -24,7 +24,7 @@
REGISTERED_BENCHMARKS: Dict[str, List[str]] = {}
REGISTERED_METRICS: Dict[str, List[str]] = {}
BASELINE_BENCHMARKS: Dict[str, str] = {}
-BUILTIN_METRICS = ["latency", "tflops", "speedup", "accuracy", "compile_time"]
+BUILTIN_METRICS = ["latency", "tflops", "speedup", "accuracy", "compile_time", "ncu_trace"]
BASELINE_SKIP_METRICS = ["speedup", "accuracy"]
PRECISION_DTYPE_MAPPING = {
    "fp32": torch.float32,
@@ -70,6 +70,15 @@ def do_bench_walltime(fn, warmup=25, rep=100):
    wall_time_ms = (end_time - start_time) * 1e3 / n_repeat
    return wall_time_ms

def _find_param_loc(l, key: str) -> int:
    """Return the index of `key` in an argv-style list, or -1 if absent."""
    try:
        return l.index(key)
    except ValueError:
        return -1

def _remove_params(l, loc):
    """Drop the flag at `loc` and its value from an argv-style list."""
    if loc == -1:
        return l
    return l[:loc] + l[loc + 2:]
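
# A quick illustration of the two helpers above (hypothetical argv; not part
# of the commit):
#
#     args = ["run.py", "--op", "softmax", "--batch-id", "3", "--metrics", "latency"]
#     loc = _find_param_loc(args, "--batch-id")  # -> 3
#     _remove_params(args, loc)  # -> ["run.py", "--op", "softmax", "--metrics", "latency"]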

@dataclass
class BenchmarkOperatorMetrics:
@@ -85,6 +94,8 @@ class BenchmarkOperatorMetrics:
    walltime: Optional[float]
    # compile time
    compile_time: Optional[float]
    # ncu trace file
    ncu_trace: Optional[str]
    # error message
    error_msg: Optional[str]
    # extra metrics
@@ -544,13 +555,16 @@ def _do_bench(
            accuracy=accuracy,
            walltime=walltime,
            compile_time=None,
            ncu_trace=None,
            error_msg=error_msg,
            extra_metrics={},
        )
        if "tflops" in self.required_metrics:
            metric.tflops = self.tflops(fn_name, self.example_inputs, metric)
        if "compile_time" in self.required_metrics:
            metric.compile_time = self.compile_time(batch_id, fn_name, metric)
        if "ncu_trace" in self.required_metrics:
            metric.ncu_trace = self.ncu_trace(batch_id, fn_name)
        extra_metrics = {}
        # run the hidden metric "_compile_time_in_task"
        # to get the compile time in the parent process
@@ -559,6 +573,13 @@
"_compile_time_in_task must be measured by itself. " \
f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
extra_metrics["_compile_time_in_task"] = self._compile_time_in_task(fn)
if "_ncu_trace_in_task" in self.required_metrics:
assert self.required_metrics == ["_ncu_trace_in_task"] and self._only and (self._batch_id is not None), \
"_ncu_trace_in_task must be measured by itself. " \
f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
from torchbenchmark._components.ncu import do_bench_ncu_in_task
do_bench_ncu_in_task(fn=fn, warmup=warmup, grad_to_none=self.get_grad_to_none(self.example_inputs))
extra_metrics["_ncu_trace_in_task"] = "success"
# generate customized metrics
if self.name in REGISTERED_METRICS:
for metric_name in REGISTERED_METRICS[self.name]:
@@ -577,29 +598,47 @@ def _do_bench(
                accuracy=None,
                walltime=None,
                compile_time=None,
                ncu_trace=None,
                error_msg="CUDA OOM",
                extra_metrics={},
            )
        return metric


    @register_metric()
    def ncu_trace(self, batch_id: int, fn_name: str) -> str:
        # Collect the NCU trace by re-running this benchmark under `ncu`,
        # restricted to a single function and input batch
        import subprocess
        import sys
        from pathlib import Path

        op_task_args = copy.deepcopy(sys.argv)
        for override_option in ["--only", "--batch-id", "--metrics"]:
            op_task_args = _remove_params(op_task_args, _find_param_loc(op_task_args, override_option))
        op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_ncu_trace_in_task"])
        # Disable DCGM, which conflicts with NCU profiling
        try:
            disable_dcgm = ["sudo", "dyno", "dcgm_profiling", "--mute=true", "--duration=1000_s"]
            subprocess.run(disable_dcgm, check=True)
        except (subprocess.SubprocessError, FileNotFoundError):
            warnings.warn("Cannot find dyno to disable DCGM. Proceeding to collect the NCU trace anyway.")
        ncu_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn_name}_{batch_id}")
        ncu_output_dir.mkdir(parents=True, exist_ok=True)
        ncu_output_file = ncu_output_dir.joinpath("ncu_output.csv").resolve()
        ncu_args = ["ncu", "--set", "full", "--replay-mode", "range", "--target-processes", "all",
                    "--csv", "-f", "--log-file", str(ncu_output_file)]
        ncu_args.extend(op_task_args)
        subprocess.check_call(ncu_args)
        return str(ncu_output_file)


    @register_metric()
    def compile_time(self, batch_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics) -> float:
        # We need to spawn a subprocess when the user wants to measure the compile time
        # of multiple batches and backends.
-        def _find_loc(l, key: str) -> int:
-            try:
-                return l.index(key)
-            except ValueError:
-                return -1
-        def _remove_element(l, loc):
-            if loc == -1:
-                return l
-            return l[:loc] + l[loc+2:]
        from torchbenchmark.operators.op_task import OpTask
        op_task_args = copy.deepcopy(self._raw_extra_args)
        for override_option in ["--only", "--batch-id", "--metrics"]:
-            op_task_args = _remove_element(op_task_args, _find_loc(op_task_args, override_option))
+            op_task_args = _remove_params(op_task_args, _find_param_loc(op_task_args, override_option))
        op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_compile_time_in_task"])
        op_task = OpTask(name=self.name)
        op_task.make_operator_instance(mode=self.mode.value, device=self.device, extra_args=op_task_args)
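Putting the two pieces together: the parent process rebuilds its own argv, overriding --only/--batch-id/--metrics so the child measures exactly one function of one input batch with the hidden _ncu_trace_in_task metric, then runs that command under ncu. Inside the child, do_bench_ncu_in_task brackets the single profiled call with cudaProfilerStart()/cudaProfilerStop(), which delimit the range-replay capture. The assembled command has this shape (operator name and paths are hypothetical):

ncu --set full --replay-mode range --target-processes all --csv -f --log-file /tmp/tritonbench_softmax_triton_softmax_0/ncu_output.csv <original argv> --only triton_softmax --batch-id 0 --metrics _ncu_trace_in_task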
