-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
executable file
·53 lines (42 loc) · 2.86 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import torch
def parse_args(argv=None):
    """Build the command-line parser for this training script and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with every option declared below, plus an extra
        ``device`` attribute: ``torch.device("cuda:<cuda_ind>")`` when
        ``torch.cuda.is_available()`` is true, otherwise
        ``torch.device("cpu")``.

    Side effects:
        Prints the selected device to stdout. On invalid arguments argparse
        prints usage and raises ``SystemExit``.
    """
    parser = argparse.ArgumentParser(description='jwseo')

    # dataloader related
    parser.add_argument("--data_dir", type=str, default="../../HDD/dataset/")
    parser.add_argument("--save_dir", type=str, default="../../HDD2/raqvae/")
    parser.add_argument("--dataset", type=str, default="CelebA",
                        choices=['cifar10', 'CelebA', 'CelebA_128', 'ImageNet'])
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--batch_size_test", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)

    # model size
    parser.add_argument("--raq_type", type=str, default="mb", choices=['mb', 'dd'],
                        help="mb: model-based, dd: data-driven")
    parser.add_argument("--model_type", type=str, default="vqvae",
                        choices=['vqvae', 'vqvae2', 'vqgan'])
    parser.add_argument("--num_embeddings", type=int, default=256,
                        help="base vocabulary size; number of possible discrete states")
    parser.add_argument("--embedding_dim", type=int, default=64,
                        help="size of the vector of the embedding of each discrete token")
    parser.add_argument("--n_hid", type=int, default=64,
                        help="number of channels controlling the size of the model")

    # Training options
    # NOTE: float-valued options must use type=float. argparse applies `type`
    # only to strings given on the command line, so a float default next to
    # type=int appears to work, but an override such as `--lr 1e-3` is
    # rejected as an invalid int.
    parser.add_argument('--n_epochs', type=int, default=300, help='number of training epochs')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--seed', type=int, default=0, help='training seed: 10, 42, 170, 682')
    parser.add_argument('--cuda_ind', type=int, default=0, help='index for cuda device')

    # Model-based options
    parser.add_argument('--cluster_target', type=int, default=512, help='Codebook clustering target')
    parser.add_argument('--max_iter', type=int, default=200, help='number of dkm iterations')
    parser.add_argument('--epsilon', type=float, default=1e-8, help='epsilon for softmax function')
    parser.add_argument('--temp', type=float, default=1e-2, help='Softmax temperature of DKM')

    # Data-driven options
    parser.add_argument("--num_embeddings_min", type=int, default=32,
                        help="minimum vocabulary size; number of possible discrete states")
    parser.add_argument("--num_embeddings_max", type=int, default=2048,
                        help="maximum vocabulary size; number of possible discrete states")
    parser.add_argument("--num_embeddings_test", type=int, default=512,
                        help="Test vocabulary size; number of possible discrete states")

    # directory for FID
    parser.add_argument('--img_dir', type=str, default='imgs/')

    # argv=None lets argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Use the requested GPU index when CUDA is present; otherwise run on CPU.
    if torch.cuda.is_available():
        args.device = torch.device("cuda:" + str(args.cuda_ind))
    else:
        args.device = torch.device("cpu")
    print("Running on device:", args.device, "...")
    return args