Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 17, 2024
1 parent 7c044b3 commit 7444057
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion sota-implementations/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def main(cfg: "DictConfig"): # noqa: F821
loss_module = make_dt_loss(cfg.loss, actor, device=model_device)

# Create optimizer
transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
transformer_optim, scheduler = make_dt_optimizer(
cfg.optim, loss_module, model_device
)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,10 @@ def make_odt_optimizer(optim_cfg, loss_module):
return dt_optimizer, log_temp_optimizer, scheduler


def make_dt_optimizer(optim_cfg, loss_module):
def make_dt_optimizer(optim_cfg, loss_module, device):
dt_optimizer = torch.optim.Adam(
loss_module.actor_network_params.flatten_keys().values(),
lr=torch.as_tensor(optim_cfg.lr),
lr=torch.tensor(optim_cfg.lr, device=device),
weight_decay=optim_cfg.weight_decay,
eps=1.0e-8,
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dqn/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ loss:

compile:
compile: False
compile_mode:
compile_mode: default
cudagraphs: False

0 comments on commit 7444057

Please sign in to comment.