Skip to content

Commit

Permalink
Add manual heuristics for num_warps and num_stages
Browse files: browse the repository at this point in the history
  • Loading branch information
voltjia committed Aug 21, 2024
1 parent a0d3cf3 commit b9bce30
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/ninetoothed/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import math
import tempfile

import triton

from ninetoothed.language import attribute, call
from ninetoothed.symbol import Symbol
from ninetoothed.tensor import Tensor
Expand Down Expand Up @@ -226,6 +228,13 @@ def visit_Assign(self, node):
return node

def _generate_autotune(self, params, meta):
device = triton.runtime.driver.active.get_current_device()
properties = triton.runtime.driver.active.utils.get_device_properties(device)
max_shared_mem = properties["max_shared_mem"]

num_warps = 8
num_stages = max_shared_mem // 2**15

configs = [
ast.Call(
func=ast.Attribute(
Expand All @@ -239,7 +248,10 @@ def _generate_autotune(self, params, meta):
values=[ast.Constant(value=value) for value in values],
)
],
keywords=[],
keywords=[
ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
],
)
for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
Expand Down

0 comments on commit b9bce30

Please sign in to comment.