Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] Update scripts with compile #2449

Open
wants to merge 1 commit into
base: gh/vmoens/28/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def main(cfg: "DictConfig"): # noqa: F821
weight_decay=cfg.optim.weight_decay,
eps=cfg.optim.eps,
)
if cfg.loss.compile:
loss_module = torch.compile(loss_module)

# Create logger
logger = None
Expand Down
2 changes: 2 additions & 0 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821
entropy_coef=cfg.loss.entropy_coef,
critic_coef=cfg.loss.critic_coef,
)
if cfg.loss.compile:
loss_module = torch.compile(loss_module)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ loss:
critic_coef: 0.25
entropy_coef: 0.01
loss_critic_type: l2
compile: True
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_mujoco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ loss:
critic_coef: 0.25
entropy_coef: 0.0
loss_critic_type: l2
compile: True
82 changes: 43 additions & 39 deletions sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import torch
import tqdm
from tensordict import TensorDict

from torchrl._utils import logger as torchrl_logger
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger
Expand Down Expand Up @@ -81,66 +83,68 @@ def main(cfg: "DictConfig"): # noqa: F821
alpha_prime_optim,
) = make_continuous_cql_optimizer(cfg, loss_module)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
for i in range(gradient_steps):
pbar.update(1)
# sample data
data = replay_buffer.sample()
# compute loss
loss_vals = loss_module(data.clone().to(device))
def update(data, i):
critic_optim.zero_grad()
q_loss, metadata = loss_module.q_loss(data)
cql_loss, cql_metadata = loss_module.cql_loss(data)
q_loss = q_loss + cql_loss
q_loss.backward()
critic_optim.step()
metadata.update(cql_metadata)

# official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
policy_optim.zero_grad()
if i >= policy_eval_start:
actor_loss = loss_vals["loss_actor"]
actor_loss, actor_metadata = loss_module.actor_loss(data)
else:
actor_loss = loss_vals["loss_actor_bc"]
q_loss = loss_vals["loss_qvalue"]
cql_loss = loss_vals["loss_cql"]

q_loss = q_loss + cql_loss

# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]
actor_loss, actor_metadata = loss_module.actor_bc_loss(data)
actor_loss.backward()
policy_optim.step()
metadata.update(actor_metadata)

alpha_optim.zero_grad()
alpha_loss, alpha_metadata = loss_module.alpha_loss(actor_metadata)
alpha_loss.backward()
alpha_optim.step()

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()
metadata.update(alpha_metadata)

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_loss, alpha_prime_metadata = loss_module.alpha_prime_loss(data)
alpha_prime_loss.backward()
alpha_prime_optim.step()
metadata.update(alpha_prime_metadata)

critic_optim.zero_grad()
# TODO: we have the option to compute losses independently retain is not needed?
q_loss.backward(retain_graph=False)
critic_optim.step()
loss_vals = TensorDict(metadata)
loss_vals["loss_qvalue"] = q_loss
loss_vals["loss_cql"] = cql_loss
loss_vals["loss_alpha"] = alpha_loss
loss = actor_loss + q_loss + alpha_loss
if alpha_prime_optim is not None:
loss_vals["loss_alpha_prime"] = alpha_prime_loss
loss = loss + alpha_prime_loss
loss_vals["loss"] = loss

return loss_vals.detach()

if cfg.loss.compile:
update = torch.compile(update, mode=cfg.loss.compile_mode)

loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
# Training loop
start_time = time.time()
pbar = tqdm.tqdm(range(gradient_steps))
for i in pbar:
# sample data
data = replay_buffer.sample().to(device)
loss_vals = update(data, i)

# log metrics
to_log = {
"loss": loss.item(),
"loss_actor_bc": loss_vals["loss_actor_bc"].item(),
"loss_actor": loss_vals["loss_actor"].item(),
"loss_qvalue": q_loss.item(),
"loss_cql": cql_loss.item(),
"loss_alpha": alpha_loss.item(),
"loss_alpha_prime": alpha_prime_loss.item(),
}
to_log = loss_vals.mean().to_dict()

# update qnet_target params
target_net_updater.step()
Expand Down
123 changes: 82 additions & 41 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import torch
import tqdm
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.record.loggers import generate_exp_name, get_logger

Expand Down Expand Up @@ -111,17 +113,77 @@ def main(cfg: "DictConfig"): # noqa: F821
evaluation_interval = cfg.logger.log_interval
eval_rollout_steps = cfg.logger.eval_steps

def update(sampled_tensordict):

critic_optim.zero_grad()
q_loss, metadata = loss_module.q_loss(sampled_tensordict)
cql_loss, metadata_cql = loss_module.cql_loss(sampled_tensordict)
metadata.update(metadata)
q_loss = q_loss + cql_loss
q_loss.backward()
critic_optim.step()

if loss_module.with_lagrange:
alpha_prime_optim.zero_grad()
alpha_prime_loss, metadata_aprime = loss_module.alpha_prime_loss(
sampled_tensordict
)
metadata.update(metadata_aprime)
alpha_prime_loss.backward()
alpha_prime_optim.step()

policy_optim.zero_grad()
# loss_actor_bc, _ = loss_module.actor_bc_loss(sampled_tensordict)
actor_loss, actor_metadata = loss_module.actor_loss(sampled_tensordict)
metadata.update(actor_metadata)
actor_loss.backward()
policy_optim.step()

alpha_optim.zero_grad()
alpha_loss, metadata_actor = loss_module.alpha_loss(actor_metadata)
metadata.update(metadata_actor)
alpha_loss.backward()
alpha_optim.step()
loss_td = TensorDict(metadata)

loss_td["loss_actor"] = actor_loss
loss_td["loss_qvalue"] = q_loss
loss_td["loss_cql"] = cql_loss
loss_td["loss_alpha"] = alpha_loss
if alpha_prime_optim:
alpha_prime_loss = loss_td["loss_alpha_prime"]

loss = actor_loss + alpha_loss + q_loss
if alpha_prime_optim is not None:
loss = loss + alpha_prime_loss

loss_td["loss"] = loss
return loss_td.detach()

if cfg.loss.compile:
update = torch.compile(update, mode=cfg.loss.compile_mode)

if cfg.loss.cudagraphs:
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

sampling_start = time.time()
for i, tensordict in enumerate(collector):
collector_iter = iter(collector)
for i in range(cfg.collector.total_frames):
timeit.print()
timeit.erase()
with timeit("collection"):
tensordict = next(collector_iter)
sampling_time = time.time() - sampling_start
pbar.update(tensordict.numel())
# update weights of the inference policy
collector.update_policy_weights_()
with timeit("update policies"):
# update weights of the inference policy
collector.update_policy_weights_()

tensordict = tensordict.view(-1)
tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("extend"):
replay_buffer.extend(tensordict.cpu())
collected_frames += current_frames

# optimization steps
Expand All @@ -130,44 +192,22 @@ def main(cfg: "DictConfig"): # noqa: F821
log_loss_td = TensorDict({}, [num_updates])
for j in range(num_updates):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample()
with timeit("sample"):
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

loss_td = loss_module(sampled_tensordict)

actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
cql_loss = loss_td["loss_cql"]
q_loss = q_loss + cql_loss
alpha_loss = loss_td["loss_alpha"]
alpha_prime_loss = loss_td["loss_alpha_prime"]

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_optim.step()

critic_optim.zero_grad()
q_loss.backward(retain_graph=False)
critic_optim.step()

log_loss_td[j] = loss_td.detach()
with timeit("update"):
loss_td = update(sampled_tensordict)
log_loss_td[j] = loss_td

# update qnet_target params
target_net_updater.step()
with timeit("target net"):
# update qnet_target params
target_net_updater.step()

# update priority
if prb:
Expand All @@ -191,10 +231,11 @@ def main(cfg: "DictConfig"): # noqa: F821
metrics_to_log["train/loss_actor"] = log_loss_td.get("loss_actor").mean()
metrics_to_log["train/loss_qvalue"] = log_loss_td.get("loss_qvalue").mean()
metrics_to_log["train/loss_alpha"] = log_loss_td.get("loss_alpha").mean()
metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
"loss_alpha_prime"
).mean()
metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
if alpha_prime_optim is not None:
metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
"loss_alpha_prime"
).mean()
# metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time

Expand All @@ -204,7 +245,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(), timeit("eval"):
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ loss:
loss_function: l2
gamma: 0.99
tau: 0.005
compile: True
4 changes: 3 additions & 1 deletion sota-implementations/cql/offline_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,7 @@ loss:
max_q_backup: False
deterministic_backup: False
num_random: 10
with_lagrange: True
with_lagrange: False
lagrange_thresh: 5.0 # tau
compile: False
compile_mode: reduce-overhead
5 changes: 4 additions & 1 deletion sota-implementations/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,8 @@ loss:
max_q_backup: False
deterministic_backup: False
num_random: 10
with_lagrange: True
with_lagrange: False
lagrange_thresh: 10.0
compile: False
compile_mode: reduce-overhead
cudagraphs: False
14 changes: 10 additions & 4 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,21 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
# We use a ProbabilisticActor to make sure that we map the
# network output to the right space using a TanhDelta
# distribution.
high = action_spec.space.high
low = action_spec.space.low
if train_env.batch_size:
high = high[(0,) * len(train_env.batch_size)]
low = low[(0,) * len(train_env.batch_size)]
actor = ProbabilisticActor(
module=actor_module,
in_keys=["loc", "scale"],
spec=action_spec,
distribution_class=TanhNormal,
distribution_kwargs={
"low": action_spec.space.low[len(train_env.batch_size) :],
"high": action_spec.space.high[
len(train_env.batch_size) :
], # remove batch-size
"low": low.to(device),
"high": high.to(device),
"tanh_loc": False,
"safe_tanh": not cfg.loss.compile,
},
default_interaction_type=ExplorationType.RANDOM,
)
Expand Down Expand Up @@ -334,6 +338,8 @@ def make_discrete_loss(loss_cfg, model):
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
if loss_cfg.compile:
loss_module = torch.compile(loss_module)

return loss_module, target_net_updater

Expand Down
Loading
Loading