Skip to content

Commit

Permalink
Update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Oct 24, 2024
1 parent 6cd924e commit 125a8ca
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sb3_contrib/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)

# Joint forward pass of obs/next_obs and actions/next_state_actions to have only
# one forward pass with shape (n_critics, 2 * batch_size, 1).
# one forward pass.
#
# This has two reasons:
# 1. According to the paper obs/actions and next_obs/next_state_actions are differently
Expand All @@ -241,6 +241,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
self.critic.set_bn_training_mode(True)
all_q_values = th.cat(self.critic(all_obs, all_actions), dim=1)
self.critic.set_bn_training_mode(False)
# (2 * batch_size, n_critics) -> (batch_size, n_critics), (batch_size, n_critics)
current_q_values, next_q_values = th.split(all_q_values, batch_size, dim=0)
# (batch_size, n_critics) -> (n_critics, batch_size, 1)
current_q_values = current_q_values.T[..., None]
Expand Down

0 comments on commit 125a8ca

Please sign in to comment.