-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
115 lines (96 loc) · 4.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from torch.utils.data import DataLoader
from learner import Learner
from loss import *
from dataset import *
import os
from sklearn import metrics
import argparse
from FFC import *
# ---------------------------------------------------------------------------
# Command-line configuration for MIL (multiple-instance learning) training.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch MIL Training')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
# NOTE(review): 0.0010000000474974513 is the float32 rendering of 1e-3 —
# presumably copied from a serialized config; effectively weight_decay=1e-3.
parser.add_argument('--w', default=0.0010000000474974513, type=float, help='weight_decay')
parser.add_argument('--modality', default='TWO', type=str, help='modality')
parser.add_argument('--input_dim', default=2048, type=int, help='input_dim')
parser.add_argument('--drop', default=0.6, type=float, help='dropout_rate')
parser.add_argument('--FFC', '-r', action='store_true',help='FFC')
args = parser.parse_args()

# Best frame-level AUC observed so far; updated in-place by test_abnormal().
best_auc = 0

# Datasets and loaders. Training pairs a batch of 30 normal videos with a
# batch of 30 anomalous videos per step; test loaders yield one video at a
# time so AUC can be accumulated per anomaly/normal pair.
normal_train_dataset = Normal_Loader(is_train=1, modality=args.modality)
normal_test_dataset = Normal_Loader(is_train=0, modality=args.modality)
anomaly_train_dataset = Anomaly_Loader(is_train=1, modality=args.modality)
anomaly_test_dataset = Anomaly_Loader(is_train=0, modality=args.modality)
normal_train_loader = DataLoader(normal_train_dataset, batch_size=30, shuffle=True)
normal_test_loader = DataLoader(normal_test_dataset, batch_size=1, shuffle=True)
anomaly_train_loader = DataLoader(anomaly_train_dataset, batch_size=30, shuffle=True)
anomaly_test_loader = DataLoader(anomaly_test_dataset, batch_size=1, shuffle=True)

# Select GPU when available; the train/eval functions below read this global.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --FFC switches to the Learner2 architecture; both take the same feature dim.
if args.FFC:
    model = Learner2(input_dim=args.input_dim, drop_p=args.drop).to(device)
else:
    model = Learner(input_dim=args.input_dim, drop_p=args.drop).to(device)

optimizer = torch.optim.Adagrad(model.parameters(), lr= args.lr, weight_decay=args.w)
# LR is decayed at epochs 25 and 50 (default gamma=0.1).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50])
# MIL ranking loss from loss.py — called as criterion(outputs, batch_size).
criterion = MIL
def train(epoch):
    """Run one training epoch over paired normal/anomaly batches.

    Each step draws one batch from each train loader, concatenates the
    anomalous and normal segment features, flattens them to a single
    (N, input_dim) tensor, and optimizes the MIL ranking loss.

    Args:
        epoch: Current epoch index (used only for the log line).
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    train_loss = 0
    for batch_idx, (normal_inputs, anomaly_inputs) in enumerate(zip(normal_train_loader, anomaly_train_loader)):
        # Anomalous videos first, then normal — the MIL loss relies on this
        # ordering to split positive/negative bags by position.
        inputs = torch.cat([anomaly_inputs, normal_inputs], dim=1)
        batch_size = inputs.shape[0]
        # Collapse (batch, segments, dim) -> (batch*segments, dim) for the model.
        inputs = inputs.view(-1, inputs.size(-1)).to(device)
        outputs = model(inputs)
        loss = criterion(outputs, batch_size)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # BUG FIX: the original used print('loss = {}', ...), which never formats
    # the '{}' placeholder (print is not str.format); use a real f-string.
    print(f'loss = {train_loss / len(normal_train_loader)}')
    scheduler.step()
def test_abnormal(epoch):
    """Evaluate frame-level AUC over paired anomaly/normal test videos.

    For each (anomalous, normal) test-video pair: runs the model on the 32
    segment features, expands the 32 segment scores back to per-frame scores,
    builds the frame-level ground-truth mask, concatenates the pair, and
    accumulates ROC AUC. Saves a checkpoint whenever the mean AUC improves.

    Args:
        epoch: Current epoch index (unused; kept for interface symmetry
            with train()).
    """
    model.eval()
    global best_auc
    auc = 0
    with torch.no_grad():
        for i, (data, data2) in enumerate(zip(anomaly_test_loader, normal_test_loader)):
            # --- anomalous video -------------------------------------------
            inputs, gts, frames = data
            # BUG FIX: the original hard-coded torch.device('cuda') here even
            # though a `device` global is selected at startup — this crashed
            # on CPU-only machines. Use the shared `device` instead.
            inputs = inputs.view(-1, inputs.size(-1)).to(device)
            score = model(inputs)
            score = score.cpu().detach().numpy()
            # Expand 32 segment scores to per-frame scores via 16-frame steps.
            score_list = np.zeros(frames[0])
            step = np.round(np.linspace(0, frames[0] // 16, 33))
            for j in range(32):
                score_list[int(step[j]) * 16:(int(step[j + 1])) * 16] = score[j]
            # Frame-level ground truth from (start, end) annotation pairs.
            gt_list = np.zeros(frames[0])
            for k in range(len(gts) // 2):
                s = gts[k * 2]
                # BUG FIX: clamp against the frame count frames[0]; the
                # original compared against the whole `frames` batch tensor.
                e = min(gts[k * 2 + 1], frames[0])
                gt_list[s - 1:e] = 1

            # --- paired normal video (ground truth is all zeros) -----------
            inputs2, gts2, frames2 = data2
            inputs2 = inputs2.view(-1, inputs2.size(-1)).to(device)
            score2 = model(inputs2)
            score2 = score2.cpu().detach().numpy()
            score_list2 = np.zeros(frames2[0])
            step2 = np.round(np.linspace(0, frames2[0] // 16, 33))
            for kk in range(32):
                score_list2[int(step2[kk]) * 16:(int(step2[kk + 1])) * 16] = score2[kk]
            gt_list2 = np.zeros(frames2[0])

            # AUC over the concatenated pair.
            score_list3 = np.concatenate((score_list, score_list2), axis=0)
            gt_list3 = np.concatenate((gt_list, gt_list2), axis=0)
            fpr, tpr, thresholds = metrics.roc_curve(gt_list3, score_list3, pos_label=1)
            auc += metrics.auc(fpr, tpr)

        # 140 = number of test-video pairs — presumably the UCF-Crime test
        # split size; TODO confirm against the dataset loaders.
        # BUG FIX: same broken '{}' format string as in train(); use f-string.
        print(f'auc = {auc / 140}')
        if best_auc < auc / 140:
            print('Saving..')
            state = {
                'net': model.state_dict(),
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt.pth')
            best_auc = auc / 140
# Drive training: 75 epochs, evaluating frame-level AUC after each epoch
# (test_abnormal also checkpoints whenever the best AUC improves).
NUM_EPOCHS = 75
for current_epoch in range(NUM_EPOCHS):
    train(current_epoch)
    test_abnormal(current_epoch)