import argparse
import os
import shutil
import sys
from functools import partial
import tensorflow as tf
from rl.agents.a2c.runner import A2CRunner
from rl.agents.a2c.agent import A2CAgent
from rl.agents.feudal.runner import FeudalRunner
from rl.agents.feudal.agent import FeudalAgent
from rl.agents.ppo.runner import PPORunner
from rl.agents.ppo.agent import PPOAgent
from rl.networks.feudal import Feudal
from rl.networks.fully_conv import FullyConv
from rl.networks.conv_lstm import ConvLSTM
from rl.environment import SubprocVecEnv, make_sc2env, SingleEnv
from rl.common.cmd_util import SC2ArgumentParser
# Just disables warnings for missing AVX/FMA instructions
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Workaround for pysc2 flags
from absl import flags
FLAGS = flags.FLAGS
FLAGS(['run.py'])
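
# Registry of available agents: each entry maps an agent name to its agent
# class, its runner class, and the policy networks it supports.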
agents = {
    'a2c': {
        'agent': A2CAgent,
        'runner': A2CRunner,
        'policies': {
            'default': FullyConv,
            'fully_conv': FullyConv,
            'conv_lstm': ConvLSTM
        }
    },
    'feudal': {
        'agent': FeudalAgent,
        'runner': FeudalRunner,
        'policies': {
            'default': Feudal,
            'feudal': Feudal
        }
    },
    'ppo': {
        'agent': PPOAgent,
        'runner': PPORunner,
        'policies': {
            'default': FullyConv,
            'fully_conv': FullyConv,
            # 'conv_lstm': ConvLSTM
        }
    },
}
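
# Parse command-line arguments and derive the checkpoint and summary
# directories for this experiment. When resuming, the previously saved
# arguments are restored; otherwise the current arguments are persisted.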
args_parser = SC2ArgumentParser()
args = args_parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.ckpt_path = os.path.join(args.save_dir, args.experiment_id)
summary_type = 'train' if args.train else 'eval'
summary_path = os.path.join(args.summary_dir, args.experiment_id, summary_type)

if args.resume:
    args = args_parser.restore(os.path.join(args.summary_dir, args.experiment_id))
    args.ow = False
else:
    args_parser.save(args, os.path.join(args.summary_dir, args.experiment_id))
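
# Example invocation (illustrative only: the exact flag names and defaults are
# defined by SC2ArgumentParser in rl/common/cmd_util.py and are assumed here
# from the args attributes referenced above):
#
#   python run.py --experiment_id my_run --agent a2c --policy fully_conv \
#       --map MoveToBeacon --envs 4 --train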


def main():
    if args.agent not in agents:
        print("Error: '{}' agent does not exist!".format(args.agent))
        sys.exit(1)

    if args.policy not in agents[args.agent]['policies']:
        print("Error: '{}' policy does not exist for '{}' agent!".format(args.policy, args.agent))
        sys.exit(1)

    if args.train and args.ow and (os.path.isdir(summary_path) or os.path.isdir(args.ckpt_path)):
        yes, no = {'yes', 'y'}, {'no', 'n', ''}
        choice = input(
            "\nWARNING! An experiment with the name '{}' already exists.\nAre you sure you want to overwrite it? [y/N]: "
            .format(args.experiment_id)
        ).lower()
        if choice in yes:
            shutil.rmtree(args.ckpt_path, ignore_errors=True)
            shutil.rmtree(summary_path, ignore_errors=True)
        else:
            print('Quitting program.')
            sys.exit(0)
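
    # Build the SC2 environment constructors: up to `max_windows` environments
    # are rendered, the rest run headless, and all of them are executed in
    # subprocesses via SubprocVecEnv.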
    size_px = (args.res, args.res)
    env_args = dict(
        map_name=args.map,
        step_mul=args.step_mul,
        game_steps_per_episode=0,
        screen_size_px=size_px,
        minimap_size_px=size_px
    )
    vis_env_args = env_args.copy()
    vis_env_args['visualize'] = args.vis
    num_vis = min(args.envs, args.max_windows)
    env_fns = [partial(make_sc2env, **vis_env_args)] * num_vis
    num_no_vis = args.envs - num_vis
    if num_no_vis > 0:
        env_fns.extend([partial(make_sc2env, **env_args)] * num_no_vis)
    envs = SubprocVecEnv(env_fns)

    summary_writer = tf.summary.FileWriter(summary_path)
    args.summary_writer = summary_writer
    network_data_format = 'NHWC' if args.nhwc else 'NCHW'
    # TODO: We should actually do individual setup and argument parser methods
    # for each agent since they require different parameters etc.
    print('\n################################\n#')
    print('# Running {} Agent with {} policy'.format(args.agent, args.policy))
    print('#\n################################\n')

    # TODO: pass args directly so each agent and runner can pick theirs
    agent = agents[args.agent]['agent'](agents[args.agent]['policies'][args.policy], args)
    runner = agents[args.agent]['runner'](agent, envs, summary_writer, args)
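
    # Main loop: run batches until the iteration limit is reached (or forever
    # when iters == -1), writing summaries every `summary_iters` iterations and
    # checkpointing every `save_iters` iterations while training.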
    i = agent.get_global_step()
    try:
        while args.iters == -1 or i < args.iters:
            write_summary = args.train and i % args.summary_iters == 0

            if i > 0 and i % args.save_iters == 0:
                _save_if_training(agent, summary_writer)

            result = runner.run_batch(train_summary=write_summary)

            if write_summary:
                agent_step, loss, summary = result
                summary_writer.add_summary(summary, global_step=agent_step)
                print('iter %d: loss = %f' % (agent_step, loss))

            i += 1
    except KeyboardInterrupt:
        pass

    _save_if_training(agent, summary_writer)

    envs.close()
    summary_writer.close()

    print(f'mean score: {runner.get_mean_score()}')
    print(f'max score: {runner.get_max_score()}')
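

# Save a checkpoint and flush pending output, but only when in training mode.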
def _save_if_training(agent, summary_writer):
    if args.train:
        agent.save(args.ckpt_path)
        summary_writer.flush()
        sys.stdout.flush()


if __name__ == "__main__":
    main()