"""
The main code is based on https://github.com/optuna/optuna-examples/blob/63fe36db4701d5b230ade04eb2283371fb2265bf/pytorch/pytorch_simple.py
"""
import os

import optuna
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms

import wandb
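
# Setup sketch (an assumption, not part of the original script): this example
# expects `pip install wandb optuna torch torchvision` and a `wandb login`
# (or a WANDB_API_KEY environment variable) before runs can be logged online.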
# wandb might cause an error without this setting.
os.environ["WANDB_START_METHOD"] = "thread"

DEVICE = torch.device("cpu")
BATCHSIZE = 128
CLASSES = 10
DIR = os.getcwd()
EPOCHS = 100
LOG_INTERVAL = 10
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
STUDY_NAME = "pytorch-optimization"


def train(optimizer, model, train_loader):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Limiting training data for faster epochs.
        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break
        data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def validate(model, valid_loader):
    # Validation of the model.
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_loader):
            # Limiting validation data.
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break
            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
            output = model(data)
            # Get the index of the max log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)
    return accuracy


def define_model(trial):
    # We optimize the number of layers, hidden units, and dropout ratio in each layer.
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []
    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float("dropout_l{}".format(i), 0.2, 0.5)
        layers.append(nn.Dropout(p))
        in_features = out_features
    layers.append(nn.Linear(in_features, CLASSES))
    layers.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*layers)
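
# For illustration only (hypothetical trial values, not from the original):
# with n_layers=1, n_units_l0=64, and dropout_l0=0.3, define_model returns
#   nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Dropout(0.3),
#                 nn.Linear(64, 10), nn.LogSoftmax(dim=1))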


# Get the data loaders of the FashionMNIST dataset.
train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        DIR, train=True, download=True, transform=transforms.ToTensor()
    ),
    batch_size=BATCHSIZE,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(DIR, train=False, transform=transforms.ToTensor()),
    batch_size=BATCHSIZE,
    shuffle=True,
)


def objective(trial):
    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizer.
    optimizer_name = trial.suggest_categorical("optimizer", ["AdamW", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    # Init the tracking experiment:
    # the hyperparameters and the trial number are stored in the run config.
    config = dict(trial.params)
    config["trial.number"] = trial.number
    wandb.init(
        project="optuna",
        entity="nzw0301",  # NOTE: this entity depends on your wandb account.
        config=config,
        group=STUDY_NAME,
        reinit=True,
    )

    # Training of the model.
    for epoch in range(EPOCHS):
        train(optimizer, model, train_loader)
        val_accuracy = validate(model, valid_loader)
        trial.report(val_accuracy, epoch)

        # Report the intermediate validation accuracy to wandb.
        wandb.log(data={"validation accuracy": val_accuracy}, step=epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            wandb.run.summary["state"] = "pruned"
            wandb.finish(quiet=True)
            raise optuna.exceptions.TrialPruned()

    # Report the final validation accuracy to wandb.
    wandb.run.summary["final accuracy"] = val_accuracy
    wandb.run.summary["state"] = "completed"
    wandb.finish(quiet=True)
    return val_accuracy


study = optuna.create_study(
    direction="maximize",
    study_name=STUDY_NAME,
    pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=100, timeout=600)
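
# A short follow-up sketch (not part of the original script): print the best
# trial after optimization. `best_trial`, `best_value`, and `best_params` are
# standard Optuna Study attributes; this assumes at least one trial completed.
print("Best trial number: {}".format(study.best_trial.number))
print("Best validation accuracy: {}".format(study.best_value))
for key, value in study.best_params.items():
    print("    {}: {}".format(key, value))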