-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
54 lines (46 loc) · 1.63 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import torchvision
class CNN(nn.Module):
    """Small convolutional classifier; a basic model for testing things.

    The layer sizes are fixed for 3-channel 32x32 inputs: three 5x5
    convolutions (stride 1, no padding) and two 2x2 max-pools reduce a 32x32
    image to a 64x4x4 feature map, which ``fc_layer`` maps to
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(3, 16, 5),   # 32x32 -> 28x28
            nn.ReLU(),
            nn.Conv2d(16, 32, 5),  # 28x28 -> 24x24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # 24x24 -> 12x12
            nn.Conv2d(32, 64, 5),  # 12x12 -> 8x8
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # 8x8 -> 4x4
        )
        self.fc_layer = nn.Sequential(
            nn.Linear(64 * 4 * 4, 100),
            nn.ReLU(),
            nn.Linear(100, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Input batch; the layer sizes expect ``(batch, 3, 32, 32)``.
        """
        out = self.layer(x)
        # Flatten per sample rather than view(-1, 1024): with a wrongly sized
        # input this now fails in fc_layer instead of silently changing the
        # batch dimension.
        out = torch.flatten(out, start_dim=1)
        return self.fc_layer(out)
class ResNet(nn.Module):
    """ResNet from ``pytorch/vision:v0.6.0`` with a freshly initialised head.

    The backbone is fetched through ``torch.hub`` with ``pretrained=False``.
    Its final ``fc`` layer is replaced by a new ``nn.Linear`` that keeps the
    original input width and produces ``num_classes`` logits.

    Args:
        version: Suffix appended to ``'resnet'`` to form the hub entry-point
            name, e.g. ``'18'`` -> ``'resnet18'``. Strings and ints both work.
        dset: Dataset name, kept for backward compatibility. Every value
            yields the same head, so it does not currently affect the model.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, version, dset='cifar', num_classes: int = 10):
        super().__init__()
        # NOTE: torch.hub.load fetches the hub repo, so this requires network
        # access (or a populated hub cache).
        self.resnet = torch.hub.load(
            'pytorch/vision:v0.6.0', f'resnet{version}', pretrained=False
        )
        # A single head replacement serves every dset value; the former
        # `if dset == 'cifar'` branches were identical.
        self.resnet.fc = nn.Linear(
            in_features=self.resnet.fc.in_features,
            out_features=num_classes,
            bias=True,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its logits."""
        return self.resnet(x)
class MobileNet(nn.Module):
    """MobileNetV2 from ``pytorch/vision:v0.6.0`` with a new classifier head.

    The backbone is fetched through ``torch.hub`` with ``pretrained=False``.
    Its ``classifier`` is replaced by dropout (p=0.2) followed by a fresh
    ``nn.Linear`` producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # NOTE: torch.hub.load fetches the hub repo, so this requires network
        # access (or a populated hub cache).
        self.mobile = torch.hub.load(
            'pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=False
        )
        # Reuse the feature width of the hub model's existing final layer
        # instead of hard-coding 1280. Assumes that layer is an nn.Linear
        # (the original code replaced it with Linear(1280, ...)) — confirm if
        # the hub tag changes.
        in_features = self.mobile.classifier[-1].in_features
        self.mobile.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(in_features, num_classes, bias=True),
        )

    def forward(self, x):
        """Run ``x`` through the wrapped MobileNetV2 and return its logits."""
        return self.mobile(x)