#!/usr/bin/env python
"""Driver for DCASE 2019 Task 2 Baseline.
See README.md in this directory for a more detailed description.
Usage:
- Download Kaggle data: train_{curated,noisy}.{csv,zip}, test.zip. Unzip
zip files into directories train_curated, train_noisy, test.
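  For example, a minimal sketch (run from wherever the downloads live):
  $ unzip train_curated.zip -d train_curated
  $ unzip train_noisy.zip -d train_noisy
  $ unzip test.zip -d test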

- We start by demonstrating how to train on the curated dataset and run
  inference on the test set. The procedure is similar if you want to train
  on the noisy dataset, or on some combination of curated and noisy data. We
  also support training on noisy data and then warm-starting curated training
  from a noisily trained checkpoint, using the --warmstart* flags. See the
  README.md in this repo for more details.

- Shuffle and randomly split the training CSV into a train set and a held-out
  validation set: train.csv, validation.csv.
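  For example, a minimal sketch using standard shell tools (this assumes
  train_curated.csv has a header row; the 4000-row split point is
  illustrative, not prescribed):
  $ tail -n +2 train_curated.csv | shuf > shuffled.csv
  $ { head -n 1 train_curated.csv; head -n 4000 shuffled.csv; } > train.csv
  $ { head -n 1 train_curated.csv; tail -n +4001 shuffled.csv; } \
      > validation.csv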

- Prepare class map:
  $ make_class_map.py < /path/to/train_curated.csv > /path/to/class_map.csv
  This should match the class_map.csv provided in this repository.

- Train a model, with checkpoints produced in a new train_dir:
  $ runner.py \
      --mode train \
      --model mobilenet-v1 \
      --class_map_path /path/to/class_map.csv \
      --train_clip_dir /path/to/train_curated \
      --train_csv_path /path/to/train.csv \
      --train_dir /path/to/train_dir
  To override default hyperparameters, also pass the --hparams flag:
      --hparams name=value,name=value,...
  See model.parse_hparams() for the default values of all hyperparameters.
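  For example (the hyperparameter names below are illustrative only; the
  actual set is defined in model.parse_hparams()):
      --hparams lr=1e-4,bs=32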

- Evaluate the trained model on the validation set, on all checkpoints in
  train_dir:
  $ runner.py \
      --mode eval \
      --model mobilenet-v1 \
      --class_map_path /path/to/class_map.csv \
      --eval_clip_dir /path/to/train_curated \
      --eval_csv_path /path/to/validation.csv \
      --eval_dir /path/to/eval_dir \
      --train_dir /path/to/train_dir
  (Make sure to use the same hparams overrides as used in training.)
  Evaluation iterates over all available checkpoints and writes marker files
  (containing per-class and overall lwlrap) in eval_dir for each checkpoint,
  so that it can be stopped and resumed safely without having to repeat any
  work.

- Training and evaluation will produce TensorFlow summaries in event log
  files in train_dir and eval_dir, which you can view by running a
  TensorBoard server pointed at these directories. Typically, you would have
  several train/eval jobs running in parallel (one for each combination of
  hyperparameters in a grid search), and a single TensorBoard visualizer job
  that lets you look at the results from all the runs in real time.
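  For example, one server watching both directories (the name:path,name:path
  form of --logdir is supported by TF 1.x TensorBoard):
  $ tensorboard --logdir train:/path/to/train_dir,eval:/path/to/eval_dir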

- Run inference with a trained model to produce predictions in the Kaggle
  submission format in the file submission.csv. You will do this inside a
  Kaggle kernel to make your submission.
  $ runner.py \
      --mode inference \
      --model mobilenet-v1 \
      --class_map_path /path/to/class_map.csv \
      --inference_clip_dir /path/to/test \
      --inference_checkpoint /path/to/train_dir/model.ckpt-<N> \
      --predictions_csv_path /path/to/submission.csv
  (Make sure to use the same hparams overrides as used in training.)
"""
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

import evaluation
import inference
import model
import train


def parse_flags(argv):
parser = argparse.ArgumentParser(description='DCASE 2019 Task 2 Baseline')
# Flags common to all modes.
all_modes_group = parser.add_argument_group('Flags common to all modes')
all_modes_group.add_argument(
'--mode', type=str, choices=['train', 'eval', 'inference'], required=True,
help='Run one of training, evaluation, or inference.')
  all_modes_group.add_argument(
      '--model', type=str, choices=['mobilenet-v1'], default='mobilenet-v1',
      help='Name of a model architecture. Current options: mobilenet-v1.')
all_modes_group.add_argument(
'--hparams', type=str, default='',
help='Model hyperparameters in comma-separated name=value format.')
  all_modes_group.add_argument(
      '--class_map_path', type=str, required=True,
      help='Path to CSV file containing map between class index and name.')
# Flags for training only.
training_group = parser.add_argument_group('Flags for training only')
training_group.add_argument(
'--train_clip_dir', type=str, default='',
help='Path to directory containing training clips.')
training_group.add_argument(
'--train_csv_path', type=str, default='',
help='Path to CSV file containing training clip filenames and labels.')
  training_group.add_argument(
      '--epoch_num_batches', type=int, default=0,
      help='Number of batches in an epoch. Required when hparams include '
      'lrdecay.')
training_group.add_argument(
'--warmstart_checkpoint', type=str, default='',
help='Path to a model checkpoint to use for warm-started training.')
training_group.add_argument(
'--warmstart_include_scopes', type=str, default='',
help='Comma-separated list of variable scopes to include when loading '
'the warm-start checkpoint.')
training_group.add_argument(
'--warmstart_exclude_scopes', type=str, default='',
help='Comma-separated list of variable scopes to exclude when loading '
'the warm-start checkpoint.')
# Flags for training and evaluation.
train_eval_group = parser.add_argument_group('Flags for training and eval')
train_eval_group.add_argument(
'--train_dir', type=str, default='',
help='Path to a directory which will hold model checkpoints and other outputs.')
# Flags for evaluation only.
eval_group = parser.add_argument_group('Flags for evaluation only')
eval_group.add_argument(
'--eval_clip_dir', type=str, default='',
help='Path to directory containing evaluation clips.')
eval_group.add_argument(
'--eval_csv_path', type=str, default='',
help='Path to CSV file containing evaluation clip filenames and labels.')
eval_group.add_argument(
'--eval_dir', type=str, default='',
help='Path to a directory holding eval results.')
# Flags for inference only.
inference_group = parser.add_argument_group('Flags for inference only')
inference_group.add_argument(
'--inference_checkpoint', type=str, default='',
help='Path to a model checkpoint to use for inference.')
inference_group.add_argument(
'--inference_clip_dir', type=str, default='',
help='Path to directory containing test clips.')
inference_group.add_argument(
'--predictions_csv_path', type=str, default='',
help='Path to a CSV file in which to store predictions.')
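  # parse_known_args() returns the flags it recognizes plus a list of any
  # remaining arguments, which the caller hands on to tf.app.run().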
flags, rest_argv = parser.parse_known_args(argv)
# Additional per-mode validation.
try:
if flags.mode == 'train':
assert flags.train_clip_dir, 'Must specify --train_clip_dir'
assert flags.train_csv_path, 'Must specify --train_csv_path'
assert flags.train_dir, 'Must specify --train_dir'
if 'lrdecay' in flags.hparams:
assert flags.epoch_num_batches > 0, (
'When using hparams.lrdecay, must specify --epoch_num_batches')
if 'warmstart' in flags.hparams:
assert flags.warmstart_checkpoint, (
'When using hparams.warmstart, must specify --warmstart_checkpoint')
elif flags.mode == 'eval':
assert flags.eval_clip_dir, 'Must specify --eval_clip_dir'
assert flags.eval_csv_path, 'Must specify --eval_csv_path'
assert flags.eval_dir, 'Must specify --eval_dir'
assert flags.train_dir, 'Must specify --train_dir'
else:
assert flags.mode == 'inference'
assert flags.inference_checkpoint, 'Must specify --inference_checkpoint'
assert flags.inference_clip_dir, 'Must specify --inference_clip_dir'
assert flags.predictions_csv_path, 'Must specify --predictions_csv_path'
except AssertionError as e:
print('\nError: ', e, '\n', file=sys.stderr)
parser.print_help(file=sys.stderr)
sys.exit(1)
  return flags, rest_argv


# Global flags namespace, populated from the command line in __main__ below.
flags = None


def main(argv):
hparams = model.parse_hparams(flags.hparams)
if flags.mode == 'train':
def split_csv(scopes):
return scopes.split(',') if scopes else None
train.train(model_name=flags.model, hparams=hparams,
class_map_path=flags.class_map_path,
train_csv_path=flags.train_csv_path,
train_clip_dir=flags.train_clip_dir,
train_dir=flags.train_dir,
epoch_batches=flags.epoch_num_batches,
warmstart_checkpoint=flags.warmstart_checkpoint,
warmstart_include_scopes=split_csv(flags.warmstart_include_scopes),
warmstart_exclude_scopes=split_csv(flags.warmstart_exclude_scopes))
elif flags.mode == 'eval':
evaluation.evaluate(model_name=flags.model, hparams=hparams,
class_map_path=flags.class_map_path,
eval_csv_path=flags.eval_csv_path,
eval_clip_dir=flags.eval_clip_dir,
eval_dir=flags.eval_dir,
train_dir=flags.train_dir)
else:
assert flags.mode == 'inference'
inference.predict(model_name=flags.model, hparams=hparams,
class_map_path=flags.class_map_path,
inference_clip_dir=flags.inference_clip_dir,
inference_checkpoint=flags.inference_checkpoint,
predictions_csv_path=flags.predictions_csv_path)


if __name__ == '__main__':
  # parse_flags() consumes the flags it recognizes; any remaining args are
  # handed back to tf.app.run() via sys.argv.
  flags, sys.argv = parse_flags(sys.argv)
  tf.app.run(main)