Merge pull request #7 from InfiniTensor/dev
Add support for handling constexpr parameters
voltjia authored Aug 23, 2024
2 parents: 2589556 + 3d996c8 · commit: 895beec
Showing 4 changed files with 49 additions and 14 deletions.
pyproject.toml — 2 changes: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.3.0"
+version = "0.4.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
src/ninetoothed/jit.py — 19 changes: 17 additions & 2 deletions
@@ -137,7 +137,7 @@ def visit_arguments(self, node):
         node.args = [
             ast.arg(arg=name)
             if not Symbol.is_constexpr(name)
-            else ast.arg(arg=name, annotation=attribute("constexpr"))
+            else ast.arg(arg=name, annotation=attribute("constexpr").node)
             for name in non_meta_names
         ] + [
             ast.arg(arg=name, annotation=attribute("constexpr").node)
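The annotation fix above matches the pattern already used for the meta arguments on the line below it: ast.arg expects a real ast.AST as its annotation, so the object returned by attribute() is unwrapped via .node before being attached. A tiny illustrative sketch, using a stand-in wrapper class rather than ninetoothed's actual attribute helper:

import ast

# Illustrative stand-in, not ninetoothed's actual `attribute` helper: the
# wrapper holds an AST node, and `.node` is what gets attached as the
# annotation of the generated argument.
class _Wrapper:
    def __init__(self, node):
        self.node = node

annotation = _Wrapper(ast.parse("ntl.constexpr", mode="eval").body)
arg = ast.arg(arg="BLOCK_SIZE", annotation=annotation.node)
print(ast.dump(arg))  # arg(arg='BLOCK_SIZE', annotation=Attribute(value=Name(id='ntl', ...), attr='constexpr', ...))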
@@ -287,15 +287,30 @@ def _generate_autotune(self, params, meta):
         )
 
     def _generate_launch(self, params, meta):
+        constexpr_params = [param for param in params if Symbol.is_constexpr(param)]
+        constexpr_params_without_prefixes = [
+            Symbol.remove_prefix(param) for param in constexpr_params
+        ]
+
         launch = ast.FunctionDef(
             name=f"launch_{self._func_def.name}",
             args=ast.arguments(
                 posonlyargs=[],
-                args=[ast.arg(arg.original.name) for arg in self._args],
+                args=[ast.arg(arg=arg.original.name) for arg in self._args]
+                + [ast.arg(arg=param) for param in constexpr_params_without_prefixes],
                 kwonlyargs=[],
                 defaults=[],
             ),
             body=[
+                ast.Assign(
+                    targets=[ast.Name(id=param, ctx=ast.Store())],
+                    value=ast.Name(id=param_without_prefix, ctx=ast.Load()),
+                )
+                for param, param_without_prefix in zip(
+                    constexpr_params, constexpr_params_without_prefixes
+                )
+            ]
+            + [
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
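For context, a minimal self-contained sketch of the re-binding statements this hunk adds to the generated launch wrapper's body: one assignment per constexpr parameter, from the unprefixed name exposed in the wrapper's signature to the prefixed name used internally. The parameter names below are illustrative, not taken from the repository.

import ast

# Illustrative names; the real ones are produced by Symbol._create_constexpr.
constexpr_params = ["_ninetoothed_constexpr_BLOCK_SIZE"]
constexpr_params_without_prefixes = ["BLOCK_SIZE"]

# Same pattern as the ast.Assign comprehension above: bind the prefixed
# internal name to the unprefixed argument the caller passed in.
assignments = [
    ast.Assign(
        targets=[ast.Name(id=param, ctx=ast.Store())],
        value=ast.Name(id=param_without_prefix, ctx=ast.Load()),
    )
    for param, param_without_prefix in zip(
        constexpr_params, constexpr_params_without_prefixes
    )
]

module = ast.fix_missing_locations(ast.Module(body=assignments, type_ignores=[]))
print(ast.unparse(module))  # _ninetoothed_constexpr_BLOCK_SIZE = BLOCK_SIZE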
src/ninetoothed/symbol.py — 28 changes: 24 additions & 4 deletions
@@ -137,19 +137,39 @@ def visit_Call(self, node):
 
     @staticmethod
     def is_constexpr(name):
-        return name.startswith("_ninetoothed_constexpr_") or Symbol.is_meta(name)
+        return name.startswith(Symbol._constexpr_prefix()) or Symbol.is_meta(name)
 
     @staticmethod
     def is_meta(name):
-        return name.startswith("_ninetoothed_meta_")
+        return name.startswith(Symbol._meta_prefix())
 
+    @staticmethod
+    def remove_prefix(name):
+        if name.startswith(Symbol._constexpr_prefix()):
+            return name.removeprefix(Symbol._constexpr_prefix())
+
+        if name.startswith(Symbol._meta_prefix()):
+            return name.removeprefix(Symbol._meta_prefix())
+
     @staticmethod
     def _create_constexpr(name):
-        return f"_ninetoothed_constexpr_{name}"
+        return f"{Symbol._constexpr_prefix()}{name}"
 
     @staticmethod
     def _create_meta(name):
-        return f"_ninetoothed_meta_{name}"
+        return f"{Symbol._meta_prefix()}{name}"
+
+    @staticmethod
+    def _constexpr_prefix():
+        return f"{Symbol._ninetoothed_prefix()}constexpr_"
+
+    @staticmethod
+    def _meta_prefix():
+        return f"{Symbol._ninetoothed_prefix()}meta_"
+
+    @staticmethod
+    def _ninetoothed_prefix():
+        return "_ninetoothed_"
 
 
 class _FindAndReplacer(ast.NodeTransformer):
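The refactor factors the hard-coded prefixes into one shared "_ninetoothed_" namespace, and remove_prefix recovers the user-facing name for the launch wrapper's signature. A small stand-alone illustration of the scheme, using plain functions rather than the Symbol class itself:

NINETOOTHED_PREFIX = "_ninetoothed_"
CONSTEXPR_PREFIX = f"{NINETOOTHED_PREFIX}constexpr_"
META_PREFIX = f"{NINETOOTHED_PREFIX}meta_"


def remove_prefix(name):
    # Mirrors the new Symbol.remove_prefix: strip whichever internal prefix
    # the name carries (a name with neither prefix falls through to None).
    if name.startswith(CONSTEXPR_PREFIX):
        return name.removeprefix(CONSTEXPR_PREFIX)

    if name.startswith(META_PREFIX):
        return name.removeprefix(META_PREFIX)


assert remove_prefix("_ninetoothed_constexpr_BLOCK_SIZE") == "BLOCK_SIZE"
assert remove_prefix("_ninetoothed_meta_BLOCK_SIZE") == "BLOCK_SIZE"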
tests/test_softmax.py — 14 changes: 7 additions & 7 deletions
@@ -3,26 +3,26 @@
 
 import ninetoothed
 import ninetoothed.language as ntl
-from ninetoothed import Tensor
+from ninetoothed import Symbol, Tensor
 from tests.skippers import skip_if_cuda_not_available
 
 
 def softmax(input):
-    output = torch.empty_like(input)
-
-    block_size = triton.next_power_of_2(input.shape[-1])
+    BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)
 
     @ninetoothed.jit
     def softmax_kernel(
-        input_row: Tensor(2, other=float("-inf")).tile((1, block_size)),
-        output_row: Tensor(2).tile((1, block_size)),
+        input_row: Tensor(2, other=float("-inf")).tile((1, BLOCK_SIZE)),
+        output_row: Tensor(2).tile((1, BLOCK_SIZE)),
     ):
         row_minus_max = input_row - ntl.max(input_row)
         numerator = ntl.exp(row_minus_max)
         denominator = ntl.sum(numerator)
         output_row = numerator / denominator  # noqa: F841
 
-    softmax_kernel(input, output)
+    output = torch.empty_like(input)
+
+    softmax_kernel(input, output, BLOCK_SIZE=triton.next_power_of_2(input.shape[-1]))
 
     return output
 
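The test shows the new flow: the tile sizes are declared with a constexpr Symbol and supplied per call as a keyword argument at launch time instead of being captured at kernel-definition time. Assuming the softmax helper defined in the test above and a CUDA device (shapes and tolerance below are illustrative), the caller-facing behavior is unchanged:

import torch

x = torch.randn(4, 1000, device="cuda")
assert torch.allclose(softmax(x), torch.softmax(x, dim=-1), atol=1e-6)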
