From 98c52e3aca8fd41f3fe5a29e68abfb148764be10 Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Tue, 29 Oct 2024 11:15:59 -0700 Subject: [PATCH] Address comments Signed-off-by: Harsh Menon --- iree/turbine/kernel/wave/codegen.py | 34 ++++++++-------- .../kernel/wave/index_sequence_analysis.py | 13 ------ iree/turbine/kernel/wave/utils.py | 40 ++++++++++++------- 3 files changed, 43 insertions(+), 44 deletions(-) diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 7a4e5091..5c02872d 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -1324,23 +1324,23 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node): offset = custom.expanded_dims[innermost_dim] # Determine whether to extract or combine. - if len(args) == 1: - # Extract the appropriate slice. - size = ( - target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim] - ) - vector = cast_vector(emitter, args[0]) - result_type = VectorType.get([size], vector.type.element_type) - slice = vector_d.extract_strided_slice( - result_type, - vector, - [offset * size], - [size], - [1], + if len(args) > 1: + raise NotImplementedError( + "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes" ) - emitter.bind_node_proxy(node, IRProxyValue(slice)) - return - raise NotImplementedError( - "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes" + # Extract the appropriate slice. The offset is obtained from the expanded_dim + # and so corresponds to the dim_query during expansion. To obtain the + # actual offset, we need to multiply by the size which is determined by comparing + # the source and target vector shapes along the innermost dimension. 
+ size = target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim] + vector = cast_vector(emitter, args[0]) + result_type = VectorType.get([size], vector.type.element_type) + slice = vector_d.extract_strided_slice( + result_type, + vector, + [offset * size], + [size], + [1], ) + emitter.bind_node_proxy(node, IRProxyValue(slice)) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 9923f873..f79e10f2 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -185,19 +185,6 @@ def is_contiguous_dim( return is_innermost_dim or all_unit_dims -def add_reshape(custom: CustomOp): - for arg in custom.node_args.values(): - if not isinstance(arg, Sequence): - arg = [arg] - for subarg in arg: - # These are ops in the parent graph that have not had their vector shapes set yet. - if subarg.vector_shapes is None: - continue - if subarg.vector_shapes != custom.vector_shapes: - with custom.graph.inserting_before(custom.fx_node): - Reshape(subarg, custom.vector_shapes).add_to_graph(custom.graph) - - def set_vector_shapes( constraints: Sequence[Constraint], mma_index: dict[MMA, dict[IndexSymbol, int]], diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 3858c7d3..4ecd66a5 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -228,7 +228,6 @@ def is_mma(node): mapping: dict[MMA, dict[IndexSymbol, int]] = {} mma_nodes = trace.walk(is_mma) - last_vector_shapes = {} for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] @@ -250,19 +249,6 @@ def is_mma(node): custom.vector_shapes.update(hardware_constraint.vector_shapes) custom.anchor = custom custom.reduction_dim = k - if last_vector_shapes and last_vector_shapes != custom.vector_shapes: - with custom.graph.inserting_before(custom.fx_node): - for i, arg in 
custom.node_args.items(): - if is_reshape_needed(arg, custom.vector_shapes, last_vector_shapes): - reshape = Reshape(arg.fx_node, last_vector_shapes).add_to_graph( - custom.graph - ) - custom_reshape = get_custom(reshape) - custom_reshape.vector_shapes = custom.vector_shapes - custom_reshape.anchor = custom - custom.update_arg(i, reshape) - - last_vector_shapes = custom.vector_shapes # Since expansion proceeds bottom-up, we set the vector shapes # of the parent reduction to the vector shapes of the last MMA node. @@ -272,6 +258,32 @@ def is_mma(node): reduction.anchor = custom mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes} + + # Determine if any reshapes are required. Reshapes are added for + # chained matmuls when the vector shapes of the operands in one matmul + # differ from those in another matmul. + for src in mma_nodes: + custom_src = get_custom(src) + for dst in mma_nodes: + if src == dst: + continue + custom_dst = get_custom(dst) + lhs_slice = capture_backward_slice(custom_dst.lhs) + rhs_slice = capture_backward_slice(custom_dst.rhs) + if src in lhs_slice or src in rhs_slice: + with custom_dst.graph.inserting_before(dst): + for i, arg in custom_dst.node_args.items(): + if is_reshape_needed( + arg, custom_dst.vector_shapes, custom_src.vector_shapes + ): + reshape = Reshape( + arg.fx_node, custom_src.vector_shapes + ).add_to_graph(custom_dst.graph) + custom_reshape = get_custom(reshape) + custom_reshape.vector_shapes = custom_dst.vector_shapes + custom_reshape.anchor = custom_dst + custom_dst.update_arg(i, reshape) + return mapping, mma_slices