NetTIME_train.py
import argparse
from NetTIME import TrainWorkflow
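# Illustrative usage (all flags are defined below; the values shown are
# the script's defaults except --seed, which is an arbitrary example):
#   python NetTIME_train.py \
#       --dataset data/datasets/training_example/training_minOverlap200_maxUnion600.h5 \
#       --index_file data/embeddings/example.pkl \
#       --experiment_name training_example \
#       --batch_size 1800 --num_epochs 50 --seed 0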
######## User Input ########
parser = argparse.ArgumentParser("Training a NetTIME model.")
# Training parameters
parser.add_argument(
    "--batch_size",
    type=int,
    default=1800,
    help="Training batch size. Default: 1800",
)
parser.add_argument(
    "--num_epochs",
    type=int,
    default=50,
    help="Number of training epochs. Default: 50",
)
parser.add_argument(
    "--num_workers",
    type=int,
    default=20,
    help="Number of workers used to perform multi-process data loading. "
    "Default: 20",
)
parser.add_argument(
    "--start_from_checkpoint",
    type=str,
    default=None,
    help="Path to a pretrained model checkpoint from which to start training. "
    "Default: None, training starts from scratch.",
)
parser.add_argument(
    "--learning_rate",
    type=float,
    default=1e-4,
    help="Learning rate. Default: 1e-4",
)
parser.add_argument(
    "--weight_decay",
    type=float,
    default=0.0,
    help="Optimizer weight decay. Default: 0.0",
)
parser.add_argument(
    "--seed", type=int, default=None, help="Random seed. Default: None"
)
parser.add_argument(
    "--loss_avg_ratio",
    type=float,
    default=0.9,
    help="Weight of loss history when calculating cumulative loss. Default: 0.9",
)
parser.add_argument(
    "--clip_threshold",
    type=float,
    default=None,
    help="The max_norm of the gradients to clip when using gradient clipping. "
    "Default: None, no gradient clipping.",
)
# Data
parser.add_argument(
    "--ct_feature",
    action="store_true",
    help="Include cell type-specific feature during training. Default: False.",
)
parser.add_argument(
    "--tf_feature",
    action="store_true",
    help="Include TF-specific feature during training. Default: False.",
)
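# Note: per the --input_size help below, enabling --ct_feature or
# --tf_feature implies setting --input_size to the total number of cell
# type and TF features; the default of 0 assumes neither flag is used.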
parser.add_argument(
    "--output_key",
    type=str,
    default=["output_conserved", "output_relaxed"],
    nargs="+",
    help="A list of keys specifying the types of target labels to use for "
    "training. Default: output_conserved output_relaxed",
)
parser.add_argument(
    "--dataset",
    type=str,
    default="data/datasets/training_example/training_minOverlap200_maxUnion600.h5",
    help="Path to training data. "
    "Default: data/datasets/training_example/training_minOverlap200_maxUnion600.h5",
)
parser.add_argument(
    "--dtype",
    type=str,
    default="TRAINING",
    help="Dataset type. Default: TRAINING.",
)
parser.add_argument(
    "--index_file",
    type=str,
    default="data/embeddings/example.pkl",
    help="Path to a pickle file containing indices for TF and cell type labels. "
    "Default: data/embeddings/example.pkl",
)
parser.add_argument(
    "--exclude_groups",
    type=str,
    default=None,
    nargs="+",
    help="List of group names to be excluded from training. "
    "Default: None, all conditions in DATASET are included during training.",
)
parser.add_argument(
    "--include_groups",
    type=str,
    default=None,
    nargs="+",
    help="List of group names to be included for training. "
    "Default: None, all conditions in DATASET are included during training.",
)
# Display and save
parser.add_argument(
    "--print_every",
    type=int,
    default=10,
    help="Display training loss every PRINT_EVERY steps. Default: 10",
)
parser.add_argument(
    "--evaluate_every",
    type=int,
    default=50,
    help="Save a model checkpoint every (PRINT_EVERY * EVALUATE_EVERY) steps "
    "for evaluation. Default: 50",
)
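# With the defaults above (PRINT_EVERY=10, EVALUATE_EVERY=50), a
# checkpoint is therefore saved every 10 * 50 = 500 training steps.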
parser.add_argument(
    "--output_dir",
    type=str,
    default="experiments/",
    help="Root directory for saving experiment results. "
    "Default: experiments/",
)
parser.add_argument(
    "--experiment_name",
    type=str,
    default="training_example",
    help="Experiment name. Default: training_example",
)
parser.add_argument(
    "--result_dir",
    type=str,
    default=None,
    help="Specify an alternative location to save model training loss files.",
)
parser.add_argument(
    "--ckpt_dir",
    type=str,
    default=None,
    help="Specify an alternative location to save model checkpoint files.",
)
# Model architecture
parser.add_argument(
    "--tf_vocab_size",
    type=int,
    default=256,
    help="TF vocabulary size. Default: 256",
)
parser.add_argument(
    "--ct_vocab_size",
    type=int,
    default=64,
    help="Cell type vocabulary size. Default: 64",
)
parser.add_argument(
    "--input_size",
    type=int,
    default=0,
    help="Sum of the number of cell type features and the number of TF features. "
    "Default: 0, no cell type and TF features included during training.",
)
parser.add_argument(
    "--output_size",
    type=int,
    default=2,
    help="Number of classes in target labels. Default: 2",
)
parser.add_argument(
    "--seq_length",
    type=int,
    default=1000,
    help="Input sequence length. Default: 1000",
)
parser.add_argument(
    "--embedding_size",
    type=int,
    default=50,
    help="Dimension of the embedding vectors for TFs and cell types. Default: 50",
)
parser.add_argument(
    "--disable_tf_embed",
    action="store_true",
    help="Do not use TF embedding vectors during training. Default: False.",
)
parser.add_argument(
    "--disable_ct_embed",
    action="store_true",
    help="Do not use CT embedding vectors during training. Default: False.",
)
parser.add_argument(
    "--fc_act_fn",
    type=str,
    default="ReLU",
    choices=["ReLU", "Tanh"],
    help="Name of the activation function for FC layers. Default: ReLU",
)
parser.add_argument(
    "--num_basic_blocks",
    type=int,
    default=2,
    help="Number of Basic Block layers. Default: 2",
)
parser.add_argument(
    "--cnn_act_fn",
    type=str,
    default="ReLU",
    choices=["ReLU", "Tanh"],
    help="Name of the activation function for CNN layers. Default: ReLU",
)
parser.add_argument(
    "--rnn_act_fn",
    type=str,
    default="Tanh",
    choices=["ReLU", "Tanh"],
    help="Activation function for RNN layers. Default: Tanh",
)
parser.add_argument(
    "--kernel_size",
    type=int,
    default=7,
    help="CNN kernel size. Default: 7",
)
parser.add_argument(
    "--stride",
    type=int,
    default=1,
    help="CNN stride size. Default: 1",
)
parser.add_argument(
    "--dropout",
    type=float,
    default=0.0,
    help="Dropout rate. Default: 0.0",
)
args = parser.parse_args()
######## Configure workflow ########
workflow = TrainWorkflow()
# Training parameters
workflow.batch_size = args.batch_size
workflow.num_epochs = args.num_epochs
workflow.num_workers = args.num_workers
workflow.start_from_checkpoint = args.start_from_checkpoint
workflow.learning_rate = args.learning_rate
workflow.weight_decay = args.weight_decay
workflow.seed = args.seed
workflow.loss_avg_ratio = args.loss_avg_ratio
workflow.clip_threshold = args.clip_threshold
# Data
workflow.ct_feature = args.ct_feature
workflow.tf_feature = args.tf_feature
workflow.output_key = args.output_key
workflow.dataset = args.dataset
workflow.dtype = args.dtype
workflow.index_file = args.index_file
workflow.exclude_groups = args.exclude_groups
workflow.include_groups = args.include_groups
# Display and save
workflow.print_every = args.print_every
workflow.evaluate_every = args.evaluate_every
workflow.output_dir = args.output_dir
workflow.experiment_name = args.experiment_name
workflow.result_dir = args.result_dir
workflow.ckpt_dir = args.ckpt_dir
# Args
# Model-architecture flags (e.g. --tf_vocab_size, --embedding_size) are not
# copied individually above; the full namespace is attached here, and is
# presumably read by TrainWorkflow via workflow.args.
workflow.args = args
######## Model Run ########
workflow.run()