-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_maml.py
123 lines (107 loc) · 5.84 KB
/
train_maml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from utils import setup_logger
import argparse
import os
import maml
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` is a known pitfall: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--cuda False`` would keep
    CUDA enabled. This converter parses the text explicitly instead.

    Args:
        value: the raw command-line text (or an already-boolean value).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected (e.g. True/False), got {!r}'.format(value))


def parse_args(argv=None):
    """
    Parse the command-line arguments for the MAML model.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads sys.argv[1:], as before.

    Returns:
        args: argparse.Namespace with the parsed arguments. ``train_domains``
            is still the raw comma-separated string; the caller splits it.
    """
    # Implicit string concatenation keeps the help text free of the source
    # indentation that backslash line continuations would embed.
    parser = argparse.ArgumentParser(
        description='Implementation of Model-Agnostic Meta Learning on '
                    'Fault Diagnosis Datasets')

    # Training parameters
    parser.add_argument('--ways', type=int, default=10,
                        help='Number of classes per task, default=10')
    parser.add_argument('--shots', type=int, default=5,
                        help='Number of support examples per class, default=5')

    # Meta-learning parameters
    parser.add_argument('--meta_lr', type=float, default=0.001,
                        help='Outer loop learning rate, default=0.001')
    parser.add_argument('--fast_lr', type=float, default=0.1,
                        help='Inner loop learning rate, default=0.1')
    parser.add_argument('--adapt_steps', type=int, default=5,
                        help='Number of inner loop steps for adaptation, default=5')
    parser.add_argument('--meta_batch_size', type=int, default=32,
                        help='Number of meta-tasks in each outer-loop batch, '
                             'default=32')
    parser.add_argument('--iters', type=int, default=300,
                        help='Number of outer-loop iterations, default=300')
    # Boolean flags use _str2bool rather than type=bool so that an explicit
    # "False" on the command line actually disables the option.
    parser.add_argument('--first_order', type=_str2bool, default=True,
                        help='Use the first-order approximation, default=True')

    # Cuda and Random Seed
    parser.add_argument('--cuda', type=_str2bool, default=True,
                        help='Use CUDA if available, default=True')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed, default=42')

    # Dataset parameters
    parser.add_argument('--data_dir_path', type=str, default='./data',
                        help='Path to the data directory, default=./data')
    parser.add_argument('--dataset', type=str, default='CWRU',
                        help='Which dataset to use, default=CWRU, '
                             'options=[CWRU, HST]')
    parser.add_argument('--preprocess', type=str, default='STFT',
                        help='Which preprocessing technique to use, '
                             'default=STFT, options=[WT, STFT, FFT]')
    parser.add_argument('--train_domains', type=str, default='0,1,2',
                        help='Training domain, integer(s) separated by commas, default=0,1,2')
    parser.add_argument('--test_domain', type=int, default=3,
                        help='Test domain, single integer, default=3')
    parser.add_argument('--train_task_num', type=int, default=200,
                        help='Number of samples per domain for training, default=200')
    parser.add_argument('--test_task_num', type=int, default=100,
                        help='Number of samples per domain for testing, default=100')

    # Curve plotting parameters
    parser.add_argument('--plot', type=_str2bool, default=True,
                        help='Plot the learning curve, default=True')
    parser.add_argument('--plot_path', type=str, default='./images',
                        help='Directory to save the learning curve, default=./images')
    parser.add_argument('--plot_step', type=int, default=50,
                        help='Step for plotting the learning curve, default=50')

    # Logging parameters
    parser.add_argument('--log', type=_str2bool, default=True,
                        help='Log the training process, default=True')
    parser.add_argument('--log_path', type=str, default='./logs',
                        help='Directory to save the logs, default=./logs')

    # Model checkpoint parameters
    parser.add_argument('--checkpoint', type=_str2bool, default=True,
                        help='Save the model checkpoints, default=True')
    parser.add_argument('--checkpoint_path', type=str, default='./checkpoints',
                        help='Directory to save the model checkpoints, default=./checkpoints')
    parser.add_argument('--checkpoint_step', type=int, default=50,
                        help='Step for saving the model checkpoints, default=50')

    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()

    # Validate the string options up front so a typo fails fast, before any
    # directories are created or training starts.
    if args.dataset not in ['CWRU', 'HST']:
        raise ValueError('Dataset must be either CWRU or HST.')
    if args.preprocess not in ['WT', 'STFT', 'FFT']:
        raise ValueError('Preprocessing technique must be either WT, STFT, or FFT.')

    # --train_domains arrives as a comma-separated string such as "0,1,2".
    # Convert it to a list of ints for training.
    raw_train_domains = args.train_domains
    try:
        args.train_domains = [int(d) for d in raw_train_domains.split(',')]
    except ValueError as err:
        raise ValueError(
            '--train_domains must be integer(s) separated by commas, '
            'got {!r}'.format(raw_train_domains)) from err
    # Build the compact domain tag (e.g. "012") from the parsed ints, so stray
    # whitespace such as "0, 1" does not end up in the experiment title.
    train_domains_str = ''.join(str(d) for d in args.train_domains)

    # Experiment title, passed to setup_logger and maml.train, in the format:
    # MAML_<dataset>_<preprocess>_<ways>w<shots>s_source<train domains>_target<test domain>
    experiment_title = 'MAML_{}_{}_{}w{}s_source{}_target{}'.format(
        args.dataset,
        args.preprocess,
        args.ways,
        args.shots,
        train_domains_str,
        args.test_domain)

    # Create output directories only for the features that are enabled.
    # exist_ok=True replaces the exists()-then-makedirs() check, which could
    # race if another process created the directory in between.
    if args.plot:
        os.makedirs(args.plot_path, exist_ok=True)
    if args.checkpoint:
        os.makedirs(args.checkpoint_path, exist_ok=True)
    if args.log:
        os.makedirs(args.log_path, exist_ok=True)
        # NOTE(review): the original indentation was lost; setup_logger is
        # assumed to run only when --log is enabled — confirm.
        setup_logger(args.log_path, experiment_title)

    maml.train(args, experiment_title)