Support for setting the target entropy in CrossQ and SAC
jan1854 committed Apr 5, 2024
1 parent e64dca3 commit 9953063
Showing 2 changed files with 20 additions and 5 deletions.
13 changes: 10 additions & 3 deletions sbx/crossq/crossq.py
@@ -66,6 +66,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[str, float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -103,6 +104,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -155,8 +157,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
@@ -251,7 +259,6 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-
             # Joint forward pass of obs/next_obs and actions/next_state_actions to have only
             # one forward pass with shape (n_critics, 2 * batch_size, 1).
             #
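For context, the "auto" branch above reproduces the usual SAC heuristic of setting the target entropy to the negative dimensionality of the action space. A minimal sketch of that computation (the 3-dimensional action shape below is a hypothetical example, not taken from this diff):

import numpy as np

# Hypothetical continuous action space, e.g. Box(-1, 1, shape=(3,))
action_shape = (3,)

# Same computation as the "auto" branch: target_entropy = -dim(action space)
target_entropy = -np.prod(action_shape).astype(np.float32)
print(target_entropy)  # -3.0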
12 changes: 10 additions & 2 deletions sbx/sac/sac.py
@@ -67,6 +67,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[str, float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -105,6 +106,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -157,8 +159,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
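With this change, the target entropy can be overridden from the constructor in both algorithms instead of always being derived from the action space. A minimal usage sketch (assuming the standard sbx/stable-baselines3-style constructor and a Gymnasium environment id; the environment and values are illustrative, not from this commit):

from sbx import SAC

# Override the "auto" heuristic (-1.0 for Pendulum-v1's 1-D action space)
# with an explicit target; ent_coef="auto" keeps the coefficient learned.
model = SAC("MlpPolicy", "Pendulum-v1", ent_coef="auto", target_entropy=-2.0)
model.learn(total_timesteps=5_000)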
