train_td3_per.py
import os
import torch
from agents.ddpg_agent import DDPGAgent
from agents.policies.td3_policy import TD3Policy
from agents.models.components.noise import GaussianProcess
from agents.memory.prioritized_memory import PrioritizedMemory
from agents.models.components.mlp import MLP
from agents.models.td3 import TD3Critic
from tasks.reacher.solutions.utils import get_simulator, STATE_SIZE, ACTION_SIZE, BRAIN_NAME
from tasks.reacher.solutions.ddpg import SOLUTIONS_CHECKPOINT_DIR
from agents.models.components.critics import Critic
from tools.lr_schedulers import DummyLRScheduler
from tools.parameter_scheduler import ParameterScheduler
import pickle
from tools.rl_constants import BrainSet, Brain, Experience, Action
from tools.rl_constants import RandomBrainAction
from tools.parameter_scheduler import LinearDecaySchedule
from tools.layer_initializations import init_layer_within_range, init_layer_inverse_root_fan_in
from torch import nn
import numpy as np
import torch.nn.functional as F
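
# Hyperparameters for TD3 with prioritized experience replay on the 20-agent Reacher task.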
NUM_AGENTS = 20
NUM_EPISODES = 200
SEED = 0
BATCH_SIZE = 512
REPLAY_BUFFER_SIZE = int(1e6)
GAMMA = 0.99 # discount factor
TAU = 5e-3 # for soft update of target parameters
N_LEARNING_ITERATIONS = 10 # number of learning updates
UPDATE_FREQUENCY = 20 # every n time step do update
MAX_T = 1000
CRITIC_WEIGHT_DECAY = 0.0 # 1e-2
ACTOR_WEIGHT_DECAY = 0.0
LR_ACTOR = 1e-4 # learning rate of the actor
LR_CRITIC = 1e-4 # learning rate of the critic
POLICY_UPDATE_FREQUENCY = 2
WARMUP_STEPS = int(5e3)
MIN_PRIORITY = 1e-3
DROPOUT = None
BATCHNORM = False
SOLVE_SCORE = 30
SAVE_TAG = 'per_td3'
ACTOR_CHECKPOINT_PATH = os.path.join(SOLUTIONS_CHECKPOINT_DIR, f'{SAVE_TAG}_actor_checkpoint.pth')
CRITIC_CHECKPOINT_PATH = os.path.join(SOLUTIONS_CHECKPOINT_DIR, f'{SAVE_TAG}_critic_checkpoint.pth')
TRAINING_SCORES_PLOT_SAVE_PATH = os.path.join(SOLUTIONS_CHECKPOINT_DIR, f'{SAVE_TAG}_training_scores.png')
TRAINING_SCORES_SAVE_PATH = os.path.join(SOLUTIONS_CHECKPOINT_DIR, f'{SAVE_TAG}_training_scores.pkl')
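

# Factory for a single TD3 agent (built on the DDPGAgent wrapper) that shares the
# supplied prioritized replay memory across all parallel arms. The actor is a
# tanh-output MLP; the critic is a TD3Critic wrapping the Critic factory below
# (TD3 uses a pair of critics for clipped double-Q learning). Exploration noise is
# Gaussian, with its standard deviation decayed linearly to zero over training.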
def get_agent(memory_):
    return DDPGAgent(
        state_shape=STATE_SIZE,
        action_size=ACTION_SIZE,
        random_seed=SEED,
        memory_factory=lambda: memory_,
        actor_model_factory=lambda: MLP(
            layer_sizes=(STATE_SIZE, 256, 128, ACTION_SIZE),
            seed=SEED, with_batchnorm=BATCHNORM, dropout=DROPOUT,
            output_function=torch.nn.Tanh(),
            output_layer_initialization_fn=init_layer_within_range,
            activation_function=torch.nn.LeakyReLU()
        ),
        critic_model_factory=lambda: TD3Critic(
            critic_model_factory=lambda: Critic(
                state_featurizer=MLP(
                    layer_sizes=(STATE_SIZE, 256),
                    dropout=DROPOUT,
                    with_batchnorm=BATCHNORM,
                    output_function=torch.nn.LeakyReLU(),
                ),
                output_module=MLP(
                    layer_sizes=(256 + ACTION_SIZE, 128, 1),
                    dropout=DROPOUT,
                    with_batchnorm=BATCHNORM,
                    activation_function=torch.nn.LeakyReLU(),
                    output_layer_initialization_fn=init_layer_within_range,
                ),
                seed=SEED,
            ),
            seed=SEED
        ),
        actor_optimizer_factory=lambda params: torch.optim.Adam(params, lr=LR_ACTOR, weight_decay=ACTOR_WEIGHT_DECAY),
        critic_optimizer_factory=lambda params: torch.optim.Adam(params, lr=LR_CRITIC, weight_decay=CRITIC_WEIGHT_DECAY),
        critic_optimizer_scheduler=lambda x: DummyLRScheduler(x),
        actor_optimizer_scheduler=lambda x: DummyLRScheduler(x),
        policy_factory=lambda: TD3Policy(
            action_dim=ACTION_SIZE,
            noise=GaussianProcess(std_fn=LinearDecaySchedule(start=0.3, end=0, steps=NUM_EPISODES)),
            seed=SEED,
            random_brain_action_factory=lambda: RandomBrainAction(
                ACTION_SIZE,
                NUM_AGENTS,
                continuous_actions=True,
                continuous_action_range=(-1, 1),
            )
        ),
        update_frequency=UPDATE_FREQUENCY,
        n_learning_iterations=N_LEARNING_ITERATIONS,
        batch_size=BATCH_SIZE,
        gamma=GAMMA,
        tau=TAU,
        policy_update_frequency=POLICY_UPDATE_FREQUENCY,
        shared_agent_brain=True
    )
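

# Build the BrainSet for the Reacher environment: one prioritized replay buffer
# shared by the agent, with alpha (prioritization strength) annealed linearly from
# 0.6 to 0 and beta (importance-sampling bias correction) annealed linearly from
# 0.4 to 1 over the course of training.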
def get_solution_brain_set():
    memory = PrioritizedMemory(
        capacity=REPLAY_BUFFER_SIZE,
        state_shape=(1, STATE_SIZE),
        # Anneal alpha linearly from 0.6 to 0
        alpha_scheduler=ParameterScheduler(initial=0.6, lambda_fn=lambda i: 0.6 - 0.6 * i / NUM_EPISODES, final=0.),
        # Anneal beta linearly from 0.4 to 1
        beta_scheduler=ParameterScheduler(initial=0.4, final=1,
                                          lambda_fn=lambda i: 0.4 + 0.6 * i / NUM_EPISODES),
        seed=SEED,
        continuous_actions=True,
        min_priority=MIN_PRIORITY
    )

    reacher_brain = Brain(
        brain_name=BRAIN_NAME,
        action_size=ACTION_SIZE,
        state_shape=STATE_SIZE,
        observation_type='vector',
        agents=[get_agent(memory)],
    )

    brain_set = BrainSet(brains=[reacher_brain])
    return brain_set
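

# The environment returns batched results for all NUM_AGENTS parallel arms; split
# them into individual Experience tuples and pass each one to the shared agent so
# every transition lands in the prioritized replay buffer.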
def step_agents_fn(brain_set: BrainSet, next_brain_environment: dict, t: int):
    for brain_name, brain_environment in next_brain_environment.items():
        agent = brain_set[brain_name].agents[0]
        for i in range(NUM_AGENTS):
            action = brain_environment['actions'][0].value[i]
            action = action[np.newaxis, ...]
            brain_agent_experience = Experience(
                state=brain_environment['states'][i].unsqueeze(0),
                action=Action(value=action),
                reward=brain_environment['rewards'][i],
                next_state=brain_environment['next_states'][i].unsqueeze(0),
                done=brain_environment['dones'][i],
                t_step=t,
            )
            agent.step(brain_agent_experience)
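

# Warm up the replay buffer for a few episodes, train, and, if the mean sliding-window
# score exceeds SOLVE_SCORE, save the actor/critic weights, the score plot, and the
# pickled training scores.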
if __name__ == "__main__":
    simulator = get_simulator()

    brain_set = get_solution_brain_set()

    simulator.warmup(brain_set, int(WARMUP_STEPS / MAX_T), max_t=MAX_T, step_agents_fn=step_agents_fn)

    agents, training_scores, i_episode, training_time = simulator.train(
        brain_set, n_episodes=NUM_EPISODES, max_t=MAX_T,
        solved_score=SOLVE_SCORE, step_agents_fn=step_agents_fn
    )

    if training_scores.get_mean_sliding_scores() > SOLVE_SCORE:
        brain = brain_set[BRAIN_NAME]
        trained_agent = brain.agents[0]
        torch.save(trained_agent.online_actor.state_dict(), ACTOR_CHECKPOINT_PATH)
        torch.save(trained_agent.online_critic.state_dict(), CRITIC_CHECKPOINT_PATH)
        training_scores.save_scores_plot(TRAINING_SCORES_PLOT_SAVE_PATH)
        with open(TRAINING_SCORES_SAVE_PATH, 'wb') as f:
            pickle.dump(training_scores, f)
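

# A minimal sketch (not part of the original script) of how the saved actor weights
# could be reloaded for evaluation, assuming the same actor architecture defined in
# get_agent() above:
#
#     eval_actor = MLP(
#         layer_sizes=(STATE_SIZE, 256, 128, ACTION_SIZE),
#         seed=SEED, with_batchnorm=BATCHNORM, dropout=DROPOUT,
#         output_function=torch.nn.Tanh(),
#         output_layer_initialization_fn=init_layer_within_range,
#         activation_function=torch.nn.LeakyReLU()
#     )
#     eval_actor.load_state_dict(torch.load(ACTOR_CHECKPOINT_PATH))
#     eval_actor.eval()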