-
Notifications
You must be signed in to change notification settings - Fork 12
/
fusion_reshape.py
142 lines (119 loc) · 6.06 KB
/
fusion_reshape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
from fusion_base import Fusion
from logging import getLogger
import numpy as np
from onnx import helper, numpy_helper, TensorProto
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionReshape(Fusion):
    """Fuse a dynamically-computed Reshape shape into a constant shape tensor.

    Matches the pattern emitted by PyTorch exporters where a Reshape's shape
    input is a Concat of per-dimension Unsqueeze(Gather(Shape(x))) subgraphs,
    and replaces it with a Constant int64 tensor using ONNX Reshape semantics
    (0 = copy the input dimension, -1 = infer the dimension).
    """
    def __init__(self, model: OnnxModel):
        # Search for Reshape nodes; fused nodes are also named "Reshape".
        super().__init__(model, "Reshape", "Reshape")

    def replace_reshape_node(self, shape, reshape_node, concat_node):
        """Rewire reshape_node's shape input to a new Constant node.

        Args:
            shape: list of ints (may contain 0 / -1) for the target shape.
            reshape_node: the Reshape node being fused.
            concat_node: the Concat node that produced the dynamic shape;
                scheduled for removal.
        """
        shape_value = np.asarray(shape, dtype=np.int64)
        constant_shape_name = self.model.create_node_name('Constant', 'constant_shape')
        new_node = helper.make_node('Constant',
                                    inputs=[],
                                    outputs=[constant_shape_name],
                                    value=helper.make_tensor(name='const_tensor',
                                                             data_type=TensorProto.INT64,
                                                             dims=shape_value.shape,
                                                             vals=shape_value.tobytes(),
                                                             raw=True))
        reshape_node.input[1] = constant_shape_name
        reshape_node.name = self.model.create_node_name('Reshape', 'Reshape_Fuse')
        self.nodes_to_remove.extend([concat_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    @staticmethod
    def _append_constant_dims(shape, concat_value):
        """Append the dims held in a Concat initializer to the shape list.

        numpy_helper.to_array returns an ndarray (never a Python list); flatten
        it to plain Python ints so the final shape list is homogeneous and
        np.asarray(shape, dtype=np.int64) cannot see a ragged/object array.
        """
        if isinstance(concat_value, np.ndarray):
            # atleast_1d handles a 0-d (scalar) initializer as well.
            shape.extend(np.atleast_1d(concat_value).tolist())
        else:
            shape.append(concat_value)

    def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
        """Attempt to fuse one Reshape node; no-op if the pattern is absent.

        The shape input must come from a 3- or 4-input Concat whose first two
        inputs are Unsqueeze(Gather(Shape(root), 0/1)) — i.e. batch and
        sequence dims copied from the input (encoded as 0). Remaining Concat
        inputs are either constant initializers or Mul/Div subgraphs over the
        same root's Shape (encoded as -1, at most one inferred dim).
        """
        if reshape_node.input[1] not in output_name_to_node:
            return

        concat_node = output_name_to_node[reshape_node.input[1]]
        if concat_node.op_type != 'Concat' or len(concat_node.input) < 3 or len(concat_node.input) > 4:
            return

        path0 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0],
                                             output_name_to_node)
        if path0 is None:
            return
        (unsqueeze_0, gather_0, shape_0) = path0

        path1 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0],
                                             output_name_to_node)
        if path1 is None:
            return
        (unsqueeze_1, gather_1, shape_1) = path1

        shape = []
        gather_value = self.model.get_constant_value(gather_0.input[1])
        if gather_value == 0:
            shape.append(0)  # 0 means "keep input dim 0" in ONNX Reshape

        gather_value = self.model.get_constant_value(gather_1.input[1])
        if gather_value == 1:
            shape.append(0)  # 0 means "keep input dim 1"

        # Both leading dims must have matched (Gather indices 0 and 1).
        if len(shape) != 2:
            return

        path2 = []
        path3 = []
        shape_nodes = [shape_0, shape_1]
        if len(concat_node.input) == 3 and self.model.get_initializer(concat_node.input[2]) is None:
            path2 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Mul', 'Gather', 'Shape'], [2, 0, 0, 0],
                                                 output_name_to_node)
            if path2 is None:
                path2 = self.model.match_parent_path(
                    concat_node, ['Unsqueeze', 'Mul', 'Squeeze', 'Slice', 'Shape'], [2, 0, 0, 0, 0],
                    output_name_to_node)  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path2 is None:
                    return

            path3 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Mul', 'Gather', 'Shape'], [2, 0, 1, 0],
                                                 output_name_to_node)
            if path3 is None:
                path3 = self.model.match_parent_path(
                    concat_node, ['Unsqueeze', 'Mul', 'Squeeze', 'Slice', 'Shape'], [2, 0, 1, 0, 0],
                    output_name_to_node)  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path3 is None:
                    return

            shape_nodes.extend([path2[-1], path3[-1]])
            shape.append(-1)  # dynamic last dim -> let Reshape infer it
        elif (len(concat_node.input) > 2):
            concat_2 = self.model.get_initializer(concat_node.input[2])
            if concat_2 is None:
                return
            self._append_constant_dims(shape, numpy_helper.to_array(concat_2))

        if len(concat_node.input) == 4 and self.model.get_initializer(concat_node.input[3]) is None:
            if -1 in shape:
                # ONNX Reshape allows at most one inferred (-1) dimension.
                return

            path2 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Div', 'Gather', 'Shape'], [3, 0, 0, 0],
                                                 output_name_to_node)
            if path2 is None:
                path2 = self.model.match_parent_path(
                    concat_node, ['Unsqueeze', 'Div', 'Squeeze', 'Slice', 'Shape'], [3, 0, 0, 0, 0],
                    output_name_to_node)  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path2 is None:
                    return
            shape_nodes.extend([path2[-1]])
            shape.append(-1)
        elif (len(concat_node.input) > 3):
            concat_3 = self.model.get_initializer(concat_node.input[3])
            if concat_3 is None:
                return
            self._append_constant_dims(shape, numpy_helper.to_array(concat_3))

        # All Shape nodes in the matched subgraphs must read the same tensor
        # that is being reshaped, otherwise 0-dims would copy the wrong dims.
        root_input = reshape_node.input[0]
        if any(shape_node.input[0] != root_input for shape_node in shape_nodes):
            return

        self.replace_reshape_node(shape, reshape_node, concat_node)
        self.nodes_to_remove.extend(path0)
        self.nodes_to_remove.extend(path1)
        self.nodes_to_remove.extend(path2)
        self.nodes_to_remove.extend(path3)