-
Notifications
You must be signed in to change notification settings - Fork 11
/
hbp_model.py
executable file
·71 lines (51 loc) · 2.54 KB
/
hbp_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import resnet_model
class Net(nn.Module):
    """Hierarchical Bilinear Pooling (HBP) network for fine-grained classification.

    A ResNet backbone exposes three feature maps from its last stage; each is
    projected with a 1x1 conv, the three pairwise element-wise products are
    pooled, signed-sqrt + L2 normalized, concatenated, and classified.

    Args:
        num_classes: size of the output label space (default 200, CUB-200 style).
        in_dim: channel count of the backbone feature maps
            (2048 for resnet50+, use 512 for resnet18/34).
        proj_dim: channel count of the 1x1 projection convolutions.
        model_root: filesystem path to the pretrained backbone weights.

    Returns (from forward): log-probabilities of shape (batch, num_classes).
    """

    def __init__(self, num_classes=200, in_dim=2048, proj_dim=8192,
                 model_root='/data/guijun/HBP_finegrained/pth/resnet50.pth'):
        super(Net, self).__init__()
        # 1x1 projections for the three stage-4 feature maps produced by the
        # backbone; one projection per map so the bilinear factors differ.
        self.proj0 = nn.Conv2d(in_dim, proj_dim, kernel_size=1, stride=1)
        self.proj1 = nn.Conv2d(in_dim, proj_dim, kernel_size=1, stride=1)
        self.proj2 = nn.Conv2d(in_dim, proj_dim, kernel_size=1, stride=1)
        # Classifier over the three concatenated pooled bilinear vectors.
        self.fc_concat = nn.Linear(proj_dim * 3, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)
        # AdaptiveAvgPool2d(1) is equivalent to the original AvgPool2d(14)
        # when the stage-4 maps are 14x14 (448x448 input), but also keeps the
        # pooled width consistent with fc_concat for any other input size.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Backbone returning three stage-4 feature maps; swap for
        # resnet_model.resnet34 (with in_dim=512) to use a lighter backbone.
        self.features = resnet_model.resnet50(pretrained=True,
                                              model_root=model_root)

    def forward(self, x):
        batch_size = x.size(0)
        # Three feature maps from the backbone's final stage.
        feature4_0, feature4_1, feature4_2 = self.features(x)
        feature4_0 = self.proj0(feature4_0)
        feature4_1 = self.proj1(feature4_1)
        feature4_2 = self.proj2(feature4_2)
        # Pairwise Hadamard (element-wise) products = factorized bilinear terms.
        inter1 = feature4_0 * feature4_1
        inter2 = feature4_0 * feature4_2
        inter3 = feature4_1 * feature4_2
        inter1 = self.avgpool(inter1).view(batch_size, -1)
        inter2 = self.avgpool(inter2).view(batch_size, -1)
        inter3 = self.avgpool(inter3).view(batch_size, -1)
        # Signed square-root (epsilon guards sqrt's gradient at 0) then L2
        # normalization, the standard bilinear-pooling post-processing.
        result1 = F.normalize(torch.sign(inter1) * torch.sqrt(torch.abs(inter1) + 1e-10))
        result2 = F.normalize(torch.sign(inter2) * torch.sqrt(torch.abs(inter2) + 1e-10))
        result3 = F.normalize(torch.sign(inter3) * torch.sqrt(torch.abs(inter3) + 1e-10))
        result = torch.cat((result1, result2, result3), 1)
        result = self.fc_concat(result)
        return self.softmax(result)