Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for calling Triton code from NineToothed and 0-dimensional tensors #9

Merged
merged 7 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "ninetoothed"
version = "0.5.0"
version = "0.6.0"
authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
description = "A domain-specific language based on Triton but providing higher-level abstraction."
readme = "README.md"
Expand Down
48 changes: 34 additions & 14 deletions src/ninetoothed/jit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import ast
import collections
import functools
import importlib.util
import inspect
import itertools
import math
import sys
import tempfile

import triton
Expand Down Expand Up @@ -41,24 +43,20 @@ def __call__(self):
ast.fix_missing_locations(tree)

unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
dependencies = self._find_dependencies()
source = "\n\n".join((unparsed, dependencies)).strip()

with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
temp_file.write(unparsed.encode("utf-8"))
temp_file.write(source.encode("utf-8"))
temp_file_name = temp_file.name

with open(temp_file_name, "r") as temp_file:
code = compile(
source=temp_file.read(),
filename=temp_file_name,
mode="exec",
)

namespace = {}
exec(code, namespace)
module = type(self)._import_from_path(temp_file_name, temp_file_name)
module_vars = vars(module)

handle = _Handle(
namespace[self.func.__name__],
namespace[f"launch_{self.func.__name__}"],
module_vars[self.func.__name__],
module_vars[f"launch_{self.func.__name__}"],
source,
)

type(self).handles[source_file][source_line] = handle
Expand All @@ -74,6 +72,24 @@ def _get_tree(self):

return ast.Module(body=[finder.result], type_ignores=[])

def _find_dependencies(self):
dependencies = set()

for obj in self.func.__globals__.values():
if isinstance(obj, triton.runtime.JITFunction):
dependencies.add(obj.src)

return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)

@staticmethod
def _import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)

return module


class CodeGenerator(ast.NodeTransformer):
def __init__(self, context):
Expand Down Expand Up @@ -102,7 +118,7 @@ def visit_FunctionDef(self, node):
self.generic_visit(node)

for arg in self._args:
if not isinstance(arg, Tensor):
if not isinstance(arg, Tensor) or arg.ndim == 0:
continue

offsets = arg.offsets()
Expand Down Expand Up @@ -355,6 +371,9 @@ def _generate_grid(self):

@staticmethod
def _generate_load(tensor, intermediate_indices=()):
if tensor.ndim == 0:
return Symbol(tensor.original.name).node

pointers, mask = CodeGenerator._generate_pointers_and_mask(
tensor, intermediate_indices
)
Expand Down Expand Up @@ -459,9 +478,10 @@ def visit_Call(self, node):


class _Handle:
def __init__(self, kernel, launch):
def __init__(self, kernel, launch, source):
self._kernel = kernel
self._launch = launch
self._source = source

def __call__(self, *args, **kwargs):
return self._launch(*args, **kwargs)
Expand Down
5 changes: 4 additions & 1 deletion src/ninetoothed/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(

self.dtype = dtype

self.name = f"tensor_{type(self).num_instances}"
self.name = f"_ninetoothed_tensor_{type(self).num_instances}"

if ndim is not None:
self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
Expand Down Expand Up @@ -103,6 +103,9 @@ def squeeze(self, dim):
)

def names(self):
if self.ndim == 0:
return {self.original.name}

return (
{self.original.pointer_string()}
| {
Expand Down
104 changes: 104 additions & 0 deletions tests/test_addmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import random

import torch

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Symbol, Tensor
from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported


def addmm(input, mat1, mat2, beta=1, alpha=1):
BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)

input_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

mat1_tiled = (
Tensor(2)
.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
.tile((1, -1))
.expand((-1, output_tiled.shape[1]))
)
mat1_tiled.dtype = mat1_tiled.dtype.squeeze(0)

mat2_tiled = (
Tensor(2)
.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
.tile((-1, 1))
.expand((output_tiled.shape[0], -1))
)
mat2_tiled.dtype = mat2_tiled.dtype.squeeze(1)

@ninetoothed.jit
def addmm_kernel(
input: input_tiled,
mat1: mat1_tiled,
mat2: mat2_tiled,
beta: Tensor(0),
alpha: Tensor(0),
output: output_tiled,
):
accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
for k in range(mat1.shape[0]):
accumulator += ntl.dot(mat1[k], mat2[k])
output = beta * input + alpha * accumulator.to(ntl.float16)

output = torch.empty(
(mat1.shape[0], mat2.shape[1]), device=mat1.device, dtype=torch.float16
)

addmm_kernel(input, mat1, mat2, beta, alpha, output)

return output


@skip_if_cuda_not_available
class TestCUDA:
@classmethod
def setup_class(cls):
torch.manual_seed(0)

shape = (512, 512)

cls.input = torch.randn(shape, device="cuda")
cls.mat1 = torch.randn(shape, device="cuda")
cls.mat2 = torch.randn(shape, device="cuda")
cls.beta = random.uniform(0, 1)
cls.alpha = random.uniform(0, 1)

def test_fp16(self):
input = type(self).input.to(torch.float16)
mat1 = type(self).mat1.to(torch.float16)
mat2 = type(self).mat2.to(torch.float16)
beta = type(self).beta
alpha = type(self).alpha

assert torch.allclose(
addmm(input, mat1, mat2, beta=beta, alpha=alpha),
torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha),
atol=0.075,
)

@skip_if_float8_e5m2_not_supported
def test_fp8(self):
input = type(self).input.to(torch.float8_e5m2)
mat1 = type(self).mat1.to(torch.float8_e5m2)
mat2 = type(self).mat2.T.to(torch.float8_e5m2)
beta = type(self).beta
alpha = type(self).alpha

assert torch.allclose(
addmm(input, mat1, mat2, beta=beta, alpha=alpha),
torch.addmm(
input.to(torch.float16),
mat1.to(torch.float16),
mat2.to(torch.float16),
beta=beta,
alpha=alpha,
),
atol=0.125,
)