Skip to content

Commit

Permalink
[TKW] Refactor gather indices generation (#222)
Browse files Browse the repository at this point in the history
Previously, we could end up generating dynamic indices for both the gather
`start_indices` and `offsets_vec`, which is suboptimal.

Now, detect whether we need a dynamic `offsets_vec` and, if so, generate
`start_indices` as 0, encoding the entire index calculation in
`offsets_vec`.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 authored Oct 17, 2024
1 parent a92f3db commit 1aa05e0
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 13 deletions.
46 changes: 33 additions & 13 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ def _get_const(val):
_enforce_non_rational(lhs, term)
res = arith_d.andi(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.logic.boolalg.BooleanFalse():
res = arith_d.constant(IntegerType.get_signless(1), 0)
stack.append(res)
case sympy.logic.boolalg.BooleanTrue():
res = arith_d.constant(IntegerType.get_signless(1), 1)
stack.append(res)
case sympy.UnevaluatedExpr():
continue
case _:
Expand Down Expand Up @@ -599,8 +605,8 @@ def _construct_gather_scatter_indices(

start_indices = _get_start_indices(result_index)
start_indices_orig = _get_start_indices(index)
dynamic_offsets = []

need_dynamic_offsets = False
start_indices_offset = _compute_offset(start_indices, strides)
for i in range(elements_per_thread):
# Update most-minor dim, i.e. in case of identity mapping it will
Expand All @@ -626,22 +632,36 @@ def _construct_gather_scatter_indices(
# arith ops and then `vector.insertelement` them into offsets vec.
offset = int(offset)
else:
dyn_offset = gen_sympy_index(add_emitter_subs(emitter), offset)
dynamic_offsets.append((i, dyn_offset))
offset = 0
need_dynamic_offsets = True
break

offsets.append(IntegerAttr.get(IndexType.get(), offset))

start_indices = _build_start_indices(emitter, result_index)
offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())

offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

for i, off in dynamic_offsets:
pos = arith_d.ConstantOp(IndexType.get(), i)
offsets_vec = vector_d.insertelement(off, offsets_vec, position=pos)
if need_dynamic_offsets:
# In case we need dynamic `offsets_vec`, set all `start_indices` to 0
# and encode entire index info in `offsets_vec`.
result_index = {key: 0 for key in symbolc_shape}
start_indices = _build_start_indices(emitter, result_index)
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
# Last item in `subs` corresponds to last item in `start_indices_orig`
# which is fastest changing dim.
# Replacing last element with `idxc.iota(elements_per_thread)` will
# generate vectorized index code, each element in it corresponding to
# individual vector element index.
subs[-1] = (
subs[-1][0],
start_indices_orig[-1] + idxc.iota(elements_per_thread),
)
indices = [i.subs(subs) for i in index_mapping]
offsets_vec = gen_sympy_index(
add_emitter_subs(emitter), _compute_offset(indices, strides)
)
else:
start_indices = _build_start_indices(emitter, result_index)
offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

mask = _build_mask(emitter, index, elements_per_thread)
if mask is None:
Expand Down
140 changes: 140 additions & 0 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,145 @@ def repeat(
# CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32>


@run_test
def test_igemm():
    """Lit test for a convolution kernel whose reads go through index mappings.

    Builds a `conv` wave kernel that reads the input `x` and the filter `we`
    via `tkw.IndexMapping`s, accumulates with `tkw.mma` inside a reduction
    over `K`, and writes the result through `out_mapping`. The generated
    module is printed, and the FileCheck directives at the end of the
    function verify that the two emitted `vector.gather` ops use the
    constant-zero index for every start index.

    NOTE(review): "igemm" presumably stands for implicit GEMM -- confirm.
    """
    # Concrete problem sizes: input is (n, c, h, w), filter is (nf, cf, hf, wf).
    n, c, h, w = 2, 640, 64, 64
    cf, hf, wf, nf = c, 3, 3, 640
    stride = 1
    padding = 0

    x = torch.randn(n, c, h, w, dtype=torch.float16)
    we = torch.randn(nf, cf, hf, wf, dtype=torch.float16)

    # Output spatial extents derived from input size, filter size, padding
    # and stride.
    h_out = (h + 2 * padding - hf) // stride + 1
    w_out = (w + 2 * padding - wf) // stride + 1
    res_shape = (n, nf, h_out, w_out)
    out = torch.zeros(res_shape, dtype=torch.float32)

    # Symbolic counterparts of the sizes above; they are bound to the concrete
    # values in the launch context at the bottom of the test.
    sym = tkl.sym
    N, C, H, W = sym.N, sym.C, sym.H, sym.W
    NF, HF, WF = sym.NF, sym.HF, sym.WF

    # Same output-size formulas as `h_out` / `w_out`, expressed on symbols.
    H_OUT = (H + 2 * padding - HF) // stride + 1
    W_OUT = (W + 2 * padding - WF) // stride + 1
    SZ_OUT = H_OUT * W_OUT

    # Flattened dimensions: K is the reduction (tiled) dimension and M is
    # distributed across workgroups/waves, per the constraints below.
    K = HF * WF * C
    M = SZ_OUT * N

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)

    # Each mapping expresses tensor coordinates as functions of the two
    # iterators (i, j).
    # NOTE(review): presumably `inputs` are coordinates on the source side of
    # the read/write and `outputs` on the destination side -- confirm against
    # the tkw.IndexMapping documentation.
    x_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            N: i // SZ_OUT,
            C: j // (HF * WF),
            H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
            W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
        },
        outputs={M: i, K: j},
    )
    w_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
        outputs={NF: i, K: j},
    )
    out_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: i, NF: j},
        outputs={
            N: i // SZ_OUT,
            NF: j,
            H_OUT: (i % SZ_OUT) % W_OUT,
            W_OUT: (i % SZ_OUT) // W_OUT,
        },
    )

    # Workgroup tile sizes
    # BLOCK_M / BLOCK_N stay symbolic and are bound in the launch context;
    # BLOCK_K is a plain Python int.
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = 16
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

    # layout == "nhwc_hwcf"
    # The memory types are declared channels-last, so the tensors created
    # above in (n, c, h, w) / (nf, cf, hf, wf) order are permuted to match.
    x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16]
    we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16]
    out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32]
    x = torch.permute(x, (0, 2, 3, 1)).contiguous()
    we = torch.permute(we, (2, 3, 1, 0)).contiguous()
    out = torch.permute(out, (0, 2, 3, 1)).contiguous()

    # Expose user-constraints
    constraints: list[tkw.Constraint] = []
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
        )
    ]

    @tkw.wave(constraints)
    def conv(
        x: x_type,
        we: we_type,
        out: out_type,
    ):
        # Accumulator register initialised to zero.
        c_reg = tkl.Register[M, NF, tkl.f32](0.0)

        # Reduction over K: each step reads a slice of `x` and `we` through
        # their mappings and accumulates the mma result.
        @tkw.reduction(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
            a_reg = tkw.read(
                x,
                mapping=x_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            b_reg = tkw.read(
                we,
                mapping=w_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        # Write the reduction result to `out` through `out_mapping`.
        tkw.write(
            repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD
        )

    # Bind every symbol to its concrete value; the printed module is what the
    # FileCheck directives below are matched against.
    with tk.gen.TestLaunchContext(
        {
            N: n,
            C: c,
            W: w,
            H: h,
            NF: nf,
            WF: wf,
            HF: hf,
            BLOCK_M: 16,
            BLOCK_N: 16,
            ELEMS_PER_THREAD: 4,
            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        },
        canonicalize=True,
    ):
        print(conv(x, we, out).module_op)
        # CHECK: func @conv
        # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

        # Check we are setting gather start indices to 0
        # CHECK: %{{.*}} = vector.gather %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : memref<2x64x64x640xf16
        # CHECK: %{{.*}} = vector.gather %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : memref<3x3x640x640xf16


@run_test
def test_add_float():
constraints: list[tkw.Constraint] = [
Expand All @@ -1020,6 +1159,7 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
with codegen_test_context():
a = torch.randn(16, 16, dtype=torch.float16)
print(test(a).module_op)
# CHECK: func @test
# CHECK: %[[SLICE:.+]] = vector.load
# CHECK: arith.addf %[[SLICE]], %[[SLICE]] : vector<16xf16>

Expand Down

0 comments on commit 1aa05e0

Please sign in to comment.