import re

import torch
import torch.fx
import torch.nn as nn
import torchvision

# used below for to_conj, initialize_conv, initialize_linear and Q
from .functions import *

grayscale = torchvision.transforms.Grayscale(num_output_channels=1)


def convert_data_for_quaternion(batch):
    """
    converts a batch of (RGB image, label) pairs into 4-channel inputs
    for QNNs by appending a grayscale channel to the three RGB channels
    """
    assert all(batch[i][0].size(0) == 3 for i in range(len(batch)))
    inputs, labels = [], []
    for image, label in batch:
        inputs.append(torch.cat([image, grayscale(image)], 0))
        labels.append(label)
    return torch.stack(inputs), torch.LongTensor(labels)
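
# Usage sketch (assumption: convert_data_for_quaternion is meant as a
# DataLoader collate_fn; CIFAR10 is used purely for illustration):
#
#   dataset = torchvision.datasets.CIFAR10(
#       "./data", download=True,
#       transform=torchvision.transforms.ToTensor())
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=8, collate_fn=convert_data_for_quaternion)
#   images, labels = next(iter(loader))  # images: (8, 4, 32, 32)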


# currently unused
def apply_quaternion_gradient(model, layers):
    """
    hooks the real-valued gradients of the given layer types and turns
    them into quaternion gradients for quaternion gradient descent
    @type model: nn.Module
    @type layers: list
    """
    for n, ((_, layer), parameter) in enumerate(zip(model.named_children(), model.parameters())):
        # the layer type is the leading word of its repr, e.g. "Conv2d"
        layer_name = re.match(r"^\w+", str(layer)).group()
        if layer_name in layers and len(parameter.shape) > 1 and n != 1:
            parameter.register_hook(to_conj)
    return model
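
# Usage sketch (hypothetical, as the function is not called anywhere yet):
#
#   model = apply_quaternion_gradient(model, ["Linear", "Conv2d"])
#   # hooked weights now receive to_conj-transformed gradients on backward()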


# wrapped so torch.fx records this data-dependent reshape as a single call
@torch.fx.wrap
def check_shapes(x):
    if x.dim() in [3, 5]:
        # fold a leading quaternion-component dim into the feature/channel
        # dim, e.g. (4, N, C, H, W) -> (N, 4*C, H, W).  The original call
        # was `x.chunk()`, which omits the required `chunks` argument;
        # 4 chunks (one per quaternion component) is assumed here.
        x = torch.cat([*x.chunk(4)], 2).squeeze()
    return x


def convert_to_quaternion(Net, verbose=False, spinor=False):
    """
    converts an initialized real-valued network into a quaternion one
    @type Net: nn.Module
    @type verbose: bool
    @type spinor: bool
    """
    last_module = len(list(Net.children()))
    layers = ["Linear", "Conv1d", "Conv2d", "Conv3d",
              "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d"]
    for n, (name, layer) in enumerate(Net.named_children()):
        layer_name = re.match(r"^\w+", str(layer)).group()
        # the last child module is left untouched
        if n != last_module - 1:
            if layer_name in layers[1:]:
                # parse e.g. "Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))"
                params = re.findall(r"(?<!\w)\d+(?<=\w)", str(layer))
                in_features, out_features, kernel_size, stride = \
                    int(params[0]), int(params[1]), (int(params[2]), int(params[3])), (int(params[4]), int(params[5]))
                assert in_features % 4 == 0, "number of in_channels must be divisible by 4"
                assert out_features % 4 == 0, "number of out_channels must be divisible by 4"
                init_func = initialize_conv
                args = (in_features // 4, out_features // 4, kernel_size)
            elif layer_name == layers[0]:
                # parse e.g. "Linear(in_features=16, out_features=32, bias=True)"
                params = re.findall(r"(?<==)\w+", str(layer))
                # bool("False") is True, so compare the parsed string instead
                in_features, out_features, bias = int(params[0]), int(params[1]), params[2] == "True"
                assert in_features % 4 == 0, "number of in_features must be divisible by 4"
                assert out_features % 4 == 0, "number of out_features must be divisible by 4"
                init_func = initialize_linear
                args = (in_features // 4, out_features // 4)
            else:
                continue
            quaternion_weight = init_func(*args)
            # use the rotation representation for spinor networks, otherwise
            # the plain real-matrix representation of the quaternion weight
            if spinor:
                weight = quaternion_weight._real_rot_repr
            else:
                weight = quaternion_weight._real_repr
            getattr(Net, name).weight = nn.Parameter(weight)
            if getattr(Net, name).bias is not None:
                getattr(Net, name).bias = nn.Parameter(torch.zeros(out_features))
            # rewrite the traced forward graph: reshape the input with
            # check_shapes and wrap the output back into a quaternion Q
            traced = torch.fx.symbolic_trace(layer)
            for node in traced.graph.nodes:
                if node.op == 'placeholder':
                    # reshape the input right after it enters the graph
                    with traced.graph.inserting_after(node):
                        new_node = traced.graph.call_function(
                            check_shapes, args=(node,))
                if any(lay in node.name for lay in ["conv", "lin"]):
                    # reroute the conv/linear call so it consumes the
                    # check_shapes node instead of the raw placeholder
                    with traced.graph.inserting_before(node):
                        all_nodes = [node for node in traced.graph.nodes]
                        new_node = traced.graph.call_function(node.target,
                                                              (all_nodes[1], *node.args[1:]), node.kwargs)
                        node.replace_all_uses_with(new_node)
                        traced.graph.erase_node(node)
                if node.op == 'output':
                    # return Q(result) instead of the raw result
                    all_nodes = [node for node in traced.graph.nodes]
                    with traced.graph.inserting_before(node):
                        new_node = traced.graph.call_function(
                            Q, args=(node.prev,))
                    node.replace_all_uses_with(new_node)
                    traced.graph.erase_node(node)
                    with traced.graph.inserting_after(node):
                        new_node = traced.graph.output(node.prev)
            if verbose:
                print("-" * 20, layer_name, "-" * 20, sep="\n")
                print(torch.fx.GraphModule(layer, traced.graph))
            traced.graph.lint()
            setattr(Net, name, torch.fx.GraphModule(layer, traced.graph))
return Net
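
# Usage sketch (the architecture is an assumption for illustration; channel
# counts must be divisible by 4 and the last child module is not converted):
#
#   class SmallNet(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.conv = nn.Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
#           self.fc = nn.Linear(8, 10)
#
#       def forward(self, x):
#           x = self.conv(x)
#           return self.fc(x.mean((2, 3)))
#
#   qnet = convert_to_quaternion(SmallNet(), verbose=True)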