From 5c81398ef854dde4eeaed51e0715c5de18a9d344 Mon Sep 17 00:00:00 2001 From: Corentin <111868204+corentinlger@users.noreply.github.com> Date: Sat, 6 Jul 2024 17:50:25 +0200 Subject: [PATCH] Update QR-DQN optimizer to only use q_net parameters (#252) * Updated QR-DQN optimizer input to only include quantile_net parameters * Fix QR-DQN paper link in docs and update changelog --- docs/misc/changelog.rst | 4 +++- docs/modules/qrdqn.rst | 2 +- sb3_contrib/qrdqn/policies.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 60e36123..63776e6e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,8 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) +- Updated QR-DQN paper link in docs (@corentinlger) Deprecations: ^^^^^^^^^^^^^ @@ -580,4 +582,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl +@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @corentinlger diff --git a/docs/modules/qrdqn.rst b/docs/modules/qrdqn.rst index 2a913e14..efafcf8b 100644 --- a/docs/modules/qrdqn.rst +++ b/docs/modules/qrdqn.rst @@ -24,7 +24,7 @@ instead of predicting the mean return (DQN). Notes ----- -- Original paper: https://arxiv.org/abs/1710.100442 +- Original paper: https://arxiv.org/abs/1710.10044 - Distributional RL (C51): https://arxiv.org/abs/1707.06887 - Further reference: https://github.com/amy12xx/ml_notes_and_reports/blob/master/distributional_rl/QRDQN.pdf diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index 07752cdc..317396ac 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -171,7 +171,7 @@ def _build(self, lr_schedule: Schedule) -> None: # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class( # type: ignore[call-arg] - self.parameters(), + self.quantile_net.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs, )