-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathWarpST.py
121 lines (101 loc) · 4.57 KB
/
WarpST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import tensorflow as tf
from bicubic_interp import bicubic_interp_2d
def WarpST(U, V, out_size, name='DeformableTransformer', **kwargs):
    """Warp ``U`` by the displacement field ``V`` with bilinear sampling.

    A regular sampling grid spanning [-1, 1] in both axes is built at
    ``out_size`` resolution, ``V`` is resized to ``out_size`` with
    ``bicubic_interp_2d`` and added to that grid, and ``U`` is sampled
    bilinearly at the resulting positions.

    Args:
        U: Rank-4 tensor indexed as [num_batch, height, width, num_channels];
            the image batch to sample from. It is cast to float32 before
            sampling.
        V: Displacement field handed to ``bicubic_interp_2d``. After resizing
            it is reshaped to [num_batch, 2, -1], so it is presumably
            [num_batch, h, w, 2] with channel 0 the x offset and channel 1
            the y offset, both in the normalized [-1, 1] grid units --
            TODO confirm against ``bicubic_interp_2d``.
        out_size: Indexable pair ``(out_height, out_width)`` giving the
            spatial size of the output.
        name: Name of the enclosing ``tf.variable_scope``.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        A float32 tensor of shape
        [num_batch, out_height, out_width, num_channels].
    """

    def _repeat(x, n_repeats):
        """Repeat each element of the 1-D tensor ``x`` ``n_repeats`` times."""
        with tf.variable_scope('_repeat'):
            # Tile a column vector along a new axis and flatten row-major, so
            # [a, b] with n_repeats=3 becomes [a, a, a, b, b, b]. This states
            # the repeat directly instead of multiplying by an integer row of
            # ones with matmul.
            column = tf.reshape(x, (-1, 1))
            tiled = tf.tile(column, tf.stack([1, n_repeats]))
            return tf.reshape(tiled, [-1])

    def _interpolate(im, x, y, out_size):
        """Bilinearly sample ``im`` at flat normalized coordinates ``x``, ``y``.

        Returns a [num_points, channels] float32 tensor, where num_points is
        the length of ``x`` / ``y``.
        """
        with tf.variable_scope('_interpolate'):
            im_shape = tf.shape(im)
            num_batch = im_shape[0]
            height = im_shape[1]
            width = im_shape[2]
            channels = im_shape[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            zero = tf.zeros([], dtype='int32')
            max_y = height - 1
            max_x = width - 1

            # Scale coordinates from [-1, 1] to [0, width] / [0, height].
            # NOTE(review): +1 therefore maps to `width` / `height`, one past
            # the last valid index; after clipping both neighbours coincide
            # and their weights cancel, so such samples evaluate to zero.
            # Confirm whether scaling by (size - 1) was intended before
            # changing it -- existing results depend on the current numerics.
            x = (x + 1.0) * width_f / 2.0
            y = (y + 1.0) * height_f / 2.0

            # Integer corners surrounding every sample position, clamped so
            # the gather indices below stay inside the image.
            x0 = tf.cast(tf.floor(x), 'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y), 'int32')
            y1 = y0 + 1

            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)

            # Offsets into the image flattened to [batch*height*width, C]:
            # `base` is the start of each sample's own batch element.
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # Look up the four corner pixels of every sample.
            im_flat = tf.reshape(im, tf.stack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # Bilinear weights, computed from the clamped corner positions.
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
            wb = tf.expand_dims((x1_f - x) * (y - y0_f), 1)
            wc = tf.expand_dims((x - x0_f) * (y1_f - y), 1)
            wd = tf.expand_dims((x - x0_f) * (y - y0_f), 1)
            return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    def _meshgrid(height, width):
        """Return a [2, height*width] grid; row 0 is x, row 1 is y, in [-1, 1]."""
        with tf.variable_scope('_meshgrid'):
            # Default 'xy' indexing yields [height, width] matrices in which
            # x varies along columns and y along rows.
            x_t, y_t = tf.meshgrid(tf.linspace(-1.0, 1.0, width),
                                   tf.linspace(-1.0, 1.0, height))
            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            return tf.concat([x_t_flat, y_t_flat], 0)

    def _transform(V, U, out_size):
        """Displace the regular grid by ``V`` and sample ``U`` at the result."""
        with tf.variable_scope('_transform'):
            u_shape = tf.shape(U)
            num_batch = u_shape[0]
            num_channels = u_shape[3]
            out_height = out_size[0]
            out_width = out_size[1]

            # One copy of the regular grid per batch element: [num_batch, 2, N].
            grid = _meshgrid(out_height, out_width)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.stack([num_batch]))
            grid = tf.reshape(grid, tf.stack([num_batch, 2, -1]))

            # Resize the displacement field to the output resolution and lay
            # it out like the grid so the two can be added element-wise.
            V = bicubic_interp_2d(V, out_size)
            V = tf.transpose(V, [0, 3, 1, 2])
            V = tf.reshape(V, [num_batch, 2, -1])
            T_g = tf.add(V, grid)

            # Row 0 holds the x coordinates, row 1 the y coordinates.
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                U, x_s_flat, y_s_flat, out_size)

            return tf.reshape(
                input_transformed,
                tf.stack([num_batch, out_height, out_width, num_channels]))

    with tf.variable_scope(name):
        return _transform(V, U, out_size)