-
Notifications
You must be signed in to change notification settings - Fork 12
/
onnx_model_gpt2.py
80 lines (64 loc) · 3.53 KB
/
onnx_model_gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import logging
import onnx
from onnx_model_bert import BertOnnxModel
from fusion_gpt_attention_no_past import FusionGptAttentionNoPast
from fusion_gpt_attention import FusionGptAttention
from fusion_gpt_attention_megatron import FusionGptAttentionMegatron
# Module-level logger, named after this module, used by the model passes below.
logger = logging.getLogger(__name__)
class Gpt2OnnxModel(BertOnnxModel):
    """ONNX graph wrapper for GPT-2 models, built on top of BertOnnxModel.

    Overrides attention fusion with the GPT-specific fusion passes and adds a
    post-processing step that rewrites ``Reshape -> Gemm -> Reshape`` chains
    into ``MatMul -> Add``.
    """

    def __init__(self, model, num_heads, hidden_size):
        """Forward the model, head count and hidden size to BertOnnxModel."""
        super().__init__(model, num_heads, hidden_size)

    def fuse_attention(self):
        """Apply the GPT attention fusion that fits the graph signature.

        A graph with exactly one input or exactly one output is handled by
        FusionGptAttentionNoPast (presumably a model exported without past
        state -- confirm against the exporter). Any other graph gets
        FusionGptAttention followed by FusionGptAttentionMegatron.
        """
        graph = self.model.graph
        if len(graph.input) == 1 or len(graph.output) == 1:
            fusion = FusionGptAttentionNoPast(self, self.num_heads)
            fusion.apply()
        else:
            fusion = FusionGptAttention(self, self.num_heads)
            fusion.apply()
            fusion = FusionGptAttentionMegatron(self, self.num_heads)
            fusion.apply()

    def postprocess(self):
        """
        Remove extra reshape nodes.

        For each Gemm node that has a direct Reshape child and whose parent
        path (input indices [0, 0]) is ``Reshape <- FastGelu`` or
        ``Reshape <- LayerNormalization``, a MatMul + Add pair is added that
        reads the root node's output directly. Consumers of the trailing
        Reshape are rewired to the Add output, and the graph is pruned once
        all Gemm nodes have been processed. Gemm nodes that do not match the
        pattern are left untouched.
        """
        logger.debug("start postprocessing...")

        input_name_to_nodes = self.input_name_to_nodes()
        output_name_to_node = self.output_name_to_node()

        reshape_count = 0
        for gemm_node in self.get_nodes_by_op_type("Gemm"):
            reshape_after_gemm = self.find_first_child_by_type(gemm_node,
                                                               'Reshape',
                                                               input_name_to_nodes,
                                                               recursive=False)
            if reshape_after_gemm is None:
                # No trailing Reshape to bypass; without this guard the
                # rewiring below would dereference None.
                continue

            nodes = self.match_parent_path(gemm_node, ['Reshape', 'FastGelu'], [0, 0], output_name_to_node)
            if nodes is None:
                nodes = self.match_parent_path(gemm_node, ['Reshape', 'LayerNormalization'], [0, 0],
                                               output_name_to_node)
                if nodes is None:
                    continue

            # Only the root (FastGelu / LayerNormalization) is needed; the
            # Reshape in front of the Gemm is simply bypassed.
            _, root_node = nodes

            # NOTE(review): the replacement assumes the Gemm carries a bias
            # (input[2]) and default alpha/beta/transA/transB -- confirm.
            matmul_node_name = self.create_node_name('MatMul', 'FullyConnect_MatMul')
            matmul_input = matmul_node_name + "_input"
            matmul_output = matmul_node_name + "_output"
            matmul_node = onnx.helper.make_node('MatMul',
                                                inputs=[matmul_input, gemm_node.input[1]],
                                                outputs=[matmul_output],
                                                name=matmul_node_name)

            add_node_name = self.create_node_name('Add', 'FullyConnect_Add')
            add_output = add_node_name + "_output"
            add_node = onnx.helper.make_node('Add',
                                             inputs=[matmul_output, gemm_node.input[2]],
                                             outputs=[add_output],
                                             name=add_node_name)

            # Consumers of the trailing Reshape now read the Add output.
            self.replace_input_of_all_nodes(reshape_after_gemm.output[0], add_output)

            # Link root node output with MatMul: rename the root's output and
            # every existing consumer of it to the MatMul input name.
            self.replace_input_of_all_nodes(root_node.output[0], matmul_input)
            root_node.output[0] = matmul_input

            self.add_node(matmul_node)
            self.add_node(add_node)

            # Both the leading and the trailing Reshape are bypassed.
            reshape_count += 2

        # Presumably drops the now-unreferenced Reshape/Gemm nodes -- see
        # the prune_graph implementation in the base class.
        self.prune_graph()
        logger.info("postprocess: remove Reshape count:%s", reshape_count)