
Commit

Address comments
harsh-nod committed Oct 17, 2024
1 parent 5e5f78f commit 65a459f
Showing 4 changed files with 24 additions and 25 deletions.
iree/turbine/kernel/ops/wave_ops.py: 2 changes (1 addition & 1 deletion)
@@ -1160,7 +1160,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
         return get_custom(self.arg).indexing_dims
 
     @property
-    def type(self) -> Memory:
+    def type(self) -> Register:
         src_type = get_custom(self.arg).type
         return src_type

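For context on this one-line fix: the property simply forwards the type of the cast's source value, and a cast consumes and produces register values, so Register is the accurate annotation and Memory was misleading. A minimal standalone sketch of the forwarding pattern (Register, Producer, and CastSketch below are illustrative stand-ins, not the real wave_ops classes):

from dataclasses import dataclass

@dataclass
class Register:
    # Stand-in for the wave register type (shape and layout elided).
    dtype: str

@dataclass
class Producer:
    # Stand-in for whichever op produced the cast's argument.
    type: Register

@dataclass
class CastSketch:
    # Stand-in for the cast op: its type property mirrors the source's
    # register type, which is why the annotation is Register, not Memory.
    arg: Producer

    @property
    def type(self) -> Register:
        return self.arg.type

print(CastSketch(Producer(Register("f16"))).type)  # Register(dtype='f16')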
iree/turbine/kernel/wave/codegen.py: 31 changes (14 additions & 17 deletions)
@@ -1220,19 +1220,20 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
         raise ValidationError("Malformed arguments") from e
     vector_src = cast_vector(emitter, register)
     src_vector_type = vector_src.type
+    src_elem_type = src_vector_type.element_type
     dst_elem_type = IrType.parse(dtype.ir_type_asm())
     dst_vector_type = VectorType.get(src_vector_type.shape, dst_elem_type)
 
     if src_vector_type == dst_vector_type:
         emitter.bind_node_proxy(node, vector_src)
         return
 
-    is_src_float = _is_float_type(src_vector_type.element_type)
+    is_src_float = _is_float_type(src_elem_type)
     is_dst_float = _is_float_type(dst_elem_type)
+    is_src_int = _is_integer_like_type(src_elem_type)
+    is_dst_int = _is_integer_like_type(dst_elem_type)
 
     conversion_ops = {
-        (True, True): lambda _, x: x,
-        (False, False): lambda _, x: x,
         (True, False): arith_d.fptosi,
         (False, True): arith_d.sitofp,
     }
@@ -1244,20 +1245,16 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
         (False, False): arith_d.trunci,
     }
 
-    dtype = (
-        get_float_type(dst_elem_type.width)
-        if is_dst_float
-        else IntegerType.get_signless(dst_elem_type.width)
-    )
-    converted_vector = conversion_ops[(is_src_float, is_dst_float)](
-        VectorType.get(src_vector_type.shape, dtype), vector_src
-    )
-
-    casted_vector = cast_ops[
-        (
-            src_vector_type.element_type.width < dst_elem_type.width,
-            is_dst_float and is_src_float,
+    if (is_src_float and is_dst_float) or (is_src_int and is_dst_int):
+        casted_vector = cast_ops[
+            (
+                src_vector_type.element_type.width < dst_elem_type.width,
+                is_dst_float and is_src_float,
+            )
+        ](dst_vector_type, vector_src)
+    else:
+        casted_vector = conversion_ops[(is_src_float, is_dst_float)](
+            dst_vector_type, vector_src
         )
-    ](dst_vector_type, converted_vector)
 
     emitter.bind_node_proxy(node, IRProxyValue(casted_vector))
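Taken together, the rewritten handle_cast splits same-domain casts (float to float, int to int) from cross-domain conversions instead of always routing through an intermediate converted_vector. Below is a minimal standalone sketch of that dispatch rule, under the assumption that the cast_ops entries not visible in this hunk (extf, extsi, truncf) match the CHECK lines in the lit test below; pick_cast_op is an illustrative helper, not part of the codebase, and the real handler tests integer-likeness with _is_integer_like_type rather than a plain equality check:

def pick_cast_op(src_is_float: bool, dst_is_float: bool,
                 src_width: int, dst_width: int) -> str:
    # Same domain: extend or truncate based on the width comparison.
    if src_is_float == dst_is_float:
        cast_ops = {
            (True, True): "arith.extf",
            (True, False): "arith.extsi",
            (False, True): "arith.truncf",
            (False, False): "arith.trunci",
        }
        return cast_ops[(src_width < dst_width, dst_is_float)]
    # Cross domain: a single float <-> signed-int conversion.
    conversion_ops = {
        (True, False): "arith.fptosi",
        (False, True): "arith.sitofp",
    }
    return conversion_ops[(src_is_float, dst_is_float)]

assert pick_cast_op(True, True, 16, 32) == "arith.extf"    # f16 -> f32
assert pick_cast_op(True, False, 32, 8) == "arith.fptosi"  # f32 -> i8
assert pick_cast_op(False, False, 16, 32) == "arith.extsi" # i16 -> i32

Note that a no-op cast never reaches either table: the handler returns early when the source and destination vector types already match.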
lit_tests/kernel/wave/casting.py: 14 changes (8 additions & 6 deletions)
@@ -63,7 +63,8 @@ def test(
     ):
         a_reg = tkw.read(a, elements_per_thread=16)
         a_reg = tkw.cast(a_reg, tkl.f32)
-        a_reg = tkw.cast(a_reg, tkl.i32)
+        a_reg = tkw.cast(a_reg, tkl.i8)
+        a_reg = tkw.cast(a_reg, tkl.f16)
         a_reg = tkw.cast(a_reg, tkl.i16)
         a_reg = tkw.cast(a_reg, tkl.i32)
         a_reg = tkw.cast(a_reg, tkl.f32)
@@ -76,8 +77,9 @@ def test(
     print(test(a, b).module_op)
 
     # CHECK: %[[D0:.*]] = arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
-    # CHECK: %[[D1:.*]] = arith.fptosi %[[D0]] : vector<16xf32> to vector<16xi32>
-    # CHECK: %[[D2:.*]] = arith.trunci %[[D1]] : vector<16xi32> to vector<16xi16>
-    # CHECK: %[[D3:.*]] = arith.extsi %[[D2]] : vector<16xi16> to vector<16xi32>
-    # CHECK: %[[D4:.*]] = arith.sitofp %[[D3]] : vector<16xi32> to vector<16xf32>
-    # CHECK: %[[D5:.*]] = arith.truncf %[[D4]] : vector<16xf32> to vector<16xf16>
+    # CHECK: %[[D1:.*]] = arith.fptosi %[[D0]] : vector<16xf32> to vector<16xi8>
+    # CHECK: %[[D2:.*]] = arith.sitofp %[[D1]] : vector<16xi8> to vector<16xf16>
+    # CHECK: %[[D3:.*]] = arith.fptosi %[[D2]] : vector<16xf16> to vector<16xi16>
+    # CHECK: %[[D4:.*]] = arith.extsi %[[D3]] : vector<16xi16> to vector<16xi32>
+    # CHECK: %[[D5:.*]] = arith.sitofp %[[D4]] : vector<16xi32> to vector<16xf32>
+    # CHECK: %[[D6:.*]] = arith.truncf %[[D5]] : vector<16xf32> to vector<16xf16>
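As a cross-check on the updated CHECK lines, here is a small standalone sketch (illustrative only, not part of the lit test) that walks the test's dtype chain with the same widen/narrow/convert rule and prints the op expected at each hop:

# (name, is_float, bit width) for each stop in the cast chain above.
chain = [
    ("f16", True, 16), ("f32", True, 32), ("i8", False, 8), ("f16", True, 16),
    ("i16", False, 16), ("i32", False, 32), ("f32", True, 32), ("f16", True, 16),
]

for (src, sf, sw), (dst, df, dw) in zip(chain, chain[1:]):
    if sf == df:
        # Same domain: extend when widening, truncate when narrowing.
        if sf:
            op = "extf" if dw > sw else "truncf"
        else:
            op = "extsi" if dw > sw else "trunci"
    else:
        # Cross domain: one float <-> signed-int conversion.
        op = "fptosi" if sf else "sitofp"
    print(f"arith.{op}: {src} -> {dst}")
# Prints extf, fptosi, sitofp, fptosi, extsi, sitofp, truncf in order,
# matching %[[D0]] through %[[D6]] above.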
tests/kernel/wave/wave_e2e_test.py: 2 changes (1 addition & 1 deletion)
@@ -901,7 +901,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
 
 
 @require_e2e
-@pytest.mark.parametrize("shape", [256, 64])
+@pytest.mark.parametrize("shape", [(256, 64)])
 def test_cast(shape, request):
     run_bench = request.config.getoption("--runperf")
     M = tkl.sym.M
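The parametrize fix matters because pytest does not pair up values for a single argname: [256, 64] expands to two test cases whose shape is a bare int, whereas [(256, 64)] yields one case whose shape is the (M, N) tuple the kernel sizes expect. A small illustration (the test names here are hypothetical):

import pytest

# Two cases; `shape` is a plain int (256, then 64), not a pair.
@pytest.mark.parametrize("shape", [256, 64])
def test_scalar_shape(shape):
    assert isinstance(shape, int)

# One case; `shape` is the full (256, 64) tuple.
@pytest.mark.parametrize("shape", [(256, 64)])
def test_tuple_shape(shape):
    assert shape == (256, 64)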
