From 7444057268ca59bee5395ebec1aeeb85cc42467c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 17 Dec 2024 16:42:09 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/decision_transformer/dt.py | 4 +++- sota-implementations/decision_transformer/utils.py | 4 ++-- sota-implementations/dqn/config_atari.yaml | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 8a9eb0c0985..9e8446ed82f 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -76,7 +76,9 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module = make_dt_loss(cfg.loss, actor, device=model_device) # Create optimizer - transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module) + transformer_optim, scheduler = make_dt_optimizer( + cfg.optim, loss_module, model_device + ) # Create inference policy inference_policy = DecisionTransformerInferenceWrapper( diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 5f14734addd..d4a67e7d3a9 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -511,10 +511,10 @@ def make_odt_optimizer(optim_cfg, loss_module): return dt_optimizer, log_temp_optimizer, scheduler -def make_dt_optimizer(optim_cfg, loss_module): +def make_dt_optimizer(optim_cfg, loss_module, device): dt_optimizer = torch.optim.Adam( loss_module.actor_network_params.flatten_keys().values(), - lr=torch.as_tensor(optim_cfg.lr), + lr=torch.tensor(optim_cfg.lr, device=device), weight_decay=optim_cfg.weight_decay, eps=1.0e-8, ) diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml index bcbada5dc36..85d513fbb2c 100644 --- a/sota-implementations/dqn/config_atari.yaml +++ b/sota-implementations/dqn/config_atari.yaml @@ -42,5 +42,5 @@ loss: compile: compile: False - compile_mode: + compile_mode: default cudagraphs: False