forked from chengdazhi/Deformable-Convolution-V2-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
34 lines (28 loc) · 839 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from modules import DeformConv
# Smoke test for the DeformConv extension: predict sampling offsets from the
# input with a plain 3x3 convolution, run one forward and one backward pass of
# DeformConv on random data, and verify the resulting shapes.
# Every layer and tensor is moved to the GPU with .cuda(), so fail early with
# a clear message when no CUDA device is present.
if not torch.cuda.is_available():
    raise RuntimeError("DeformConv test requires a CUDA-capable device")

num_deformable_groups = 2
N, inC, inH, inW = 2, 6, 512, 512
# Expected output geometry; checked against the real output at the end.
outC, outH, outW = 4, 512, 512
kH, kW = 3, 3

# Offset predictor. It emits num_deformable_groups * 2 * kH * kW channels
# (presumably a 2-D displacement per kernel tap per deformable group; confirm
# against the extension). With a 3x3 kernel, stride 1 and padding 1, the
# spatial size stays at inH x inW.
conv = nn.Conv2d(
    inC,
    num_deformable_groups * 2 * kH * kW,
    kernel_size=(kH, kW),
    stride=(1, 1),
    padding=(1, 1),
    bias=False).cuda()

# Layer under test: deformable conv from inC to outC channels with the same
# kernel/stride/padding as the offset predictor, so output is outH x outW.
conv_offset2d = DeformConv(
    inC,
    outC, (kH, kW),
    stride=1,
    padding=1,
    num_deformable_groups=num_deformable_groups).cuda()

inputs = Variable(torch.randn(N, inC, inH, inW).cuda(), requires_grad=True)

# Offsets are predicted from the input rather than sampled at random. To feed
# DeformConv independent random offsets instead, replace the next line with:
#   offset = Variable(torch.randn(N, num_deformable_groups * 2 * kH * kW,
#                                 inH, inW).cuda(), requires_grad=True)
offset = conv(inputs)
output = conv_offset2d(inputs, offset)

# Use the output itself as the upstream gradient. For a smoke test any tensor
# with the output's shape is enough to drive the backward kernels.
output.backward(output.data)
print(output.size())

# Check the configured geometry. Printing alone would never fail on a wrong
# shape.
expected_output_size = (N, outC, outH, outW)
actual_output_size = tuple(output.size())
assert actual_output_size == expected_output_size, (
    "unexpected output size {}, expected {}".format(
        actual_output_size, expected_output_size))

# The backward pass must have produced a gradient for the leaf input tensor.
assert inputs.grad is not None, "backward did not populate inputs.grad"
assert tuple(inputs.grad.size()) == (N, inC, inH, inW), (
    "unexpected input-gradient size {}".format(tuple(inputs.grad.size())))