From bef788ab04ea989f2a19fc61c1f96f53c13a0c7f Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Tue, 13 Jun 2023 20:59:08 -0400
Subject: [PATCH] Add attention op insertion code

---
 shark/shark_importer.py | 48 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/shark/shark_importer.py b/shark/shark_importer.py
index e12f7c0922..547dca4bfe 100644
--- a/shark/shark_importer.py
+++ b/shark/shark_importer.py
@@ -312,6 +312,51 @@ def get_f16_inputs(inputs, is_f16, f16_input_mask):
     return tuple(f16_masked_inputs)
 
 
+def insert_attention_block(fx_g):
+    import torch
+
+    unary_ops = [
+        torch.ops.aten._unsafe_view,
+        torch.ops.aten.view,
+        torch.ops.aten.expand,
+        torch.ops.aten.clone,
+    ]
+
+    def traverse(node):
+        while node.target in unary_ops:
+            node = node.args[0]
+        return node
+
+    for node in fx_g.graph.nodes:
+        if node.target in [torch.ops.aten.bmm]:
+            outer_bmm = node
+            node = traverse(outer_bmm.args[0])
+            if node.target in [torch.ops.aten._softmax]:
+                softmax_node = node
+                node = traverse(softmax_node.args[0])
+                if node.target in [torch.ops.aten.bmm]:
+                    inner_bmm = node
+                    value = outer_bmm.args[1]
+                    key = inner_bmm.args[1]
+                    with fx_g.graph.inserting_before(outer_bmm):
+                        key = fx_g.graph.call_function(
+                            torch.ops.aten.transpose,
+                            args=(key, -2, -1),
+                            kwargs={},
+                        )
+                        query = inner_bmm.args[0]
+                        new_node = fx_g.graph.call_function(
+                            torch.ops.aten.scaled_dot_product_attention,
+                            args=(query, key, value),
+                            kwargs={},
+                        )
+                        outer_bmm.append(new_node)
+                        outer_bmm.append(key)
+                        outer_bmm.replace_all_uses_with(new_node)
+
+    fx_g.graph.lint()
+
+
 # Upcasts the block/list of ops.
 def add_upcast(fx_g):
     import torch
@@ -548,6 +593,8 @@ def strip_overloads(gm):
 
     strip_overloads(fx_g)
 
+    insert_attention_block(fx_g)
+
     if is_f16:
         fx_g = fx_g.half()
         transform_fx(fx_g)
@@ -567,6 +614,7 @@ def strip_overloads(gm):
         return ts_graph
 
     inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
+
     mlir_importer = SharkImporter(
         ts_graph,
         inputs,
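
Note (illustrative only, not part of the patch): a minimal sketch of how the new
insert_attention_block pass could be exercised on a toy FX graph. It assumes torch >= 2.0
(for aten.scaled_dot_product_attention), that make_fx lowers the toy function to the
aten.bmm -> aten._softmax -> aten.bmm chain the pass matches, and that overloads have been
stripped first, as the importer does via strip_overloads(); toy_attention and the tensor
shapes below are hypothetical.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

from shark.shark_importer import insert_attention_block


def toy_attention(q, k, v):
    # bmm -> softmax -> bmm: the exact chain the pass looks for
    scores = torch.bmm(q, k.transpose(-2, -1))
    probs = torch.softmax(scores, dim=-1)
    return torch.bmm(probs, v)


q = torch.randn(2, 16, 64)
k = torch.randn(2, 16, 64)
v = torch.randn(2, 16, 64)

fx_g = make_fx(toy_attention)(q, k, v)

# The importer runs strip_overloads() before this pass, so node targets are
# overload packets (aten.bmm) rather than overloads (aten.bmm.default);
# mimic that here so the pattern match fires.
for n in fx_g.graph.nodes:
    if isinstance(n.target, torch._ops.OpOverload):
        n.target = n.target.overloadpacket

insert_attention_block(fx_g)
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
print(fx_g.graph)  # should now contain a scaled_dot_product_attention node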