Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add op to cast between dtypes #223

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def shuffle(src: "Register", offset: int, width: int) -> "Register":
...


def cast(src: "Register", dtype: DataType) -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
Expand Down Expand Up @@ -1159,3 +1163,23 @@ def indexing_dims(self) -> list[IndexSymbol]:
def type(self) -> Memory:
src_type = get_custom(self.arg).type
return src_type


@define_op("cast")
@dataclass
class CastOp(CustomOp, ABC):
"""
Represents a cast operation.
"""

arg: fx.Node
dtype: DataType

@property
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def type(self) -> Memory:
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
src_shape = get_custom(self.arg).type.symbolic_shape
return Register[*src_shape, self.dtype]
63 changes: 63 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Attribute,
DenseElementsAttr,
FloatAttr,
F16Type,
F32Type,
IndexType,
InsertionPoint,
Expand Down Expand Up @@ -63,6 +64,7 @@
CustomOp,
scheduling_barrier,
scheduling_group_barrier,
cast,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
Expand Down Expand Up @@ -1198,3 +1200,64 @@ def handle_get_result(emitter: WaveEmitter, node: fx.Node):
@handle_op(operator.getitem)
def handle_getitem(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("getitem: Currently only stub implementation")


def get_float_type(bitwidth: int):
match bitwidth:
case 16:
return F16Type.get()
case 32:
return F32Type.get()
case _:
raise NotImplementedError(f"Unsupported float bitwidth: {bitwidth}")


@handle_op(cast)
def handle_cast(emitter: WaveEmitter, node: fx.Node):
try:
register, dtype = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
vector_src = cast_vector(emitter, register)
src_vector_type = vector_src.type
dst_elem_type = IrType.parse(dtype.ir_type_asm())
dst_vector_type = VectorType.get(src_vector_type.shape, dst_elem_type)

if src_vector_type == dst_vector_type:
emitter.bind_node_proxy(node, vector_src)
return

is_src_float = _is_float_type(src_vector_type.element_type)
is_dst_float = _is_float_type(dst_elem_type)

conversion_ops = {
(True, True): lambda _, x: x,
(False, False): lambda _, x: x,
(True, False): arith_d.fptosi,
(False, True): arith_d.sitofp,
}

cast_ops = {
(True, True): arith_d.extf,
(True, False): arith_d.extsi,
(False, True): arith_d.truncf,
(False, False): arith_d.trunci,
}

dtype = (
get_float_type(dst_elem_type.width)
raikonenfnu marked this conversation as resolved.
Show resolved Hide resolved
if is_dst_float
else IntegerType.get_signless(dst_elem_type.width)
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
)
converted_vector = conversion_ops[(is_src_float, is_dst_float)](
VectorType.get(src_vector_type.shape, dtype), vector_src
raikonenfnu marked this conversation as resolved.
Show resolved Hide resolved
)

casted_vector = cast_ops[
(
src_vector_type.element_type.width < dst_elem_type.width,
is_dst_float and is_src_float,
)
](dst_vector_type, converted_vector)

emitter.bind_node_proxy(node, IRProxyValue(casted_vector))
83 changes: 83 additions & 0 deletions lit_tests/kernel/wave/casting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# RUN: python %s | FileCheck %s

import pytest
from typing import Callable
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import run_test
import torch

M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
B = tkl.sym.B
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
BLOCK_B = tkl.sym.BLOCK_B
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEM_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEM_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
bindings = {
M: 16,
N: 16,
K: 16,
BLOCK_M: 16,
BLOCK_N: 16,
BLOCK_K: 16,
ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
}

# Remove dynamic symbols from the bindings.
for sym in dynamic_symbols:
if sym in bindings:
del bindings[sym]

return tk.gen.TestLaunchContext(
bindings, canonicalize=canonicalize, dynamic_symbols=dynamic_symbols
)


@run_test
def test_cast():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
a_reg = tkw.read(a, elements_per_thread=16)
a_reg = tkw.cast(a_reg, tkl.f32)
a_reg = tkw.cast(a_reg, tkl.i32)
a_reg = tkw.cast(a_reg, tkl.i16)
a_reg = tkw.cast(a_reg, tkl.i32)
a_reg = tkw.cast(a_reg, tkl.f32)
a_reg = tkw.cast(a_reg, tkl.f16)
tkw.write(a_reg, b, elements_per_thread=16)

with codegen_test_context(canonicalize=True):
a = torch.randn(16, 16, dtype=torch.float16)
b = torch.zeros(16, 16, dtype=torch.float16)
print(test(a, b).module_op)

# CHECK: %[[D0:.*]] = arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
# CHECK: %[[D1:.*]] = arith.fptosi %[[D0]] : vector<16xf32> to vector<16xi32>
# CHECK: %[[D2:.*]] = arith.trunci %[[D1]] : vector<16xi32> to vector<16xi16>
# CHECK: %[[D3:.*]] = arith.extsi %[[D2]] : vector<16xi16> to vector<16xi32>
# CHECK: %[[D4:.*]] = arith.sitofp %[[D3]] : vector<16xi32> to vector<16xf32>
# CHECK: %[[D5:.*]] = arith.truncf %[[D4]] : vector<16xf32> to vector<16xf16>
58 changes: 58 additions & 0 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,3 +898,61 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
stride=stride,
run_bench=True,
)


@require_e2e
@pytest.mark.parametrize("shape", [256, 64])
def test_cast(shape, request):
run_bench = request.config.getoption("--runperf")
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Each workgroup works on single row of input data, and rows are further
# split into blocks of size up to 256. We have single wave per WG,
# and with default wave size of 64, each thread is operating on up to 4
# elements.
wave_size = 64
BLOCK_M = 1
# Tile size cannot be dynamic, so we use a fixed value here.
BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size)
ELEMS_PER_THREAD = BLOCK_N / wave_size

constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=wave_size,
waves_per_block=(1, 1, 1),
vector_shapes={M: BLOCK_M, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
res = tkw.cast(res, tkl.f16)
tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

a = torch.randn(shape, dtype=torch.float32)
b = torch.zeros(shape, dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
run=True,
run_bench=run_bench,
run_config=config,
):
test(a, b)
assert_allclose(a.to(dtype=torch.float16), b)
Loading