diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py
index db584713..5b69058d 100644
--- a/jax_triton/pallas/pallas_call.py
+++ b/jax_triton/pallas/pallas_call.py
@@ -67,7 +67,8 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
     return value
   assert is_indexing is not None
   output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
-  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing)])
+  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
+                                                            dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)
 
 def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
@@ -215,7 +216,10 @@ def _block_map_function(new_idx, *args):
   block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(_block_map_function), idx_avals)
   shape = aval.shape if block_mapping is None else block_mapping.block_shape
-  new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
+  if dim is batching.not_mapped:
+    new_block_shape = shape
+  else:
+    new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
   return BlockMapping(new_block_shape,
                       jax_core.ClosedJaxpr(block_mapping_jaxpr, consts))
 
@@ -240,7 +244,22 @@ def _pallas_call_batching_rule(args, dims, *,
   # dimensions. For now, we just use 0.
   # TODO(sharadmv): explore inferring better output dimensions via a heuristic
   # TODO(sharadmv): explore a long term solution to output dim inference
+
+  # When we have input/output aliasing, since the output will be mapped, we need
+  # to make sure to broadcast the input across that dimension if it is not
+  # mapped.
+  dims_ = list(dims)
+  args_ = list(args)
+  for input_index, _ in input_output_aliases:
+    dim = dims_[input_index]
+    if dim is batching.not_mapped:
+      dims_[input_index] = 0
+      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
+  args = tuple(args_)
+  dims = tuple(dims_)
+
   all_dims = list(dims) + [0] * len(out_shapes)
+
   batched_block_mappings = map(partial(_batch_block_mapping, grid_spec.grid),
                                avals, all_dims, block_mappings)
   batched_in_shapes = tuple(
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index 0c44d27b..77c9fb79 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -623,6 +623,16 @@ def add_one(x_ref, o_ref):
     out_ref = jnp.arange(1, 9)
     np.testing.assert_allclose(out, out_ref)
 
+  def test_vmap_of_simple_kernel_with_in_axes_None(self):
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
+        debug=False)
+    def add(x_ref, y_ref, o_ref):
+      o_ref[()] = x_ref[()] + y_ref[()]
+    out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1)
+    out_ref = jnp.arange(1, 9)
+    np.testing.assert_allclose(out, out_ref)
+
   def test_double_vmap_of_simple_kernel(self):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
@@ -669,6 +679,18 @@ def add_one(x_ref, o_ref):
     out_ref = jnp.arange(1, 9).reshape((4, 2))
     np.testing.assert_allclose(out, out_ref)
 
+  def test_vmap_of_kernel_with_input_output_aliases(self):
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
+        debug=False,
+        input_output_aliases={1:0},
+        grid=())
+    def add(x_ref, _, o_ref):
+      o_ref[()] = x_ref[()] + o_ref[()] + 1
+    out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1)
+    out_ref = jnp.arange(2, 10)
+    np.testing.assert_allclose(out, out_ref)
+
   def test_vmap_of_slicing_kernel_different_axes(self):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
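
Reviewer note (not part of the patch): the `dtype=np.bool_` annotation in `_maybe_dynamic_slice` matters because `np.array([])` defaults to `float64`, and NumPy rejects float arrays as indices, so the no-indexed-dimensions case previously raised instead of producing an empty boolean mask. A minimal NumPy sketch of the failure mode and the fix:

```python
import numpy as np

is_indexing = []  # no dimensions are being indexed away

# Without an explicit dtype, the empty mask is float64 and indexing raises:
#   np.arange(0)[np.array([])]  ->  IndexError (non-integer/boolean index)

# With dtype=np.bool_, it is a valid (empty) boolean mask:
squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
                                                          dtype=np.bool_)])
assert squeeze_dims == ()
```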
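Also for reference, a standalone sketch of the case the batching-rule changes enable: `jax.vmap` with `in_axes=None` on an input that aliases a (necessarily mapped) output. This assumes `jax_triton.pallas` is importable as `pl` and that a Triton-capable GPU is available; the `accumulate` kernel is a hypothetical example, not code from this diff:

```python
import functools

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((), jnp.int32),
    input_output_aliases={1: 0},  # second input aliases the output
    grid=())
def accumulate(x_ref, _, o_ref):
  # o_ref starts out aliased to the second argument, so it can be read.
  o_ref[()] = x_ref[()] + o_ref[()]

# The second argument is unmapped (in_axes=None), but the output it aliases
# is mapped; the batching rule now broadcasts the argument to the batch size
# (via batching.broadcast) before setting up the alias. Expected: [5 6 7 8].
print(jax.vmap(accumulate, in_axes=(0, None))(jnp.arange(4), 5))
```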