From b8bc64748497c8835f4f5d731463c45596168876 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 9 Feb 2023 09:56:20 -0800 Subject: [PATCH] Add missing parameters to jax-triton kernel call cache key. PiperOrigin-RevId: 508402360 --- jax_triton/triton_lib.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index d92d0e79..029fc427 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -132,19 +132,14 @@ def avals_to_layouts(avals): def get_or_create_triton_kernel( - ctx, fn, - scalar_arg_dtypes, + arg_dtypes, *, num_warps, num_stages, metaparams, dump_binary_path, ) -> triton_kernel_call_lib.TritonKernel: - arg_dtypes = list(map(get_triton_type, ctx.avals_in)) - for idx, dtype in scalar_arg_dtypes: - arg_dtypes.insert(idx, dtype) - arg_dtypes.extend(map(get_triton_type, ctx.avals_out)) signature = dict(enumerate(arg_dtypes)) constants = {fn.arg_names.index(k): v for k, v in metaparams.items()} @@ -223,15 +218,17 @@ def triton_kernel_call_lowering( ) i32_type = ir.IntegerType.get_signless(32) - scalar_arg_dtypes = [] + arg_dtypes = list(map(get_triton_type, ctx.avals_in)) # Buffer args are filled in at runtime. encoded_args = [None] * (len(array_args) + len(scalar_args) + len(out_shapes)) for idx, dtype, v in scalar_args: - scalar_arg_dtypes.append((idx, dtype)) + arg_dtypes.insert(idx, dtype) encoded_args[idx] = triton_kernel_call_lib.encode_kernel_parameter(v, dtype) + arg_dtypes.extend(map(get_triton_type, ctx.avals_out)) + if isinstance(fn, triton.runtime.autotuner.Autotuner): - if any(idx not in fn.key_idx for idx, _ in scalar_arg_dtypes): + if any(idx not in fn.key_idx for idx, _, _ in scalar_args): logging.warning( "Auto-tuning key does not include all scalar arguments. " "We may perform redundant auto-tuning." @@ -264,7 +261,13 @@ def prune_configs(configs, named_args): # Cache auto-tuned calls with the same parameters, so the auto-tuning need # only be performed once. - cache_key = (fn, tuple(configs), tuple(encoded_args)) + cache_key = ( + fn, + tuple(configs), + tuple(arg_dtypes), + tuple(encoded_args), + tuple(metaparams.items()), + ) kernel_call = _KERNEL_CALL_CACHE.get(cache_key) if kernel_call is None: @@ -273,9 +276,8 @@ def prune_configs(configs, named_args): config_metaparams = metaparams.copy() config_metaparams.update(config.kwargs) kernel = get_or_create_triton_kernel( - ctx, fn, - scalar_arg_dtypes, + arg_dtypes, num_warps=config.num_warps, num_stages=config.num_stages, metaparams=config_metaparams,