Skip to content

Commit

Permalink
Fix displayed entropy coefficient (ent_coef) value
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Aug 25, 2024
1 parent 94ce7c6 commit 5d206af
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
5 changes: 4 additions & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
- (actor_loss_value, qf_loss_value, ent_coef_value),
+ (actor_loss_value, qf_loss_value, ent_coef_loss),
) = self._train(
self.gamma,
self.target_entropy,
Expand All @@ -224,11 +224,14 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.ent_coef_state,
self.key,
)
+ ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())
+ if isinstance(self.ent_coef, EntropyCoef):
+ self.logger.record("train/ent_coef_loss", ent_coef_loss.item())

@staticmethod
@jax.jit
Expand Down
6 changes: 5 additions & 1 deletion sbx/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
- (actor_loss_value, qf_loss_value, ent_coef_value),
+ (actor_loss_value, qf_loss_value, ent_coef_loss),
) = self._train(
self.gamma,
self.tau,
Expand All @@ -227,11 +227,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.ent_coef_state,
self.key,
)
+ ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
+
self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())
+ if isinstance(self.ent_coef, EntropyCoef):
+ self.logger.record("train/ent_coef_loss", ent_coef_loss.item())

@staticmethod
@jax.jit
Expand Down
6 changes: 5 additions & 1 deletion sbx/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
- (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value),
+ (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss),
) = self._train(
self.gamma,
self.tau,
Expand All @@ -232,11 +232,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.ent_coef_state,
self.key,
)
+ ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
+
self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf1_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())
+ if isinstance(self.ent_coef, EntropyCoef):
+ self.logger.record("train/ent_coef_loss", ent_coef_loss.item())

@staticmethod
@partial(jax.jit, static_argnames=["n_target_quantiles"])
Expand Down

0 comments on commit 5d206af

Please sign in to comment.