-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_ppo_change_friction_10.py
53 lines (46 loc) · 1.38 KB
/
train_ppo_change_friction_10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from pgdrive import PGDriveEnv
from ray import tune
from utils import train, get_train_parser
def main():
    """Launch the ``change_friction_10`` PPO experiment on PGDriveEnv.

    Parses the shared training CLI (``get_train_parser``), builds an RLlib
    PPO config, and hands it to the project's ``train`` helper.

    The training environment grid-searches ``start_seed`` over
    5000/6000/7000/8000/9000, with ``environment_num=100`` and the vehicle's
    ``wheel_friction`` fixed at 1.0. Evaluation uses ``environment_num=200``
    and ``start_seed=0``.
    """
    cli_args = get_train_parser().parse_args()

    # Fraction of a GPU requested per trial: only ask for one when the CLI
    # says GPUs are available at all.
    gpu_share = 0 if cli_args.num_gpus == 0 else 0.5

    # Budget forwarded as ``stop`` to ``train``; its unit is decided by that
    # helper (presumably environment timesteps -- TODO confirm in utils).
    stop_budget = 10 ** 7

    ppo_config = {
        "env": PGDriveEnv,
        "env_config": {
            "environment_num": tune.grid_search([100]),
            "start_seed": tune.grid_search([5000, 6000, 7000, 8000, 9000]),
            # NOTE(review): the experiment is named "change_friction_10" but
            # friction is 1.0 here -- confirm the intended value.
            "vehicle_config": {
                "wheel_friction": tune.grid_search([1.0]),
            },
        },
        # ----- Evaluation -----
        "evaluation_interval": 5,
        "evaluation_num_episodes": 20,
        "evaluation_config": {
            "env_config": {"environment_num": 200, "start_seed": 0},
        },
        "evaluation_num_workers": 2,
        "metrics_smoothing_episodes": 20,
        # ----- Training -----
        "horizon": 1000,
        "num_sgd_iter": 20,
        "lr": 5e-5,
        "rollout_fragment_length": 200,
        "sgd_minibatch_size": 100,
        "train_batch_size": 30000,
        "num_gpus": gpu_share,
        "num_cpus_per_worker": 0.25,
        "num_cpus_for_driver": 1,
        "num_workers": 10,
    }

    train(
        "PPO",
        exp_name="change_friction_10",
        config=ppo_config,
        stop=stop_budget,
        keep_checkpoints_num=5,
        num_gpus=cli_args.num_gpus,
        # NOTE(review): hard-coded to a single seed; the parser's
        # ``num_seeds`` option is intentionally (?) not forwarded.
        num_seeds=1,
        test_mode=cli_args.test,
        # For debugging, ``local_mode=True`` can be passed here.
    )


if __name__ == '__main__':
    main()