-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_mlp.py
124 lines (104 loc) · 4.29 KB
/
train_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import pytorch_lightning as pl
import torch
import torch.nn as nn
from models.mlp_classification import MultiLayerPerceptron
from utils import MnistDataModule
from tqdm import tqdm
import pickle
# Compute device for the run: the CUDA device when one is usable, else the CPU.
# NOTE(review): within this file the name is only referenced by commented-out
# code, so nothing visible here is actually moved onto this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def train_FO_mlp(num_epochs=None):
    """Fit the MLP classifier on MNIST with a PyTorch Lightning ``Trainer``.

    Builds ``MnistDataModule(bs=64)`` and a default ``MultiLayerPerceptron``,
    runs ``Trainer.fit`` and then pickles the classifier's ``tr_loss``,
    ``time`` and ``query`` attributes (under the keys ``'Tr_Loss'``,
    ``'Time'`` and ``'Query'``) to ``MNIST_FO_lr1e3_bs64.pickle`` in the
    current working directory.

    Args:
        num_epochs: Value passed to the trainer as ``max_epochs``. When
            ``None`` (the default) the module-level ``epochs`` variable is
            used, which is how the script behaved before this parameter
            existed. Passing it explicitly lets the function be called from
            an importing module, where that global is not defined.

    Raises:
        NameError: If ``num_epochs`` is omitted and no module-level
            ``epochs`` has been defined.
    """
    if num_epochs is None:
        # Legacy path: ``epochs`` is assigned by the ``__main__`` block.
        num_epochs = epochs

    mnist_dm = MnistDataModule(bs=64)
    mnistclassifier = MultiLayerPerceptron()

    trainer_kwargs = {'max_epochs': num_epochs}
    if torch.cuda.is_available():  # if you have GPUs
        # Extra settings are only applied on CUDA machines, exactly as before:
        # a single device, no gradient accumulation, and a fractional
        # ``val_check_interval``.
        trainer_kwargs.update(
            devices=1,
            accumulate_grad_batches=1,
            val_check_interval=0.1,
        )
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

    dict_results = {
        'Tr_Loss': mnistclassifier.tr_loss,
        'Time': mnistclassifier.time,
        'Query': mnistclassifier.query,
    }
    with open('MNIST_FO_lr1e3_bs64.pickle', 'wb') as f:
        pickle.dump(dict_results, f)
def train_ZO_mlp(num_epochs=None):
    """Run the hand-written ZO training loop for the MLP classifier on MNIST.

    Uses ``MnistDataModule(bs=128)`` and a ``MultiLayerPerceptron`` created
    with ``zero_order_eps=1e-3`` and ``learning_rate=1e-3``. Every epoch first
    iterates the validation loader through ``validation_step_ZO`` and then the
    training loader through ``training_step_ZO``. Afterwards the classifier's
    ``tr_loss``, ``time`` and ``query`` attributes are pickled (keys
    ``'Tr_Loss'``, ``'Time'``, ``'Query'``) to ``MNIST_ZO_lr1e3_bs128.pickle``
    in the current working directory.

    Args:
        num_epochs: Number of epochs to run. When ``None`` (the default) the
            module-level ``epochs`` variable is used, which is how the script
            behaved before this parameter existed.

    Raises:
        NameError: If ``num_epochs`` is omitted and no module-level
            ``epochs`` has been defined.
    """
    if num_epochs is None:
        # Legacy path: ``epochs`` is assigned by the ``__main__`` block.
        num_epochs = epochs

    mnist_dm = MnistDataModule(bs=128)
    mnist_dm.setup(stage='fit')
    train_dataloader = mnist_dm.train_dataloader()
    val_dataloader = mnist_dm.val_dataloader()

    mnistclassifier = MultiLayerPerceptron(zero_order_eps=1e-3, learning_rate=1e-3)
    model = mnistclassifier.model
    # The network is put in eval mode once and never switched back to train
    # mode. NOTE(review): presumably intentional for the ZO steps - confirm.
    model.eval()
    # NOTE: neither the model nor the batches are moved to another device
    # here. If device placement is added, the result must be reassigned
    # (``x = x.to(device)``) because ``Tensor.to`` is not an in-place operation.

    for _ in range(num_epochs):
        # validation loop
        for x, y in tqdm(val_dataloader):
            mnistclassifier.validation_step_ZO(model, x, y)

        # training loop
        for x, y in tqdm(train_dataloader):
            mnistclassifier.training_step_ZO(model, (x, y))

    dict_results = {
        'Tr_Loss': mnistclassifier.tr_loss,
        'Time': mnistclassifier.time,
        'Query': mnistclassifier.query,
    }
    with open('MNIST_ZO_lr1e3_bs128.pickle', 'wb') as f:
        pickle.dump(dict_results, f)
def train_ZO_SVRG_Coord_Rand_mlp(num_epochs=None):
    """Run the ZO-SVRG (Rand/Coord variant) training loop for the MLP on MNIST.

    Uses ``MnistDataModule()`` with its default arguments and a
    ``MultiLayerPerceptron`` created with ``zero_order_eps=1e-3`` and
    ``learning_rate=1e-3``. Every epoch first iterates the validation loader
    through ``validation_step_ZO`` and then the training loader through
    ``training_step_ZO_SVRG_Rand_Coord``, which also receives the epoch index
    and the batch index. Only the classifier's ``tr_loss`` attribute is
    pickled (key ``'Tr_Loss'``), to
    ``MLP_ZO_SVRG_Coord_Rand_FD150_bs64_lr1e3.pickle`` in the current working
    directory.

    Args:
        num_epochs: Number of epochs to run. When ``None`` (the default) the
            module-level ``epochs`` variable is used, which is how the script
            behaved before this parameter existed.

    Raises:
        NameError: If ``num_epochs`` is omitted and no module-level
            ``epochs`` has been defined.
    """
    if num_epochs is None:
        # Legacy path: ``epochs`` is assigned by the ``__main__`` block.
        num_epochs = epochs

    mnist_dm = MnistDataModule()
    mnist_dm.setup(stage='fit')
    train_dataloader = mnist_dm.train_dataloader()
    val_dataloader = mnist_dm.val_dataloader()

    mnistclassifier = MultiLayerPerceptron(zero_order_eps=1e-3, learning_rate=1e-3)
    model = mnistclassifier.model
    # The network is put in eval mode once and never switched back to train
    # mode. NOTE(review): presumably intentional for the ZO steps - confirm.
    model.eval()
    # NOTE: neither the model nor the batches are moved to another device
    # here. If device placement is added, the result must be reassigned
    # (``x = x.to(device)``) because ``Tensor.to`` is not an in-place operation.

    for epoch in range(num_epochs):
        print('epoch:', epoch)

        # validation loop
        for x, y in tqdm(val_dataloader):
            mnistclassifier.validation_step_ZO(model, x, y)

        # training loop
        for i, (x, y) in enumerate(tqdm(train_dataloader)):
            mnistclassifier.training_step_ZO_SVRG_Rand_Coord(model, (x, y), epoch, i)

    # NOTE(review): the file name encodes "bs64" and "FD150", but the batch
    # size comes from MnistDataModule's default and no such values are set in
    # this function - verify the name still describes the run.
    dict_results = {'Tr_Loss': mnistclassifier.tr_loss}
    with open('MLP_ZO_SVRG_Coord_Rand_FD150_bs64_lr1e3.pickle', 'wb') as f:
        pickle.dump(dict_results, f)
def train_ZO_SVRG_mlp(num_epochs=None):
    """Run the ZO-SVRG training loop for the MLP classifier on MNIST.

    Uses ``MnistDataModule(bs=9096)`` and a ``MultiLayerPerceptron`` created
    with ``zero_order_eps=1e-3``, ``learning_rate=1e-5``,
    ``learning_rate_aux=2e-3`` and ``q=2``. Every epoch first iterates the
    validation loader through ``validation_step_ZO`` and then the training
    loader through ``training_step_ZO_SVRG``, which also receives the epoch
    index and the batch index. Afterwards the classifier's ``tr_loss``,
    ``time`` and ``query`` attributes are pickled (keys ``'Tr_Loss'``,
    ``'Time'``, ``'Query'``) to ``MNIST_ZO_SVRG_q2_bs32_lr1e5.pickle`` in the
    current working directory.

    Args:
        num_epochs: Number of epochs to run. When ``None`` (the default) the
            module-level ``epochs`` variable is used, which is how the script
            behaved before this parameter existed.

    Raises:
        NameError: If ``num_epochs`` is omitted and no module-level
            ``epochs`` has been defined.
    """
    if num_epochs is None:
        # Legacy path: ``epochs`` is assigned by the ``__main__`` block.
        num_epochs = epochs

    mnist_dm = MnistDataModule(bs=9096)
    mnist_dm.setup(stage='fit')
    train_dataloader = mnist_dm.train_dataloader()
    val_dataloader = mnist_dm.val_dataloader()

    mnistclassifier = MultiLayerPerceptron(
        zero_order_eps=1e-3,
        learning_rate=1e-5,
        learning_rate_aux=2e-3,
        q=2,
    )
    model = mnistclassifier.model
    # The network is put in eval mode once and never switched back to train
    # mode. NOTE(review): presumably intentional for the ZO steps - confirm.
    model.eval()
    # NOTE: neither the model nor the batches are moved to another device
    # here. If device placement is added, the result must be reassigned
    # (``x = x.to(device)``) because ``Tensor.to`` is not an in-place operation.

    for epoch in range(num_epochs):
        print('epoch:', epoch)

        # validation loop
        for x, y in tqdm(val_dataloader):
            mnistclassifier.validation_step_ZO(model, x, y)

        # training loop
        for i, (x, y) in enumerate(tqdm(train_dataloader)):
            mnistclassifier.training_step_ZO_SVRG(model, (x, y), epoch, i)

    # NOTE(review): the file name says "bs32" while the data module above is
    # built with bs=9096 - confirm which one is intended before relying on
    # the name to identify the run.
    dict_results = {
        'Tr_Loss': mnistclassifier.tr_loss,
        'Time': mnistclassifier.time,
        'Query': mnistclassifier.query,
    }
    with open('MNIST_ZO_SVRG_q2_bs32_lr1e5.pickle', 'wb') as f:
        pickle.dump(dict_results, f)
if __name__ == "__main__":
    # Script entry point.
    # ``epochs`` is deliberately a module-level global: the training functions
    # in this file read it to decide how many epochs to run, so it has to be
    # assigned before any of them is called.
    epochs = 100

    # Only the plain ZO experiment is launched; swap in one of the other
    # ``train_*`` functions defined in this file to run a different one.
    train_ZO_mlp()