From a6709da9447aeef524b8db25d440f413755f43cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Martin=20Kubov=C4=8D=C3=ADk?= Date: Fri, 22 Dec 2023 08:43:11 +0100 Subject: [PATCH] update --- config/dqn.yaml | 10 +++++----- rl_toolkit/networks/models/dueling.py | 15 +++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/config/dqn.yaml b/config/dqn.yaml index 66911f8..2fd2025 100644 --- a/config/dqn.yaml +++ b/config/dqn.yaml @@ -8,17 +8,17 @@ Server: # Agent process Agent: temp_init: 0.5 - temp_min: 0.05 - temp_decay: 0.99999 + temp_min: 0.01 + temp_decay: 0.999999 warmup_steps: 1000 # Learner process Learner: train_steps: 1000000 batch_size: 256 - warmup_steps: 500 + warmup_steps: 1000 # for learning rate scheduler gamma: 0.99 - tau: 0.01 + tau: 0.005 # Model definition Model: @@ -31,7 +31,7 @@ Model: learning_rate: !!float 3e-4 global_clipnorm: 1.0 weight_decay: !!float 1e-4 - frame_stack: 16 + frame_stack: 16 # 12 # Paths save_path: "./save/model" diff --git a/rl_toolkit/networks/models/dueling.py b/rl_toolkit/networks/models/dueling.py index 989dcd3..c69abfc 100644 --- a/rl_toolkit/networks/models/dueling.py +++ b/rl_toolkit/networks/models/dueling.py @@ -8,6 +8,9 @@ Layer, LayerNormalization, MultiHeadAttention, + GlobalAveragePooling1D, + GlobalMaxPooling1D, + Lambda, ) @@ -139,8 +142,12 @@ def __init__( for _ in range(num_layers) ] + # Reduce + # self.flatten = Lambda(lambda x: x[:, -1]) + # self.flatten = GlobalMaxPooling1D() + self.flatten = GlobalAveragePooling1D() + # Output - self.norm = LayerNormalization(epsilon=1e-6) self.V = Dense( 1, activation=None, @@ -158,10 +165,10 @@ def call(self, inputs, training=None): for layer in self.e_layers: x = layer(x, training=training) - x = self.norm(x, training=training) - # select last timestep for prediction a_t - x = x[:, -1] + # Reduce block + x = self.flatten(x, training=training) + # x = self.drop_out(x, training=training) # compute value & advantage V = self.V(x, training=training)