-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunGan.py
56 lines (37 loc) · 1.62 KB
/
runGan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Let's try an experiment: train GAN configuration(s) and save all results to OUTDIR.
import itertools
import os

import numpy as np

from hw3.answers import part2_vae_hyperparams
from hw3.experiments import run_experiment_GAN
# from hw3.autoencoder import AutoEncoderError
# --- Run naming and output location -----------------------------------------
Init_Name = 'Gan_Tain_1outof2'  # prefix of every run name (string kept as-is)
OUTDIR = './results/Res7/'      # out_dir for experiments; the results file is saved here too
AllResults = {}                 # run name -> experiment result or failure message

# --- Training / model sizes -------------------------------------------------
# NOTE(review): oneExp passes bs_train=8 rather than this value — confirm
# which batch size is intended.
bs = 16
Epochs = 100
h_dim, z_dim = 256, 128

# --- Learning rates (generator / discriminator) -----------------------------
gen_lr, dis_lr = 8e-4, 5e-4

# --- Label settings forwarded to run_experiment_GAN -------------------------
data_label = 1
label_noise = 0.3
def oneExp(gen_lr: float, des_lr: float, generator_optim: str, dsc_optim: str):
    """Train one GAN configuration and record its outcome in ``AllResults``.

    The run name is built from ``Init_Name``, the optimizer names and the two
    learning rates. It is passed to ``run_experiment_GAN`` and used as the
    key under which the outcome is stored in ``AllResults``.

    Args:
        gen_lr: Generator learning rate.
        des_lr: Discriminator learning rate.
        generator_optim: Generator optimizer name, passed straight through to
            ``run_experiment_GAN`` (this script uses 'SGD' / 'Adam').
        dsc_optim: Discriminator optimizer name, passed through likewise.

    Returns:
        The value returned by ``run_experiment_GAN``, or the string
        ``'Failed... <error>'`` if it raised ``OSError``. The same value is
        stored in ``AllResults[name]``.
    """
    # Build the name before the ``try`` so the failure branch can always key on
    # it. The format is byte-identical to earlier runs so result keys/paths match.
    name = (Init_Name + 'DSC_' + dsc_optim + 'GEN_' + generator_optim
            + 'genlr_' + str(gen_lr) + 'deslr' + str(des_lr))
    try:
        res = run_experiment_GAN(
            name, out_dir=OUTDIR, seed=42,
            # Training params
            # NOTE(review): module-level ``bs`` is 16 but this run uses 8 —
            # confirm which batch size is intended.
            bs_train=8, bs_test=None, batches=100, epochs=Epochs,
            early_stopping=10, checkpoints=None, print_every=10,
            # Model params (sizes come from the module-level constants)
            h_dim=h_dim, z_dim=z_dim, gen_lr=gen_lr, des_lr=des_lr,
            generator_optim=generator_optim, dsc_optim=dsc_optim,
            data_label=data_label, label_noise=label_noise,
        )
    except OSError as e:
        # Only OSError is recorded as a failure; any other exception propagates
        # and aborts the script.
        res = 'Failed... ' + str(e)
    AllResults[name] = res
    return res
# for gen_lr in [0.0008,0.0005,0.0001]:
# for des_lr in [0.0008,0.0005,0.0001]:
# for generator_optim in ['SGD','Adam']:
# for dsc_optim in ['SGD', 'Adam']:
# oneExp(gen_lr,des_lr,generator_optim,dsc_optim)
oneExp(gen_lr,dis_lr,'SGD','SGD')
np.save(OUTDIR + 'RunFinal', AllResults)