This repository has been archived by the owner on Sep 12, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
68 lines (53 loc) · 1.82 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, BatchNormalization
from tensorflow.keras.layers import ReLU, Input
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Module-level TF1-compat session setup; this runs as a side effect on import.
# allow_growth tells the GPU allocator to start small and grow on demand
# instead of reserving the whole GPU memory region up front.
# NOTE(review): under TF2 eager execution the usual equivalent is
# tf.config.experimental.set_memory_growth(gpu, True) per physical GPU;
# confirm this compat.v1 InteractiveSession still has the intended effect
# for the Keras model built below.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
def depth_to_space(inp, scale):
    """Rearrange channel data into spatial blocks via ``tf.nn.depth_to_space``.

    Args:
        inp: Tensor to rearrange, in TF's default ``data_format`` (NHWC).
        scale: Block size. Each spatial dimension grows by this factor and
            the channel count shrinks by ``scale ** 2``.

    Returns:
        The rearranged tensor.
    """
    rearranged = tf.nn.depth_to_space(input=inp, block_size=scale)
    return rearranged
def linear(x):
    """Apply Keras' ``linear`` (identity) activation to ``x``.

    Kept as an explicit step so un-activated outputs are visible in the
    model-building code; Keras returns the input unmodified.
    """
    identity_activation = tf.keras.activations.linear
    return identity_activation(x)
def relu6(x):
    """Apply a ReLU whose output is capped at 6.0.

    A new ``ReLU`` layer instance is created on every call.
    """
    capped_relu = ReLU(max_value=6.0)
    return capped_relu(x)
def _upscale(inp, inp_filter, scale):
    """Upsample with depth_to_space, then remix channels with a 1x1 conv.

    Args:
        inp: Feature tensor; its channel count must be divisible by
            ``scale ** 2`` (a depth_to_space requirement).
        inp_filter: Number of output channels of the 1x1 convolution.
        scale: Spatial upsampling factor given to ``depth_to_space``.

    Returns:
        Tensor with ``inp_filter`` channels and spatial size scaled by
        ``scale``.
    """
    upsampled = depth_to_space(inp, scale)
    pointwise = Conv2D(inp_filter, kernel_size=(1, 1), strides=(1, 1), padding='same')
    return pointwise(upsampled)
def block(inp, out_filters, exp_ratio):
    """Expand / depthwise / project block that halves the spatial size.

    Pipeline: 1x1 expansion conv -> BatchNorm -> ReLU6 -> 3x3 depthwise conv
    (stride 2) -> ReLU6 -> 1x1 projection conv -> linear activation.
    There is no residual connection.

    Args:
        inp: Input feature tensor whose channel dimension is known statically.
        out_filters: Number of channels produced by the final 1x1 projection.
        exp_ratio: Expansion factor; the first 1x1 conv produces
            ``in_channels * exp_ratio`` channels.

    Returns:
        Tensor with ``out_filters`` channels. Its spatial size is halved,
        rounded up because of ``padding='same'`` with stride 2.

    Raises:
        ValueError: If the channel dimension of ``inp`` is not known statically.
    """
    # Keras reports the layout as 'channels_last' / 'channels_first' (plural).
    # Comparing against any other spelling silently selects the wrong axis and
    # reads a spatial dimension as the channel count.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    inp_channel = K.int_shape(inp)[channel_axis]
    if inp_channel is None:
        raise ValueError(
            'block() requires the channel dimension of its input to be known statically'
        )
    exp_filter = inp_channel * exp_ratio

    # 1x1 expansion -> BatchNorm over the channel axis -> ReLU6.
    x = Conv2D(exp_filter, (1, 1), padding='same')(inp)
    x = BatchNormalization(axis=channel_axis)(x)
    x = relu6(x)
    # 3x3 depthwise conv with stride 2: this is where the block downsamples.
    x = DepthwiseConv2D((3, 3), padding='same', strides=(2, 2))(x)
    x = relu6(x)
    # 1x1 projection with no non-linearity on its output.
    x = Conv2D(out_filters, (1, 1), padding='same')(x)
    x = linear(x)
    return x
def bottleneck(inp, out_filt, t, n):
    """Chain ``block`` calls: one unconditionally, then ``n - 1`` more.

    Args:
        inp: Input feature tensor.
        out_filt: Output channels of every chained block.
        t: Expansion ratio forwarded to each block.
        n: Requested number of blocks; any value below 1 still yields one.

    Returns:
        Output tensor of the last block in the chain.
    """
    # NOTE(review): each block uses a stride-2 depthwise conv, so n > 1
    # downsamples n times; confirm that is intended.
    features = block(inp, out_filt, t)
    for _ in range(n - 1):
        features = block(features, out_filt, t)
    return features
def nn(inp_shape):
    """Build the Keras model, print its summary, and return it.

    Layer sequence:
    - an 8-filter 3x3 stride-2 stem conv (named 'input_layer');
    - three single-block bottlenecks with 16, 32 and 64 output channels;
    - two ``_upscale`` steps (x2 each) producing 32 and 48 channels;
    - a final depth_to_space x4, which leaves 48 / 16 = 3 output channels.

    NOTE(review): with stride-2 blocks the net downsamples by 16 overall and
    upsamples by 16, so the output presumably matches the input spatial size
    when height and width are divisible by 16. Confirm for other sizes.

    Args:
        inp_shape: Shape of one input sample (without batch), as for ``Input``.

    Returns:
        The constructed ``Model``.
    """
    inputs = Input(shape=inp_shape)
    x = Conv2D(8, (3, 3), strides=(2, 2), padding='same', name='input_layer')(inputs)

    # (output channels, expansion ratio, repeats) for each bottleneck stage.
    stages = ((16, 1, 1), (32, 6, 1), (64, 6, 1))
    for filters, expansion, repeats in stages:
        x = bottleneck(x, filters, expansion, repeats)

    # Output channels of each x2 upscaling step.
    for filters in (32, 48):
        x = _upscale(x, filters, 2)

    x = depth_to_space(x, 4)

    sr_model = Model(inputs, x)
    sr_model.summary()
    return sr_model