Skip to content

Commit

Permalink
Add missing parameters to jax-triton kernel call cache key.
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 508402360
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Feb 9, 2023
1 parent 7088013 commit b8bc647
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions jax_triton/triton_lib.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -132,19 +132,14 @@ def avals_to_layouts(avals):


def get_or_create_triton_kernel(
- ctx,
fn,
- scalar_arg_dtypes,
+ arg_dtypes,
*,
num_warps,
num_stages,
metaparams,
dump_binary_path,
) -> triton_kernel_call_lib.TritonKernel:
- arg_dtypes = list(map(get_triton_type, ctx.avals_in))
- for idx, dtype in scalar_arg_dtypes:
- arg_dtypes.insert(idx, dtype)
- arg_dtypes.extend(map(get_triton_type, ctx.avals_out))
signature = dict(enumerate(arg_dtypes))

constants = {fn.arg_names.index(k): v for k, v in metaparams.items()}
Expand Down Expand Up @@ -223,15 +218,17 @@ def triton_kernel_call_lowering(
)
i32_type = ir.IntegerType.get_signless(32)

- scalar_arg_dtypes = []
+ arg_dtypes = list(map(get_triton_type, ctx.avals_in))
# Buffer args are filled in at runtime.
encoded_args = [None] * (len(array_args) + len(scalar_args) + len(out_shapes))
for idx, dtype, v in scalar_args:
- scalar_arg_dtypes.append((idx, dtype))
+ arg_dtypes.insert(idx, dtype)
encoded_args[idx] = triton_kernel_call_lib.encode_kernel_parameter(v, dtype)

+ arg_dtypes.extend(map(get_triton_type, ctx.avals_out))
+
if isinstance(fn, triton.runtime.autotuner.Autotuner):
- if any(idx not in fn.key_idx for idx, _ in scalar_arg_dtypes):
+ if any(idx not in fn.key_idx for idx, _, _ in scalar_args):
logging.warning(
"Auto-tuning key does not include all scalar arguments. "
"We may perform redundant auto-tuning."
Expand Down Expand Up @@ -264,7 +261,13 @@ def prune_configs(configs, named_args):

# Cache auto-tuned calls with the same parameters, so the auto-tuning need
# only be performed once.
- cache_key = (fn, tuple(configs), tuple(encoded_args))
+ cache_key = (
+ fn,
+ tuple(configs),
+ tuple(arg_dtypes),
+ tuple(encoded_args),
+ tuple(metaparams.items()),
+ )
kernel_call = _KERNEL_CALL_CACHE.get(cache_key)

if kernel_call is None:
Expand All @@ -273,9 +276,8 @@ def prune_configs(configs, named_args):
config_metaparams = metaparams.copy()
config_metaparams.update(config.kwargs)
kernel = get_or_create_triton_kernel(
- ctx,
fn,
- scalar_arg_dtypes,
+ arg_dtypes,
num_warps=config.num_warps,
num_stages=config.num_stages,
metaparams=config_metaparams,
Expand Down

0 comments on commit b8bc647

Please sign in to comment.