-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
35 lines (34 loc) · 904 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from model import maddpg
import os
import argparse
from gym.spaces.discrete import Discrete
from gym.spaces.multi_discrete import MultiDiscrete
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"-s",
"--scenario",
help="set the scenario.",
type=str
)
args = parser.parse_args()
env_id = args.scenario
os.makedirs('./models/{}'.format(env_id), exist_ok=True)
# * the size of replay buffer must be appropriate
test = maddpg(
env_id=env_id,
episode=10000,
learning_rate=1e-3,
gamma=0.97,
capacity=10000,
batch_size=128,
value_iter=1,
policy_iter=1,
rho=0.99,
render=False,
episode_len=45,
train_freq=5,
entropy_weight=0.0001,
model_path=False
)
test.run()