training.py
import argparse
import os
import socket
import sys
from pprint import pprint

from loguru import logger
from optuna.pruners import MedianPruner
from optuna.visualization import plot_optimization_history, plot_param_importances

from vimms_gym.wrappers import HistoryWrapper, flatten_dict_observations

sys.path.append('.')

import optuna

# the import order is important to use all cpu cores
# import numpy as np
import torch
from stable_baselines3 import PPO, DQN
from sb3_contrib import RecurrentPPO, MaskablePPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback

from tune import sample_dqn_params, sample_ppo_params, TrialEvalCallback
from vimms_gym.experiments import preset_qcb_small, ENV_QCB_SMALL_GAUSSIAN, \
    ENV_QCB_MEDIUM_GAUSSIAN, \
    ENV_QCB_LARGE_GAUSSIAN, ENV_QCB_SMALL_EXTRACTED, ENV_QCB_MEDIUM_EXTRACTED, \
    ENV_QCB_LARGE_EXTRACTED, preset_qcb_medium, preset_qcb_large
from vimms.Common import create_if_not_exist
from vimms_gym.env import DDAEnv
from vimms_gym.common import HISTORY_HORIZON, MAX_EVAL_TIME_PER_EPISODE, METHOD_PPO, \
    METHOD_PPO_RECURRENT, METHOD_DQN, ALPHA, BETA, EVAL_METRIC_REWARD, \
    EVAL_METRIC_F1, EVAL_METRIC_COVERAGE_PROP, EVAL_METRIC_INTENSITY_PROP, \
    EVAL_METRIC_MS1_MS2_RATIO, EVAL_METRIC_EFFICIENCY, GYM_ENV_NAME, GYM_NUM_ENV, USE_SUBPROC
# CheckpointCallback counts callback calls (one per vectorised env.step()), so divide the
# total-timestep budget by the number of parallel environments
TRAINING_CHECKPOINT_FREQ = int(100E6)
TRAINING_CHECKPOINT_FREQ = max(TRAINING_CHECKPOINT_FREQ // GYM_NUM_ENV, 1)

EVAL_FREQ = 1E5
N_TRIALS = 100
N_EVAL_EPISODES = 30


def train(model_name, timesteps, horizon, params, max_peaks, out_dir, out_file, verbose=0):
    set_torch_threads()
    model_params = params['model']
    env = make_environment(max_peaks, params, horizon)
    model = init_model(model_name, model_params, env, out_dir=out_dir, verbose=verbose)

    checkpoint_callback = CheckpointCallback(
        save_freq=TRAINING_CHECKPOINT_FREQ, save_path=out_dir,
        name_prefix='%s_checkpoint' % model_name)

    log_interval = 1 if verbose == 2 else 4
    model.learn(total_timesteps=int(timesteps), callback=checkpoint_callback,
                log_interval=log_interval)

    if out_file is None:
        out_file = '%s_%s.zip' % (GYM_ENV_NAME, model_name)
    fname = os.path.join(out_dir, out_file)
    model.save(fname)
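
# The saved zip can be reloaded later for evaluation. A minimal sketch (assuming the model
# was trained with MaskablePPO, as init_model below does for METHOD_PPO; `env` is a
# hypothetical evaluation environment, not defined here):
#
#   model = MaskablePPO.load(os.path.join(out_dir, out_file))
#   obs = env.reset()
#   action, _states = model.predict(obs, deterministic=True)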


def tune(model_name, timesteps, horizon, params, max_peaks, out_dir,
         n_trials, n_eval_episodes, eval_freq, eval_metric, max_eval_time_per_episode,
         tune_model, tune_reward, n_startup_trials=0, verbose=0):
    set_torch_threads()

    # Do not prune before 1/3 of the max budget is used
    n_evaluations = max(1, int(timesteps) // int(eval_freq))
    n_warmup_steps = int(timesteps // 3)
    pruner = MedianPruner(n_startup_trials=n_startup_trials, n_warmup_steps=n_warmup_steps)
    logger.info(
        f'Doing {n_evaluations} intermediate evaluations for pruning based on the number of timesteps'
        f' (1 evaluation every {int(eval_freq)} timesteps)'
        f' after a warmup of {n_warmup_steps} steps'
    )

    # Persist the study in a local SQLite database so that tuning can be resumed
    study_name = f'{model_name}'
    db_name = os.path.abspath(os.path.join(out_dir, 'study.db'))
    storage_name = f'sqlite:///{db_name}'
    study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True,
                                pruner=pruner, direction='maximize')
    try:
        objective = Objective(model_name, timesteps, horizon, params, max_peaks, out_dir,
                              n_evaluations, n_eval_episodes, eval_metric,
                              max_eval_time_per_episode,
                              tune_model, tune_reward, verbose=verbose)
        study.optimize(objective, n_trials=n_trials, catch=(ValueError,))
    except KeyboardInterrupt:
        pass

    trial = study.best_trial
    logger.info(f'Number of finished trials: {len(study.trials)}')
    logger.info('Best trial:')
    logger.info(f'Value: {trial.value}')
    logger.info('Params:')
    for key, value in trial.params.items():
        logger.info(f'    {key}: {value}')

    # Write a report CSV, using a new index if previous reports exist
    i = 0
    while os.path.exists(os.path.join(out_dir, f'study_{i}.csv')):
        i += 1
    study.trials_dataframe().to_csv(os.path.join(out_dir, f'study_{i}.csv'))

    # Plot optimisation results (requires plotly and kaleido for image export)
    try:
        fig1 = plot_optimization_history(study)
        fig1.write_image(os.path.join(out_dir, f'fig1_{i}.png'))
        fig2 = plot_param_importances(study)
        fig2.write_image(os.path.join(out_dir, f'fig2_{i}.png'))
    except (ValueError, ImportError, RuntimeError):
        pass
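
# A finished (or interrupted) study can be inspected later from the same SQLite file, e.g.
# (a minimal sketch; the study name is the model name that was passed to tune()):
#
#   study = optuna.load_study(study_name='<model_name>', storage='sqlite:///<out_dir>/study.db')
#   print(study.best_trial.params)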


class Objective(object):
    def __init__(self, model_name, timesteps, horizon, params, max_peaks, out_dir,
                 n_evaluations, n_eval_episodes, eval_metric, max_eval_time_per_episode,
                 tune_model, tune_reward, verbose=0):
        self.model_name = model_name
        self.timesteps = timesteps
        self.horizon = horizon
        self.params = params
        self.max_peaks = max_peaks
        self.out_dir = out_dir
        self.n_evaluations = n_evaluations
        self.n_eval_episodes = n_eval_episodes
        self.eval_metric = eval_metric
        self.max_eval_time_per_episode = max_eval_time_per_episode
        self.tune_model = tune_model
        self.tune_reward = tune_reward
        self.verbose = verbose

    def __call__(self, trial):
        # Sample parameters
        if self.model_name == METHOD_PPO:
            sampled_params = sample_ppo_params(trial, self.tune_model, self.tune_reward)
        elif self.model_name == METHOD_PPO_RECURRENT:
            sampled_params = sample_ppo_params(trial, self.tune_model, self.tune_reward)
        elif self.model_name == METHOD_DQN:
            sampled_params = sample_dqn_params(trial, self.tune_model, self.tune_reward)
        else:
            raise ValueError(f'Unknown model name: {self.model_name}')

        # Generate model and reward parameters
        if self.tune_model:  # if tuning, use the sampled model parameters
            model_params = dict(sampled_params)
            # alpha and beta are reward parameters, not model parameters, so drop them if present
            model_params.pop('alpha', None)
            model_params.pop('beta', None)
        else:  # otherwise use pre-defined model parameters
            model_params = self.params['model']

        if self.tune_reward:  # if tuning, use the sampled reward parameters
            self.params['env']['alpha'] = sampled_params['alpha']
            self.params['env']['beta'] = sampled_params['beta']
        else:  # otherwise leave them as they are
            pass

        # Create the RL model
        env = make_environment(self.max_peaks, self.params, self.horizon)
        model = init_model(self.model_name, model_params, env, out_dir=self.out_dir,
                           verbose=self.verbose)

        # Create env used for evaluation
        # eval_env = make_environment(self.max_peaks, self.params)
        # print('Creating evaluation environment with params', self.params)
        eval_env = DDAEnv(self.max_peaks, self.params)
        eval_env = flatten_dict_observations(eval_env)
        eval_env = HistoryWrapper(eval_env, horizon=self.horizon)
        eval_env = Monitor(eval_env)

        # Create the callback that will periodically evaluate
        # and report the performance
        optuna_eval_freq = int(self.timesteps / self.n_evaluations)
        optuna_eval_freq = max(optuna_eval_freq // GYM_NUM_ENV,
                               1)  # adjust for multiple environments
        eval_callback = TrialEvalCallback(
            eval_env, self.model_name, trial, self.eval_metric, self.n_eval_episodes,
            optuna_eval_freq, self.max_eval_time_per_episode,
            deterministic=True, verbose=self.verbose,
            best_model_save_path=self.out_dir,
            log_path=self.out_dir
        )

        try:
            log_interval = 1 if self.verbose == 2 else 4
            model.learn(int(self.timesteps), callback=eval_callback, log_interval=log_interval)
            # Free memory
            model.env.close()
            eval_env.close()
        except (AssertionError, ValueError) as e:
            # Sometimes, random hyperparams can generate NaN
            # Free memory
            model.env.close()
            eval_env.close()
            # Prune hyperparams that generate NaNs
            print(e)
            print('============')
            print('Sampled parameters:')
            pprint(sampled_params)
            raise optuna.exceptions.TrialPruned()

        is_pruned = eval_callback.is_pruned
        reward = eval_callback.last_mean_reward

        del model.env, eval_env
        del model

        if is_pruned:
            raise optuna.exceptions.TrialPruned()
        return reward


def init_model(model_name, model_params, env, out_dir=None, verbose=0):
    tensorboard_log = None  # disable tensorboard logging when no output directory is given
    if out_dir is not None:
        tensorboard_log = os.path.join(out_dir, '%s_%s_tensorboard' % (GYM_ENV_NAME, model_name))

    model = None
    if model_name == METHOD_PPO:
        model = MaskablePPO('MlpPolicy', env, tensorboard_log=tensorboard_log, verbose=verbose,
                            **model_params)
    elif model_name == METHOD_PPO_RECURRENT:
        model = RecurrentPPO('MlpLstmPolicy', env, tensorboard_log=tensorboard_log,
                             verbose=verbose,
                             **model_params)
    elif model_name == METHOD_DQN:
        model = DQN('MlpPolicy', env, tensorboard_log=tensorboard_log, verbose=verbose,
                    **model_params)
    assert model is not None
    return model


def set_torch_threads():
    torch_threads = 1  # Set pytorch num threads to 1 for faster training
    if socket.gethostname() == 'cauchy':  # except on cauchy where we have no gpu, only cpu
        torch_threads = 40
    torch.set_num_threads(torch_threads)


def mask_fn(env):
    return env.valid_action_mask()
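
# NOTE: mask_fn is not referenced elsewhere in this file. Its signature matches what
# sb3_contrib's ActionMasker wrapper expects, so presumably the intended usage is something
# like the sketch below (assuming DDAEnv exposes valid_action_mask()):
#
#   from sb3_contrib.common.wrappers import ActionMasker
#   env = ActionMasker(env, mask_fn)  # MaskablePPO then queries the mask at every step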


def make_environment(max_peaks, params, horizon):
    def make_env(rank, seed=0):
        def _init():
            env = DDAEnv(max_peaks, params)
            check_env(env)
            env.seed(rank)
            env = flatten_dict_observations(env)
            env = HistoryWrapper(env, horizon=horizon)
            env = Monitor(env)
            return env

        set_random_seed(seed)
        return _init

    if not USE_SUBPROC:
        env = DummyVecEnv([make_env(i) for i in range(GYM_NUM_ENV)])
    else:
        env = SubprocVecEnv([make_env(i) for i in range(GYM_NUM_ENV)])
    return env


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Training and parameter optimisation script for ViMMS-Gym')
    parser.add_argument('--results', default=os.path.abspath('notebooks'), type=str,
                        help='Base location to store results')
    parser.add_argument('--out_file', default=None, type=str,
                        help='Output model filename. If None, a default name will be used.')
    parser.add_argument('--verbose', default=0, type=int,
                        help='Verbosity level')

    # model parameters
    parser.add_argument('--model', choices=[
        METHOD_PPO,
        METHOD_PPO_RECURRENT,
        METHOD_DQN
    ], required=True, type=str, help='Specify model name')
    parser.add_argument('--timesteps', required=True, type=float, help='Training timesteps')
    parser.add_argument('--horizon', default=HISTORY_HORIZON, type=int,
                        help='How many actions and observations to consider for history wrapping')
    parser.add_argument('--tune_model', action='store_true',
                        help='Optimise model parameters instead of training')
    parser.add_argument('--tune_reward', action='store_true',
                        help='Optimise reward parameters instead of training')
    parser.add_argument('--n_trials', default=N_TRIALS, type=int,
                        help='How many trials in optuna tuning')
    parser.add_argument('--n_eval_episodes', default=N_EVAL_EPISODES, type=int,
                        help='How many evaluation episodes in optuna tuning')
    parser.add_argument('--eval_freq', default=EVAL_FREQ, type=float,
                        help='Frequency of intermediate evaluation steps before pruning an '
                             'episode in optuna tuning')
    parser.add_argument('--max_eval_time_per_episode', default=MAX_EVAL_TIME_PER_EPISODE,
                        type=float,
                        help='Maximum time allowed to run one evaluation episode in a trial '
                             'during optuna tuning')
    parser.add_argument('--eval_metric', choices=[
        EVAL_METRIC_REWARD,
        EVAL_METRIC_F1,
        EVAL_METRIC_COVERAGE_PROP,
        EVAL_METRIC_INTENSITY_PROP,
        EVAL_METRIC_MS1_MS2_RATIO,
        EVAL_METRIC_EFFICIENCY
    ], type=str, help='Specify evaluation metric in optuna tuning')

    # environment parameters
    parser.add_argument('--preset', choices=[
        ENV_QCB_SMALL_GAUSSIAN,
        ENV_QCB_MEDIUM_GAUSSIAN,
        ENV_QCB_LARGE_GAUSSIAN,
        ENV_QCB_SMALL_EXTRACTED,
        ENV_QCB_MEDIUM_EXTRACTED,
        ENV_QCB_LARGE_EXTRACTED
    ], required=True, type=str, help='Specify environmental preset')
    parser.add_argument('--alpha', default=ALPHA, type=float,
                        help='First weight parameter in the reward function')
    parser.add_argument('--beta', default=BETA, type=float,
                        help='Second weight parameter in the reward function')
    args = parser.parse_args()

    model_name = args.model
    if args.tune_reward:
        alpha = None
        beta = None
    else:
        alpha = args.alpha
        beta = args.beta

    out_dir = os.path.abspath(os.path.join(args.results, args.model))
    create_if_not_exist(out_dir)

    # choose one preset and generate parameters for it
    presets = {
        ENV_QCB_SMALL_GAUSSIAN: {'f': preset_qcb_small, 'extract': False},
        ENV_QCB_MEDIUM_GAUSSIAN: {'f': preset_qcb_medium, 'extract': False},
        ENV_QCB_LARGE_GAUSSIAN: {'f': preset_qcb_large, 'extract': False},
        ENV_QCB_SMALL_EXTRACTED: {'f': preset_qcb_small, 'extract': True},
        ENV_QCB_MEDIUM_EXTRACTED: {'f': preset_qcb_medium, 'extract': True},
        ENV_QCB_LARGE_EXTRACTED: {'f': preset_qcb_large, 'extract': True},
    }
    preset_func = presets[args.preset]['f']
    extract = presets[args.preset]['extract']
    params, max_peaks = preset_func(model_name, alpha=alpha, beta=beta,
                                    extract_chromatograms=extract)

    # actually train the model here
    if args.tune_model or args.tune_reward:
        tune(model_name, args.timesteps, args.horizon, params, max_peaks, out_dir, args.n_trials,
             args.n_eval_episodes, int(args.eval_freq), args.eval_metric,
             args.max_eval_time_per_episode,
             args.tune_model, args.tune_reward, verbose=args.verbose)
    else:
        train(model_name, args.timesteps, args.horizon, params, max_peaks, out_dir, args.out_file,
              verbose=args.verbose)
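
# Example invocation (illustrative only: the valid strings for --model and --preset are the
# METHOD_* and ENV_QCB_* constants imported from vimms_gym.common and vimms_gym.experiments,
# whose values are not shown in this file):
#
#   python training.py --model <METHOD_PPO> --preset <ENV_QCB_SMALL_GAUSSIAN> \
#       --timesteps 1e6 --results notebooks --verbose 1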