diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 74c253db..ca3d4aa5 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -1160,7 +1160,7 @@ def indexing_dims(self) -> list[IndexSymbol]: return get_custom(self.arg).indexing_dims @property - def type(self) -> Memory: + def type(self) -> Register: src_type = get_custom(self.arg).type return src_type diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index acab88fc..f7d83244 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -1220,6 +1220,7 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node): raise ValidationError("Malformed arguments") from e vector_src = cast_vector(emitter, register) src_vector_type = vector_src.type + src_elem_type = src_vector_type.element_type dst_elem_type = IrType.parse(dtype.ir_type_asm()) dst_vector_type = VectorType.get(src_vector_type.shape, dst_elem_type) @@ -1227,12 +1228,12 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node): emitter.bind_node_proxy(node, vector_src) return - is_src_float = _is_float_type(src_vector_type.element_type) + is_src_float = _is_float_type(src_elem_type) is_dst_float = _is_float_type(dst_elem_type) + is_src_int = _is_integer_like_type(src_elem_type) + is_dst_int = _is_integer_like_type(dst_elem_type) conversion_ops = { - (True, True): lambda _, x: x, - (False, False): lambda _, x: x, (True, False): arith_d.fptosi, (False, True): arith_d.sitofp, } @@ -1244,20 +1245,16 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node): (False, False): arith_d.trunci, } - dtype = ( - get_float_type(dst_elem_type.width) - if is_dst_float - else IntegerType.get_signless(dst_elem_type.width) - ) - converted_vector = conversion_ops[(is_src_float, is_dst_float)]( - VectorType.get(src_vector_type.shape, dtype), vector_src - ) - - casted_vector = cast_ops[ - ( - src_vector_type.element_type.width < dst_elem_type.width, - is_dst_float and is_src_float, + if (is_src_float and is_dst_float) or (is_src_int and is_dst_int): + casted_vector = cast_ops[ + ( + src_vector_type.element_type.width < dst_elem_type.width, + is_dst_float and is_src_float, + ) + ](dst_vector_type, vector_src) + else: + casted_vector = conversion_ops[(is_src_float, is_dst_float)]( + dst_vector_type, vector_src ) - ](dst_vector_type, converted_vector) emitter.bind_node_proxy(node, IRProxyValue(casted_vector)) diff --git a/lit_tests/kernel/wave/casting.py b/lit_tests/kernel/wave/casting.py index ec2c5ea4..6b88d83a 100644 --- a/lit_tests/kernel/wave/casting.py +++ b/lit_tests/kernel/wave/casting.py @@ -63,7 +63,8 @@ def test( ): a_reg = tkw.read(a, elements_per_thread=16) a_reg = tkw.cast(a_reg, tkl.f32) - a_reg = tkw.cast(a_reg, tkl.i32) + a_reg = tkw.cast(a_reg, tkl.i8) + a_reg = tkw.cast(a_reg, tkl.f16) a_reg = tkw.cast(a_reg, tkl.i16) a_reg = tkw.cast(a_reg, tkl.i32) a_reg = tkw.cast(a_reg, tkl.f32) @@ -76,8 +77,9 @@ def test( print(test(a, b).module_op) # CHECK: %[[D0:.*]] = arith.extf {{.*}} : vector<16xf16> to vector<16xf32> - # CHECK: %[[D1:.*]] = arith.fptosi %[[D0]] : vector<16xf32> to vector<16xi32> - # CHECK: %[[D2:.*]] = arith.trunci %[[D1]] : vector<16xi32> to vector<16xi16> - # CHECK: %[[D3:.*]] = arith.extsi %[[D2]] : vector<16xi16> to vector<16xi32> - # CHECK: %[[D4:.*]] = arith.sitofp %[[D3]] : vector<16xi32> to vector<16xf32> - # CHECK: %[[D5:.*]] = arith.truncf %[[D4]] : vector<16xf32> to vector<16xf16> + # CHECK: %[[D1:.*]] = arith.fptosi %[[D0]] : vector<16xf32> to vector<16xi8> + # CHECK: %[[D2:.*]] = arith.sitofp %[[D1]] : vector<16xi8> to vector<16xf16> + # CHECK: %[[D3:.*]] = arith.fptosi %[[D2]] : vector<16xf16> to vector<16xi16> + # CHECK: %[[D4:.*]] = arith.extsi %[[D3]] : vector<16xi16> to vector<16xi32> + # CHECK: %[[D5:.*]] = arith.sitofp %[[D4]] : vector<16xi32> to vector<16xf32> + # CHECK: %[[D6:.*]] = arith.truncf %[[D5]] : vector<16xf32> to vector<16xf16> diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 36f518a3..40c0a6f3 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -901,7 +901,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: @require_e2e -@pytest.mark.parametrize("shape", [256, 64]) +@pytest.mark.parametrize("shape", [(256, 64)]) def test_cast(shape, request): run_bench = request.config.getoption("--runperf") M = tkl.sym.M