From e7448391cad548b253a6e381bb803da252187f87 Mon Sep 17 00:00:00 2001 From: Tobias Birchler Date: Tue, 5 Dec 2023 23:47:10 +0100 Subject: [PATCH] Add C51 algorithm (DLR-RM/stable-baselines3#622) --- sb3_contrib/__init__.py | 2 + sb3_contrib/c51/__init__.py | 4 + sb3_contrib/c51/c51.py | 316 +++++++++++++++++++++++++++++++++ sb3_contrib/c51/policies.py | 336 ++++++++++++++++++++++++++++++++++++ 4 files changed, 658 insertions(+) create mode 100644 sb3_contrib/c51/__init__.py create mode 100644 sb3_contrib/c51/c51.py create mode 100644 sb3_contrib/c51/policies.py diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 3fbd28d8..919ae873 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,6 +1,7 @@ import os from sb3_contrib.ars import ARS +from sb3_contrib.c51 import C51 from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO from sb3_contrib.qrdqn import QRDQN @@ -14,6 +15,7 @@ __all__ = [ "ARS", + "C51", "MaskablePPO", "RecurrentPPO", "QRDQN", diff --git a/sb3_contrib/c51/__init__.py b/sb3_contrib/c51/__init__.py new file mode 100644 index 00000000..f8e2947b --- /dev/null +++ b/sb3_contrib/c51/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.c51.c51 import C51 +from sb3_contrib.c51.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["C51", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/c51/c51.py b/sb3_contrib/c51/c51.py new file mode 100644 index 00000000..0e651a60 --- /dev/null +++ b/sb3_contrib/c51/c51.py @@ -0,0 +1,316 @@ +import warnings +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union + +import numpy as np +import torch as th +from gymnasium import spaces +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update + +from sb3_contrib.c51.policies import C51Policy, CategoricalNetwork, CnnPolicy, MlpPolicy, MultiInputPolicy + +SelfC51 = TypeVar("SelfC51", bound="C51") + + +def project(supports, weights, target_support): + """Projects a batch of (support, weights) onto target_support. + + Based on equation (7) in (Bellemare et al., 2017): https://arxiv.org/abs/1707.06887 + In the rest of the comments we will refer to this equation simply as Eq7. + + Args: + supports: Batch of supports. + weights: Batch of weights. + target_support: Target support. + + Returns: + Batch of weights after projection. + """ + v_min, v_max = target_support[0], target_support[-1] + # `N` in Eq7. + n_atoms = target_support.shape[0] + # delta_z = `\Delta z` in Eq7. + delta_z = (v_max - v_min) / (n_atoms - 1) + # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7. + clipped_support = th.clip(supports, v_min, v_max) + # numerator = `|clipped_support - z_i|` in Eq7. + numerator = th.abs(clipped_support[:, None] - target_support[:, None]) + quotient = 1 - (numerator / delta_z) + # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7. + clipped_quotient = th.clip(quotient, 0, 1) + # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))` in Eq7. 
+ inner_prod = clipped_quotient * weights[:, None] + return th.sum(inner_prod, dim=-1) + + +class C51(OffPolicyAlgorithm): + """ + Categorical Deep Q-Network (C51) + Paper: https://arxiv.org/abs/1707.06887 + Default hyperparameters are taken from the paper and are tuned for Atari games. + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout + (see ``train_freq`` and ``n_episodes_rollout``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param target_update_interval: update the target network every ``target_update_interval`` + environment steps. + :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced + :param exploration_initial_eps: initial value of random action probability + :param exploration_final_eps: final value of random action probability + :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping) + :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average + the reported success rate, mean episode length, and mean reward over + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + # Linear schedule will be defined in `_setup_model()` + exploration_schedule: Schedule + categorical_net: CategoricalNetwork + categorical_net_target: CategoricalNetwork + policy: C51Policy + + def __init__( + self, + policy: Union[str, Type[C51Policy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 2.5e-4, + buffer_size: int = 1000000, # 1e6 + learning_starts: int = 50000, + batch_size: int = 32, + tau: float = 1.0, + gamma: float = 0.99, + train_freq: int = 4, + gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.01, + max_grad_norm: Optional[float] = None, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + action_noise=None, # No action noise + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + policy_kwargs=policy_kwargs, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + seed=seed, + sde_support=False, + optimize_memory_usage=optimize_memory_usage, + supported_action_spaces=(spaces.Discrete,), + support_multi_env=True, + ) + + self.exploration_initial_eps = exploration_initial_eps + self.exploration_final_eps = exploration_final_eps + self.exploration_fraction = exploration_fraction + self.target_update_interval = target_update_interval + # For updating the target network with multiple envs: + self._n_calls = 0 + self.max_grad_norm = max_grad_norm + # "epsilon" for the epsilon-greedy exploration + self.exploration_rate = 0.0 + + if "optimizer_class" not in self.policy_kwargs: + self.policy_kwargs["optimizer_class"] = th.optim.Adam + # Proposed in the C51 paper where `batch_size = 32` + self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size) + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super()._setup_model() + self._create_aliases() + # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 + self.batch_norm_stats = get_parameters_by_name(self.categorical_net, ["running_"]) + self.batch_norm_stats_target = get_parameters_by_name(self.categorical_net_target, ["running_"]) + self.exploration_schedule = get_linear_fn( + self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction + ) + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self.n_envs > 1: + if self.n_envs > self.target_update_interval: + warnings.warn( + "The number of environments used is greater than the target network " + f"update interval ({self.n_envs} > {self.target_update_interval}), " + "therefore the target network will be updated after each call to env.step() " + f"which 
corresponds to {self.n_envs} steps."
+                )
+
+            self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
+
+    def _create_aliases(self) -> None:
+        self.categorical_net = self.policy.categorical_net
+        self.categorical_net_target = self.policy.categorical_net_target
+        self.support = self.categorical_net.support
+
+    def _on_step(self) -> None:
+        """
+        Update the exploration rate and target network if needed.
+        This method is called in ``collect_rollouts()`` after each step in the environment.
+        """
+        self._n_calls += 1
+        if self._n_calls % self.target_update_interval == 0:
+            polyak_update(self.categorical_net.parameters(), self.categorical_net_target.parameters(), self.tau)
+            # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
+            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
+
+        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
+        self.logger.record("rollout/exploration_rate", self.exploration_rate)
+
+    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
+        # Switch to train mode (this affects batch norm / dropout)
+        self.policy.set_training_mode(True)
+        # Update learning rate according to schedule
+        self._update_learning_rate(self.policy.optimizer)
+
+        losses = []
+        for _ in range(gradient_steps):
+            # Sample replay buffer
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+
+            with th.no_grad():
+                # Compute the next categorical probabilities using the target network
+                next_probabilities = th.softmax(self.categorical_net_target(replay_data.next_observations), dim=-1)
+                # Next Q-values are the expected support values; select the greedy actions
+                next_actions = (next_probabilities * self.support).sum(dim=-1).argmax(dim=-1)
+                # Follow greedy policy: keep only the distribution of the selected action
+                next_probabilities = next_probabilities[th.arange(batch_size), next_actions]
+                # 1-step TD target: apply the Bellman operator to the support atoms
+                target_support = replay_data.rewards + (1 - replay_data.dones) * self.gamma * self.support
+                # Project the shifted distribution back onto the fixed support (Eq. 7 of the paper)
+                targets = project(target_support, next_probabilities, self.support)
+
+            # Get current estimated categorical logits for the replayed actions
+            logits = self.categorical_net(replay_data.observations)
+            logits = logits[th.arange(batch_size), replay_data.actions.squeeze(dim=1)]
+
+            # Cross-entropy loss between predicted logits and projected target probabilities
+            loss = th.nn.functional.cross_entropy(logits, targets)
+            losses.append(loss.item())
+
+            # Optimize the policy
+            self.policy.optimizer.zero_grad()
+            loss.backward()
+            # Clip gradient norm
+            if self.max_grad_norm is not None:
+                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+            self.policy.optimizer.step()
+
+        # Increase update counter
+        self._n_updates += gradient_steps
+
+        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        self.logger.record("train/loss", np.mean(losses))
+
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next state (used in recurrent policies)
+        """
+        if not deterministic and np.random.rand() < self.exploration_rate:
+            if self.policy.is_vectorized_observation(observation):
+                if isinstance(observation, dict):
+                    n_batch = observation[next(iter(observation.keys()))].shape[0]
+                else:
+                    n_batch = observation.shape[0]
+                action = np.array([self.action_space.sample() for _ in range(n_batch)])
+            else:
+                action = np.array(self.action_space.sample())
+        else:
+            action, state = self.policy.predict(observation, state, episode_start, deterministic)
+        return action, state
+
+    def learn(
+        self: SelfC51,
+        total_timesteps: int,
+        callback: MaybeCallback = None,
+        log_interval: int = 4,
+        tb_log_name: str = "C51",
+        reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
+    ) -> SelfC51:
+        return super().learn(
+            total_timesteps=total_timesteps,
+            callback=callback,
+            log_interval=log_interval,
+            tb_log_name=tb_log_name,
+            reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
+        )
+
+    def _excluded_save_params(self) -> List[str]:
+        return super()._excluded_save_params() + ["categorical_net", "categorical_net_target"]  # noqa: RUF005
+
+    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+        state_dicts = ["policy", "policy.optimizer"]
+
+        return state_dicts, []
diff --git a/sb3_contrib/c51/policies.py b/sb3_contrib/c51/policies.py
new file mode 100644
index 00000000..57100238
--- /dev/null
+++ b/sb3_contrib/c51/policies.py
@@ -0,0 +1,336 @@
+from typing import Any, Dict, List, Optional, Type
+
+import torch as th
+from gymnasium import spaces
+from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.torch_layers import (
+    BaseFeaturesExtractor,
+    CombinedExtractor,
+    FlattenExtractor,
+    NatureCNN,
+    create_mlp,
+)
+from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
+from torch import nn
+
+
+class CategoricalNetwork(BasePolicy):
+    """
+    Categorical network for C51
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param n_atoms: Number of atoms
+    :param v_min: the value distribution support is [v_min, v_max]. If None, we set it to be -v_max.
+    :param v_max: the value distribution support is [v_min, v_max].
+    :param net_arch: The specification of the network architecture.
+    :param activation_fn: Activation function
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    """
+
+    action_space: spaces.Discrete
+
+    def __init__(
+        self,
+        observation_space: spaces.Space,
+        action_space: spaces.Discrete,
+        features_extractor: BaseFeaturesExtractor,
+        features_dim: int,
+        n_atoms: int = 51,
+        v_min: Optional[int] = None,
+        v_max: int = 10,
+        net_arch: Optional[List[int]] = None,
+        activation_fn: Type[nn.Module] = nn.ReLU,
+        normalize_images: bool = True,
+    ) -> None:
+        super().__init__(
+            observation_space,
+            action_space,
+            features_extractor=features_extractor,
+            normalize_images=normalize_images,
+        )
+
+        if net_arch is None:
+            net_arch = [64, 64]
+
+        self.net_arch = net_arch
+        self.activation_fn = activation_fn
+        self.features_dim = features_dim
+        self.n_atoms = n_atoms
+        v_min = v_min if v_min is not None else -v_max
+        self.support = nn.parameter.Parameter(th.linspace(v_min, v_max, n_atoms), requires_grad=False)
+        action_dim = int(self.action_space.n)  # number of actions
+        categorical_net = create_mlp(self.features_dim, action_dim * self.n_atoms, self.net_arch, self.activation_fn)
+        self.categorical_net = nn.Sequential(*categorical_net)
+
+    def forward(self, obs: PyTorchObs) -> th.Tensor:
+        """
+        Predict the categorical logits.
+
+        :param obs: Observation
+        :return: The estimated categorical logits for each action.
+        """
+        logits = self.categorical_net(self.extract_features(obs, self.features_extractor))
+        return logits.view(-1, int(self.action_space.n), self.n_atoms)
+
+    def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
+        q_values = (th.softmax(self(observation), dim=-1) * self.support).sum(dim=-1)
+        # Greedy action
+        action = q_values.argmax(dim=1).reshape(-1)
+        return action
+
+    def _get_constructor_parameters(self) -> Dict[str, Any]:
+        data = super()._get_constructor_parameters()
+
+        data.update(
+            dict(
+                net_arch=self.net_arch,
+                features_dim=self.features_dim,
+                n_atoms=self.n_atoms,
+                activation_fn=self.activation_fn,
+                features_extractor=self.features_extractor,
+            )
+        )
+        return data
+
+
+class C51Policy(BasePolicy):
+    """
+    Policy class with categorical and target networks for C51
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param n_atoms: Number of atoms
+    :param v_min: the value distribution support is [v_min, v_max]. If None, we set it to be -v_max.
+    :param v_max: the value distribution support is [v_min, v_max].
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + categorical_net: CategoricalNetwork + categorical_net_target: CategoricalNetwork + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Discrete, + lr_schedule: Schedule, + n_atoms: int = 51, + v_min: Optional[int] = None, + v_max: int = 10, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, + ) + + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [64, 64] + + self.net_arch = net_arch + self.activation_fn = activation_fn + + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "n_atoms": n_atoms, + "v_min": v_min, + "v_max": v_max, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the network and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self.categorical_net = self.make_categorical_net() + self.categorical_net_target = self.make_categorical_net() + self.categorical_net_target.load_state_dict(self.categorical_net.state_dict()) + self.categorical_net_target.set_training_mode(False) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class( # type: ignore[call-arg] + self.parameters(), + lr=lr_schedule(1), + **self.optimizer_kwargs, + ) + + def make_categorical_net(self) -> CategoricalNetwork: + # Make sure we always have separate networks for features extractors etc + net_args = self._update_features_extractor(self.net_args, features_extractor=None) + return CategoricalNetwork(**net_args).to(self.device) + + def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: + return self.categorical_net._predict(obs, deterministic=deterministic) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + n_atoms=self.net_args["n_atoms"], + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. 
+ This affects certain modules, such as batch normalisation and dropout. + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.categorical_net.set_training_mode(mode) + self.training = mode + + +MlpPolicy = C51Policy + + +class CnnPolicy(C51Policy): + """ + Policy class for C51 when using images as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_atoms: Number of atoms + :param v_min: the value distribution support is [v_min, v_max]. If None, we set it to be -v_max. + :param v_max: the value distribution support is [v_min, v_max]. + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Discrete, + lr_schedule: Schedule, + n_atoms: int = 51, + v_min: Optional[int] = None, + v_max: int = 10, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__( + observation_space, + action_space, + lr_schedule, + n_atoms, + v_min, + v_max, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputPolicy(C51Policy): + """ + Policy class for C51 when using dict observations as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_atoms: Number of atoms + :param v_min: the value distribution support is [v_min, v_max]. If None, we set it to be -v_max. + :param v_max: the value distribution support is [v_min, v_max]. + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Dict, + action_space: spaces.Discrete, + lr_schedule: Schedule, + n_atoms: int = 51, + v_min: Optional[int] = None, + v_max: int = 10, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__( + observation_space, + action_space, + lr_schedule, + n_atoms, + v_min, + v_max, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + )
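A quick numerical check of the categorical projection implemented by `project()` in sb3_contrib/c51/c51.py can help when reviewing Eq. 7. This is a minimal sketch, assuming the patch above is applied so that `project` is importable; the three-atom support and the reward/discount values are made up for illustration:

    import torch as th

    from sb3_contrib.c51.c51 import project

    # Fixed support with three atoms: z = [-1, 0, 1] (illustrative, not the class defaults)
    support = th.tensor([-1.0, 0.0, 1.0])
    # Next-state distribution for a single transition: all mass on the middle atom (z = 0)
    next_probabilities = th.tensor([[0.0, 1.0, 0.0]])
    # Bellman-shifted atoms for reward = 0.5, gamma = 1.0, non-terminal: [-0.5, 0.5, 1.5]
    target_support = 0.5 + 1.0 * support.unsqueeze(0)
    # Eq. 7: the mass sitting at 0.5 is split evenly between the neighbouring atoms 0 and 1,
    # and the clipped atom at 1.5 carries no mass here
    print(project(target_support, next_probabilities, support))  # expected: tensor([[0.0, 0.5, 0.5]])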
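An end-to-end usage sketch, assuming sb3_contrib is installed with this patch and gymnasium's CartPole-v1 is available; the hyperparameters below are illustrative only (the class defaults target Atari), not tuned values:

    import gymnasium as gym

    from sb3_contrib import C51

    env = gym.make("CartPole-v1")
    # CartPole returns lie in [0, 500], so widen the support from the Atari default of [-10, 10]
    model = C51(
        "MlpPolicy",
        env,
        policy_kwargs=dict(n_atoms=51, v_min=0, v_max=500),
        buffer_size=10_000,
        learning_starts=1_000,
        verbose=1,
    )
    model.learn(total_timesteps=20_000)

    obs, _ = env.reset()
    action, _ = model.predict(obs, deterministic=True)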