Merge pull request #9 from InfiniTensor/dev
Add support for calling Triton code from NineToothed and 0-dimensional tensors
voltjia authored Oct 22, 2024
2 parents 34d3e04 + 8fe5586 commit 2c64fd8
Showing 4 changed files with 143 additions and 16 deletions.
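For context before the per-file diffs, here is a minimal sketch of what the two features enable together. It is illustrative only: `scale` and `scale_kernel` are made-up names, and the 1-D tiling follows the patterns used in `tests/test_addmm.py` below. A plain `@triton.jit` function referenced from a NineToothed kernel is picked up by the new `_find_dependencies` and inlined into the generated source, and a `Tensor(0)` annotation passes a scalar straight through to the kernel.

```python
import ninetoothed
import triton
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)


@triton.jit
def scale(x, factor):
    # Plain Triton code: its source is inlined into the generated module.
    return x * factor


@ninetoothed.jit
def scale_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),
    factor: Tensor(0),  # 0-dimensional: used as a bare scalar, not a pointer
    y: Tensor(1).tile((BLOCK_SIZE,)),
):
    y = scale(x, factor)  # calling Triton code from NineToothed
```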
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.5.0"
+version = "0.6.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
48 changes: 34 additions & 14 deletions src/ninetoothed/jit.py
@@ -1,9 +1,11 @@
 import ast
 import collections
 import functools
+import importlib.util
 import inspect
 import itertools
 import math
+import sys
 import tempfile
 
 import triton
@@ -41,24 +43,20 @@ def __call__(self):
         ast.fix_missing_locations(tree)
 
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
+        dependencies = self._find_dependencies()
+        source = "\n\n".join((unparsed, dependencies)).strip()
 
         with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
-            temp_file.write(unparsed.encode("utf-8"))
+            temp_file.write(source.encode("utf-8"))
             temp_file_name = temp_file.name
 
-        with open(temp_file_name, "r") as temp_file:
-            code = compile(
-                source=temp_file.read(),
-                filename=temp_file_name,
-                mode="exec",
-            )
-
-        namespace = {}
-        exec(code, namespace)
+        module = type(self)._import_from_path(temp_file_name, temp_file_name)
+        module_vars = vars(module)
 
         handle = _Handle(
-            namespace[self.func.__name__],
-            namespace[f"launch_{self.func.__name__}"],
+            module_vars[self.func.__name__],
+            module_vars[f"launch_{self.func.__name__}"],
+            source,
         )
 
         type(self).handles[source_file][source_line] = handle
@@ -74,6 +72,24 @@ def _get_tree(self):
 
         return ast.Module(body=[finder.result], type_ignores=[])
 
+    def _find_dependencies(self):
+        dependencies = set()
+
+        for obj in self.func.__globals__.values():
+            if isinstance(obj, triton.runtime.JITFunction):
+                dependencies.add(obj.src)
+
+        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
+
 
 class CodeGenerator(ast.NodeTransformer):
     def __init__(self, context):
@@ -102,7 +118,7 @@ def visit_FunctionDef(self, node):
         self.generic_visit(node)
 
         for arg in self._args:
-            if not isinstance(arg, Tensor):
+            if not isinstance(arg, Tensor) or arg.ndim == 0:
                 continue
 
             offsets = arg.offsets()
@@ -355,6 +371,9 @@ def _generate_grid(self):
 
     @staticmethod
    def _generate_load(tensor, intermediate_indices=()):
+        if tensor.ndim == 0:
+            return Symbol(tensor.original.name).node
+
         pointers, mask = CodeGenerator._generate_pointers_and_mask(
             tensor, intermediate_indices
         )
@@ -459,9 +478,10 @@ def visit_Call(self, node):
 
 
 class _Handle:
-    def __init__(self, kernel, launch):
+    def __init__(self, kernel, launch, source):
         self._kernel = kernel
         self._launch = launch
+        self._source = source
 
     def __call__(self, *args, **kwargs):
         return self._launch(*args, **kwargs)
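Two mechanisms are worth noting in this file. `_find_dependencies` collects the source of every `triton.runtime.JITFunction` found in the kernel's globals and re-prefixes it with `@triton.jit`, so Triton helpers called from a NineToothed kernel travel with the generated code. And `__call__` now imports the generated temp file as a real module (registered in `sys.modules`) instead of `exec`-ing a compiled string. A standalone sketch of that import pattern, with a made-up `greet` function standing in for the generated kernel source:

```python
import importlib.util
import sys
import tempfile

# Stand-in for the generated kernel source that __call__ writes out.
source = "def greet():\n    return 'hi'\n"

with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
    temp_file.write(source.encode("utf-8"))
    temp_file_name = temp_file.name

# The same steps as _import_from_path: build a spec from the file path,
# materialize a module from it, register it in sys.modules, and execute it.
spec = importlib.util.spec_from_file_location(temp_file_name, temp_file_name)
module = importlib.util.module_from_spec(spec)
sys.modules[temp_file_name] = module
spec.loader.exec_module(module)

print(module.greet())  # prints: hi
```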
5 changes: 4 additions & 1 deletion src/ninetoothed/tensor.py
@@ -21,7 +21,7 @@ def __init__(
 
         self.dtype = dtype
 
-        self.name = f"tensor_{type(self).num_instances}"
+        self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
 
         if ndim is not None:
             self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
@@ -103,6 +103,9 @@ def squeeze(self, dim):
         )
 
     def names(self):
+        if self.ndim == 0:
+            return {self.original.name}
+
         return (
             {self.original.pointer_string()}
             | {
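The `names()` change mirrors the 0-dimensional special cases in `jit.py`: a scalar tensor has no pointer, size, or stride symbols, so the only name it contributes to the generated kernel's parameter list is its own. A quick illustration; the set contents shown are indicative, based on the `_ninetoothed_tensor_` prefix introduced above:

```python
from ninetoothed import Tensor

scalar = Tensor(0)
matrix = Tensor(2)

# A 0-dimensional tensor contributes only its bare argument name.
print(scalar.names())  # e.g. {"_ninetoothed_tensor_0"}

# A 2-dimensional tensor contributes a pointer name plus size/stride symbols.
print(matrix.names())
```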
104 changes: 104 additions & 0 deletions tests/test_addmm.py
@@ -0,0 +1,104 @@
+import random
+
+import torch
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Symbol, Tensor
+from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
+
+
+def addmm(input, mat1, mat2, beta=1, alpha=1):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+
+    input_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    mat1_tiled = (
+        Tensor(2)
+        .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+        .tile((1, -1))
+        .expand((-1, output_tiled.shape[1]))
+    )
+    mat1_tiled.dtype = mat1_tiled.dtype.squeeze(0)
+
+    mat2_tiled = (
+        Tensor(2)
+        .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+        .tile((-1, 1))
+        .expand((output_tiled.shape[0], -1))
+    )
+    mat2_tiled.dtype = mat2_tiled.dtype.squeeze(1)
+
+    @ninetoothed.jit
+    def addmm_kernel(
+        input: input_tiled,
+        mat1: mat1_tiled,
+        mat2: mat2_tiled,
+        beta: Tensor(0),
+        alpha: Tensor(0),
+        output: output_tiled,
+    ):
+        accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
+        for k in range(mat1.shape[0]):
+            accumulator += ntl.dot(mat1[k], mat2[k])
+        output = beta * input + alpha * accumulator.to(ntl.float16)
+
+    output = torch.empty(
+        (mat1.shape[0], mat2.shape[1]), device=mat1.device, dtype=torch.float16
+    )
+
+    addmm_kernel(input, mat1, mat2, beta, alpha, output)
+
+    return output
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+        shape = (512, 512)
+
+        cls.input = torch.randn(shape, device="cuda")
+        cls.mat1 = torch.randn(shape, device="cuda")
+        cls.mat2 = torch.randn(shape, device="cuda")
+        cls.beta = random.uniform(0, 1)
+        cls.alpha = random.uniform(0, 1)
+
+    def test_fp16(self):
+        input = type(self).input.to(torch.float16)
+        mat1 = type(self).mat1.to(torch.float16)
+        mat2 = type(self).mat2.to(torch.float16)
+        beta = type(self).beta
+        alpha = type(self).alpha
+
+        assert torch.allclose(
+            addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            atol=0.075,
+        )
+
+    @skip_if_float8_e5m2_not_supported
+    def test_fp8(self):
+        input = type(self).input.to(torch.float8_e5m2)
+        mat1 = type(self).mat1.to(torch.float8_e5m2)
+        mat2 = type(self).mat2.T.to(torch.float8_e5m2)
+        beta = type(self).beta
+        alpha = type(self).alpha
+
+        assert torch.allclose(
+            addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            torch.addmm(
+                input.to(torch.float16),
+                mat1.to(torch.float16),
+                mat2.to(torch.float16),
+                beta=beta,
+                alpha=alpha,
+            ),
+            atol=0.125,
+        )
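For reference, the identity the test relies on is `torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) == beta * input + alpha * (mat1 @ mat2)`, which is what the kernel above computes blockwise. A quick CPU sanity check of that reference semantics, independent of the kernel:

```python
import torch

torch.manual_seed(0)
input, mat1, mat2 = (torch.randn(4, 4) for _ in range(3))
beta, alpha = 0.5, 2.0

# torch.addmm(input, mat1, mat2) == beta * input + alpha * (mat1 @ mat2)
expected = beta * input + alpha * (mat1 @ mat2)
assert torch.allclose(torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha), expected)
```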
