From b3edc8d63696685c82598714801776e1df6899d4 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 21 Oct 2024 16:03:38 +0800 Subject: [PATCH 1/7] Store the generated source code in `jit._Handle` for debugging purposes --- src/ninetoothed/jit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ninetoothed/jit.py b/src/ninetoothed/jit.py index 70cd8b9..1116b79 100644 --- a/src/ninetoothed/jit.py +++ b/src/ninetoothed/jit.py @@ -59,6 +59,7 @@ def __call__(self): handle = _Handle( namespace[self.func.__name__], namespace[f"launch_{self.func.__name__}"], + unparsed, ) type(self).handles[source_file][source_line] = handle @@ -459,9 +460,10 @@ def visit_Call(self, node): class _Handle: - def __init__(self, kernel, launch): + def __init__(self, kernel, launch, source): self._kernel = kernel self._launch = launch + self._source = source def __call__(self, *args, **kwargs): return self._launch(*args, **kwargs) From cdba87007adaff01ac530247572fb8a1edce89a3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 22 Oct 2024 10:56:06 +0800 Subject: [PATCH 2/7] Use `importlib.util` instead of `compile` and `exec` to invoke Triton --- src/ninetoothed/jit.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/ninetoothed/jit.py b/src/ninetoothed/jit.py index 1116b79..c0024a4 100644 --- a/src/ninetoothed/jit.py +++ b/src/ninetoothed/jit.py @@ -1,9 +1,11 @@ import ast import collections import functools +import importlib.util import inspect import itertools import math +import sys import tempfile import triton @@ -46,19 +48,12 @@ def __call__(self): temp_file.write(unparsed.encode("utf-8")) temp_file_name = temp_file.name - with open(temp_file_name, "r") as temp_file: - code = compile( - source=temp_file.read(), - filename=temp_file_name, - mode="exec", - ) - - namespace = {} - exec(code, namespace) + module = type(self)._import_from_path(temp_file_name, temp_file_name) + module_vars = vars(module) handle = _Handle( - namespace[self.func.__name__], - namespace[f"launch_{self.func.__name__}"], + module_vars[self.func.__name__], + module_vars[f"launch_{self.func.__name__}"], unparsed, ) @@ -75,6 +70,15 @@ def _get_tree(self): return ast.Module(body=[finder.result], type_ignores=[]) + @staticmethod + def _import_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + return module + class CodeGenerator(ast.NodeTransformer): def __init__(self, context): From 45ed047bd04885599d88af3d13d010c0792c9dc1 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 22 Oct 2024 11:19:02 +0800 Subject: [PATCH 3/7] Add support for calling Triton code from NineToothed --- src/ninetoothed/jit.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/ninetoothed/jit.py b/src/ninetoothed/jit.py index c0024a4..d35fc91 100644 --- a/src/ninetoothed/jit.py +++ b/src/ninetoothed/jit.py @@ -43,9 +43,11 @@ def __call__(self): ast.fix_missing_locations(tree) unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":") + dependencies = self._find_dependencies() + source = "\n\n".join((unparsed, dependencies)).strip() with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file: - temp_file.write(unparsed.encode("utf-8")) + temp_file.write(source.encode("utf-8")) temp_file_name = temp_file.name module = type(self)._import_from_path(temp_file_name, temp_file_name) @@ -54,7 +56,7 @@ def __call__(self): handle = _Handle( module_vars[self.func.__name__], module_vars[f"launch_{self.func.__name__}"], - unparsed, + source, ) type(self).handles[source_file][source_line] = handle @@ -70,6 +72,15 @@ def _get_tree(self): return ast.Module(body=[finder.result], type_ignores=[]) + def _find_dependencies(self): + dependencies = set() + + for obj in self.func.__globals__.values(): + if isinstance(obj, triton.runtime.JITFunction): + dependencies.add(obj.src) + + return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies) + @staticmethod def _import_from_path(module_name, file_path): spec = importlib.util.spec_from_file_location(module_name, file_path) From 99998d7d0452cb76da4862963c908862c044ba9e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 22 Oct 2024 15:19:52 +0800 Subject: [PATCH 4/7] Treat a 0-dimensional tensor as a scalar --- src/ninetoothed/jit.py | 5 ++++- src/ninetoothed/tensor.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ninetoothed/jit.py b/src/ninetoothed/jit.py index d35fc91..677d00b 100644 --- a/src/ninetoothed/jit.py +++ b/src/ninetoothed/jit.py @@ -118,7 +118,7 @@ def visit_FunctionDef(self, node): self.generic_visit(node) for arg in self._args: - if not isinstance(arg, Tensor): + if not isinstance(arg, Tensor) or arg.ndim == 0: continue offsets = arg.offsets() @@ -371,6 +371,9 @@ def _generate_grid(self): @staticmethod def _generate_load(tensor, intermediate_indices=()): + if tensor.ndim == 0: + return Symbol(tensor.original.name).node + pointers, mask = CodeGenerator._generate_pointers_and_mask( tensor, intermediate_indices ) diff --git a/src/ninetoothed/tensor.py b/src/ninetoothed/tensor.py index ffd5681..68befa8 100644 --- a/src/ninetoothed/tensor.py +++ b/src/ninetoothed/tensor.py @@ -103,6 +103,9 @@ def squeeze(self, dim): ) def names(self): + if self.ndim == 0: + return {self.original.name} + return ( {self.original.pointer_string()} | { From 5a20e4b28fafe1ed2b6f5973f42464a7bd08e488 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 22 Oct 2024 15:45:50 +0800 Subject: [PATCH 5/7] Prefix tensor names with `_ninetoothed_` --- src/ninetoothed/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ninetoothed/tensor.py b/src/ninetoothed/tensor.py index 68befa8..c5b9203 100644 --- a/src/ninetoothed/tensor.py +++ b/src/ninetoothed/tensor.py @@ -21,7 +21,7 @@ def __init__( self.dtype = dtype - self.name = f"tensor_{type(self).num_instances}" + self.name = f"_ninetoothed_tensor_{type(self).num_instances}" if ndim is not None: self.shape = (Symbol(self.size_string(i)) for i in range(ndim)) From 7edeccfd1b3881defed867e9ee13116d9e0bd1e8 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 22 Oct 2024 16:36:42 +0800 Subject: [PATCH 6/7] Add test for addmm --- tests/test_addmm.py | 104 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 tests/test_addmm.py diff --git a/tests/test_addmm.py b/tests/test_addmm.py new file mode 100644 index 0000000..ab100f2 --- /dev/null +++ b/tests/test_addmm.py @@ -0,0 +1,104 @@ +import random + +import torch + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor +from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported + + +def addmm(input, mat1, mat2, beta=1, alpha=1): + BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True) + BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True) + BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True) + + input_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) + + output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N)) + + mat1_tiled = ( + Tensor(2) + .tile((BLOCK_SIZE_M, BLOCK_SIZE_K)) + .tile((1, -1)) + .expand((-1, output_tiled.shape[1])) + ) + mat1_tiled.dtype = mat1_tiled.dtype.squeeze(0) + + mat2_tiled = ( + Tensor(2) + .tile((BLOCK_SIZE_K, BLOCK_SIZE_N)) + .tile((-1, 1)) + .expand((output_tiled.shape[0], -1)) + ) + mat2_tiled.dtype = mat2_tiled.dtype.squeeze(1) + + @ninetoothed.jit + def addmm_kernel( + input: input_tiled, + mat1: mat1_tiled, + mat2: mat2_tiled, + beta: Tensor(0), + alpha: Tensor(0), + output: output_tiled, + ): + accumulator = ntl.zeros(output.shape, dtype=ntl.float32) + for k in range(mat1.shape[0]): + accumulator += ntl.dot(mat1[k], mat2[k]) + output = beta * input + alpha * accumulator.to(ntl.float16) + + output = torch.empty( + (mat1.shape[0], mat2.shape[1]), device=mat1.device, dtype=torch.float16 + ) + + addmm_kernel(input, mat1, mat2, beta, alpha, output) + + return output + + +@skip_if_cuda_not_available +class TestCUDA: + @classmethod + def setup_class(cls): + torch.manual_seed(0) + + shape = (512, 512) + + cls.input = torch.randn(shape, device="cuda") + cls.mat1 = torch.randn(shape, device="cuda") + cls.mat2 = torch.randn(shape, device="cuda") + cls.beta = random.uniform(0, 1) + cls.alpha = random.uniform(0, 1) + + def test_fp16(self): + input = type(self).input.to(torch.float16) + mat1 = type(self).mat1.to(torch.float16) + mat2 = type(self).mat2.to(torch.float16) + beta = type(self).beta + alpha = type(self).alpha + + assert torch.allclose( + addmm(input, mat1, mat2, beta=beta, alpha=alpha), + torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha), + atol=0.075, + ) + + @skip_if_float8_e5m2_not_supported + def test_fp8(self): + input = type(self).input.to(torch.float8_e5m2) + mat1 = type(self).mat1.to(torch.float8_e5m2) + mat2 = type(self).mat2.T.to(torch.float8_e5m2) + beta = type(self).beta + alpha = type(self).alpha + + assert torch.allclose( + addmm(input, mat1, mat2, beta=beta, alpha=alpha), + torch.addmm( + input.to(torch.float16), + mat1.to(torch.float16), + mat2.to(torch.float16), + beta=beta, + alpha=alpha, + ), + atol=0.125, + ) From 8fe5586ba60c96c8571bd4c6b34f0e77f6a7c9d0 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 22 Oct 2024 17:13:40 +0800 Subject: [PATCH 7/7] Increment the version number from 0.5.0 to 0.6.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e1dd19c..5df9264 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "ninetoothed" -version = "0.5.0" +version = "0.6.0" authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }] description = "A domain-specific language based on Triton but providing higher-level abstraction." readme = "README.md"