From 595f89de5309b260d97db33858350a5e10ddd3ee Mon Sep 17 00:00:00 2001
From: naumix
Date: Thu, 31 Oct 2024 14:42:24 +0100
Subject: [PATCH 01/21] add BRO

---
 sbx/bro/__init__.py |   3 +
 sbx/bro/bro.py      | 528 ++++++++++++++++++++++++++++++++++++++++++++
 sbx/bro/policies.py | 259 ++++++++++++++++++++++
 3 files changed, 790 insertions(+)
 create mode 100644 sbx/bro/__init__.py
 create mode 100644 sbx/bro/bro.py
 create mode 100644 sbx/bro/policies.py

diff --git a/sbx/bro/__init__.py b/sbx/bro/__init__.py
new file mode 100644
index 0000000..f1be1c2
--- /dev/null
+++ b/sbx/bro/__init__.py
@@ -0,0 +1,3 @@
+from sbx.bro.bro import BRO
+
+__all__ = ["BRO"]
diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py
new file mode 100644
index 0000000..37b4e3b
--- /dev/null
+++ b/sbx/bro/bro.py
@@ -0,0 +1,528 @@
+from functools import partial
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from flax.training.train_state import TrainState
+from gymnasium import spaces
+from jax.typing import ArrayLike
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+
+from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
+from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
+from sbx.bro.policies import BROPolicy
+
+
+class EntropyCoef(nn.Module):
+    ent_coef_init: float = 1.0
+
+    @nn.compact
+    def __call__(self) -> jnp.ndarray:
+        log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init)))
+        return jnp.exp(log_ent_coef)
+
+
+class ConstantEntropyCoef(nn.Module):
+    ent_coef_init: float = 1.0
+
+    @nn.compact
+    def __call__(self) -> float:
+        # Hack to not optimize the entropy coefficient while not having to use if/else for the jit
+        # TODO: add parameter in train to remove that hack
+        self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init))
+        return self.ent_coef_init
+
+
+class BRO(OffPolicyAlgorithmJax):
+    policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = {  # type: ignore[assignment]
+        "MlpPolicy": BROPolicy,
+        # Minimal dict support using flatten()
+        "MultiInputPolicy": BROPolicy,
+    }
+
+    policy: BROPolicy
+    action_space: spaces.Box  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        policy,
+        env: Union[GymEnv, str],
+        # BRO
+        n_quantiles: int = 100,
+        learning_rate: Union[float, Schedule] = 3e-4,
+        qf_learning_rate: Optional[float] = None,
+        buffer_size: int = 1_000_000,  # 1e6
+        learning_starts: int = 100,
+        batch_size: int = 256,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        train_freq: Union[int, Tuple[int, str]] = 1,
+        gradient_steps: int = 2,
+        policy_delay: int = 1,
+        action_noise: Optional[ActionNoise] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
+        use_sde: bool = False,
+        sde_sample_freq: int = -1,
+        use_sde_at_warmup: bool = False,
+        tensorboard_log: Optional[str] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = None,
+        verbose: int = 0,
+        seed: Optional[int] = None,
+        device: str = "auto",
+        _init_setup_model: bool = True,
+    ) -> None:
+        super().__init__(
+            policy=policy,
+            env=env,
+            learning_rate=learning_rate,
+            qf_learning_rate=qf_learning_rate,
+            buffer_size=buffer_size,
+            learning_starts=learning_starts,
+            batch_size=batch_size,
+            tau=tau,
+            gamma=gamma,
+            train_freq=train_freq,
+            gradient_steps=gradient_steps,
+            action_noise=action_noise,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
+            use_sde=use_sde,
+            sde_sample_freq=sde_sample_freq,
+            use_sde_at_warmup=use_sde_at_warmup,
+            policy_kwargs=policy_kwargs,
+            tensorboard_log=tensorboard_log,
+            verbose=verbose,
+            seed=seed,
+            supported_action_spaces=(spaces.Box,),
+            support_multi_env=True,
+        )
+
+        self.policy_delay = policy_delay
+        self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
+
+        self.n_quantiles = n_quantiles
+        taus_ = jnp.arange(0, n_quantiles + 1) / n_quantiles
+        self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None]
+
+        self.distributional = True if self.n_quantiles > 1 else False
+        if _init_setup_model:
+            self._setup_model()
+
+    def _setup_model(self) -> None:
+        super()._setup_model()
+
+        if not hasattr(self, "policy") or self.policy is None:
+            self.policy = self.policy_class(  # type: ignore[assignment]
+                self.observation_space,
+                self.action_space,
+                self.lr_schedule,
+                self.n_quantiles,
+                **self.policy_kwargs,
+            )
+
+        assert isinstance(self.qf_learning_rate, float)
+
+        self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
+
+        self.key, ent_key = jax.random.split(self.key, 2)
+
+        self.actor = self.policy.actor  # type: ignore[assignment]
+        self.qf = self.policy.qf  # type: ignore[assignment]
+
+        # The entropy coefficient or entropy can be learned automatically
+        # see Automating Entropy Adjustment for Maximum Entropy RL section
+        # of https://arxiv.org/abs/1812.05905
+        if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"):
+            # Default initial value of ent_coef when learned
+            ent_coef_init = 1.0
+            if "_" in self.ent_coef_init:
+                ent_coef_init = float(self.ent_coef_init.split("_")[1])
+            assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0"
+
+            # Note: we optimize the log of the entropy coeff which is slightly different from the paper
+            # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
+            self.ent_coef = EntropyCoef(ent_coef_init)
+        else:
+            # This will throw an error if a malformed string (different from 'auto') is passed
+            assert isinstance(
+                self.ent_coef_init, float
+            ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}"
+            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)  # type: ignore[assignment]
+
+        self.ent_coef_state = TrainState.create(
+            apply_fn=self.ent_coef.apply,
+            params=self.ent_coef.init(ent_key)["params"],
+            tx=optax.adam(
+                learning_rate=self.learning_rate,
+            ),
+        )
+
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
+
+    def learn(
+        self,
+        total_timesteps: int,
+        callback: MaybeCallback = None,
+        log_interval: int = 4,
+        tb_log_name: str = "BRO",
+        reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
+    ):
+        return super().learn(
+            total_timesteps=total_timesteps,
+            callback=callback,
+            log_interval=log_interval,
+            tb_log_name=tb_log_name,
+            reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
+        )
+
+    def train(self, gradient_steps: int, batch_size: int) -> None:
+        assert self.replay_buffer is not None
+        # Sample all at once for efficiency (so we can jit the for loop)
+        data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
+
+        if isinstance(data.observations, dict):
+            keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
+            obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
+            next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
+        else:
+            obs = data.observations.numpy()
+            next_obs = data.next_observations.numpy()
+
+        # Convert to numpy
+        data = ReplayBufferSamplesNp(  # type: ignore[assignment]
+            obs,
+            data.actions.numpy(),
+            next_obs,
+            data.dones.numpy().flatten(),
+            data.rewards.numpy().flatten(),
+        )
+
+        (
+            self.policy.qf_state,
+            self.policy.actor_state,
+            self.ent_coef_state,
+            self.key,
+            (actor_loss_value, qf_loss_value, ent_coef_value),
+        ) = self._train(
+            self.gamma,
+            self.tau,
+            self.target_entropy,
+            gradient_steps,
+            data,
+            self.policy_delay,
+            (self._n_updates + 1) % self.policy_delay,
+            self.policy.qf_state,
+            self.policy.actor_state,
+            self.ent_coef_state,
+            self.quantile_taus,
+            self.distributional,
+            self.key,
+        )
+        self._n_updates += gradient_steps
+        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        self.logger.record("train/actor_loss", actor_loss_value.item())
+        self.logger.record("train/critic_loss", qf_loss_value.item())
+        self.logger.record("train/ent_coef", ent_coef_value.item())
+
+    @staticmethod
+    @jax.jit
+    def update_critic(
+        gamma: float,
+        actor_state: TrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        actions: jax.Array,
+        next_observations: jax.Array,
+        rewards: jax.Array,
+        dones: jax.Array,
+        quantile_taus: jax.Array,
+        key: jax.Array,
+    ):
+        key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
+        # sample action from the actor
+        dist = actor_state.apply_fn(actor_state.params, next_observations)
+        next_state_actions = dist.sample(seed=noise_key)
+        next_log_prob = dist.log_prob(next_state_actions)
+
+        ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+
+        qf_next_values = qf_state.apply_fn(
+            qf_state.target_params,
+            next_observations,
+            next_state_actions,
+            rngs={"dropout": dropout_key_target},
+        )
+
+        next_q_values = jnp.mean(qf_next_values, axis=0)
+        # td error + entropy term
+        next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1)
+        # shape is (batch_size, 1)
+        target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
+
+        def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
+            # shape is (n_critics, batch_size, 1)
+            current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
+            return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
+
+        qf_loss_value, grads = jax.value_and_grad(mse_loss, has_aux=False)(qf_state.params, dropout_key_current)
+        qf_state = qf_state.apply_gradients(grads=grads)
+
+        return (
+            qf_state,
+            (qf_loss_value, ent_coef_value),
+            key,
+        )
+
+    @staticmethod
+    @jax.jit
+    def update_critic_quantile(
+        gamma: float,
+        actor_state: TrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        actions: jax.Array,
+        next_observations: jax.Array,
+        rewards: jax.Array,
+        dones: jax.Array,
+        quantile_taus: jax.Array,
+        key: jax.Array,
+    ):
+
+        key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
+        # sample action from the actor
+        dist = actor_state.apply_fn(actor_state.params, next_observations)
+        next_state_actions = dist.sample(seed=noise_key)
+        next_log_prob = dist.log_prob(next_state_actions)
+
+        ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+
+        qf_next_values = qf_state.apply_fn(
+            qf_state.target_params,
+            next_observations,
+            next_state_actions,
+            rngs={"dropout": dropout_key_target},
+        )
+
+        next_q_values = jnp.mean(qf_next_values, axis=0)
+        # entropy term
+        next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1)
+        # shape is (batch_size, 1, n_quantiles), broadcast later against the current quantiles
+        target_q_values = rewards[..., None, None] + (1 - dones[..., None, None]) * gamma * next_q_values[:, None, :]
+
+        def quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
+            # shape is (n_critics, batch_size, n_quantiles)
+            current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
+            quantile_td_error = current_q_values[..., None] - target_q_values[None, ...]
+            def calculate_quantile_huber_loss(quantile_td_error: jnp.ndarray, quantile_taus: jnp.ndarray, kappa: float = 1.0) -> jnp.ndarray:
+                element_wise_huber_loss = jnp.where(jnp.absolute(quantile_td_error) <= kappa, 0.5 * quantile_td_error ** 2, kappa * (jnp.absolute(quantile_td_error) - 0.5 * kappa))
+                mask = jax.lax.stop_gradient(jnp.where(quantile_td_error < 0, 1, 0))  # detach this
+                element_wise_quantile_huber_loss = jnp.absolute(quantile_taus - mask) * element_wise_huber_loss / kappa
+                quantile_huber_loss = element_wise_quantile_huber_loss.sum(axis=0).sum(axis=1).mean()
+                return quantile_huber_loss
+            quantile_huber_loss = calculate_quantile_huber_loss(quantile_td_error, quantile_taus)
+            return quantile_huber_loss
+
+        qf_loss_value, grads = jax.value_and_grad(quantile_loss, has_aux=False)(qf_state.params, dropout_key_current)
+        qf_state = qf_state.apply_gradients(grads=grads)
+
+        return (
+            qf_state,
+            (qf_loss_value, ent_coef_value),
+            key,
+        )
+
+    @staticmethod
+    @jax.jit
+    def update_actor(
+        actor_state: RLTrainState,
+        qf_state: RLTrainState,
+        ent_coef_state: TrainState,
+        observations: jax.Array,
+        key: jax.Array,
+    ):
+        key, dropout_key, noise_key = jax.random.split(key, 3)
+
+        def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
+            dist = actor_state.apply_fn(params, observations)
+            actor_actions = dist.sample(seed=noise_key)
+            log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
+
+            qf_pi = qf_state.apply_fn(
+                qf_state.params,
+                observations,
+                actor_actions,
+                rngs={"dropout": dropout_key},
+            )
+
+            # Take mean among all critics
+            qf_pi_lb = jnp.mean(qf_pi, axis=0)
+            qf_pi_lb = jnp.mean(qf_pi_lb, axis=-1, keepdims=True)
+            ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params})
+            actor_loss = (ent_coef_value * log_prob - qf_pi_lb).mean()
+            return actor_loss, -log_prob.mean()
+
+        (actor_loss_value, entropy), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params)
+        actor_state = actor_state.apply_gradients(grads=grads)
+
+        return actor_state, qf_state, actor_loss_value, key, entropy
+
+    @staticmethod
+    @jax.jit
+    def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
+        qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
+        return qf_state
+
+    @staticmethod
@jax.jit + def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float): + def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array: + ent_coef_value = ent_coef_state.apply_fn({"params": temp_params}) + ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() # type: ignore[union-attr] + return ent_coef_loss + + ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) + ent_coef_state = ent_coef_state.apply_gradients(grads=grads) + + return ent_coef_state, ent_coef_loss + + @classmethod + def update_actor_and_temperature( + cls, + actor_state: RLTrainState, + qf_state: RLTrainState, + ent_coef_state: TrainState, + observations: jax.Array, + target_entropy: ArrayLike, + key: jax.Array, + ): + (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( + actor_state, + qf_state, + ent_coef_state, + observations, + key, + ) + ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) + return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key + + @classmethod + @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"]) + def _train( + cls, + gamma: float, + tau: float, + target_entropy: ArrayLike, + gradient_steps: int, + data: ReplayBufferSamplesNp, + policy_delay: int, + policy_delay_offset: int, + qf_state: RLTrainState, + actor_state: TrainState, + ent_coef_state: TrainState, + quantile_taus: jax.Array, + distributional: bool, + key: jax.Array, + ): + assert data.observations.shape[0] % gradient_steps == 0 + batch_size = data.observations.shape[0] // gradient_steps + + carry = { + "actor_state": actor_state, + "qf_state": qf_state, + "ent_coef_state": ent_coef_state, + "key": key, + "info": { + "actor_loss": jnp.array(0.0), + "qf_loss": jnp.array(0.0), + "ent_coef_loss": jnp.array(0.0), + }, + } + + def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + # Note: this method must be defined inline because + # `fori_loop` expect a signature fn(index, carry) -> carry + actor_state = carry["actor_state"] + qf_state = carry["qf_state"] + ent_coef_state = carry["ent_coef_state"] + key = carry["key"] + info = carry["info"] + batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) + batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) + batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + + (qf_state, (qf_loss_value, ent_coef_value), key) = jax.lax.cond( + distributional == True, + # If True: + cls.update_critic_quantile, + # If False: + cls.update_critic, + gamma, + actor_state, + qf_state, + ent_coef_state, + batch_obs, + batch_act, + batch_next_obs, + batch_rew, + batch_done, + quantile_taus, + key, + ) + + qf_state = cls.soft_update(tau, qf_state) + + (actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond( + (policy_delay_offset + i) % policy_delay == 0, + # If True: + cls.update_actor_and_temperature, + # If False: + lambda *_: (actor_state, qf_state, ent_coef_state, info["actor_loss"], info["ent_coef_loss"], key), + actor_state, + qf_state, + ent_coef_state, + batch_obs, + target_entropy, + key, + ) + info = {"actor_loss": 
actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value} + + return { + "actor_state": actor_state, + "qf_state": qf_state, + "ent_coef_state": ent_coef_state, + "key": key, + "info": info, + } + + update_carry = jax.lax.fori_loop(0, gradient_steps, one_update, carry) + + return ( + update_carry["qf_state"], + update_carry["actor_state"], + update_carry["ent_coef_state"], + update_carry["key"], + (update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]), + ) diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py new file mode 100644 index 0000000..9b9c96e --- /dev/null +++ b/sbx/bro/policies.py @@ -0,0 +1,259 @@ +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tensorflow_probability.substrates.jax as tfp +from flax.training.train_state import TrainState +from gymnasium import spaces +from stable_baselines3.common.type_aliases import Schedule + +from sbx.common.distributions import TanhTransformedDistribution +from sbx.common.policies import BaseJaxPolicy, Flatten +from sbx.common.type_aliases import RLTrainState + +tfd = tfp.distributions + +class BroNetBlock(nn.Module): + n_units: int + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + out = nn.Dense(self.n_units)(x) + out = nn.LayerNorm()(out) + out = self.activation_fn(out) + out = nn.Dense(self.n_units)(out) + out = nn.LayerNorm()(out) + return x + out + +class BroNet(nn.Module): + net_arch: Sequence[int] + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense(self.net_arch[0], self.activation_fn)(x) + x = nn.LayerNorm()(x) + x = self.activation_fn(x) + for n_units in self.net_arch: + x = BroNetBlock(n_units)(x) + return x + +class Actor(nn.Module): + net_arch: Sequence[int] + action_dim: int + log_std_min: float = -10 + log_std_max: float = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + def get_std(self): + # Make it work with gSDE + return jnp.array(0.0) + + @nn.compact + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + x = Flatten()(x) + x = BroNet(net_arch=self.net_arch)(x) + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * 0.5 * (1 + nn.tanh(log_std)) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + return dist + +class Critic(nn.Module): + net_arch: Sequence[int] + n_quantiles: int = 100 + dropout_rate: Optional[float] = None + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: + x = Flatten()(x) + x = jnp.concatenate([x, action], -1) + x = BroNet(self.net_arch, self.activation_fn)(x) + x = nn.Dense(self.n_quantiles)(x) + return x + +class VectorCritic(nn.Module): + net_arch: Sequence[int] + n_quantiles: int = 100 + n_critics: int = 2 + dropout_rate: Optional[float] = None + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): + # Idea taken from https://github.com/perrin-isir/xpag + # Similar to https://github.com/tinkoff-ai/CORL for PyTorch + 
vmap_critic = nn.vmap( + Critic, + variable_axes={"params": 0}, # parameters not shared between the critics + split_rngs={"params": True}, # different initializations + in_axes=None, + out_axes=0, + axis_size=self.n_critics, + ) + q_values = vmap_critic( + net_arch=self.net_arch, + n_quantiles=self.n_quantiles, + activation_fn=self.activation_fn, + )(obs, action) + return q_values + +class BROPolicy(BaseJaxPolicy): + action_space: spaces.Box # type: ignore[assignment] + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Box, + lr_schedule: Schedule, + # BRO + n_quantiles: int = 100, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + dropout_rate: float = 0.0, + layer_norm: bool = False, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, + use_sde: bool = False, + # Note: most gSDE parameters are not used + # this is to keep API consistent with SB3 + log_std_init: float = -3, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class=None, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = False, + ): + if optimizer_kwargs is None: + # Note: the default value for b1 is 0.9 in Adam. + # b1=0.5 is used in the original CrossQ implementation + # but shows only little overall improvement. + optimizer_kwargs = {} + if optimizer_class in [optax.adam, optax.adamw]: + optimizer_kwargs["b1"] = 0.5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) + + self.dropout_rate = dropout_rate + self.layer_norm = layer_norm + self.n_quantiles = n_quantiles + if net_arch is not None: + if isinstance(net_arch, list): + self.net_arch_pi = self.net_arch_qf = net_arch + else: + self.net_arch_pi = net_arch["pi"] + self.net_arch_qf = net_arch["qf"] + else: + self.net_arch_pi = [256] + self.net_arch_qf = [512, 512] + print(self.net_arch_qf) + self.n_critics = n_critics + self.use_sde = use_sde + self.activation_fn = activation_fn + + self.key = self.noise_key = jax.random.PRNGKey(0) + + def build(self, key: jax.Array, + lr_schedule: Schedule, + qf_learning_rate: float) -> jax.Array: + key, actor_key, qf_key, dropout_key = jax.random.split(key, 4) + # Keep a key for the actor + key, self.key = jax.random.split(key, 2) + # Initialize noise + self.reset_noise() + + if isinstance(self.observation_space, spaces.Dict): + obs = jnp.array([spaces.flatten(self.observation_space, self.observation_space.sample())]) + else: + obs = jnp.array([self.observation_space.sample()]) + action = jnp.array([self.action_space.sample()]) + + self.actor = Actor( + action_dim=int(np.prod(self.action_space.shape)), + net_arch=self.net_arch_pi, + activation_fn=self.activation_fn, + ) + + # Hack to make gSDE work without modifying internal SB3 code + self.actor.reset_noise = self.reset_noise + + self.actor_state = TrainState.create( + apply_fn=self.actor.apply, + params=self.actor.init(actor_key, obs), + tx=self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + #learning_rate=qf_learning_rate, # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + + self.qf = VectorCritic( + net_arch=self.net_arch_qf, + n_quantiles=self.n_quantiles, + 
n_critics=self.n_critics, + dropout_rate=self.dropout_rate, + activation_fn=self.activation_fn, + ) + + self.qf_state = RLTrainState.create( + apply_fn=self.qf.apply, + params=self.qf.init( + {"params": qf_key, "dropout": dropout_key}, + obs, + action, + ), + target_params=self.qf.init( + {"params": qf_key, "dropout": dropout_key}, + obs, + action, + ), + tx=self.optimizer_class( + learning_rate=qf_learning_rate, # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + + self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] + self.qf.apply = jax.jit( # type: ignore[method-assign] + self.qf.apply, + static_argnames=("dropout_rate", "use_layer_norm"), + ) + + return key + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + """ + self.key, self.noise_key = jax.random.split(self.key, 2) + + def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] + if deterministic: + return BaseJaxPolicy.select_action(self.actor_state, observation) + # Trick to use gSDE: repeat sampled noise by using the same noise key + if not self.use_sde: + self.reset_noise() + return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) + + + From 1580b130bbe14d550222f91d6374fbd9df5ab270 Mon Sep 17 00:00:00 2001 From: naumix Date: Thu, 31 Oct 2024 14:43:45 +0100 Subject: [PATCH 02/21] Update policies.py --- sbx/bro/policies.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 9b9c96e..5297d6e 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -162,7 +162,8 @@ def __init__( self.net_arch_qf = net_arch["qf"] else: self.net_arch_pi = [256] - self.net_arch_qf = [512, 512] + # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR + self.net_arch_qf = [1024, 1024] print(self.net_arch_qf) self.n_critics = n_critics self.use_sde = use_sde From 2c0322be63805f523ada5582183aeaceb197d6cc Mon Sep 17 00:00:00 2001 From: naumix Date: Fri, 1 Nov 2024 19:00:19 +0100 Subject: [PATCH 03/21] add --- sbx/__init__.py | 3 +++ sbx/bro/bro.py | 9 +++++---- sbx/common/off_policy_algorithm.py | 1 + 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sbx/__init__.py b/sbx/__init__.py index a7c13bc..8ec1325 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -7,6 +7,8 @@ from sbx.sac import SAC from sbx.td3 import TD3 from sbx.tqc import TQC +from sbx.bro import BRO + # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") @@ -30,4 +32,5 @@ def DroQ(*args, **kwargs): "SAC", "TD3", "TQC", + "BRO", ] diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index 37b4e3b..fe1c1e9 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -19,6 +19,7 @@ from sbx.bro.policies import BROPolicy + class EntropyCoef(nn.Module): ent_coef_init: float = 1.0 @@ -238,10 +239,10 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, ) self._n_updates += gradient_steps - self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - self.logger.record("train/actor_loss", actor_loss_value.item()) - self.logger.record("train/critic_loss", qf_loss_value.item()) - self.logger.record("train/ent_coef", ent_coef_value.item()) + 
#self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + #self.logger.record("train/actor_loss", actor_loss_value.item()) + #self.logger.record("train/critic_loss", qf_loss_value.item()) + #self.logger.record("train/ent_coef", ent_coef_value.item()) @staticmethod @jax.jit diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index ba0b9ed..fddf57b 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -115,6 +115,7 @@ def _setup_model(self) -> None: device="cpu", # force cpu device to easy torch -> numpy conversion n_envs=self.n_envs, optimize_memory_usage=self.optimize_memory_usage, + handle_timeout_termination=False, **replay_buffer_kwargs, ) # Convert train freq parameter to TrainFreq object From f23b80f17172e0250a9dfef0976dd4a5422c1d48 Mon Sep 17 00:00:00 2001 From: naumix Date: Fri, 1 Nov 2024 22:44:59 +0100 Subject: [PATCH 04/21] add scripts --- make_dmc.py | 157 ++++++++++++++++++++++++++++++++++ sbx/bro/bro.py | 5 ++ scripts/dmc/acrobot.sh | 13 +++ scripts/dmc/cheetah.sh | 13 +++ scripts/dmc/dog_run.sh | 13 +++ scripts/dmc/dog_stand.sh | 13 +++ scripts/dmc/dog_trot.sh | 13 +++ scripts/dmc/dog_walk.sh | 13 +++ scripts/dmc/finger.sh | 13 +++ scripts/dmc/fish.sh | 13 +++ scripts/dmc/hopper.sh | 13 +++ scripts/dmc/humanoid_run.sh | 13 +++ scripts/dmc/humanoid_stand.sh | 13 +++ scripts/dmc/humanoid_walk.sh | 13 +++ scripts/dmc/pendulum.sh | 13 +++ scripts/dmc/quadruped.sh | 13 +++ scripts/dmc/walker.sh | 13 +++ scripts/run_tests.sh | 2 - train.py | 110 ++++++++++++++++++++++++ 19 files changed, 467 insertions(+), 2 deletions(-) create mode 100644 make_dmc.py create mode 100644 scripts/dmc/acrobot.sh create mode 100644 scripts/dmc/cheetah.sh create mode 100644 scripts/dmc/dog_run.sh create mode 100644 scripts/dmc/dog_stand.sh create mode 100644 scripts/dmc/dog_trot.sh create mode 100644 scripts/dmc/dog_walk.sh create mode 100644 scripts/dmc/finger.sh create mode 100644 scripts/dmc/fish.sh create mode 100644 scripts/dmc/hopper.sh create mode 100644 scripts/dmc/humanoid_run.sh create mode 100644 scripts/dmc/humanoid_stand.sh create mode 100644 scripts/dmc/humanoid_walk.sh create mode 100644 scripts/dmc/pendulum.sh create mode 100644 scripts/dmc/quadruped.sh create mode 100644 scripts/dmc/walker.sh delete mode 100755 scripts/run_tests.sh create mode 100644 train.py diff --git a/make_dmc.py b/make_dmc.py new file mode 100644 index 0000000..57374dd --- /dev/null +++ b/make_dmc.py @@ -0,0 +1,157 @@ +#adapted from https://github.com/imgeorgiev/dmc2gymnasium + +import logging +import numpy as np + +from dm_control import suite +from dm_env import specs +from gymnasium.core import Env +from gymnasium.spaces import Box +from gymnasium import spaces +from gymnasium.wrappers import FlattenObservation, RescaleAction + +def _spec_to_box(spec, dtype=np.float32): + def extract_min_max(s): + assert s.dtype == np.float64 or s.dtype == np.float32 + dim = int(np.prod(s.shape)) + if type(s) == specs.Array: + bound = np.inf * np.ones(dim, dtype=np.float32) + return -bound, bound + elif type(s) == specs.BoundedArray: + zeros = np.zeros(dim, dtype=np.float32) + return s.minimum + zeros, s.maximum + zeros + else: + logging.error("Unrecognized type") + mins, maxs = [], [] + for s in spec: + mn, mx = extract_min_max(s) + mins.append(mn) + maxs.append(mx) + low = np.concatenate(mins, axis=0).astype(dtype) + high = np.concatenate(maxs, axis=0).astype(dtype) + assert low.shape == high.shape + return Box(low, high, dtype=dtype) 
+ + +def _flatten_obs(obs, dtype=np.float32): + obs_pieces = [] + for v in obs.values(): + flat = np.array([v]) if np.isscalar(v) else v.ravel() + obs_pieces.append(flat) + return np.concatenate(obs_pieces, axis=0).astype(dtype) + + +class DMCGym(Env): + def __init__( + self, + env_name, + task_kwargs={}, + environment_kwargs={}, + #rendering="egl", + render_height=64, + render_width=64, + render_camera_id=0, + action_repeat=1 + ): + domain = env_name.split('-')[0] + task = env_name.split('-')[1] + self._env = suite.load( + domain, + task, + task_kwargs, + environment_kwargs, + ) + + # placeholder to allow built in gymnasium rendering + self.render_mode = "rgb_array" + self.render_height = render_height + self.render_width = render_width + self.render_camera_id = render_camera_id + + self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) + self._norm_action_space = spaces.Box( + low=-1.0, + high=1.0, + shape=self._true_action_space.shape, + dtype=np.float32 + ) + + self._observation_space = _spec_to_box(self._env.observation_spec().values()) + self._action_space = _spec_to_box([self._env.action_spec()]) + self.action_repeat = action_repeat + + # set seed if provided with task_kwargs + if "random" in task_kwargs: + seed = task_kwargs["random"] + self._observation_space.seed(seed) + self._action_space.seed(seed) + + def __getattr__(self, name): + """Add this here so that we can easily access attributes of the underlying env""" + return getattr(self._env, name) + + @property + def observation_space(self): + return self._observation_space + + @property + def action_space(self): + return self._action_space + + @property + def reward_range(self): + """DMC always has a per-step reward range of (0, 1)""" + return 0, 1 + + def _convert_action(self, action): + action = action.astype(np.float64) + true_delta = self._true_action_space.high - self._true_action_space.low + norm_delta = self._norm_action_space.high - self._norm_action_space.low + action = (action - self._norm_action_space.low) / norm_delta + action = action * true_delta + self._true_action_space.low + action = action.astype(np.float32) + return action + + def step(self, action): + assert self._norm_action_space.contains(action) + action = self._convert_action(action) + assert self._true_action_space.contains(action) + action = np.clip(action, -1.0, 1.0) + reward = 0 + info = {} + for i in range(self.action_repeat): + timestep = self._env.step(action) + observation = _flatten_obs(timestep.observation) + reward += timestep.reward + termination = False # we never reach a goal + truncation = timestep.last() + if truncation: + return observation, reward, termination, truncation, info + return observation, reward, termination, truncation, info + + def reset(self, seed=None, options=None): + if seed is not None: + if not isinstance(seed, np.random.RandomState): + seed = np.random.RandomState(seed) + self._env.task._random = seed + if options: + logging.warn("Currently doing nothing with options={:}".format(options)) + timestep = self._env.reset() + observation = _flatten_obs(timestep.observation) + info = {} + return observation, info + + def render(self, height=None, width=None, camera_id=None): + height = height or self.render_height + width = width or self.render_width + camera_id = camera_id or self.render_camera_id + return self._env.physics.render(height=height, width=width, camera_id=camera_id) + + +def make_env_dmc(env_name: str, action_repeat: int = 1) -> Env: + env = DMCGym(env_name=env_name, 
action_repeat=action_repeat) + env = RescaleAction(env, -1.0, 1.0) + env = FlattenObservation(env) + return env + + \ No newline at end of file diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index fe1c1e9..a3b29c3 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -239,6 +239,11 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, ) self._n_updates += gradient_steps + return { + 'actor_loss': actor_loss_value.item(), + 'critic_loss': qf_loss_value.item(), + 'ent_coef': ent_coef_value.item(), + } #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") #self.logger.record("train/actor_loss", actor_loss_value.item()) #self.logger.record("train/critic_loss", qf_loss_value.item()) diff --git a/scripts/dmc/acrobot.sh b/scripts/dmc/acrobot.sh new file mode 100644 index 0000000..50070c2 --- /dev/null +++ b/scripts/dmc/acrobot.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup + +wait diff --git a/scripts/dmc/cheetah.sh b/scripts/dmc/cheetah.sh new file mode 100644 index 0000000..c6e8b3f --- /dev/null +++ b/scripts/dmc/cheetah.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run + +wait diff --git a/scripts/dmc/dog_run.sh b/scripts/dmc/dog_run.sh new file mode 100644 index 0000000..88fbfa7 --- /dev/null +++ b/scripts/dmc/dog_run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run + +wait diff --git a/scripts/dmc/dog_stand.sh b/scripts/dmc/dog_stand.sh new file mode 100644 index 0000000..7184f93 --- /dev/null +++ b/scripts/dmc/dog_stand.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand + +wait diff --git a/scripts/dmc/dog_trot.sh b/scripts/dmc/dog_trot.sh new file mode 100644 index 0000000..701d141 --- /dev/null +++ b/scripts/dmc/dog_trot.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot + +wait diff --git a/scripts/dmc/dog_walk.sh b/scripts/dmc/dog_walk.sh new file mode 100644 index 0000000..24f9635 --- /dev/null +++ b/scripts/dmc/dog_walk.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk + +wait diff --git a/scripts/dmc/finger.sh 
b/scripts/dmc/finger.sh new file mode 100644 index 0000000..8cf0d21 --- /dev/null +++ b/scripts/dmc/finger.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard + +wait diff --git a/scripts/dmc/fish.sh b/scripts/dmc/fish.sh new file mode 100644 index 0000000..c18e8e3 --- /dev/null +++ b/scripts/dmc/fish.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim + +wait diff --git a/scripts/dmc/hopper.sh b/scripts/dmc/hopper.sh new file mode 100644 index 0000000..7ecd2cc --- /dev/null +++ b/scripts/dmc/hopper.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop + +wait diff --git a/scripts/dmc/humanoid_run.sh b/scripts/dmc/humanoid_run.sh new file mode 100644 index 0000000..7f33786 --- /dev/null +++ b/scripts/dmc/humanoid_run.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run + +wait diff --git a/scripts/dmc/humanoid_stand.sh b/scripts/dmc/humanoid_stand.sh new file mode 100644 index 0000000..16fa76f --- /dev/null +++ b/scripts/dmc/humanoid_stand.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand + +wait diff --git a/scripts/dmc/humanoid_walk.sh b/scripts/dmc/humanoid_walk.sh new file mode 100644 index 0000000..651c2a3 --- /dev/null +++ b/scripts/dmc/humanoid_walk.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk + +wait diff --git a/scripts/dmc/pendulum.sh b/scripts/dmc/pendulum.sh new file mode 100644 index 0000000..5ca5fd4 --- /dev/null +++ b/scripts/dmc/pendulum.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup + +wait diff --git a/scripts/dmc/quadruped.sh b/scripts/dmc/quadruped.sh new file mode 100644 index 0000000..5de95f1 --- /dev/null +++ b/scripts/dmc/quadruped.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash 
+source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run + +wait diff --git a/scripts/dmc/walker.sh b/scripts/dmc/walker.sh new file mode 100644 index 0000000..3e97f7f --- /dev/null +++ b/scripts/dmc/walker.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run + +wait diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh deleted file mode 100755 index 9ecb207..0000000 --- a/scripts/run_tests.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive" diff --git a/train.py b/train.py new file mode 100644 index 0000000..cc38cfc --- /dev/null +++ b/train.py @@ -0,0 +1,110 @@ +import gymnasium as gym +from sbx.bro.bro import BRO +import numpy as np +from make_dmc import make_env_dmc +import wandb + +from absl import app, flags + +flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') +flags.DEFINE_string('benchmark', 'dmc', 'Environment name.') +flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') +flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') +flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') +flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') +flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') +flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') +flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.') +FLAGS = flags.FLAGS + +''' +class flags: + env_name: str = "cheetah-run" + learning_starts: int = 5000 + training_steps: int = 10_000 + seed: int = 0 + batch_size: int = 128 + gradient_steps: int = 2 + use_wandb: bool = False + n_quantiles: int = 100 + eval_freq: int = 5000 + num_episodes: int = 5 +FLAGS = flags() +''' + +def evaluate(env, model, num_episodes): + returns = np.zeros(num_episodes) + for episode in range(num_episodes): + not_done = True + obs, _ = env.reset(seed=np.random.randint(1e7)) + obs = np.expand_dims(obs, axis=0) + ret = 0 + while not_done: + action = model.policy.forward(obs, deterministic=True)[0] + next_obs, reward, term, trun, info = env.step(action) + next_obs = np.expand_dims(next_obs, axis=0) + obs = next_obs + ret += reward + if term or trun : + not_done = False + returns[episode] = ret + return {'return_eval': returns.mean()} + +def log_to_wandb(step, infos): + dict_to_log = {'timestep': step} + for info_key in infos: + dict_to_log[f'{info_key}'] = infos[info_key] + wandb.log(dict_to_log, step=step) + +def get_env(benchmark, env_name): + if benchmark == 'gym': + return gym.make(FLAGS.env_name) + else: + return make_env_dmc(env_name=FLAGS.env_name, action_repeat=1) + +def main(_): + SEED = np.random.randint(1e7) + wandb.init( + config=FLAGS, + entity='naumix', + project='BRO_SBX', + group=f'{FLAGS.env_name}_{SEED}', + name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}' + ) + + env = get_env(FLAGS.benchmark, FLAGS.env_name) + eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) + model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, 
n_quantiles=FLAGS.n_quantiles, seed=SEED) + np.random.seed(SEED) + + obs, _ = env.reset(seed=np.random.randint(1e7)) + obs = np.expand_dims(obs, axis=0) + + for i in range(1, FLAGS.training_steps+1): + if i <= FLAGS.learning_starts: + action = env.action_space.sample() + else: + action = model.policy.forward(obs, deterministic=False)[0] + next_obs, reward, term, trun, info = env.step(action) + next_obs = np.expand_dims(next_obs, axis=0) + + done = 1.0 if (term and not trun) else 0.0 + + model.replay_buffer.add(obs, next_obs, action, reward, done, info) + if term or trun: + obs, _ = env.reset(seed=np.random.randint(1e7)) + obs = np.expand_dims(obs, axis=0) + else: + obs = next_obs + + if i >= FLAGS.learning_starts: + train_info = model.train(FLAGS.gradient_steps, FLAGS.batch_size) + + if i % FLAGS.eval_freq == 0: + eval_info = evaluate(eval_env, model, FLAGS.num_episodes) + info = {**eval_info, **train_info} + #print(eval_info) + log_to_wandb(i, info) + +if __name__ == '__main__': + app.run(main) From e291f699c02f938ecfbccb5d1e160ba681c3ff68 Mon Sep 17 00:00:00 2001 From: naumix Date: Fri, 1 Nov 2024 23:54:02 +0100 Subject: [PATCH 05/21] scripts --- scripts/dmc/acrobot.sh | 2 +- scripts/dmc/cheetah.sh | 2 +- scripts/dmc/dog_run.sh | 2 +- scripts/dmc/dog_stand.sh | 2 +- scripts/dmc/dog_trot.sh | 2 +- scripts/dmc/dog_walk.sh | 2 +- scripts/dmc/finger.sh | 2 +- scripts/dmc/fish.sh | 2 +- scripts/dmc/hopper.sh | 2 +- scripts/dmc/humanoid_run.sh | 2 +- scripts/dmc/humanoid_stand.sh | 2 +- scripts/dmc/humanoid_walk.sh | 2 +- scripts/dmc/pendulum.sh | 2 +- scripts/dmc/quadruped.sh | 2 +- scripts/dmc/walker.sh | 2 +- train.py => train_torch.py | 4 ++-- 16 files changed, 17 insertions(+), 17 deletions(-) rename train.py => train_torch.py (98%) diff --git a/scripts/dmc/acrobot.sh b/scripts/dmc/acrobot.sh index 50070c2..f5e59bf 100644 --- a/scripts/dmc/acrobot.sh +++ b/scripts/dmc/acrobot.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup & python3 train_torch.py --env_name=acrobot-swingup +python3 train_torch.py --env_name=acrobot-swingup wait diff --git a/scripts/dmc/cheetah.sh b/scripts/dmc/cheetah.sh index c6e8b3f..9c025df 100644 --- a/scripts/dmc/cheetah.sh +++ b/scripts/dmc/cheetah.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run & python3 train_torch.py --env_name=cheetah-run +python3 train_torch.py --env_name=cheetah-run wait diff --git a/scripts/dmc/dog_run.sh b/scripts/dmc/dog_run.sh index 88fbfa7..45b27ea 100644 --- a/scripts/dmc/dog_run.sh +++ b/scripts/dmc/dog_run.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run & python3 train_torch.py --env_name=dog-run +python3 train_torch.py --env_name=dog-run wait diff --git a/scripts/dmc/dog_stand.sh b/scripts/dmc/dog_stand.sh index 7184f93..07a4584 100644 --- a/scripts/dmc/dog_stand.sh +++ b/scripts/dmc/dog_stand.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand & python3 train_torch.py --env_name=dog-stand +python3 train_torch.py --env_name=dog-stand wait diff --git a/scripts/dmc/dog_trot.sh b/scripts/dmc/dog_trot.sh index 701d141..d7b1411 100644 --- 
a/scripts/dmc/dog_trot.sh +++ b/scripts/dmc/dog_trot.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot & python3 train_torch.py --env_name=dog-trot +python3 train_torch.py --env_name=dog-trot wait diff --git a/scripts/dmc/dog_walk.sh b/scripts/dmc/dog_walk.sh index 24f9635..4221da2 100644 --- a/scripts/dmc/dog_walk.sh +++ b/scripts/dmc/dog_walk.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk & python3 train_torch.py --env_name=dog-walk +python3 train_torch.py --env_name=dog-walk wait diff --git a/scripts/dmc/finger.sh b/scripts/dmc/finger.sh index 8cf0d21..8a10360 100644 --- a/scripts/dmc/finger.sh +++ b/scripts/dmc/finger.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard & python3 train_torch.py --env_name=finger-turn_hard +python3 train_torch.py --env_name=finger-turn_hard wait diff --git a/scripts/dmc/fish.sh b/scripts/dmc/fish.sh index c18e8e3..7dca0fa 100644 --- a/scripts/dmc/fish.sh +++ b/scripts/dmc/fish.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim & python3 train_torch.py --env_name=fish-swim +python3 train_torch.py --env_name=fish-swim wait diff --git a/scripts/dmc/hopper.sh b/scripts/dmc/hopper.sh index 7ecd2cc..fea40d9 100644 --- a/scripts/dmc/hopper.sh +++ b/scripts/dmc/hopper.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop & python3 train_torch.py --env_name=hopper-hop +python3 train_torch.py --env_name=hopper-hop wait diff --git a/scripts/dmc/humanoid_run.sh b/scripts/dmc/humanoid_run.sh index 7f33786..f434d0b 100644 --- a/scripts/dmc/humanoid_run.sh +++ b/scripts/dmc/humanoid_run.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run & python3 train_torch.py --env_name=humanoid-run +python3 train_torch.py --env_name=humanoid-run wait diff --git a/scripts/dmc/humanoid_stand.sh b/scripts/dmc/humanoid_stand.sh index 16fa76f..e22bff0 100644 --- a/scripts/dmc/humanoid_stand.sh +++ b/scripts/dmc/humanoid_stand.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand & python3 train_torch.py --env_name=humanoid-stand +python3 train_torch.py --env_name=humanoid-stand wait diff --git a/scripts/dmc/humanoid_walk.sh b/scripts/dmc/humanoid_walk.sh index 651c2a3..64a2a37 100644 --- a/scripts/dmc/humanoid_walk.sh +++ b/scripts/dmc/humanoid_walk.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk & python3 train_torch.py --env_name=humanoid-walk +python3 train_torch.py --env_name=humanoid-walk wait diff --git a/scripts/dmc/pendulum.sh b/scripts/dmc/pendulum.sh index 5ca5fd4..0e54c35 100644 --- a/scripts/dmc/pendulum.sh +++ b/scripts/dmc/pendulum.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 
train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup & python3 train_torch.py --env_name=pendulum-swingup +python3 train_torch.py --env_name=pendulum-swingup wait diff --git a/scripts/dmc/quadruped.sh b/scripts/dmc/quadruped.sh index 5de95f1..e91ba25 100644 --- a/scripts/dmc/quadruped.sh +++ b/scripts/dmc/quadruped.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run & python3 train_torch.py --env_name=quadruped-run +python3 train_torch.py --env_name=quadruped-run wait diff --git a/scripts/dmc/walker.sh b/scripts/dmc/walker.sh index 3e97f7f..3da6aa7 100644 --- a/scripts/dmc/walker.sh +++ b/scripts/dmc/walker.sh @@ -8,6 +8,6 @@ conda activate sbx module load cuDNN/8.9.2.26-CUDA-12.2.0 -python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run & python3 train_torch.py --env_name=walker-run +python3 train_torch.py --env_name=walker-run wait diff --git a/train.py b/train_torch.py similarity index 98% rename from train.py rename to train_torch.py index cc38cfc..ab106f5 100644 --- a/train.py +++ b/train_torch.py @@ -68,8 +68,8 @@ def main(_): config=FLAGS, entity='naumix', project='BRO_SBX', - group=f'{FLAGS.env_name}_{SEED}', - name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}' + group=f'{FLAGS.env_name}', + name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}' ) env = get_env(FLAGS.benchmark, FLAGS.env_name) From c2ca2f9a5bb8259322078d9e465092f2a50caa5a Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 2 Nov 2024 18:13:22 +0100 Subject: [PATCH 06/21] add --- sbx/bro/bro.py | 121 +++++++++++++++++++++++++++-- sbx/bro/policies.py | 2 +- sbx/common/off_policy_algorithm.py | 2 +- scripts/gym/ant.sh | 13 ++++ scripts/gym/cheetah.sh | 13 ++++ scripts/gym/hopper.sh | 13 ++++ scripts/gym/walker.sh | 13 ++++ train_torch.py | 15 ++-- 8 files changed, 179 insertions(+), 13 deletions(-) create mode 100644 scripts/gym/ant.sh create mode 100644 scripts/gym/cheetah.sh create mode 100644 scripts/gym/hopper.sh create mode 100644 scripts/gym/walker.sh diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index a3b29c3..70b05c6 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -27,8 +27,7 @@ class EntropyCoef(nn.Module): def __call__(self) -> jnp.ndarray: log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init))) return jnp.exp(log_ent_coef) - - + class ConstantEntropyCoef(nn.Module): ent_coef_init: float = 1.0 @@ -39,7 +38,27 @@ def __call__(self) -> float: self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init)) return self.ent_coef_init - +@jax.jit +def _get_stats( + actor_state: RLTrainState, + qf_state: RLTrainState, + ent_coef_state: TrainState, + observations: jax.Array, + key: jax.Array, +): + key, dropout_key, noise_key = jax.random.split(key, 3) + dist = actor_state.apply_fn(actor_state.params, observations) + actor_actions = dist.sample(seed=noise_key) + log_prob = dist.log_prob(actor_actions).reshape(-1, 1) + qf_pi = qf_state.apply_fn( + qf_state.params, + observations, + actor_actions, + rngs={"dropout": dropout_key}, + ) + ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) + return qf_pi.mean(), jnp.absolute(actor_actions).mean(), ent_coef_value.mean(), -log_prob.mean() + class BRO(OffPolicyAlgorithmJax): policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = { # type: ignore[assignment] 
"MlpPolicy": BROPolicy, @@ -110,11 +129,13 @@ def __init__( self.policy_delay = policy_delay self.ent_coef_init = ent_coef self.target_entropy = target_entropy + self.init_key = jax.random.PRNGKey(seed) self.n_quantiles = n_quantiles taus_ = jnp.arange(0, n_quantiles+1) / n_quantiles self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None] + self.distributional = True if self.n_quantiles > 1 else False if _init_setup_model: self._setup_model() @@ -122,6 +143,61 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() + if not hasattr(self, "policy") or self.policy is None: + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, + self.action_space, + self.lr_schedule, + self.n_quantiles, + **self.policy_kwargs, + ) + + assert isinstance(self.qf_learning_rate, float) + self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) + + self.key, ent_key = jax.random.split(self.key, 2) + + self.actor = self.policy.actor # type: ignore[assignment] + self.qf = self.policy.qf # type: ignore[assignment] + + # The entropy coefficient or entropy can be learned automatically + # see Automating Entropy Adjustment for Maximum Entropy RL section + # of https://arxiv.org/abs/1812.05905 + if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"): + # Default initial value of ent_coef when learned + ent_coef_init = 1.0 + if "_" in self.ent_coef_init: + ent_coef_init = float(self.ent_coef_init.split("_")[1]) + assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0" + + # Note: we optimize the log of the entropy coeff which is slightly different from the paper + # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 + self.ent_coef = EntropyCoef(ent_coef_init) + else: + # This will throw an error if a malformed string (different from 'auto') is passed + assert isinstance( + self.ent_coef_init, float + ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}" + self.ent_coef = ConstantEntropyCoef(self.ent_coef_init) # type: ignore[assignment] + + self.ent_coef_state = TrainState.create( + apply_fn=self.ent_coef.apply, + params=self.ent_coef.init(ent_key)["params"], + tx=optax.adam( + learning_rate=self.learning_rate, b1=0.5 + ), + ) + + # Target entropy is used when learning the entropy coefficient + if self.target_entropy == "auto": + # automatically set target entropy if needed + self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2 # type: ignore + else: + # Force conversion + # this will also throw an error for unexpected string + self.target_entropy = float(self.target_entropy) + + def reset(self): if not hasattr(self, "policy") or self.policy is None: self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, @@ -133,7 +209,7 @@ def _setup_model(self) -> None: assert isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) self.key, ent_key = jax.random.split(self.key, 2) @@ -164,7 +240,7 @@ def _setup_model(self) -> None: apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], tx=optax.adam( - learning_rate=self.learning_rate, + learning_rate=self.learning_rate, b1=0.5 ), ) @@ -242,7 +318,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: return { 'actor_loss': 
actor_loss_value.item(), 'critic_loss': qf_loss_value.item(), - 'ent_coef': ent_coef_value.item(), + 'ent_loss': ent_coef_value.item(), } #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") #self.logger.record("train/actor_loss", actor_loss_value.item()) @@ -431,7 +507,40 @@ def update_actor_and_temperature( ) ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key + + def get_stats(self, batch_size): + data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + if isinstance(data.observations, dict): + keys = list(self.observation_space.keys()) # type: ignore[attr-defined] + obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) + next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1) + else: + obs = data.observations.numpy() + next_obs = data.next_observations.numpy() + + # Convert to numpy + data = ReplayBufferSamplesNp( # type: ignore[assignment] + obs, + data.actions.numpy(), + next_obs, + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + ) + q, a, temp, ent = _get_stats( + self.policy.actor_state, + self.policy.qf_state, + self.ent_coef_state, + obs, + self.key, + ) + + return { + 'q': q.mean().item(), + 'a': a.item(), + 'temp': temp.item(), + 'entropy': ent.item()} + @classmethod @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"]) def _train( diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 5297d6e..1bee8bd 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -128,7 +128,7 @@ def __init__( features_extractor_class=None, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index fddf57b..11e6c8c 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -71,7 +71,7 @@ def __init__( support_multi_env=support_multi_env, ) # Will be updated later - self.key = jax.random.PRNGKey(0) + self.key = jax.random.PRNGKey(seed) # Note: we do not allow schedule for it self.qf_learning_rate = qf_learning_rate diff --git a/scripts/gym/ant.sh b/scripts/gym/ant.sh new file mode 100644 index 0000000..1d63aee --- /dev/null +++ b/scripts/gym/ant.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Ant-v4 + +wait diff --git a/scripts/gym/cheetah.sh b/scripts/gym/cheetah.sh new file mode 100644 index 0000000..63108f6 --- /dev/null +++ b/scripts/gym/cheetah.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=HalfCheetah-v4 + +wait diff --git a/scripts/gym/hopper.sh b/scripts/gym/hopper.sh new file mode 100644 index 0000000..3eeca18 --- /dev/null +++ b/scripts/gym/hopper.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source 
~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Hopper-v4 + +wait diff --git a/scripts/gym/walker.sh b/scripts/gym/walker.sh new file mode 100644 index 0000000..8a3e68c --- /dev/null +++ b/scripts/gym/walker.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Walker2d-v4 + +wait diff --git a/train_torch.py b/train_torch.py index ab106f5..0aa57d1 100644 --- a/train_torch.py +++ b/train_torch.py @@ -8,9 +8,9 @@ flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') flags.DEFINE_string('benchmark', 'dmc', 'Environment name.') -flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') +flags.DEFINE_integer('learning_starts', 2000, 'Number of training steps to start training.') flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') -flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') +flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') @@ -74,9 +74,10 @@ def main(_): env = get_env(FLAGS.benchmark, FLAGS.env_name) eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) - model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED) + model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, learning_starts=FLAGS.learning_starts, gradient_steps=FLAGS.gradient_steps) np.random.seed(SEED) + reset_list = [20000] obs, _ = env.reset(seed=np.random.randint(1e7)) obs = np.expand_dims(obs, axis=0) @@ -97,12 +98,16 @@ def main(_): else: obs = next_obs + if i in reset_list: + model.reset() + if i >= FLAGS.learning_starts: train_info = model.train(FLAGS.gradient_steps, FLAGS.batch_size) - + if i % FLAGS.eval_freq == 0: eval_info = evaluate(eval_env, model, FLAGS.num_episodes) - info = {**eval_info, **train_info} + stat_info = model.get_stats(FLAGS.batch_size) + info = {**eval_info, **train_info, **stat_info} #print(eval_info) log_to_wandb(i, info) From 6599938b6cef194d00825b781752b1bbcb30e4e2 Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 2 Nov 2024 22:15:27 +0100 Subject: [PATCH 07/21] Update policies.py --- sbx/bro/policies.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 1bee8bd..576ef80 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -21,7 +21,7 @@ class BroNetBlock(nn.Module): activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: out = nn.Dense(self.n_units)(x) out = nn.LayerNorm()(out) out = self.activation_fn(out) @@ -74,10 +74,10 @@ class Critic(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: x = Flatten()(x) - x = jnp.concatenate([x, action], -1) - x = BroNet(self.net_arch, self.activation_fn)(x) - x = nn.Dense(self.n_quantiles)(x) - return x + out = jnp.concatenate([x, action], -1) + out = 
BroNet(self.net_arch, self.activation_fn)(out) + out = nn.Dense(self.n_quantiles)(out) + return out class VectorCritic(nn.Module): net_arch: Sequence[int] @@ -139,7 +139,7 @@ def __init__( # but shows only little overall improvement. optimizer_kwargs = {} if optimizer_class in [optax.adam, optax.adamw]: - optimizer_kwargs["b1"] = 0.5 + pass super().__init__( observation_space, @@ -163,7 +163,7 @@ def __init__( else: self.net_arch_pi = [256] # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR - self.net_arch_qf = [1024, 1024] + self.net_arch_qf = [512, 512] print(self.net_arch_qf) self.n_critics = n_critics self.use_sde = use_sde From 9e2c6dd31f6037e3631cd0395dbd2e3f741eab24 Mon Sep 17 00:00:00 2001 From: naumix Date: Sun, 3 Nov 2024 13:47:04 +0100 Subject: [PATCH 08/21] add --- sbx/bro/policies.py | 2 +- scripts/gym/humanoid.sh | 13 +++++++++++++ train_torch.py | 15 ++++++++------- 3 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 scripts/gym/humanoid.sh diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 576ef80..dc2a56e 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -117,7 +117,7 @@ def __init__( n_quantiles: int = 100, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, - layer_norm: bool = False, + layer_norm: bool = True, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, # Note: most gSDE parameters are not used diff --git a/scripts/gym/humanoid.sh b/scripts/gym/humanoid.sh new file mode 100644 index 0000000..7d2a7d2 --- /dev/null +++ b/scripts/gym/humanoid.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# + +source ~/miniconda3/bin/activate +conda init bash +source ~/.bashrc +conda activate sbx + +module load cuDNN/8.9.2.26-CUDA-12.2.0 + +python3 train_torch.py --benchmark=gym --env_name=Humanoid-v4 + +wait diff --git a/train_torch.py b/train_torch.py index 0aa57d1..f747f7d 100644 --- a/train_torch.py +++ b/train_torch.py @@ -8,19 +8,20 @@ flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') flags.DEFINE_string('benchmark', 'dmc', 'Environment name.') -flags.DEFINE_integer('learning_starts', 2000, 'Number of training steps to start training.') +flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') -flags.DEFINE_integer('batch_size', 256, 'Mini batch size.') +flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') -flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') +flags.DEFINE_integer('n_quantiles', 1, 'Number of training steps.') flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.') FLAGS = flags.FLAGS ''' class flags: + benchmark: str = 'dmc' env_name: str = "cheetah-run" - learning_starts: int = 5000 + learning_starts: int = 9999 training_steps: int = 10_000 seed: int = 0 batch_size: int = 128 @@ -69,15 +70,15 @@ def main(_): entity='naumix', project='BRO_SBX', group=f'{FLAGS.env_name}', - name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}' + name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}_repro' ) env = get_env(FLAGS.benchmark, FLAGS.env_name) eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) - model = BRO("MlpPolicy", env, 
learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, learning_starts=FLAGS.learning_starts, gradient_steps=FLAGS.gradient_steps) + model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, gradient_steps=FLAGS.gradient_steps) np.random.seed(SEED) - reset_list = [20000] + reset_list = [15000] obs, _ = env.reset(seed=np.random.randint(1e7)) obs = np.expand_dims(obs, axis=0) From c5aef9394c97a013ac5262221af9cfa9b8c19a37 Mon Sep 17 00:00:00 2001 From: naumix Date: Mon, 4 Nov 2024 00:18:36 +0100 Subject: [PATCH 09/21] add --- sbx/bro/bro.py | 41 +++++++++++++++++++++++++++++++++-------- sbx/bro/policies.py | 2 +- train_torch.py | 4 ++-- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index 70b05c6..3f3cac6 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -75,6 +75,7 @@ def __init__( env: Union[GymEnv, str], #BRO n_quantiles: int = 100, + pessimism: float = 0.0, learning_rate: Union[float, Schedule] = 3e-4, qf_learning_rate: Optional[float] = None, buffer_size: int = 1_000_000, # 1e6 @@ -134,7 +135,7 @@ def __init__( self.n_quantiles = n_quantiles taus_ = jnp.arange(0, n_quantiles+1) / n_quantiles self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None] - + self.pessimism = pessimism self.distributional = True if self.n_quantiles > 1 else False if _init_setup_model: @@ -312,6 +313,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.quantile_taus, self.distributional, + self.pessimism, self.key, ) self._n_updates += gradient_steps @@ -338,6 +340,7 @@ def update_critic( rewards: jax.Array, dones: jax.Array, quantile_taus: jax.Array, + pessimism: float, key: jax.Array, ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) @@ -354,8 +357,13 @@ def update_critic( next_state_actions, rngs={"dropout": dropout_key_target}, ) - - next_q_values = jnp.mean(qf_next_values, axis=0) + + ensemble_size = qf_next_values.shape[0] + diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2 + i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal + critic_disagreement = jnp.mean(diff[i, j], axis=0) + + next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement # td error + entropy term next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) # shape is (batch_size, 1) @@ -388,6 +396,7 @@ def update_critic_quantile( rewards: jax.Array, dones: jax.Array, quantile_taus: jax.Array, + pessimism: float, key: jax.Array, ): @@ -405,17 +414,22 @@ def update_critic_quantile( next_state_actions, rngs={"dropout": dropout_key_target}, ) - - next_q_values = jnp.mean(qf_next_values, axis=0) + + # calculate disagreement + ensemble_size = qf_next_values.shape[0] + diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2 + i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal + critic_disagreement = jnp.mean(diff[i, j], axis=0) + next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement # entropy term next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) - # shape is (batch_size, n_quantiles, n_quantiles) + # shape is (batch_size, 1, n_quantiles) target_q_values = rewards[..., None, None] + (1 - dones[..., None, 
None]) * gamma * next_q_values[:, None, :] def quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: # shape is (n_critics, batch_size, 1) current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key}) - quantile_td_error = current_q_values[..., None] - target_q_values[None, ...] + quantile_td_error = target_q_values[None, ...] - current_q_values[..., None] def calculate_quantile_huber_loss(quantile_td_error: jnp.ndarray, quantile_taus: jnp.ndarray, kappa: float = 1.0) -> jnp.ndarray: element_wise_huber_loss = jnp.where(jnp.absolute(quantile_td_error) <= kappa, 0.5 * quantile_td_error ** 2, kappa * (jnp.absolute(quantile_td_error) - 0.5 * kappa)) mask = jax.lax.stop_gradient(jnp.where(quantile_td_error < 0, 1, 0)) # detach this @@ -441,6 +455,7 @@ def update_actor( qf_state: RLTrainState, ent_coef_state: TrainState, observations: jax.Array, + pessimism: float, key: jax.Array, ): key, dropout_key, noise_key = jax.random.split(key, 3) @@ -458,7 +473,12 @@ def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]: ) # Take mean among all critics - qf_pi_lb = jnp.mean(qf_pi, axis=0) + ensemble_size = qf_pi.shape[0] + diff = jnp.abs(qf_pi[:, None, :, :] - qf_pi[None, :, :, :]) / 2 + i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal + critic_disagreement = jnp.mean(diff[i, j], axis=0) + + qf_pi_lb = jnp.mean(qf_pi, axis=0) - pessimism * critic_disagreement qf_pi_lb = jnp.mean(qf_pi_lb, axis=-1, keepdims=True) ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) actor_loss = (ent_coef_value * log_prob - qf_pi_lb).mean() @@ -496,6 +516,7 @@ def update_actor_and_temperature( ent_coef_state: TrainState, observations: jax.Array, target_entropy: ArrayLike, + pessimism: float, key: jax.Array, ): (actor_state, qf_state, actor_loss_value, key, entropy) = cls.update_actor( @@ -503,6 +524,7 @@ def update_actor_and_temperature( qf_state, ent_coef_state, observations, + pessimism, key, ) ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) @@ -557,6 +579,7 @@ def _train( ent_coef_state: TrainState, quantile_taus: jax.Array, distributional: bool, + pessimism: float, key: jax.Array, ): assert data.observations.shape[0] % gradient_steps == 0 @@ -604,6 +627,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: batch_rew, batch_done, quantile_taus, + pessimism, key, ) @@ -620,6 +644,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: ent_coef_state, batch_obs, target_entropy, + pessimism, key, ) info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value} diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index dc2a56e..992fa30 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -115,6 +115,7 @@ def __init__( lr_schedule: Schedule, # BRO n_quantiles: int = 100, + n_critics: int = 2, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = True, @@ -130,7 +131,6 @@ def __init__( normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, optimizer_kwargs: Optional[Dict[str, Any]] = None, - n_critics: int = 2, share_features_extractor: bool = False, ): if optimizer_kwargs is None: diff --git a/train_torch.py b/train_torch.py index f747f7d..ea72cd8 100644 --- a/train_torch.py +++ b/train_torch.py @@ -20,7 +20,7 @@ ''' class 
flags: benchmark: str = 'dmc' - env_name: str = "cheetah-run" + env_name: str = "walker-walk" learning_starts: int = 9999 training_steps: int = 10_000 seed: int = 0 @@ -109,7 +109,7 @@ def main(_): eval_info = evaluate(eval_env, model, FLAGS.num_episodes) stat_info = model.get_stats(FLAGS.batch_size) info = {**eval_info, **train_info, **stat_info} - #print(eval_info) + print(eval_info) log_to_wandb(i, info) if __name__ == '__main__': From c40d2d9b3dfc988f0e0473c3798a7f89c5dffecb Mon Sep 17 00:00:00 2001 From: naumix Date: Mon, 4 Nov 2024 00:22:09 +0100 Subject: [PATCH 10/21] Update train_torch.py --- train_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_torch.py b/train_torch.py index ea72cd8..0fd5825 100644 --- a/train_torch.py +++ b/train_torch.py @@ -12,7 +12,7 @@ flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') -flags.DEFINE_integer('n_quantiles', 1, 'Number of training steps.') +flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.') FLAGS = flags.FLAGS From a826d1c8c2a0dd73dc1eb54a171d274a7e6aa280 Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 9 Nov 2024 18:20:11 -0800 Subject: [PATCH 11/21] Update bro.py --- sbx/bro/bro.py | 55 +------------------------------------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index 3f3cac6..fa27148 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -143,60 +143,7 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() - - if not hasattr(self, "policy") or self.policy is None: - self.policy = self.policy_class( # type: ignore[assignment] - self.observation_space, - self.action_space, - self.lr_schedule, - self.n_quantiles, - **self.policy_kwargs, - ) - - assert isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) - - self.key, ent_key = jax.random.split(self.key, 2) - - self.actor = self.policy.actor # type: ignore[assignment] - self.qf = self.policy.qf # type: ignore[assignment] - - # The entropy coefficient or entropy can be learned automatically - # see Automating Entropy Adjustment for Maximum Entropy RL section - # of https://arxiv.org/abs/1812.05905 - if isinstance(self.ent_coef_init, str) and self.ent_coef_init.startswith("auto"): - # Default initial value of ent_coef when learned - ent_coef_init = 1.0 - if "_" in self.ent_coef_init: - ent_coef_init = float(self.ent_coef_init.split("_")[1]) - assert ent_coef_init > 0.0, "The initial value of ent_coef must be greater than 0" - - # Note: we optimize the log of the entropy coeff which is slightly different from the paper - # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 - self.ent_coef = EntropyCoef(ent_coef_init) - else: - # This will throw an error if a malformed string (different from 'auto') is passed - assert isinstance( - self.ent_coef_init, float - ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}" - self.ent_coef = ConstantEntropyCoef(self.ent_coef_init) # type: ignore[assignment] - - self.ent_coef_state = TrainState.create( - apply_fn=self.ent_coef.apply, - params=self.ent_coef.init(ent_key)["params"], - tx=optax.adam( - 
learning_rate=self.learning_rate, b1=0.5 - ), - ) - - # Target entropy is used when learning the entropy coefficient - if self.target_entropy == "auto": - # automatically set target entropy if needed - self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) / 2 # type: ignore - else: - # Force conversion - # this will also throw an error for unexpected string - self.target_entropy = float(self.target_entropy) + self.reset() def reset(self): if not hasattr(self, "policy") or self.policy is None: From 826bec93c43f48b392ca33b9a2af2cd0b44bab80 Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 9 Nov 2024 18:23:57 -0800 Subject: [PATCH 12/21] compliance with sbx --- sbx/bro/bro.py | 46 ++------------------ train_torch.py | 116 ------------------------------------------------- 2 files changed, 4 insertions(+), 158 deletions(-) delete mode 100644 train_torch.py diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index fa27148..139f3b1 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -264,15 +264,10 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, ) self._n_updates += gradient_steps - return { - 'actor_loss': actor_loss_value.item(), - 'critic_loss': qf_loss_value.item(), - 'ent_loss': ent_coef_value.item(), - } - #self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - #self.logger.record("train/actor_loss", actor_loss_value.item()) - #self.logger.record("train/critic_loss", qf_loss_value.item()) - #self.logger.record("train/ent_coef", ent_coef_value.item()) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/actor_loss", actor_loss_value.item()) + self.logger.record("train/critic_loss", qf_loss_value.item()) + self.logger.record("train/ent_coef", ent_coef_value.item()) @staticmethod @jax.jit @@ -476,39 +471,6 @@ def update_actor_and_temperature( ) ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key - - def get_stats(self, batch_size): - data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) - - if isinstance(data.observations, dict): - keys = list(self.observation_space.keys()) # type: ignore[attr-defined] - obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) - next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1) - else: - obs = data.observations.numpy() - next_obs = data.next_observations.numpy() - - # Convert to numpy - data = ReplayBufferSamplesNp( # type: ignore[assignment] - obs, - data.actions.numpy(), - next_obs, - data.dones.numpy().flatten(), - data.rewards.numpy().flatten(), - ) - q, a, temp, ent = _get_stats( - self.policy.actor_state, - self.policy.qf_state, - self.ent_coef_state, - obs, - self.key, - ) - - return { - 'q': q.mean().item(), - 'a': a.item(), - 'temp': temp.item(), - 'entropy': ent.item()} @classmethod @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"]) diff --git a/train_torch.py b/train_torch.py deleted file mode 100644 index 0fd5825..0000000 --- a/train_torch.py +++ /dev/null @@ -1,116 +0,0 @@ -import gymnasium as gym -from sbx.bro.bro import BRO -import numpy as np -from make_dmc import make_env_dmc -import wandb - -from absl import app, flags - -flags.DEFINE_string('env_name', 'cheetah-run', 'Environment name.') -flags.DEFINE_string('benchmark', 'dmc', 
'Environment name.') -flags.DEFINE_integer('learning_starts', 2500, 'Number of training steps to start training.') -flags.DEFINE_integer('training_steps', 1000000, 'Number of training steps.') -flags.DEFINE_integer('batch_size', 128, 'Mini batch size.') -flags.DEFINE_integer('gradient_steps', 2, 'Number of updates per step.') -flags.DEFINE_integer('n_quantiles', 100, 'Number of training steps.') -flags.DEFINE_integer('eval_freq', 25000, 'Eval interval.') -flags.DEFINE_integer('num_episodes', 5, 'Number of episodes used for evaluation.') -FLAGS = flags.FLAGS - -''' -class flags: - benchmark: str = 'dmc' - env_name: str = "walker-walk" - learning_starts: int = 9999 - training_steps: int = 10_000 - seed: int = 0 - batch_size: int = 128 - gradient_steps: int = 2 - use_wandb: bool = False - n_quantiles: int = 100 - eval_freq: int = 5000 - num_episodes: int = 5 -FLAGS = flags() -''' - -def evaluate(env, model, num_episodes): - returns = np.zeros(num_episodes) - for episode in range(num_episodes): - not_done = True - obs, _ = env.reset(seed=np.random.randint(1e7)) - obs = np.expand_dims(obs, axis=0) - ret = 0 - while not_done: - action = model.policy.forward(obs, deterministic=True)[0] - next_obs, reward, term, trun, info = env.step(action) - next_obs = np.expand_dims(next_obs, axis=0) - obs = next_obs - ret += reward - if term or trun : - not_done = False - returns[episode] = ret - return {'return_eval': returns.mean()} - -def log_to_wandb(step, infos): - dict_to_log = {'timestep': step} - for info_key in infos: - dict_to_log[f'{info_key}'] = infos[info_key] - wandb.log(dict_to_log, step=step) - -def get_env(benchmark, env_name): - if benchmark == 'gym': - return gym.make(FLAGS.env_name) - else: - return make_env_dmc(env_name=FLAGS.env_name, action_repeat=1) - -def main(_): - SEED = np.random.randint(1e7) - wandb.init( - config=FLAGS, - entity='naumix', - project='BRO_SBX', - group=f'{FLAGS.env_name}', - name=f'BRO_Quantile:{FLAGS.n_quantiles}_BS:{FLAGS.batch_size}_{SEED}_repro' - ) - - env = get_env(FLAGS.benchmark, FLAGS.env_name) - eval_env = get_env(FLAGS.benchmark, FLAGS.env_name) - model = BRO("MlpPolicy", env, learning_starts=FLAGS.learning_starts, verbose=0, n_quantiles=FLAGS.n_quantiles, seed=SEED, batch_size=FLAGS.batch_size, gradient_steps=FLAGS.gradient_steps) - np.random.seed(SEED) - - reset_list = [15000] - obs, _ = env.reset(seed=np.random.randint(1e7)) - obs = np.expand_dims(obs, axis=0) - - for i in range(1, FLAGS.training_steps+1): - if i <= FLAGS.learning_starts: - action = env.action_space.sample() - else: - action = model.policy.forward(obs, deterministic=False)[0] - next_obs, reward, term, trun, info = env.step(action) - next_obs = np.expand_dims(next_obs, axis=0) - - done = 1.0 if (term and not trun) else 0.0 - - model.replay_buffer.add(obs, next_obs, action, reward, done, info) - if term or trun: - obs, _ = env.reset(seed=np.random.randint(1e7)) - obs = np.expand_dims(obs, axis=0) - else: - obs = next_obs - - if i in reset_list: - model.reset() - - if i >= FLAGS.learning_starts: - train_info = model.train(FLAGS.gradient_steps, FLAGS.batch_size) - - if i % FLAGS.eval_freq == 0: - eval_info = evaluate(eval_env, model, FLAGS.num_episodes) - stat_info = model.get_stats(FLAGS.batch_size) - info = {**eval_info, **train_info, **stat_info} - print(eval_info) - log_to_wandb(i, info) - -if __name__ == '__main__': - app.run(main) From a9ac2501a0520b1fd07f69b9c2342b986509393e Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 9 Nov 2024 18:58:22 -0800 Subject: [PATCH 13/21] blackify 
--- make_dmc.py | 157 ----------------------------- sbx/bro/bro.py | 71 +++++++------ sbx/bro/policies.py | 29 +++--- sbx/common/off_policy_algorithm.py | 5 +- scripts/dmc/acrobot.sh | 13 --- scripts/dmc/cheetah.sh | 13 --- scripts/dmc/dog_run.sh | 13 --- scripts/dmc/dog_stand.sh | 13 --- scripts/dmc/dog_trot.sh | 13 --- scripts/dmc/dog_walk.sh | 13 --- scripts/dmc/finger.sh | 13 --- scripts/dmc/fish.sh | 13 --- scripts/dmc/hopper.sh | 13 --- scripts/dmc/humanoid_run.sh | 13 --- scripts/dmc/humanoid_stand.sh | 13 --- scripts/dmc/humanoid_walk.sh | 13 --- scripts/dmc/pendulum.sh | 13 --- scripts/dmc/quadruped.sh | 13 --- scripts/dmc/walker.sh | 13 --- scripts/gym/ant.sh | 13 --- scripts/gym/cheetah.sh | 13 --- scripts/gym/hopper.sh | 13 --- scripts/gym/humanoid.sh | 13 --- scripts/gym/walker.sh | 13 --- scripts/run_tests.sh | 2 + scripts/test_bro.py | 7 ++ 26 files changed, 65 insertions(+), 466 deletions(-) delete mode 100644 make_dmc.py delete mode 100644 scripts/dmc/acrobot.sh delete mode 100644 scripts/dmc/cheetah.sh delete mode 100644 scripts/dmc/dog_run.sh delete mode 100644 scripts/dmc/dog_stand.sh delete mode 100644 scripts/dmc/dog_trot.sh delete mode 100644 scripts/dmc/dog_walk.sh delete mode 100644 scripts/dmc/finger.sh delete mode 100644 scripts/dmc/fish.sh delete mode 100644 scripts/dmc/hopper.sh delete mode 100644 scripts/dmc/humanoid_run.sh delete mode 100644 scripts/dmc/humanoid_stand.sh delete mode 100644 scripts/dmc/humanoid_walk.sh delete mode 100644 scripts/dmc/pendulum.sh delete mode 100644 scripts/dmc/quadruped.sh delete mode 100644 scripts/dmc/walker.sh delete mode 100644 scripts/gym/ant.sh delete mode 100644 scripts/gym/cheetah.sh delete mode 100644 scripts/gym/hopper.sh delete mode 100644 scripts/gym/humanoid.sh delete mode 100644 scripts/gym/walker.sh create mode 100644 scripts/run_tests.sh create mode 100644 scripts/test_bro.py diff --git a/make_dmc.py b/make_dmc.py deleted file mode 100644 index 57374dd..0000000 --- a/make_dmc.py +++ /dev/null @@ -1,157 +0,0 @@ -#adapted from https://github.com/imgeorgiev/dmc2gymnasium - -import logging -import numpy as np - -from dm_control import suite -from dm_env import specs -from gymnasium.core import Env -from gymnasium.spaces import Box -from gymnasium import spaces -from gymnasium.wrappers import FlattenObservation, RescaleAction - -def _spec_to_box(spec, dtype=np.float32): - def extract_min_max(s): - assert s.dtype == np.float64 or s.dtype == np.float32 - dim = int(np.prod(s.shape)) - if type(s) == specs.Array: - bound = np.inf * np.ones(dim, dtype=np.float32) - return -bound, bound - elif type(s) == specs.BoundedArray: - zeros = np.zeros(dim, dtype=np.float32) - return s.minimum + zeros, s.maximum + zeros - else: - logging.error("Unrecognized type") - mins, maxs = [], [] - for s in spec: - mn, mx = extract_min_max(s) - mins.append(mn) - maxs.append(mx) - low = np.concatenate(mins, axis=0).astype(dtype) - high = np.concatenate(maxs, axis=0).astype(dtype) - assert low.shape == high.shape - return Box(low, high, dtype=dtype) - - -def _flatten_obs(obs, dtype=np.float32): - obs_pieces = [] - for v in obs.values(): - flat = np.array([v]) if np.isscalar(v) else v.ravel() - obs_pieces.append(flat) - return np.concatenate(obs_pieces, axis=0).astype(dtype) - - -class DMCGym(Env): - def __init__( - self, - env_name, - task_kwargs={}, - environment_kwargs={}, - #rendering="egl", - render_height=64, - render_width=64, - render_camera_id=0, - action_repeat=1 - ): - domain = env_name.split('-')[0] - task = env_name.split('-')[1] - 
self._env = suite.load( - domain, - task, - task_kwargs, - environment_kwargs, - ) - - # placeholder to allow built in gymnasium rendering - self.render_mode = "rgb_array" - self.render_height = render_height - self.render_width = render_width - self.render_camera_id = render_camera_id - - self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) - self._norm_action_space = spaces.Box( - low=-1.0, - high=1.0, - shape=self._true_action_space.shape, - dtype=np.float32 - ) - - self._observation_space = _spec_to_box(self._env.observation_spec().values()) - self._action_space = _spec_to_box([self._env.action_spec()]) - self.action_repeat = action_repeat - - # set seed if provided with task_kwargs - if "random" in task_kwargs: - seed = task_kwargs["random"] - self._observation_space.seed(seed) - self._action_space.seed(seed) - - def __getattr__(self, name): - """Add this here so that we can easily access attributes of the underlying env""" - return getattr(self._env, name) - - @property - def observation_space(self): - return self._observation_space - - @property - def action_space(self): - return self._action_space - - @property - def reward_range(self): - """DMC always has a per-step reward range of (0, 1)""" - return 0, 1 - - def _convert_action(self, action): - action = action.astype(np.float64) - true_delta = self._true_action_space.high - self._true_action_space.low - norm_delta = self._norm_action_space.high - self._norm_action_space.low - action = (action - self._norm_action_space.low) / norm_delta - action = action * true_delta + self._true_action_space.low - action = action.astype(np.float32) - return action - - def step(self, action): - assert self._norm_action_space.contains(action) - action = self._convert_action(action) - assert self._true_action_space.contains(action) - action = np.clip(action, -1.0, 1.0) - reward = 0 - info = {} - for i in range(self.action_repeat): - timestep = self._env.step(action) - observation = _flatten_obs(timestep.observation) - reward += timestep.reward - termination = False # we never reach a goal - truncation = timestep.last() - if truncation: - return observation, reward, termination, truncation, info - return observation, reward, termination, truncation, info - - def reset(self, seed=None, options=None): - if seed is not None: - if not isinstance(seed, np.random.RandomState): - seed = np.random.RandomState(seed) - self._env.task._random = seed - if options: - logging.warn("Currently doing nothing with options={:}".format(options)) - timestep = self._env.reset() - observation = _flatten_obs(timestep.observation) - info = {} - return observation, info - - def render(self, height=None, width=None, camera_id=None): - height = height or self.render_height - width = width or self.render_width - camera_id = camera_id or self.render_camera_id - return self._env.physics.render(height=height, width=width, camera_id=camera_id) - - -def make_env_dmc(env_name: str, action_repeat: int = 1) -> Env: - env = DMCGym(env_name=env_name, action_repeat=action_repeat) - env = RescaleAction(env, -1.0, 1.0) - env = FlattenObservation(env) - return env - - \ No newline at end of file diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index 139f3b1..faca6da 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -19,7 +19,6 @@ from sbx.bro.policies import BROPolicy - class EntropyCoef(nn.Module): ent_coef_init: float = 1.0 @@ -27,7 +26,8 @@ class EntropyCoef(nn.Module): def __call__(self) -> jnp.ndarray: log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: 
jnp.full((), jnp.log(self.ent_coef_init))) return jnp.exp(log_ent_coef) - + + class ConstantEntropyCoef(nn.Module): ent_coef_init: float = 1.0 @@ -38,6 +38,7 @@ def __call__(self) -> float: self.param("dummy_param", init_fn=lambda key: jnp.full((), self.ent_coef_init)) return self.ent_coef_init + @jax.jit def _get_stats( actor_state: RLTrainState, @@ -58,7 +59,8 @@ def _get_stats( ) ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) return qf_pi.mean(), jnp.absolute(actor_actions).mean(), ent_coef_value.mean(), -log_prob.mean() - + + class BRO(OffPolicyAlgorithmJax): policy_aliases: ClassVar[Dict[str, Type[BROPolicy]]] = { # type: ignore[assignment] "MlpPolicy": BROPolicy, @@ -73,13 +75,13 @@ def __init__( self, policy, env: Union[GymEnv, str], - #BRO + # BRO n_quantiles: int = 100, - pessimism: float = 0.0, + pessimism: float = 0.1, learning_rate: Union[float, Schedule] = 3e-4, qf_learning_rate: Optional[float] = None, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 100, + learning_starts: int = 2_500, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, @@ -130,13 +132,12 @@ def __init__( self.policy_delay = policy_delay self.ent_coef_init = ent_coef self.target_entropy = target_entropy - self.init_key = jax.random.PRNGKey(seed) - + self.n_quantiles = n_quantiles - taus_ = jnp.arange(0, n_quantiles+1) / n_quantiles + taus_ = jnp.arange(0, n_quantiles + 1) / n_quantiles self.quantile_taus = ((taus_[1:] + taus_[:-1]) / 2.0)[None, ..., None] self.pessimism = pessimism - + self.distributional = True if self.n_quantiles > 1 else False if _init_setup_model: self._setup_model() @@ -144,7 +145,7 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() self.reset() - + def reset(self): if not hasattr(self, "policy") or self.policy is None: self.policy = self.policy_class( # type: ignore[assignment] @@ -157,7 +158,7 @@ def reset(self): assert isinstance(self.qf_learning_rate, float) - self.key = self.policy.build(self.init_key, self.lr_schedule, self.qf_learning_rate) + self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) self.key, ent_key = jax.random.split(self.key, 2) @@ -187,9 +188,7 @@ def reset(self): self.ent_coef_state = TrainState.create( apply_fn=self.ent_coef.apply, params=self.ent_coef.init(ent_key)["params"], - tx=optax.adam( - learning_rate=self.learning_rate, b1=0.5 - ), + tx=optax.adam(learning_rate=self.learning_rate, b1=0.5), ) # Target entropy is used when learning the entropy coefficient @@ -299,12 +298,12 @@ def update_critic( next_state_actions, rngs={"dropout": dropout_key_target}, ) - + ensemble_size = qf_next_values.shape[0] diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2 - i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal + i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal critic_disagreement = jnp.mean(diff[i, j], axis=0) - + next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement # td error + entropy term next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) @@ -324,7 +323,7 @@ def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: (qf_loss_value, ent_coef_value), key, ) - + @staticmethod @jax.jit def update_critic_quantile( @@ -356,11 +355,11 @@ def update_critic_quantile( next_state_actions, rngs={"dropout": dropout_key_target}, ) - + # calculate disagreement ensemble_size = 
qf_next_values.shape[0] diff = jnp.abs(qf_next_values[:, None, :, :] - qf_next_values[None, :, :, :]) / 2 - i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal + i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal critic_disagreement = jnp.mean(diff[i, j], axis=0) next_q_values = jnp.mean(qf_next_values, axis=0) - pessimism * critic_disagreement # entropy term @@ -372,12 +371,20 @@ def quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.A # shape is (n_critics, batch_size, 1) current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key}) quantile_td_error = target_q_values[None, ...] - current_q_values[..., None] - def calculate_quantile_huber_loss(quantile_td_error: jnp.ndarray, quantile_taus: jnp.ndarray, kappa: float = 1.0) -> jnp.ndarray: - element_wise_huber_loss = jnp.where(jnp.absolute(quantile_td_error) <= kappa, 0.5 * quantile_td_error ** 2, kappa * (jnp.absolute(quantile_td_error) - 0.5 * kappa)) - mask = jax.lax.stop_gradient(jnp.where(quantile_td_error < 0, 1, 0)) # detach this + + def calculate_quantile_huber_loss( + quantile_td_error: jnp.ndarray, quantile_taus: jnp.ndarray, kappa: float = 1.0 + ) -> jnp.ndarray: + element_wise_huber_loss = jnp.where( + jnp.absolute(quantile_td_error) <= kappa, + 0.5 * quantile_td_error**2, + kappa * (jnp.absolute(quantile_td_error) - 0.5 * kappa), + ) + mask = jax.lax.stop_gradient(jnp.where(quantile_td_error < 0, 1, 0)) # detach this element_wise_quantile_huber_loss = jnp.absolute(quantile_taus - mask) * element_wise_huber_loss / kappa quantile_huber_loss = element_wise_quantile_huber_loss.sum(axis=0).sum(axis=1).mean() return quantile_huber_loss + quantile_huber_loss = calculate_quantile_huber_loss(quantile_td_error, quantile_taus) return quantile_huber_loss @@ -397,7 +404,7 @@ def update_actor( qf_state: RLTrainState, ent_coef_state: TrainState, observations: jax.Array, - pessimism: float, + pessimism: float, key: jax.Array, ): key, dropout_key, noise_key = jax.random.split(key, 3) @@ -413,13 +420,13 @@ def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]: actor_actions, rngs={"dropout": dropout_key}, ) - + # Take mean among all critics ensemble_size = qf_pi.shape[0] diff = jnp.abs(qf_pi[:, None, :, :] - qf_pi[None, :, :, :]) / 2 - i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal + i, j = jnp.triu_indices(ensemble_size, k=1) # Get indices for upper triangle without the diagonal critic_disagreement = jnp.mean(diff[i, j], axis=0) - + qf_pi_lb = jnp.mean(qf_pi, axis=0) - pessimism * critic_disagreement qf_pi_lb = jnp.mean(qf_pi_lb, axis=-1, keepdims=True) ent_coef_value = ent_coef_state.apply_fn({"params": ent_coef_state.params}) @@ -471,7 +478,7 @@ def update_actor_and_temperature( ) ent_coef_state, ent_coef_loss_value = cls.update_temperature(target_entropy, ent_coef_state, entropy) return actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key - + @classmethod @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset", "distributional"]) def _train( @@ -519,7 +526,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) batch_done = 
jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) - + (qf_state, (qf_loss_value, ent_coef_value), key) = jax.lax.cond( distributional == True, # If True: @@ -538,8 +545,8 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: quantile_taus, pessimism, key, - ) - + ) + qf_state = cls.soft_update(tau, qf_state) (actor_state, qf_state, ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond( diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index 992fa30..fed9bb8 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -16,6 +16,7 @@ tfd = tfp.distributions + class BroNetBlock(nn.Module): n_units: int activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -29,19 +30,21 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: out = nn.LayerNorm()(out) return x + out + class BroNet(nn.Module): net_arch: Sequence[int] activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - + @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = nn.Dense(self.net_arch[0], self.activation_fn)(x) x = nn.LayerNorm()(x) x = self.activation_fn(x) for n_units in self.net_arch: x = BroNetBlock(n_units)(x) return x - + + class Actor(nn.Module): net_arch: Sequence[int] action_dim: int @@ -64,7 +67,8 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) return dist - + + class Critic(nn.Module): net_arch: Sequence[int] n_quantiles: int = 100 @@ -78,7 +82,8 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: out = BroNet(self.net_arch, self.activation_fn)(out) out = nn.Dense(self.n_quantiles)(out) return out - + + class VectorCritic(nn.Module): net_arch: Sequence[int] n_quantiles: int = 100 @@ -105,6 +110,7 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): )(obs, action) return q_values + class BROPolicy(BaseJaxPolicy): action_space: spaces.Box # type: ignore[assignment] @@ -120,7 +126,7 @@ def __init__( dropout_rate: float = 0.0, layer_norm: bool = True, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, - use_sde: bool = False, + use_sde: bool = False, # Note: most gSDE parameters are not used # this is to keep API consistent with SB3 log_std_init: float = -3, @@ -150,7 +156,7 @@ def __init__( optimizer_kwargs=optimizer_kwargs, squash_output=True, ) - + self.dropout_rate = dropout_rate self.layer_norm = layer_norm self.n_quantiles = n_quantiles @@ -171,9 +177,7 @@ def __init__( self.key = self.noise_key = jax.random.PRNGKey(0) - def build(self, key: jax.Array, - lr_schedule: Schedule, - qf_learning_rate: float) -> jax.Array: + def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) -> jax.Array: key, actor_key, qf_key, dropout_key = jax.random.split(key, 4) # Keep a key for the actor key, self.key = jax.random.split(key, 2) @@ -200,7 +204,7 @@ def build(self, key: jax.Array, params=self.actor.init(actor_key, obs), tx=self.optimizer_class( learning_rate=lr_schedule(1), # type: ignore[call-arg] - #learning_rate=qf_learning_rate, # type: ignore[call-arg] + # learning_rate=qf_learning_rate, # type: ignore[call-arg] **self.optimizer_kwargs, ), ) @@ -255,6 +259,3 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n if not self.use_sde: self.reset_noise() return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) - - - diff --git a/sbx/common/off_policy_algorithm.py 
b/sbx/common/off_policy_algorithm.py index 11e6c8c..7878215 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -22,7 +22,7 @@ def __init__( learning_rate: Union[float, Schedule], qf_learning_rate: Optional[float] = None, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 100, + learning_starts: int = 2_500, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, @@ -71,7 +71,7 @@ def __init__( support_multi_env=support_multi_env, ) # Will be updated later - self.key = jax.random.PRNGKey(seed) + self.key = jax.random.PRNGKey(0) # Note: we do not allow schedule for it self.qf_learning_rate = qf_learning_rate @@ -115,7 +115,6 @@ def _setup_model(self) -> None: device="cpu", # force cpu device to easy torch -> numpy conversion n_envs=self.n_envs, optimize_memory_usage=self.optimize_memory_usage, - handle_timeout_termination=False, **replay_buffer_kwargs, ) # Convert train freq parameter to TrainFreq object diff --git a/scripts/dmc/acrobot.sh b/scripts/dmc/acrobot.sh deleted file mode 100644 index f5e59bf..0000000 --- a/scripts/dmc/acrobot.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=acrobot-swingup - -wait diff --git a/scripts/dmc/cheetah.sh b/scripts/dmc/cheetah.sh deleted file mode 100644 index 9c025df..0000000 --- a/scripts/dmc/cheetah.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=cheetah-run - -wait diff --git a/scripts/dmc/dog_run.sh b/scripts/dmc/dog_run.sh deleted file mode 100644 index 45b27ea..0000000 --- a/scripts/dmc/dog_run.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=dog-run - -wait diff --git a/scripts/dmc/dog_stand.sh b/scripts/dmc/dog_stand.sh deleted file mode 100644 index 07a4584..0000000 --- a/scripts/dmc/dog_stand.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=dog-stand - -wait diff --git a/scripts/dmc/dog_trot.sh b/scripts/dmc/dog_trot.sh deleted file mode 100644 index d7b1411..0000000 --- a/scripts/dmc/dog_trot.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=dog-trot - -wait diff --git a/scripts/dmc/dog_walk.sh b/scripts/dmc/dog_walk.sh deleted file mode 100644 index 4221da2..0000000 --- a/scripts/dmc/dog_walk.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=dog-walk - -wait diff --git a/scripts/dmc/finger.sh b/scripts/dmc/finger.sh deleted file mode 100644 index 8a10360..0000000 --- a/scripts/dmc/finger.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 
train_torch.py --env_name=finger-turn_hard - -wait diff --git a/scripts/dmc/fish.sh b/scripts/dmc/fish.sh deleted file mode 100644 index 7dca0fa..0000000 --- a/scripts/dmc/fish.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=fish-swim - -wait diff --git a/scripts/dmc/hopper.sh b/scripts/dmc/hopper.sh deleted file mode 100644 index fea40d9..0000000 --- a/scripts/dmc/hopper.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=hopper-hop - -wait diff --git a/scripts/dmc/humanoid_run.sh b/scripts/dmc/humanoid_run.sh deleted file mode 100644 index f434d0b..0000000 --- a/scripts/dmc/humanoid_run.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=humanoid-run - -wait diff --git a/scripts/dmc/humanoid_stand.sh b/scripts/dmc/humanoid_stand.sh deleted file mode 100644 index e22bff0..0000000 --- a/scripts/dmc/humanoid_stand.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=humanoid-stand - -wait diff --git a/scripts/dmc/humanoid_walk.sh b/scripts/dmc/humanoid_walk.sh deleted file mode 100644 index 64a2a37..0000000 --- a/scripts/dmc/humanoid_walk.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=humanoid-walk - -wait diff --git a/scripts/dmc/pendulum.sh b/scripts/dmc/pendulum.sh deleted file mode 100644 index 0e54c35..0000000 --- a/scripts/dmc/pendulum.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=pendulum-swingup - -wait diff --git a/scripts/dmc/quadruped.sh b/scripts/dmc/quadruped.sh deleted file mode 100644 index e91ba25..0000000 --- a/scripts/dmc/quadruped.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=quadruped-run - -wait diff --git a/scripts/dmc/walker.sh b/scripts/dmc/walker.sh deleted file mode 100644 index 3da6aa7..0000000 --- a/scripts/dmc/walker.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --env_name=walker-run - -wait diff --git a/scripts/gym/ant.sh b/scripts/gym/ant.sh deleted file mode 100644 index 1d63aee..0000000 --- a/scripts/gym/ant.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --benchmark=gym --env_name=Ant-v4 - -wait diff --git a/scripts/gym/cheetah.sh b/scripts/gym/cheetah.sh deleted 
file mode 100644 index 63108f6..0000000 --- a/scripts/gym/cheetah.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --benchmark=gym --env_name=HalfCheetah-v4 - -wait diff --git a/scripts/gym/hopper.sh b/scripts/gym/hopper.sh deleted file mode 100644 index 3eeca18..0000000 --- a/scripts/gym/hopper.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --benchmark=gym --env_name=Hopper-v4 - -wait diff --git a/scripts/gym/humanoid.sh b/scripts/gym/humanoid.sh deleted file mode 100644 index 7d2a7d2..0000000 --- a/scripts/gym/humanoid.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --benchmark=gym --env_name=Humanoid-v4 - -wait diff --git a/scripts/gym/walker.sh b/scripts/gym/walker.sh deleted file mode 100644 index 8a3e68c..0000000 --- a/scripts/gym/walker.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -# - -source ~/miniconda3/bin/activate -conda init bash -source ~/.bashrc -conda activate sbx - -module load cuDNN/8.9.2.26-CUDA-12.2.0 - -python3 train_torch.py --benchmark=gym --env_name=Walker2d-v4 - -wait diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh new file mode 100644 index 0000000..b356a08 --- /dev/null +++ b/scripts/run_tests.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes -m "not expensive" \ No newline at end of file diff --git a/scripts/test_bro.py b/scripts/test_bro.py new file mode 100644 index 0000000..98304a8 --- /dev/null +++ b/scripts/test_bro.py @@ -0,0 +1,7 @@ +from sbx.bro.bro import BRO +import gymnasium as gym + +env = gym.make("Pendulum-v1") + +model = BRO("MlpPolicy", env, verbose=1) +model.learn(total_timesteps=10_000, progress_bar=True) From a8968fdbb57a6f111d463e69752773ba48de9378 Mon Sep 17 00:00:00 2001 From: naumix <43146457+naumix@users.noreply.github.com> Date: Sat, 9 Nov 2024 19:16:55 -0800 Subject: [PATCH 14/21] Delete scripts/test_bro.py --- scripts/test_bro.py | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 scripts/test_bro.py diff --git a/scripts/test_bro.py b/scripts/test_bro.py deleted file mode 100644 index 98304a8..0000000 --- a/scripts/test_bro.py +++ /dev/null @@ -1,7 +0,0 @@ -from sbx.bro.bro import BRO -import gymnasium as gym - -env = gym.make("Pendulum-v1") - -model = BRO("MlpPolicy", env, verbose=1) -model.learn(total_timesteps=10_000, progress_bar=True) From 6d0e2bdfa509a49ea49c325b132d375462b7c54d Mon Sep 17 00:00:00 2001 From: naumix <43146457+naumix@users.noreply.github.com> Date: Sat, 9 Nov 2024 19:17:24 -0800 Subject: [PATCH 15/21] Update off_policy_algorithm.py --- sbx/common/off_policy_algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 7878215..ba0b9ed 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -22,7 +22,7 @@ def __init__( learning_rate: Union[float, Schedule], qf_learning_rate: Optional[float] = None, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 2_500, + learning_starts: int = 100, batch_size: 
int = 256, tau: float = 0.005, gamma: float = 0.99, From 2fc7fba6ddd085e7a38c98e7d22621bd91cefe6e Mon Sep 17 00:00:00 2001 From: naumix <43146457+naumix@users.noreply.github.com> Date: Sat, 9 Nov 2024 19:22:27 -0800 Subject: [PATCH 16/21] Update policies.py --- sbx/bro/policies.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index fed9bb8..f89f559 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -169,8 +169,7 @@ def __init__( else: self.net_arch_pi = [256] # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR - self.net_arch_qf = [512, 512] - print(self.net_arch_qf) + self.net_arch_qf = [1024, 1024] self.n_critics = n_critics self.use_sde = use_sde self.activation_fn = activation_fn From b48836016152ca85abbd22e20d8599484d32f011 Mon Sep 17 00:00:00 2001 From: naumix Date: Sat, 9 Nov 2024 19:30:08 -0800 Subject: [PATCH 17/21] Update policies.py --- sbx/bro/policies.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index fed9bb8..f89f559 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -169,8 +169,7 @@ def __init__( else: self.net_arch_pi = [256] # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR - self.net_arch_qf = [512, 512] - print(self.net_arch_qf) + self.net_arch_qf = [1024, 1024] self.n_critics = n_critics self.use_sde = use_sde self.activation_fn = activation_fn From 2bb4ccae5362ef1976f158a3cc82f7bcb2d36af3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 10 Nov 2024 14:41:47 +0100 Subject: [PATCH 18/21] Fix import, argument and add run test --- sbx/__init__.py | 3 +-- sbx/bro/bro.py | 4 ++-- sbx/bro/policies.py | 8 +++++--- tests/test_run.py | 4 ++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sbx/__init__.py b/sbx/__init__.py index 8ec1325..8b47be1 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -1,5 +1,6 @@ import os +from sbx.bro import BRO from sbx.crossq import CrossQ from sbx.ddpg import DDPG from sbx.dqn import DQN @@ -7,8 +8,6 @@ from sbx.sac import SAC from sbx.td3 import TD3 from sbx.tqc import TQC -from sbx.bro import BRO - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index faca6da..c77afc9 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -14,9 +14,9 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from sbx.bro.policies import BROPolicy from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.bro.policies import BROPolicy class EntropyCoef(nn.Module): @@ -528,7 +528,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) (qf_state, (qf_loss_value, ent_coef_value), key) = jax.lax.cond( - distributional == True, + distributional, # If True: cls.update_critic_quantile, # If False: diff --git a/sbx/bro/policies.py b/sbx/bro/policies.py index f89f559..8b44b09 100644 --- a/sbx/bro/policies.py +++ b/sbx/bro/policies.py @@ -37,11 +37,11 @@ class BroNet(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - x = nn.Dense(self.net_arch[0], 
self.activation_fn)(x) + x = nn.Dense(self.net_arch[0])(x) x = nn.LayerNorm()(x) x = self.activation_fn(x) for n_units in self.net_arch: - x = BroNetBlock(n_units)(x) + x = BroNetBlock(n_units, self.activation_fn)(x) return x @@ -168,7 +168,9 @@ def __init__( self.net_arch_qf = net_arch["qf"] else: self.net_arch_pi = [256] - # In the paper we use [512, 512] although we also use higher RR, here we use bigger network size to compensate for the smaller RR + # In the original implementation, the authors use [512, 512] + # but with a higher replay ratio (RR), + # here we use bigger network size to compensate for the smaller RR self.net_arch_qf = [1024, 1024] self.n_critics = n_critics self.use_sde = use_sde diff --git a/tests/test_run.py b/tests/test_run.py index 18d6dec..54b676a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,7 +8,7 @@ from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy -from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ +from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ, BRO def check_save_load(model, model_class, tmp_path): @@ -116,7 +116,7 @@ def test_dropout(model_class): model.learn(110) -@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ]) +@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ, BRO]) def test_policy_kwargs(model_class) -> None: env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1" From 7d526ce18b24435012a44e3cdd6f86647a65165d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 10 Nov 2024 14:46:40 +0100 Subject: [PATCH 19/21] Update learning starts to match SAC and others --- sbx/bro/bro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbx/bro/bro.py b/sbx/bro/bro.py index c77afc9..3da47a7 100644 --- a/sbx/bro/bro.py +++ b/sbx/bro/bro.py @@ -81,7 +81,7 @@ def __init__( learning_rate: Union[float, Schedule] = 3e-4, qf_learning_rate: Optional[float] = None, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 2_500, + learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, From 8a4a8c483e1b80f9b64eed62038592dd33bbfd2e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Nov 2024 09:28:13 +0100 Subject: [PATCH 20/21] Sort imports --- tests/test_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_run.py b/tests/test_run.py index 54b676a..f6ed009 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,7 +8,7 @@ from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy -from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ, BRO +from sbx import BRO, DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ def check_save_load(model, model_class, tmp_path): From e28060ead3bc6405aacaebda1dea11d1ec19c7d9 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Nov 2024 10:31:56 +0100 Subject: [PATCH 21/21] Fix run_test mode --- scripts/run_tests.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts/run_tests.sh diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh old mode 100644 new mode 100755
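
For reference, the tree as it stands after this series can be smoke-tested along the lines of the (since removed) scripts/test_bro.py. The snippet below is a minimal sketch rather than part of any patch: it assumes the final "from sbx import BRO" export added in PATCH 18, uses the BRO keyword arguments (n_quantiles, pessimism, gradient_steps, learning_starts) with the default values visible in sbx/bro/bro.py at the end of the series, and reuses the stable_baselines3 evaluate_policy helper already imported in tests/test_run.py.

    import gymnasium as gym
    from stable_baselines3.common.evaluation import evaluate_policy

    from sbx import BRO

    # Short smoke test mirroring the deleted scripts/test_bro.py:
    # train BRO briefly on Pendulum and report the evaluation return.
    env = gym.make("Pendulum-v1")

    # n_quantiles > 1 selects the distributional (quantile) critic loss;
    # pessimism scales the critic-disagreement penalty introduced in PATCH 09.
    model = BRO(
        "MlpPolicy",
        env,
        n_quantiles=100,
        pessimism=0.1,
        gradient_steps=2,
        learning_starts=100,
        verbose=1,
    )
    model.learn(total_timesteps=10_000, progress_bar=True)

    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
    print(f"return: {mean_reward:.1f} +/- {std_reward:.1f}")

Setting n_quantiles=1 makes self.distributional false, so the jax.lax.cond in _train falls back to the plain mean-squared TD update (update_critic) instead of the quantile Huber loss (update_critic_quantile).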