-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_vgg19.py
127 lines (116 loc) · 9.13 KB
/
model_vgg19.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Modified by Dacrol
import find_mxnet
import mxnet as mx
import os, sys
from collections import namedtuple
# Bundle returned by get_executor: the bound executor together with handles to
# its input array, the input's gradient buffer, the style outputs, the content
# output and the dict of all bound argument arrays.
ConvExecutor = namedtuple('ConvExecutor', 'executor data data_grad style content arg_dict')
def get_symbol(workspace=64):
    """Build the VGG19 convolutional trunk up to relu5_1.

    Every convolution is 3x3 with stride 1 and padding 1, followed by a ReLU.
    Stages 1-4 each end with a 2x2/stride-2 *average* pooling layer; the
    network is truncated right after relu5_1. Layer names (conv<s>_<i>,
    relu<s>_<i>, pool<s>) must stay unchanged: the argument names derived from
    them are used as keys when pretrained weights are loaded.

    Args:
        workspace: value passed as ``workspace`` to every Convolution layer
            (default 64, the value previously hard-coded).

    Returns:
        A ``(style, content)`` pair of symbol groups: ``style`` groups
        relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 (in that order);
        ``content`` groups relu4_2 only.
    """
    # (stage, number of conv layers, filters per conv, pool after stage?)
    stages = (
        (1, 2, 64, True),
        (2, 2, 128, True),
        (3, 4, 256, True),
        (4, 4, 512, True),
        (5, 1, 512, False),  # truncated: only conv5_1 / relu5_1
    )
    data = mx.sym.Variable("data")
    net = data
    relus = {}  # "<stage>_<idx>" -> ReLU output symbol
    for stage, num_convs, num_filter, pool_after in stages:
        for idx in range(1, num_convs + 1):
            suffix = "%d_%d" % (stage, idx)
            net = mx.symbol.Convolution(name="conv" + suffix, data=net, num_filter=num_filter,
                                        pad=(1, 1), kernel=(3, 3), stride=(1, 1),
                                        no_bias=False, workspace=workspace)
            net = mx.symbol.Activation(name="relu" + suffix, data=net, act_type="relu")
            relus[suffix] = net
        if pool_after:
            net = mx.symbol.Pooling(name="pool%d" % stage, data=net, pad=(0, 0),
                                    kernel=(2, 2), stride=(2, 2), pool_type="avg")
    # style and content layers
    style = mx.sym.Group([relus["%d_1" % s] for s in range(1, 6)])
    content = mx.sym.Group([relus["4_2"]])
    return style, content
def get_executor(style, content, input_size, ctx, params_path="./model/vgg19.params"):
    """Bind the style/content feature graph and load pretrained weights.

    Args:
        style: symbol group whose outputs are returned as ``style``.
        content: symbol group placed after ``style`` in the bound graph; only
            the last bound output is returned as ``content``, so this is
            expected to hold a single output.
        input_size: ``(height, width)``; the ``data`` input is bound with
            shape ``(1, 3, height, width)``.
        ctx: MXNet context on which every argument array is allocated.
        params_path: file read with ``mx.nd.load``. Each non-``data`` argument
            is filled from the entry ``"arg:" + name``; arguments missing from
            the file are reported and left zero-initialised. Defaults to the
            previously hard-coded ``./model/vgg19.params`` (relative to the
            current working directory).

    Returns:
        ConvExecutor with the bound executor, the ``data`` input array, its
        gradient buffer, the style outputs, the content output and the full
        argument dict.
    """
    out = mx.sym.Group([style, content])
    # make executor
    arg_shapes, _, _ = out.infer_shape(data=(1, 3, input_size[0], input_size[1]))
    arg_names = out.list_arguments()
    arg_dict = dict(zip(arg_names, [mx.nd.zeros(shape, ctx=ctx) for shape in arg_shapes]))
    # Gradient buffer only for the input image; no other argument is given one.
    grad_dict = {"data": arg_dict["data"].copyto(ctx)}
    # init with pretrained weight
    pretrained = mx.nd.load(params_path)
    for name in arg_names:
        if name == "data":
            continue
        key = "arg:" + name
        if key in pretrained:
            pretrained[key].copyto(arg_dict[name])
        else:
            print("Skip argument %s" % name)
    executor = out.bind(ctx=ctx, args=arg_dict, args_grad=grad_dict, grad_req="write")
    # Outputs are ordered style outputs first, then content (Group order above).
    return ConvExecutor(executor=executor,
                        data=arg_dict["data"],
                        data_grad=grad_dict["data"],
                        style=executor.outputs[:-1],
                        content=executor.outputs[-1],
                        arg_dict=arg_dict)
def get_model(input_size, ctx):
    """Build the VGG19 style/content graph and bind it for ``input_size`` on ``ctx``.

    Thin convenience wrapper: constructs the symbols with get_symbol() and
    returns the ConvExecutor produced by get_executor().
    """
    style_sym, content_sym = get_symbol()
    conv_exec = get_executor(style_sym, content_sym, input_size, ctx)
    return conv_exec