-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_utils.py
153 lines (118 loc) · 4.43 KB
/
model_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
from torch import nn
from torchvision.models import resnet50, alexnet, vgg13, vgg16, densenet161
from torchvision.models import ResNet50_Weights, AlexNet_Weights, VGG13_Weights, VGG16_Weights, DenseNet161_Weights
def freeze_model_params_(model):
    """
    Disable gradient tracking for every parameter of a model, in place.

    Args:
        model (nn.Module): the network whose parameters should be frozen.

    Returns:
        None
    """
    for weight in model.parameters():
        # In-place setter; equivalent to assigning `requires_grad = False`.
        weight.requires_grad_(False)
def get_model_trainable_params(model):
    """
    Get the unfrozen parameters of a model.

    Args:
        model (nn.Module): a convolutional neural network

    Returns:
        trainable_params (List): the parameters whose `requires_grad` flag is
            set, i.e. the ones to learn during the training. A real list is
            returned (not a lazy iterator) so the result can be measured with
            `len()` and iterated more than once.
    """
    # A lazy `filter` object would be exhausted after a single pass, leaving
    # any later consumer with nothing; materialise the result up front.
    return [param for param in model.parameters() if param.requires_grad]
def build_classifier(num_features, num_hidden_units, num_classes, dropout):
    """
    Assemble a one-hidden-layer classification head.

    The head is: Linear -> ReLU -> Dropout -> Linear -> LogSoftmax (over
    dimension 1).

    Args:
        num_features (Int): the number of units in the input layer
        num_hidden_units (Int): the number of units in the hidden layer
        num_classes (Int): the number of units in the output layer
        dropout (Float): the nodes drop probability

    Returns:
        classifier (nn.Module): a model capable of classification
    """
    hidden_layer = nn.Linear(num_features, num_hidden_units)
    output_layer = nn.Linear(num_hidden_units, num_classes)
    # Layers stay positional (unnamed) so the state_dict keys are the
    # numeric indices of the sequence.
    stages = [
        hidden_layer,
        nn.ReLU(),
        nn.Dropout(dropout),
        output_layer,
        nn.LogSoftmax(dim=1),
    ]
    return nn.Sequential(*stages)
# Inspired by pytorch documentation.
# SEE:
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
def build_model(arch, num_hidden_units, num_classes, dropout):
    """
    Download the appropriate pretrained model according to the specified
    architecture and adjust the classifier.

    All pretrained parameters are frozen; only the freshly built classifier
    head is left trainable.

    Args:
        arch (String): the architecture of the pretrained models. One of
            "resnet50", "alexnet", "vgg13", "vgg16" or "densenet161".
        num_hidden_units (Int): the number of units in the hidden layer of the
            classifier.
        num_classes (Int): the number of classes in the dataset.
        dropout (Float): the nodes drop probability

    Returns:
        model (nn.Module): a convolutional neural network with a classifier
            layer ready to be trained.

    Raises:
        ValueError: if `arch` is not one of the supported architectures.
    """
    supported_archs = ("resnet50", "alexnet", "vgg13", "vgg16", "densenet161")
    # Validate before touching any pretrained weights so a typo fails fast
    # with an exception the caller can handle, instead of terminating the
    # whole process from inside a library function.
    if arch not in supported_archs:
        raise ValueError(
            f"Invalid model name {arch!r}; expected one of: "
            f"{', '.join(supported_archs)}"
        )

    if arch == "resnet50":
        model = resnet50(weights=ResNet50_Weights.DEFAULT)
        head_attr = "fc"
        num_features = model.fc.in_features
    elif arch == "densenet161":
        model = densenet161(weights=DenseNet161_Weights.DEFAULT)
        head_attr = "classifier"
        num_features = model.classifier.in_features
    else:
        if arch == "alexnet":
            model = alexnet(weights=AlexNet_Weights.DEFAULT)
        elif arch == "vgg13":
            model = vgg13(weights=VGG13_Weights.DEFAULT)
        else:
            model = vgg16(weights=VGG16_Weights.DEFAULT)
        head_attr = "classifier"
        # These classifiers are nn.Sequential stacks. torchvision's AlexNet
        # stack begins with nn.Dropout (which has no `in_features`), so take
        # the input width from the first nn.Linear rather than from index 0.
        num_features = next(
            layer.in_features
            for layer in model.classifier
            if isinstance(layer, nn.Linear)
        )

    # Freeze the backbone first; the new head is created afterwards so its
    # parameters keep requires_grad=True.
    freeze_model_params_(model)
    setattr(
        model,
        head_attr,
        build_classifier(num_features, num_hidden_units, num_classes, dropout),
    )
    return model
def rebuild_model_from_checkpoint(checkpoint_path):
    """
    Recreate a model from a saved checkpoint.

    The checkpoint is loaded onto the CPU and must provide the keys "arch",
    "num_hidden_units", "num_classes", "dropout", "model_state_dict" and
    "class_to_idx".

    Args:
        checkpoint_path (Path): a path to the model's checkpoint.

    Returns:
        model (nn.Module): a convolutional neural network with its saved
            weights restored and a `class_to_idx` attribute attached, ready
            to be retrained or used for inference.
    """
    # NOTE(review): torch.load unpickles the file's contents; presumably
    # checkpoints only come from trusted sources -- confirm with callers.
    saved = torch.load(checkpoint_path, map_location="cpu")

    # Hyper-parameters, in the positional order build_model expects.
    hyperparam_keys = ("arch", "num_hidden_units", "num_classes", "dropout")
    restored = build_model(*(saved[key] for key in hyperparam_keys))

    restored.load_state_dict(saved["model_state_dict"])
    restored.class_to_idx = saved["class_to_idx"]
    return restored