From 5d781e250696a215e0411306d6ce79e863c3ed52 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 Aug 2024 13:48:18 +0800 Subject: [PATCH 1/4] Add `symbol.Symbol.remove_prefix` to remove prefixes from constexpr and meta symbols --- src/ninetoothed/symbol.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/ninetoothed/symbol.py b/src/ninetoothed/symbol.py index 7e48ef7..2fded55 100644 --- a/src/ninetoothed/symbol.py +++ b/src/ninetoothed/symbol.py @@ -137,19 +137,39 @@ def visit_Call(self, node): @staticmethod def is_constexpr(name): - return name.startswith("_ninetoothed_constexpr_") or Symbol.is_meta(name) + return name.startswith(Symbol._constexpr_prefix()) or Symbol.is_meta(name) @staticmethod def is_meta(name): - return name.startswith("_ninetoothed_meta_") + return name.startswith(Symbol._meta_prefix()) + + @staticmethod + def remove_prefix(name): + if name.startswith(Symbol._constexpr_prefix()): + return name.removeprefix(Symbol._constexpr_prefix()) + + if name.startswith(Symbol._meta_prefix()): + return name.removeprefix(Symbol._meta_prefix()) @staticmethod def _create_constexpr(name): - return f"_ninetoothed_constexpr_{name}" + return f"{Symbol._constexpr_prefix()}{name}" @staticmethod def _create_meta(name): - return f"_ninetoothed_meta_{name}" + return f"{Symbol._meta_prefix()}{name}" + + @staticmethod + def _constexpr_prefix(): + return f"{Symbol._ninetoothed_prefix()}constexpr_" + + @staticmethod + def _meta_prefix(): + return f"{Symbol._ninetoothed_prefix()}meta_" + + @staticmethod + def _ninetoothed_prefix(): + return "_ninetoothed_" class _FindAndReplacer(ast.NodeTransformer): From f5651fb178ece0b8a79b88da1069184e8f1f86f5 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 Aug 2024 13:51:57 +0800 Subject: [PATCH 2/4] Add support for handling constexpr parameters --- src/ninetoothed/jit.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/ninetoothed/jit.py b/src/ninetoothed/jit.py index 128e90b..6092ed9 100644 --- a/src/ninetoothed/jit.py +++ b/src/ninetoothed/jit.py @@ -137,7 +137,7 @@ def visit_arguments(self, node): node.args = [ ast.arg(arg=name) if not Symbol.is_constexpr(name) - else ast.arg(arg=name, annotation=attribute("constexpr")) + else ast.arg(arg=name, annotation=attribute("constexpr").node) for name in non_meta_names ] + [ ast.arg(arg=name, annotation=attribute("constexpr").node) @@ -287,15 +287,30 @@ def _generate_autotune(self, params, meta): ) def _generate_launch(self, params, meta): + constexpr_params = [param for param in params if Symbol.is_constexpr(param)] + constexpr_params_without_prefixes = [ + Symbol.remove_prefix(param) for param in constexpr_params + ] + launch = ast.FunctionDef( name=f"launch_{self._func_def.name}", args=ast.arguments( posonlyargs=[], - args=[ast.arg(arg.original.name) for arg in self._args], + args=[ast.arg(arg=arg.original.name) for arg in self._args] + + [ast.arg(arg=param) for param in constexpr_params_without_prefixes], kwonlyargs=[], defaults=[], ), body=[ + ast.Assign( + targets=[ast.Name(id=param, ctx=ast.Store())], + value=ast.Name(id=param_without_prefix, ctx=ast.Load()), + ) + for param, param_without_prefix in zip( + constexpr_params, constexpr_params_without_prefixes + ) + ] + + [ ast.Expr( ast.Call( func=ast.Subscript( From 27e73576105887f5a16b304ecc58d545995b9449 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 Aug 2024 13:58:02 +0800 Subject: [PATCH 3/4] Replace the original `block_size` variable in `test_softmax.py` with the constexpr symbol `BLOCK_SIZE` --- tests/test_softmax.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index c926bd4..a3e2a12 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -3,26 +3,26 @@ import ninetoothed import ninetoothed.language as ntl -from ninetoothed import Tensor +from ninetoothed import Symbol, Tensor from tests.skippers import skip_if_cuda_not_available def softmax(input): - output = torch.empty_like(input) - - block_size = triton.next_power_of_2(input.shape[-1]) + BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True) @ninetoothed.jit def softmax_kernel( - input_row: Tensor(2, other=float("-inf")).tile((1, block_size)), - output_row: Tensor(2).tile((1, block_size)), + input_row: Tensor(2, other=float("-inf")).tile((1, BLOCK_SIZE)), + output_row: Tensor(2).tile((1, BLOCK_SIZE)), ): row_minus_max = input_row - ntl.max(input_row) numerator = ntl.exp(row_minus_max) denominator = ntl.sum(numerator) output_row = numerator / denominator # noqa: F841 - softmax_kernel(input, output) + output = torch.empty_like(input) + + softmax_kernel(input, output, BLOCK_SIZE=triton.next_power_of_2(input.shape[-1])) return output From 3d996c84749975d89f8ac52f4c6dfd7ddc52c47c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 23 Aug 2024 14:03:26 +0800 Subject: [PATCH 4/4] Increment the version number from 0.3.0 to 0.4.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ca57066..42afa67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "ninetoothed" -version = "0.3.0" +version = "0.4.0" authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }] description = "A domain-specific language based on Triton but providing higher-level abstraction." readme = "README.md"