main_cav_ppo.py
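"""Train PPO on the CAV velocity environment (CAVVelEnv)."""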
import os
import click
import torch
from dowel import CsvOutput, logger
from garage import wrap_experiment
from garage.envs import normalize
from garage.experiment.deterministic import set_seed
from garage.sampler import RaySampler
from garage.torch.algos import PPO
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer

from cav_environment import CAVVelEnv


@click.command()
@click.option('--seed', default=1)
@click.option('--epochs', default=2000)
@click.option('--episodes_per_task', default=10)
@click.option('--max_episode_length', default=75)
@click.option('--saved_dir', default=os.getcwd()+"/logs")
@click.option('--log_file', default=os.getcwd()+"/logs/ppo.csv")
@wrap_experiment
def main(ctxt, seed, epochs, episodes_per_task, max_episode_length, saved_dir, log_file):
"""Train PPO with CAV environment.
Set up environment and algorithm and run the task.
Args:
seed (int): Used to seed the random number generator to produce
determinism.
epochs (int): Number of training epochs.
episodes_per_task (int): Number of episodes per epoch per task
for training.
max_episode_length (int): The maximum steps allowed for an
episode.
saved_dir (str): Path where snapshots are saved.
log_file (str): Path where csvs are saved.
"""
set_seed(seed)
logger.add_output(CsvOutput(log_file))
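    # Wrap the CAV velocity environment; actions in [-10, 10] from the
    # policy are rescaled to the environment's action bounds.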
env = normalize(CAVVelEnv(max_episode_length=max_episode_length),
expected_action_scale=10.)
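    # Gaussian MLP policy with two hidden layers of 16 tanh units.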
policy = GaussianMLPPolicy(
env_spec=env.spec,
hidden_sizes=(16, 16),
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None,
)
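    # Ray-based sampler that collects episodes in parallel worker processes.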
    sampler = RaySampler(agents=policy,
                         envs=env,
                         max_episode_length=env.spec.max_episode_length)
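    # MLP value function used as the baseline for advantage estimation.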
value_function = GaussianMLPValueFunction(env_spec=env.spec)
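    # PPO with a clipped surrogate objective (clip range 0.2) and
    # GAE(lambda=0.97) advantage estimation.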
algo = PPO(env_spec=env.spec,
policy=policy,
value_function=value_function,
sampler=sampler,
discount=0.99,
gae_lambda=0.97,
lr_clip_range=2e-1)
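    # Store snapshots in saved_dir; each training epoch collects about
    # episodes_per_task * max_episode_length environment steps.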
ctxt.snapshot_dir = saved_dir
trainer = Trainer(ctxt)
trainer.setup(algo, env)
trainer.train(n_epochs=epochs,
batch_size=episodes_per_task * max_episode_length)


if __name__ == '__main__':
    main()