Skip to content

Commit

Permalink
Add --dump-ir option
Browse files Browse the repository at this point in the history
Summary:
This will dump the Triton IR if available (i.e. if the benchmarked kernel is a
Triton one). Otherwise it is a no-op.

Reviewed By: chenyang78

Differential Revision: D59138284

fbshipit-source-id: 20e2424b3873b3fd9e918b50af82def2a6f4ef84
  • Loading branch information
int3 authored and facebook-github-bot committed Jul 2, 2024
1 parent f5e3f03 commit 9410a0f
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,11 @@ def parse_args(
action="store_true",
help="Run this under test mode, potentially skipping expensive steps like autotuning."
)
parser.add_argument(
"--dump-ir",
action="store_true",
help="Dump Triton IR",
)
return parser.parse_known_args(args)

class PostInitProcessor(type):
Expand Down Expand Up @@ -893,6 +898,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
continue
func = getattr(self, metric_name)
metrics.extra_metrics[metric_name] = func(fn, self.example_inputs, metrics)
if self.tb_args.dump_ir:
self.dump_ir(input_id, fn)
except torch.cuda.OutOfMemoryError:
metrics.error_msg = "CUDA OOM"
except Exception as e:
Expand Down Expand Up @@ -1092,3 +1099,30 @@ def work_func():
self._op_flops[fn] = _get_flops(self, fn)
op_flops = self._op_flops[fn]
return op_flops / metrics.latency / 1e12 * 1e3

def dump_ir(self, input_id, fn):
    """Run ``fn`` once and write the IR of every Triton kernel it launches.

    ``JITFunction.run`` is temporarily replaced with a wrapper that records
    each call's return value (used below as the compiled kernel, via its
    ``asm`` and ``name`` attributes). If ``fn`` never reaches
    ``JITFunction.run``, nothing is captured and this method does nothing
    else.

    Args:
        input_id: Identifier of the current input; becomes part of each
            output file name.
        fn: Zero-argument callable to execute. Its ``_name`` attribute is
            used as the file-name prefix.
    """
    from unittest import mock
    from triton.runtime.jit import JITFunction

    ir_kinds = ("ttir", "ttgir", "llir", "ptx", "amdgcn")
    unpatched_run = JITFunction.run
    captured = []

    # Monkeypatching is the only practical hook for getting at the compiled
    # kernels: wrap the real ``run`` and keep whatever it hands back.
    def capturing_run(jit_fn, *args, **kwargs):
        result = unpatched_run(jit_fn, *args, **kwargs)
        captured.append(result)
        return result

    with mock.patch.object(JITFunction, "run", capturing_run):
        fn()

    if not captured:
        # No JITFunction.run call happened during fn(); nothing to dump.
        return

    ir_dir = self.get_temp_path("ir")
    ir_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Writing Triton IR to %s", ir_dir)

    for kernel in captured:
        asm = kernel.asm
        for ir in ir_kinds:
            if ir not in asm:
                # This kernel did not produce that representation.
                continue
            # One file per (kernel, IR kind); the extension is the IR name.
            out_path = ir_dir / f"{fn._name}_{kernel.name}_{input_id}.{ir}"
            with open(out_path, "w") as f:
                f.write(asm[ir])

0 comments on commit 9410a0f

Please sign in to comment.