Merge pull request #77 from jax-ml:batching-fix
PiperOrigin-RevId: 508421789
The jax_triton Authors committed Feb 9, 2023
2 parents b8bc647 + 2cf597d commit e3a1931
Showing 2 changed files with 43 additions and 2 deletions.
23 changes: 21 additions & 2 deletions jax_triton/pallas/pallas_call.py
@@ -67,7 +67,8 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
     return value
   assert is_indexing is not None
   output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape)
-  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing)])
+  squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing,
+                                                            dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)
 
 def _maybe_dynamic_update_slice(start_idx, block_shape, value, update,
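(Illustrative aside, not part of the diff.) A minimal NumPy sketch of the behavior the explicit dtype=np.bool_ above guards against: if the is_indexing flags arrive as integer-valued entries (a hypothetical here), indexing with them performs positional (fancy) indexing instead of boolean masking, so the wrong dimensions would be squeezed.

import numpy as np

is_indexing = [1, 0, 1]             # hypothetical integer-valued flags
dims = np.arange(len(is_indexing))  # array([0, 1, 2])
print(dims[np.array(is_indexing)])                  # [1 0 1] -- positions, not a mask
print(dims[np.array(is_indexing, dtype=np.bool_)])  # [0 2]   -- boolean mask, as intended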
@@ -215,7 +216,10 @@ def _block_map_function(new_idx, *args):
   block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(_block_map_function), idx_avals)
   shape = aval.shape if block_mapping is None else block_mapping.block_shape
-  new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
+  if dim is batching.not_mapped:
+    new_block_shape = shape
+  else:
+    new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
   return BlockMapping(new_block_shape,
                       jax_core.ClosedJaxpr(block_mapping_jaxpr, consts))

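(Illustrative aside, not part of the diff.) A self-contained sketch of what the branch above does to a block shape, using local stand-ins rather than the library's own batching.not_mapped and pallas_core.mapped objects: an operand that is not batched (e.g. a vmap in_axes=None argument) keeps its block shape, while a batched operand gets a mapped sentinel inserted at its batch dimension.

NOT_MAPPED = None   # stand-in for batching.not_mapped
MAPPED = "mapped"   # stand-in for pallas_core.mapped

def tuple_insert(t, idx, val):
  # Local equivalent of the tuple_insert helper used in the diff.
  return t[:idx] + (val,) + t[idx:]

def new_block_shape(shape, dim):
  if dim is NOT_MAPPED:
    return shape
  return tuple_insert(shape, dim, MAPPED)

print(new_block_shape((8, 128), 0))           # ('mapped', 8, 128) -- batched along axis 0
print(new_block_shape((8, 128), NOT_MAPPED))  # (8, 128)           -- unmapped operand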
@@ -240,7 +244,22 @@ def _pallas_call_batching_rule(args, dims, *,
   # dimensions. For now, we just use 0.
   # TODO(sharadmv): explore inferring better output dimensions via a heuristic
   # TODO(sharadmv): explore a long term solution to output dim inference
+
+  # When we have input/output aliasing, since the output will be mapped, we need
+  # to make sure to broadcast the input across that dimension if it is not
+  # mapped.
+  dims_ = list(dims)
+  args_ = list(args)
+  for input_index, _ in input_output_aliases:
+    dim = dims_[input_index]
+    if dim is batching.not_mapped:
+      dims_[input_index] = 0
+      args_[input_index] = batching.broadcast(args_[input_index], axis_size, 0)
+  args = tuple(args_)
+  dims = tuple(dims_)
+
   all_dims = list(dims) + [0] * len(out_shapes)
+
   batched_block_mappings = map(partial(_batch_block_mapping, grid_spec.grid),
                                avals, all_dims, block_mappings)
   batched_in_shapes = tuple(
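(Illustrative aside, not part of the diff.) A standalone sketch of the broadcasting step added above, with jnp.broadcast_to standing in for batching.broadcast: when an operand is aliased to an output but is not itself batched, it is replicated along a new leading axis of size axis_size so that, like the output (which is always mapped), it carries the batch dimension.

import jax.numpy as jnp

axis_size = 8         # size of the vmapped axis
y = jnp.int32(1)      # unmapped operand (in_axes=None) aliased to an output
y_batched = jnp.broadcast_to(y, (axis_size, *jnp.shape(y)))
print(y_batched.shape)  # (8,) -- now carries the batch axis at position 0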
22 changes: 22 additions & 0 deletions tests/pallas_test.py
@@ -623,6 +623,16 @@ def add_one(x_ref, o_ref):
     out_ref = jnp.arange(1, 9)
     np.testing.assert_allclose(out, out_ref)
 
+  def test_vmap_of_simple_kernel_with_in_axes_None(self):
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
+        debug=False)
+    def add(x_ref, y_ref, o_ref):
+      o_ref[()] = x_ref[()] + y_ref[()]
+    out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1)
+    out_ref = jnp.arange(1, 9)
+    np.testing.assert_allclose(out, out_ref)
+
   def test_double_vmap_of_simple_kernel(self):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
@@ -669,6 +679,18 @@ def add_one(x_ref, o_ref):
     out_ref = jnp.arange(1, 9).reshape((4, 2))
     np.testing.assert_allclose(out, out_ref)
 
+  def test_vmap_of_kernel_with_input_output_aliases(self):
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
+        debug=False,
+        input_output_aliases={1: 0},
+        grid=())
+    def add(x_ref, _, o_ref):
+      o_ref[()] = x_ref[()] + o_ref[()] + 1
+    out = jax.vmap(add, in_axes=(0, None))(jnp.arange(8), 1)
+    out_ref = jnp.arange(2, 10)
+    np.testing.assert_allclose(out, out_ref)
+
   def test_vmap_of_slicing_kernel_different_axes(self):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
