diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index 94e2bbc..4a1ffeb 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.actor_state,
             self.ent_coef_state,
             self.key,
-            (actor_loss_value, qf_loss_value, ent_coef_value),
+            (actor_loss_value, qf_loss_value, ent_coef_loss),
         ) = self._train(
             self.gamma,
             self.target_entropy,
@@ -224,11 +224,14 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.ent_coef_state,
             self.key,
         )
+        ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
         self._n_updates += gradient_steps
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/actor_loss", actor_loss_value.item())
         self.logger.record("train/critic_loss", qf_loss_value.item())
         self.logger.record("train/ent_coef", ent_coef_value.item())
+        if isinstance(self.ent_coef, EntropyCoef):
+            self.logger.record("train/ent_coef_loss", ent_coef_loss.item())
 
     @staticmethod
     @jax.jit
diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py
index 11f8ff5..f566aea 100644
--- a/sbx/sac/sac.py
+++ b/sbx/sac/sac.py
@@ -213,7 +213,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.actor_state,
             self.ent_coef_state,
             self.key,
-            (actor_loss_value, qf_loss_value, ent_coef_value),
+            (actor_loss_value, qf_loss_value, ent_coef_loss),
         ) = self._train(
             self.gamma,
             self.tau,
@@ -227,11 +227,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.ent_coef_state,
             self.key,
         )
+        ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
+
         self._n_updates += gradient_steps
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/actor_loss", actor_loss_value.item())
         self.logger.record("train/critic_loss", qf_loss_value.item())
         self.logger.record("train/ent_coef", ent_coef_value.item())
+        if isinstance(self.ent_coef, EntropyCoef):
+            self.logger.record("train/ent_coef_loss", ent_coef_loss.item())
 
     @staticmethod
     @jax.jit
diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py
index 5161f4d..b65c0c5 100644
--- a/sbx/tqc/tqc.py
+++ b/sbx/tqc/tqc.py
@@ -216,7 +216,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.actor_state,
             self.ent_coef_state,
             self.key,
-            (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value),
+            (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss),
         ) = self._train(
             self.gamma,
             self.tau,
@@ -232,11 +232,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.ent_coef_state,
             self.key,
         )
+        ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
+
        self._n_updates += gradient_steps
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         self.logger.record("train/actor_loss", actor_loss_value.item())
         self.logger.record("train/critic_loss", qf1_loss_value.item())
         self.logger.record("train/ent_coef", ent_coef_value.item())
+        if isinstance(self.ent_coef, EntropyCoef):
+            self.logger.record("train/ent_coef_loss", ent_coef_loss.item())
 
     @staticmethod
     @partial(jax.jit, static_argnames=["n_target_quantiles"])