-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
132 lines (96 loc) · 3.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#
import glob
import os
import argparse
from PIL import Image
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.metrics import roc_curve, roc_auc_score, auc
from sklearn.model_selection import GroupKFold
from models import *
from bilinear_layers import *
import utils
class oaDataset(Dataset):
    """Dataset of pre-cropped knee ROI samples stored as pickled numpy records.

    Every regular file directly under ``root_path`` is one sample. It is read
    with ``np.load(..., allow_pickle=True)`` and unpacked into six fields:
    ``image, min_jsw, jsw_des, fjsw, kl, jsize``. Only ``image``, ``jsw_des``
    and ``kl`` are used.

    Args:
        root_path: Directory containing one record file per sample. It is not
            searched recursively.
        transform: Callable applied to the grayscale PIL image (e.g. a
            ``torchvision.transforms.Compose``). Its output is returned as-is.
    """

    def __init__(self, root_path, transform):
        self.roi_root = root_path
        # glob() yields entries in arbitrary, filesystem-dependent order. Sort
        # them so sample indices (and therefore seeded shuffles) are
        # reproducible. Without recursive=True, "**" behaves like "*" and also
        # matches subdirectories, which np.load cannot read, so keep files only.
        self.flist = sorted(
            path
            for path in glob.glob(os.path.join(self.roi_root, "**"))
            if os.path.isfile(path)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of sample files found under ``roi_root``."""
        return len(self.flist)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, jsw, grade)``.

        ``image`` is the crop converted to an 8-bit grayscale ('L') PIL image
        and passed through ``self.transform``. ``jsw`` is ``jsw_des`` as a
        float32 tensor. ``grade`` is 1 when the stored KL grade is greater
        than 1, otherwise 0.
        """
        fname = self.flist[idx]
        # NOTE: allow_pickle=True can execute arbitrary code embedded in the
        # file; only point this dataset at trusted, locally generated data.
        # The loaded `image` is a numpy array, not a PIL image.
        image, _min_jsw, jsw_des, _fjsw, kl, _jsize = np.load(fname, allow_pickle=True)
        grade = 1 if kl > 1 else 0
        image = Image.fromarray(image.astype('uint8'), 'L')
        image = self.transform(image)
        jsw = torch.from_numpy(np.array(jsw_des)).float()
        return image, jsw, grade
def _resolve_model(name):
    """Resolve a ``--model`` command-line value to a callable in module scope.

    argparse applies ``type`` only to string values (command-line input and
    string defaults). A non-string default such as the ``combined`` class is
    therefore passed through untouched.

    Raises:
        argparse.ArgumentTypeError: if ``name`` is undefined or not callable.
    """
    try:
        model = globals()[name]
    except KeyError:
        raise argparse.ArgumentTypeError(f"unknown model: {name!r}") from None
    if not callable(model):
        raise argparse.ArgumentTypeError(f"{name!r} is not callable")
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): this path is not used anywhere below; kept for CLI
    # compatibility.
    parser.add_argument('--oai_assesment', default='/../DATA/OAI/kXR_SQ_BU00_SAS/kxr_sq_bu00.sas7bdat')
    # Without `type`, `--model combined` would store the string "combined" and
    # `args.model()` would fail; resolve the name to the object it refers to.
    parser.add_argument('--model', default=combined, type=_resolve_model)
    args = parser.parse_args()

    SEED = 42
    MAX_EPOCH = 100
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # Lets cuDNN auto-tune kernels: faster, but not bit-for-bit reproducible.
    cudnn.benchmark = True

    train_transforms = transforms.Compose([
        transforms.Resize(size=(56, 56)),
        transforms.RandomCrop(size=(48, 48)),
        transforms.ToTensor(),
    ])
    eval_transforms = transforms.Compose([
        transforms.Resize(size=(48, 48)),
        transforms.ToTensor(),
    ])

    train_root = '../DATA/CROPS/OAI_00m_tm_fjsw_standardized/'
    test_root = '../DATA/CROPS/MOST_00m_tm_fjsw_standardized/'
    train_dataloader = DataLoader(oaDataset(train_root, train_transforms),
                                  batch_size=64,
                                  shuffle=True,
                                  num_workers=8,
                                  drop_last=True)
    # Evaluate on every test sample: drop_last=True would silently exclude up
    # to batch_size - 1 samples from the reported AUC.
    test_dataloader = DataLoader(oaDataset(test_root, eval_transforms),
                                 batch_size=64,
                                 num_workers=8,
                                 drop_last=False)

    net = args.model().to(device)
    # Loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True)

    best_auc = 0.0
    best_epoch = None
    fpr, tpr = [], []
    # NOTE(review): the best epoch is selected on the test loader (MOST), so
    # the reported best AUC is optimistically biased. Consider selecting on a
    # held-out split of the training data (e.g. GroupKFold over OAI).
    for epoch in range(MAX_EPOCH):
        optimizer, _ = utils.adjust_learning_rate(optimizer, epoch, 0.00001, 0.01, 8)
        train_loss = utils.train_epoch(epoch, net, optimizer, train_dataloader, criterion)
        val_loss, preds, truth = utils.validate_epoch(net, test_dataloader, criterion)
        auc_val = roc_auc_score(truth, preds)
        print(epoch + 1, train_loss, val_loss, auc_val)
        if auc_val > best_auc:
            best_auc = auc_val
            best_epoch = epoch
            fpr, tpr, _ = roc_curve(truth, preds)

    if best_epoch is not None:
        plt.figure(figsize=(10, 10))
        plt.plot(fpr, tpr, lw=2, alpha=0.8,
                 label='ROC (AUC = %0.2f)' % (best_auc))
        plt.xlabel("False positive rate")
        plt.ylabel("True positive rate")
        # Without a legend the AUC label above is never displayed.
        plt.legend(loc="lower right")
        plt.show()
        print("Best epoch:", best_epoch + 1)
    print("Best AUC:", best_auc)