diff --git a/README.md b/README.md index 8c39d88..5bc336b 100644 --- a/README.md +++ b/README.md @@ -14,33 +14,36 @@ Welcome to `actorch`, a deep reinforcement learning framework for fast prototypi - [REINFORCE](https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) - [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783) - [Actor-Critic Kronecker-Factored Trust Region (ACKTR)](https://arxiv.org/abs/1708.05144) +- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) - [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) - [Advantage-Weighted Regression (AWR)](https://arxiv.org/abs/1910.00177) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971) +- [Distributional Deep Deterministic Policy Gradient (D3PG)](https://arxiv.org/abs/1804.08617) - [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/abs/1802.09477) +- [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290) --------------------------------------------------------------------------------------------------------- ## 💡 Key features - Support for [OpenAI Gymnasium](https://gymnasium.farama.org/) environments -- Support for custom observation/action spaces -- Support for custom multimodal input multimodal output models -- Support for recurrent models (e.g. RNNs, LSTMs, GRUs, etc.) -- Support for custom policy/value distributions -- Support for custom preprocessing/postprocessing pipelines -- Support for custom exploration strategies +- Support for **custom observation/action spaces** +- Support for **custom multimodal input multimodal output models** +- Support for **recurrent models** (e.g. RNNs, LSTMs, GRUs, etc.) +- Support for **custom policy/value distributions** +- Support for **custom preprocessing/postprocessing pipelines** +- Support for **custom exploration strategies** - Support for [normalizing flows](https://arxiv.org/abs/1906.02771) - Batched environments (both for training and evaluation) -- Batched trajectory replay -- Batched and distributional value estimation (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561)) -- Data parallel and distributed data parallel multi-GPU training and evaluation -- Automatic mixed precision training -- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and hyperparameter tuning at any scale -- Effortless experiment definition through Python-based configuration files -- Built-in visualization tool to plot performance metrics -- Modular object-oriented design -- Detailed API documentation +- Batched **trajectory replay** +- Batched and **distributional value estimation** (e.g. 
batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561)) +- Data parallel and distributed data parallel **multi-GPU training and evaluation** +- Automatic **mixed precision training** +- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and **hyperparameter tuning** at any scale +- Effortless experiment definition through **Python-based configuration files** +- Built-in **visualization tool** to plot performance metrics +- Modular **object-oriented** design +- Detailed **API documentation** --------------------------------------------------------------------------------------------------------- @@ -161,7 +164,7 @@ experiment_params = ExperimentParams( enable_amp=False, enable_reproducibility=True, log_sys_usage=True, - suppress_warnings=False, + suppress_warnings=True, ), ) ``` @@ -197,6 +200,8 @@ You can find the generated plots in `plots`. Congratulations, you ran your first experiment! +See `examples` for additional configuration file examples. + **HINT**: since a configuration file is a regular Python script, you can use all the features of the language (e.g. inheritance). @@ -217,6 +222,21 @@ features of the language (e.g. inheritance). --------------------------------------------------------------------------------------------------------- +## @ Citation + +``` +@misc{DellaLibera2022ACTorch, + author = {Luca Della Libera}, + title = {{ACTorch}: a Deep Reinforcement Learning Framework for Fast Prototyping}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/lucadellalib/actorch}}, +} +``` + +--------------------------------------------------------------------------------------------------------- + ## 📧 Contact [luca.dellalib@gmail.com](mailto:luca.dellalib@gmail.com) diff --git a/actorch/algorithms/__init__.py b/actorch/algorithms/__init__.py index bbb311d..3235bc7 100644 --- a/actorch/algorithms/__init__.py +++ b/actorch/algorithms/__init__.py @@ -20,7 +20,10 @@ from actorch.algorithms.acktr import * from actorch.algorithms.algorithm import * from actorch.algorithms.awr import * +from actorch.algorithms.d3pg import * from actorch.algorithms.ddpg import * from actorch.algorithms.ppo import * from actorch.algorithms.reinforce import * +from actorch.algorithms.sac import * from actorch.algorithms.td3 import * +from actorch.algorithms.trpo import * diff --git a/actorch/algorithms/a2c.py b/actorch/algorithms/a2c.py index 9cce312..4a3e6e4 100644 --- a/actorch/algorithms/a2c.py +++ b/actorch/algorithms/a2c.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Advantage Actor-Critic.""" +"""Advantage Actor-Critic (A2C).""" import contextlib from typing import Any, Callable, Dict, Optional, Tuple, Type, Union @@ -38,7 +38,7 @@ DistributedDataParallelREINFORCE, LRScheduler, ) -from actorch.algorithms.utils import prepare_model +from actorch.algorithms.utils import normalize_, prepare_model from actorch.algorithms.value_estimation import n_step_return from actorch.distributions import Deterministic from actorch.envs import BatchedEnv @@ -65,7 +65,7 @@ class A2C(REINFORCE): - """Advantage Actor-Critic. + """Advantage Actor-Critic (A2C). 
References ---------- @@ -208,12 +208,8 @@ def setup(self, config: "Dict[str, Any]") -> "None": self.config = A2C.Config(**self.config) self.config["_accept_kwargs"] = True super().setup(config) - self._value_network = ( - self._build_value_network().train().to(self._device, non_blocking=True) - ) - self._value_network_loss = ( - self._build_value_network_loss().train().to(self._device, non_blocking=True) - ) + self._value_network = self._build_value_network() + self._value_network_loss = self._build_value_network_loss() self._value_network_optimizer = self._build_value_network_optimizer() self._value_network_optimizer_lr_scheduler = ( self._build_value_network_optimizer_lr_scheduler() @@ -324,16 +320,20 @@ def _build_value_network(self) -> "Network": self.value_network_normalizing_flows, ) self._log_graph(value_network.wrapped_model.model, "value_network_model") - return value_network + return value_network.train().to(self._device, non_blocking=True) def _build_value_network_loss(self) -> "Loss": if self.value_network_loss_builder is None: self.value_network_loss_builder = torch.nn.MSELoss if self.value_network_loss_config is None: self.value_network_loss_config: "Dict[str, Any]" = {} - return self.value_network_loss_builder( - reduction="none", - **self.value_network_loss_config, + return ( + self.value_network_loss_builder( + reduction="none", + **self.value_network_loss_config, + ) + .train() + .to(self._device, non_blocking=True) ) def _build_value_network_optimizer(self) -> "Optimizer": @@ -374,6 +374,8 @@ def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() self.num_return_steps.step() result["num_return_steps"] = self.num_return_steps() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) return result # override @@ -405,17 +407,7 @@ def _train_on_batch( self.num_return_steps(), ) if self.normalize_advantage: - length = mask.sum(dim=1, keepdim=True) - advantages_mean = advantages.sum(dim=1, keepdim=True) / length - advantages -= advantages_mean - advantages *= mask - advantages_stddev = ( - ((advantages**2).sum(dim=1, keepdim=True) / length) - .sqrt() - .clamp(min=1e-6) - ) - advantages /= advantages_stddev - advantages *= mask + normalize_(advantages, dim=-1, mask=mask) # Discard next state value state_values = state_values[:, :-1] @@ -449,10 +441,12 @@ def _train_on_batch_value_network( state_value = state_values[mask] target = targets[mask] loss = self._value_network_loss(state_value, target) - loss *= is_weight[:, None].expand_as(mask)[mask] + priority = None + if self._buffer.is_prioritized: + loss *= is_weight[:, None].expand_as(mask)[mask] + priority = loss.detach().abs().to("cpu").numpy() loss = loss.mean() optimize_result = self._optimize_value_network(loss) - priority = None result = { "state_value": state_value.mean().item(), "target": target.mean().item(), @@ -490,7 +484,7 @@ def _get_default_value_network_preprocessor( class DistributedDataParallelA2C(DistributedDataParallelREINFORCE): - """Distributed data parallel Advantage Actor-Critic. + """Distributed data parallel Advantage Actor-Critic (A2C). See Also -------- diff --git a/actorch/algorithms/acktr.py b/actorch/algorithms/acktr.py index 931b16c..4884914 100644 --- a/actorch/algorithms/acktr.py +++ b/actorch/algorithms/acktr.py @@ -14,7 +14,7 @@ # limitations under the License. 
# ============================================================================== -"""Actor-Critic Kronecker-Factored Trust Region.""" +"""Actor-Critic Kronecker-Factored Trust Region (ACKTR).""" from typing import Any, Callable, Dict, Optional, Union @@ -43,7 +43,7 @@ class ACKTR(A2C): - """Actor-Critic Kronecker-Factored Trust Region. + """Actor-Critic Kronecker-Factored Trust Region (ACKTR). References ---------- @@ -287,7 +287,7 @@ def _optimize_policy_network(self, loss: "Tensor") -> "Dict[str, Any]": class DistributedDataParallelACKTR(DistributedDataParallelA2C): - """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region. + """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region (ACKTR). See Also -------- diff --git a/actorch/algorithms/algorithm.py b/actorch/algorithms/algorithm.py index 1fcd34b..fdb568d 100644 --- a/actorch/algorithms/algorithm.py +++ b/actorch/algorithms/algorithm.py @@ -51,7 +51,6 @@ from ray.tune.syncer import NodeSyncer from ray.tune.trial import ExportFormat from torch import Tensor -from torch.cuda.amp import autocast from torch.distributions import Bernoulli, Categorical, Distribution, Normal from torch.profiler import profile, record_function, tensorboard_trace_handler from torch.utils.data import DataLoader @@ -113,7 +112,7 @@ class Algorithm(ABC, Trainable): _EXPORT_FORMATS = [ExportFormat.CHECKPOINT, ExportFormat.MODEL] - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True + _OFF_POLICY = True class Config(dict): """Keyword arguments expected in the configuration received by `setup`.""" @@ -692,12 +691,7 @@ def _build_train_env(self) -> "BatchedEnv": if self.train_env_config is None: self.train_env_config = {} - try: - train_env = self.train_env_builder( - **self.train_env_config, - ) - except TypeError: - train_env = self.train_env_builder(**self.train_env_config) + train_env = self.train_env_builder(**self.train_env_config) if not isinstance(train_env, BatchedEnv): train_env.close() train_env = SerialBatchedEnv(self.train_env_builder, self.train_env_config) @@ -866,7 +860,7 @@ def _build_policy_network(self) -> "PolicyNetwork": # noqa: C901 self.policy_network_postprocessors, ) self._log_graph(policy_network.wrapped_model.model, "policy_network_model") - return policy_network + return policy_network.train().to(self._device, non_blocking=True) def _build_train_agent(self) -> "Agent": if self.train_agent_builder is None: @@ -989,14 +983,18 @@ def _build_dataloader(self) -> "DataLoader": if self.dataloader_builder is None: self.dataloader_builder = DataLoader if self.dataloader_config is None: - fork = torch.multiprocessing.get_start_method() == "fork" + use_mp = ( + self._OFF_POLICY + and not self._buffer.is_prioritized + and torch.multiprocessing.get_start_method() == "fork" + ) self.dataloader_config = { - "num_workers": 1 if fork else 0, + "num_workers": 1 if use_mp else 0, "pin_memory": True, "timeout": 0, "worker_init_fn": None, "generator": None, - "prefetch_factor": 1 if fork else 2, + "prefetch_factor": 1 if use_mp else 2, "pin_memory_device": "", } if self.dataloader_config is None: @@ -1037,20 +1035,15 @@ def _train_step(self) -> "Dict[str, Any]": if self.train_num_episodes_per_iter: train_num_episodes_per_iter = self.train_num_episodes_per_iter() self.train_num_episodes_per_iter.step() - with ( - autocast(**self.enable_amp) - if self.enable_amp["enabled"] - else contextlib.suppress() + for experience, done in self._train_sampler.sample( + train_num_timesteps_per_iter, + train_num_episodes_per_iter, ): - 
for experience, done in self._train_sampler.sample( - train_num_timesteps_per_iter, - train_num_episodes_per_iter, - ): - self._buffer.add(experience, done) + self._buffer.add(experience, done) result = self._train_sampler.stats self._cumrewards += result["episode_cumreward"] - if not self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH: + if not self._OFF_POLICY: for schedule in self._buffer_dataset.schedules.values(): schedule.step() train_epoch_result = self._train_epoch() @@ -1060,7 +1053,7 @@ def _train_step(self) -> "Dict[str, Any]": schedule.step() for schedule in self._buffer.schedules.values(): schedule.step() - if self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH: + if self._OFF_POLICY: for schedule in self._buffer_dataset.schedules.values(): schedule.step() return result @@ -1113,16 +1106,11 @@ def _eval_step(self) -> "Dict[str, Any]": eval_num_episodes_per_iter = self.eval_num_episodes_per_iter() self.eval_num_episodes_per_iter.step() self._eval_sampler.reset() - with ( - autocast(**self.enable_amp) - if self.enable_amp["enabled"] - else contextlib.suppress() + for _ in self._eval_sampler.sample( + eval_num_timesteps_per_iter, + eval_num_episodes_per_iter, ): - for _ in self._eval_sampler.sample( - eval_num_timesteps_per_iter, - eval_num_episodes_per_iter, - ): - pass + pass for schedule in self._eval_agent.schedules.values(): schedule.step() return self._eval_sampler.stats @@ -1353,7 +1341,7 @@ def __init__( Default to ``{}``. placement_strategy: The placement strategy - (see https://docs.ray.io/en/latest/ray-core/placement-group.html). + (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0). backend: The backend for distributed execution (see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group). diff --git a/actorch/algorithms/awr.py b/actorch/algorithms/awr.py index f13d6db..1e887c3 100644 --- a/actorch/algorithms/awr.py +++ b/actorch/algorithms/awr.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Advantage-Weighted Regression.""" +"""Advantage-Weighted Regression (AWR).""" import contextlib from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -31,6 +31,7 @@ from actorch.agents import Agent from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.utils import normalize_ from actorch.algorithms.value_estimation import lambda_return from actorch.buffers import Buffer from actorch.datasets import BufferDataset @@ -48,7 +49,7 @@ class AWR(A2C): - """Advantage-Weighted Regression. + """Advantage-Weighted Regression (AWR). 
References ---------- @@ -59,9 +60,7 @@ class AWR(A2C): """ - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True # override - - _RESET_BUFFER = False # override + _OFF_POLICY = True # override # override class Config(dict): @@ -277,6 +276,8 @@ def _train_step(self) -> "Dict[str, Any]": ) result["weight_clip"] = self.weight_clip() result["temperature"] = self.temperature() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) result["buffer_num_experiences"] = self._buffer.num_experiences result["buffer_num_full_trajectories"] = self._buffer.num_full_trajectories result.pop("num_return_steps", None) @@ -311,17 +312,7 @@ def _train_on_batch( self.trace_decay(), ) if self.normalize_advantage: - length = mask.sum(dim=1, keepdim=True) - advantages_mean = advantages.sum(dim=1, keepdim=True) / length - advantages -= advantages_mean - advantages *= mask - advantages_stddev = ( - ((advantages**2).sum(dim=1, keepdim=True) / length) - .sqrt() - .clamp(min=1e-6) - ) - advantages /= advantages_stddev - advantages *= mask + normalize_(advantages, dim=-1, mask=mask) # Discard next state value state_values = state_values[:, :-1] @@ -352,16 +343,19 @@ def _train_on_batch_policy_network( raise ValueError( f"`entropy_coeff` ({entropy_coeff}) must be in the interval [0, inf)" ) + weight_clip = self.weight_clip() if weight_clip <= 0.0: raise ValueError( f"`weight_clip` ({weight_clip}) must be in the interval (0, inf)" ) + temperature = self.temperature() if temperature <= 0.0: raise ValueError( f"`temperature` ({temperature}) must be in the interval (0, inf)" ) + with ( autocast(**self.enable_amp) if self.enable_amp["enabled"] @@ -392,7 +386,7 @@ def _train_on_batch_policy_network( class DistributedDataParallelAWR(DistributedDataParallelA2C): - """Distributed data parallel Advantage-Weighted Regression. + """Distributed data parallel Advantage-Weighted Regression (AWR). See Also -------- diff --git a/actorch/algorithms/d3pg.py b/actorch/algorithms/d3pg.py new file mode 100644 index 0000000..ef201d0 --- /dev/null +++ b/actorch/algorithms/d3pg.py @@ -0,0 +1,430 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Distributional Deep Deterministic Policy Gradient (D3PG).""" + +import contextlib +import logging +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from gymnasium import Env +from numpy import ndarray +from torch import Tensor +from torch.cuda.amp import autocast +from torch.distributions import Distribution, kl_divergence +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from actorch.agents import Agent +from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.ddpg import DDPG, DistributedDataParallelDDPG, Loss, LRScheduler +from actorch.algorithms.utils import freeze_params, sync_polyak_ +from actorch.algorithms.value_estimation import n_step_return +from actorch.buffers import Buffer, ProportionalBuffer +from actorch.distributions import Finite +from actorch.envs import BatchedEnv +from actorch.models import Model +from actorch.networks import ( + DistributionParametrization, + Network, + NormalizingFlow, + Processor, +) +from actorch.samplers import Sampler +from actorch.schedules import Schedule + + +__all__ = [ + "D3PG", + "DistributedDataParallelD3PG", +] + + +_LOGGER = logging.getLogger(__name__) + + +class D3PG(DDPG): + """Distributional Deep Deterministic Policy Gradient (D3PG). + + References + ---------- + .. [1] G. Barth-Maron, M. W. Hoffman, D. Budden, W. Dabney, D. Horgan, D. TB, + A. Muldal, N. Heess, and T. Lillicrap. + "Distributed Distributional Deterministic Policy Gradients". + In: ICLR. 2018. + URL: https://arxiv.org/abs/1804.08617 + + """ + + # override + class Config(dict): + """Keyword arguments expected in the configuration received by `setup`.""" + + def __init__( + self, + train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]", + train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1, + eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None, + eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + policy_network_model_config: 
"Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + policy_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_distribution_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Distribution]]]]" = None, + value_network_distribution_parametrization: "Tunable[RefOrFutureRef[Optional[DistributionParametrization]]]" = None, + value_network_distribution_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_normalizing_flow: "Tunable[RefOrFutureRef[Optional[NormalizingFlow]]]" = None, + value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None, + value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Buffer]]]]" = None, + buffer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_checkpoint: "Tunable[RefOrFutureRef[bool]]" = False, + dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None, + dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99, + num_return_steps: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 3, + num_updates_per_iter: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1000, + batch_size: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 128, + max_trajectory_length: "Tunable[RefOrFutureRef[Union[int, float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + sync_freq: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1, + polyak_weight: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.001, + max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100, + seed: "Tunable[RefOrFutureRef[int]]" = 0, + enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False, + enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False, + enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False, + suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False, + _accept_kwargs: "bool" = False, + **kwargs: "Any", + ) -> "None": + if 
not _accept_kwargs and kwargs: + raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}") + super().__init__( + train_env_builder=train_env_builder, + train_env_config=train_env_config, + train_agent_builder=train_agent_builder, + train_agent_config=train_agent_config, + train_sampler_builder=train_sampler_builder, + train_sampler_config=train_sampler_config, + train_num_timesteps_per_iter=train_num_timesteps_per_iter, + train_num_episodes_per_iter=train_num_episodes_per_iter, + eval_freq=eval_freq, + eval_env_builder=eval_env_builder, + eval_env_config=eval_env_config, + eval_agent_builder=eval_agent_builder, + eval_agent_config=eval_agent_config, + eval_sampler_builder=eval_sampler_builder, + eval_sampler_config=eval_sampler_config, + eval_num_timesteps_per_iter=eval_num_timesteps_per_iter, + eval_num_episodes_per_iter=eval_num_episodes_per_iter, + policy_network_preprocessors=policy_network_preprocessors, + policy_network_model_builder=policy_network_model_builder, + policy_network_model_config=policy_network_model_config, + policy_network_postprocessors=policy_network_postprocessors, + policy_network_optimizer_builder=policy_network_optimizer_builder, + policy_network_optimizer_config=policy_network_optimizer_config, + policy_network_optimizer_lr_scheduler_builder=policy_network_optimizer_lr_scheduler_builder, + policy_network_optimizer_lr_scheduler_config=policy_network_optimizer_lr_scheduler_config, + value_network_preprocessors=value_network_preprocessors, + value_network_model_builder=value_network_model_builder, + value_network_model_config=value_network_model_config, + value_network_distribution_builder=value_network_distribution_builder, + value_network_distribution_parametrization=value_network_distribution_parametrization, + value_network_distribution_config=value_network_distribution_config, + value_network_normalizing_flow=value_network_normalizing_flow, + value_network_loss_builder=value_network_loss_builder, + value_network_loss_config=value_network_loss_config, + value_network_optimizer_builder=value_network_optimizer_builder, + value_network_optimizer_config=value_network_optimizer_config, + value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder, + value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config, + buffer_builder=buffer_builder, + buffer_config=buffer_config, + buffer_checkpoint=buffer_checkpoint, + dataloader_builder=dataloader_builder, + dataloader_config=dataloader_config, + discount=discount, + num_return_steps=num_return_steps, + num_updates_per_iter=num_updates_per_iter, + batch_size=batch_size, + max_trajectory_length=max_trajectory_length, + sync_freq=sync_freq, + polyak_weight=polyak_weight, + max_grad_l2_norm=max_grad_l2_norm, + cumreward_window_size=cumreward_window_size, + seed=seed, + enable_amp=enable_amp, + enable_reproducibility=enable_reproducibility, + enable_anomaly_detection=enable_anomaly_detection, + enable_profiling=enable_profiling, + log_sys_usage=log_sys_usage, + suppress_warnings=suppress_warnings, + **kwargs, + ) + + # override + def setup(self, config: "Dict[str, Any]") -> "None": + self.config = D3PG.Config(**self.config) + self.config["_accept_kwargs"] = True + super().setup(config) + self._warn_failed_logging = True + + # override + def _build_buffer(self) -> "Buffer": + if self.buffer_builder is None: + self.buffer_builder = ProportionalBuffer + if self.buffer_config is None: + self.buffer_config = { + "capacity": int(1e5), + "prioritization": 1.0, + 
"bias_correction": 0.4, + "epsilon": 1e-5, + } + return super()._build_buffer() + + # override + def _build_value_network(self) -> "Network": + if self.value_network_distribution_builder is None: + self.value_network_distribution_builder = Finite + if self.value_network_distribution_parametrization is None: + self.value_network_distribution_parametrization = { + "logits": ( + {"logits": (51,)}, + lambda x: x["logits"], + ), + } + if self.value_network_distribution_config is None: + self.value_network_distribution_config = { + "atoms": torch.linspace(-10.0, 10.0, 51).to(self._device), + "validate_args": False, + } + if self.value_network_distribution_parametrization is None: + self.value_network_distribution_parametrization = {} + if self.value_network_distribution_config is None: + self.value_network_distribution_config = {} + + if self.value_network_normalizing_flow is not None: + self.value_network_normalizing_flows = { + "value": self.value_network_normalizing_flow, + } + else: + self.value_network_normalizing_flows: "Dict[str, NormalizingFlow]" = {} + + self.value_network_distribution_builders = { + "value": self.value_network_distribution_builder, + } + self.value_network_distribution_parametrizations = { + "value": self.value_network_distribution_parametrization, + } + self.value_network_distribution_configs = { + "value": self.value_network_distribution_config, + } + return super()._build_value_network() + + # override + def _train_on_batch( + self, + idx: "int", + experiences: "Dict[str, Tensor]", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + sync_freq = self.sync_freq() + if sync_freq < 1 or not float(sync_freq).is_integer(): + raise ValueError( + f"`sync_freq` ({sync_freq}) " + f"must be in the integer interval [1, inf)" + ) + sync_freq = int(sync_freq) + + result = {} + + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + target_actions, _ = self._target_policy_network( + experiences["observation"], mask=mask + ) + observations_target_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], target_actions] + ], + dim=-1, + ) + self._target_value_network(observations_target_actions, mask=mask) + target_action_values = self._target_value_network.distribution + # Discard next observation + experiences["observation"] = experiences["observation"][:, :-1, ...] 
+ mask = mask[:, 1:] + targets, _ = n_step_return( + target_action_values, + experiences["reward"], + experiences["terminal"], + mask, + self.discount(), + self.num_return_steps(), + return_advantage=False, + ) + + # Compute action values + observations_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], experiences["action"]] + ], + dim=-1, + ) + self._value_network(observations_actions, mask=mask) + action_values = self._value_network.distribution + + result["value_network"], priority = self._train_on_batch_value_network( + action_values, + targets, + is_weight, + mask, + ) + result["policy_network"] = self._train_on_batch_policy_network( + experiences, + mask, + ) + + # Synchronize + if idx % sync_freq == 0: + sync_polyak_( + self._policy_network, + self._target_policy_network, + self.polyak_weight(), + ) + sync_polyak_( + self._value_network, + self._target_value_network, + self.polyak_weight(), + ) + + self._grad_scaler.update() + return result, priority + + # override + def _train_on_batch_policy_network( + self, + experiences: "Dict[str, Tensor]", + mask: "Tensor", + ) -> "Dict[str, Any]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + actions, _ = self._policy_network(experiences["observation"], mask=mask) + observations_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], actions] + ], + dim=-1, + ) + with freeze_params(self._value_network): + self._value_network(observations_actions, mask=mask) + action_values = self._value_network.distribution + try: + action_values = action_values.mean + except Exception as e: + raise RuntimeError(f"Could not compute `action_values.mean`: {e}") + action_value = action_values[mask] + loss = -action_value.mean() + optimize_result = self._optimize_policy_network(loss) + result = {"loss": loss.item()} + result.update(optimize_result) + return result + + # override + def _train_on_batch_value_network( + self, + action_values: "Distribution", + targets: "Distribution", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + loss = kl_divergence(targets, action_values)[mask] + priority = None + if self._buffer.is_prioritized: + loss *= is_weight[:, None].expand_as(mask)[mask] + priority = loss.detach().abs().to("cpu").numpy() + loss = loss.mean() + optimize_result = self._optimize_value_network(loss) + result = {} + try: + result["action_value"] = action_values.mean[mask].mean().item() + result["target"] = targets.mean[mask].mean().item() + except Exception as e: + if self._warn_failed_logging: + _LOGGER.warning(f"Could not log `action_value` and/or `target`: {e}") + self._warn_failed_logging = False + result["loss"] = loss.item() + result.update(optimize_result) + return result, priority + + +class DistributedDataParallelD3PG(DistributedDataParallelDDPG): + """Distributed data parallel Distributional Deep Deterministic Policy Gradient (D3PG). + + See Also + -------- + actorch.algorithms.d3pg.D3PG + + """ + + _ALGORITHM_CLS = D3PG # override diff --git a/actorch/algorithms/ddpg.py b/actorch/algorithms/ddpg.py index 79cc4a6..17190cc 100644 --- a/actorch/algorithms/ddpg.py +++ b/actorch/algorithms/ddpg.py @@ -14,7 +14,7 @@ # limitations under the License. 
# ============================================================================== -"""Deep Deterministic Policy Gradient.""" +"""Deep Deterministic Policy Gradient (DDPG).""" import contextlib import copy @@ -33,7 +33,7 @@ from actorch.agents import Agent, GaussianNoiseAgent from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler from actorch.algorithms.algorithm import RefOrFutureRef, Tunable -from actorch.algorithms.utils import freeze_params, prepare_model, sync_polyak +from actorch.algorithms.utils import freeze_params, prepare_model, sync_polyak_ from actorch.algorithms.value_estimation import n_step_return from actorch.buffers import Buffer from actorch.datasets import BufferDataset @@ -53,7 +53,7 @@ class DDPG(A2C): - """Deep Deterministic Policy Gradient. + """Deep Deterministic Policy Gradient (DDPG). References ---------- @@ -65,9 +65,7 @@ class DDPG(A2C): """ - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True # override - - _RESET_BUFFER = False # override + _OFF_POLICY = True # override # override class Config(dict): @@ -203,12 +201,8 @@ def setup(self, config: "Dict[str, Any]") -> "None": self.config = DDPG.Config(**self.config) self.config["_accept_kwargs"] = True super().setup(config) - self._target_policy_network = copy.deepcopy(self._policy_network) - self._target_policy_network.eval().to(self._device, non_blocking=True) - self._target_policy_network.requires_grad_(False) - self._target_value_network = copy.deepcopy(self._value_network) - self._target_value_network.eval().to(self._device, non_blocking=True) - self._target_value_network.requires_grad_(False) + self._target_policy_network = self._build_target_policy_network() + self._target_value_network = self._build_target_value_network() if not isinstance(self.num_updates_per_iter, Schedule): self.num_updates_per_iter = ConstantSchedule(self.num_updates_per_iter) if not isinstance(self.batch_size, Schedule): @@ -300,6 +294,16 @@ def _build_value_network(self) -> "Network": } return super()._build_value_network() + def _build_target_policy_network(self) -> "Network": + target_policy_network = copy.deepcopy(self._policy_network) + target_policy_network.requires_grad_(False) + return target_policy_network.eval().to(self._device, non_blocking=True) + + def _build_target_value_network(self) -> "Network": + target_value_network = copy.deepcopy(self._value_network) + target_value_network.requires_grad_(False) + return target_value_network.eval().to(self._device, non_blocking=True) + # override def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() @@ -314,6 +318,7 @@ def _train_step(self) -> "Dict[str, Any]": ) result["sync_freq"] = self.sync_freq() result["polyak_weight"] = self.polyak_weight() + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) result["buffer_num_experiences"] = self._buffer.num_experiences result["buffer_num_full_trajectories"] = self._buffer.num_full_trajectories result.pop("entropy_coeff", None) @@ -365,6 +370,7 @@ def _train_on_batch( mask, self.discount(), self.num_return_steps(), + return_advantage=False, ) # Compute action values @@ -390,12 +396,12 @@ def _train_on_batch( # Synchronize if idx % sync_freq == 0: - sync_polyak( + sync_polyak_( self._policy_network, self._target_policy_network, self.polyak_weight(), ) - sync_polyak( + sync_polyak_( self._value_network, self._target_value_network, self.polyak_weight(), @@ -425,7 +431,8 @@ def _train_on_batch_policy_network( ) with freeze_params(self._value_network): action_values, _ = 
self._value_network(observations_actions, mask=mask) - loss = -action_values[mask].mean() + action_value = action_values[mask] + loss = -action_value.mean() optimize_result = self._optimize_policy_network(loss) result = {"loss": loss.item()} result.update(optimize_result) @@ -482,7 +489,7 @@ def _get_default_policy_network_distribution_config( class DistributedDataParallelDDPG(DistributedDataParallelA2C): - """Distributed data parallel Deep Deterministic Policy Gradient. + """Distributed data parallel Deep Deterministic Policy Gradient (DDPG). See Also -------- diff --git a/actorch/algorithms/ppo.py b/actorch/algorithms/ppo.py index 2fd2963..eee2cf4 100644 --- a/actorch/algorithms/ppo.py +++ b/actorch/algorithms/ppo.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Proximal Policy Optimization.""" +"""Proximal Policy Optimization (PPO).""" import contextlib from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -31,6 +31,7 @@ from actorch.agents import Agent from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.utils import normalize_ from actorch.algorithms.value_estimation import lambda_return from actorch.envs import BatchedEnv from actorch.models import Model @@ -46,7 +47,7 @@ class PPO(A2C): - """Proximal Policy Optimization. + """Proximal Policy Optimization (PPO). References ---------- @@ -234,6 +235,8 @@ def _train_step(self) -> "Dict[str, Any]": result["num_epochs"] = self.num_epochs() result["minibatch_size"] = self.minibatch_size() result["ratio_clip"] = self.ratio_clip() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) result.pop("num_return_steps", None) return result @@ -301,22 +304,8 @@ def _train_on_batch( else contextlib.suppress() ): if self.normalize_advantage: - length_batch = mask_batch.sum(dim=1, keepdim=True) - advantages_batch_mean = ( - advantages_batch.sum(dim=1, keepdim=True) / length_batch - ) - advantages_batch -= advantages_batch_mean - advantages_batch *= mask_batch - advantages_batch_stddev = ( - ( - (advantages_batch**2).sum(dim=1, keepdim=True) - / length_batch - ) - .sqrt() - .clamp(min=1e-6) - ) - advantages_batch /= advantages_batch_stddev - advantages_batch *= mask_batch + normalize_(advantages_batch, dim=-1, mask=mask_batch) + state_values_batch, _ = self._value_network( experiences_batch["observation"], mask=mask_batch ) @@ -387,7 +376,7 @@ def _train_on_batch_policy_network( class DistributedDataParallelPPO(DistributedDataParallelA2C): - """Distributed data parallel Proximal Policy Optimization. + """Distributed data parallel Proximal Policy Optimization (PPO). 
See Also -------- diff --git a/actorch/algorithms/reinforce.py b/actorch/algorithms/reinforce.py index 20db9d1..c0625ec 100644 --- a/actorch/algorithms/reinforce.py +++ b/actorch/algorithms/reinforce.py @@ -67,9 +67,7 @@ class REINFORCE(Algorithm): """ - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = False # override - - _RESET_BUFFER = True + _OFF_POLICY = False # override # override class Config(dict): @@ -281,7 +279,7 @@ def _build_policy_network_optimizer_lr_scheduler( # override def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() - if self._RESET_BUFFER: + if not self._OFF_POLICY: self._buffer.reset() self.discount.step() self.entropy_coeff.step() diff --git a/actorch/algorithms/sac.py b/actorch/algorithms/sac.py new file mode 100644 index 0000000..6999ff8 --- /dev/null +++ b/actorch/algorithms/sac.py @@ -0,0 +1,605 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Soft Actor-Critic (SAC).""" + +import contextlib +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from gymnasium import Env, spaces +from numpy import ndarray +from torch import Tensor +from torch.cuda.amp import autocast +from torch.distributions import Distribution, Normal +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from actorch.agents import Agent +from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.td3 import TD3, DistributedDataParallelTD3, Loss, LRScheduler +from actorch.algorithms.utils import freeze_params, sync_polyak_ +from actorch.algorithms.value_estimation import n_step_return +from actorch.buffers import Buffer +from actorch.envs import BatchedEnv +from actorch.models import Model +from actorch.networks import DistributionParametrization, NormalizingFlow, Processor +from actorch.samplers import Sampler +from actorch.schedules import ConstantSchedule, Schedule +from actorch.utils import singledispatchmethod + + +__all__ = [ + "DistributedDataParallelSAC", + "SAC", +] + + +class SAC(TD3): + """Soft Actor-Critic (SAC). + + References + ---------- + .. [1] T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine. + "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor". + In: ICML. 2018, pp. 1861-1870. + URL: https://arxiv.org/abs/1801.01290 + .. [2] T. Haarnoja, A. Zhou, K. Hartikainen, G. Tucker, S. Ha, J. Tan, V. Kumar, + H. Zhu, A. Gupta, P. Abbeel, and S. Levine. + "Soft Actor-Critic Algorithms and Applications". + In: arXiv. 2018. 
+ URL: https://arxiv.org/abs/1812.05905 + + """ + + # override + class Config(dict): + """Keyword arguments expected in the configuration received by `setup`.""" + + def __init__( + self, + train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]", + train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1, + eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None, + eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_distribution_builders: "Tunable[RefOrFutureRef[Optional[Dict[str, Callable[..., Distribution]]]]]" = None, + policy_network_distribution_parametrizations: "Tunable[RefOrFutureRef[Optional[Dict[str, DistributionParametrization]]]]" = None, + policy_network_distribution_configs: "Tunable[RefOrFutureRef[Optional[Dict[str, Dict[str, Any]]]]]" = None, + policy_network_normalizing_flows: "Tunable[RefOrFutureRef[Optional[Dict[str, NormalizingFlow]]]]" = None, + policy_network_sample_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Distribution], Tensor]]]]" = None, + policy_network_prediction_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Tensor], Tensor]]]]" = None, + policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + policy_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_loss_builder: 
"Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None, + value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + temperature_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + temperature_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + temperature_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + temperature_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Buffer]]]]" = None, + buffer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_checkpoint: "Tunable[RefOrFutureRef[bool]]" = False, + dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None, + dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99, + num_return_steps: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1, + num_updates_per_iter: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1000, + batch_size: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 128, + max_trajectory_length: "Tunable[RefOrFutureRef[Union[int, float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + sync_freq: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1, + polyak_weight: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.001, + temperature: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.1, + max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100, + seed: "Tunable[RefOrFutureRef[int]]" = 0, + enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False, + enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False, + enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False, + suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False, + _accept_kwargs: "bool" = False, + **kwargs: "Any", + ) -> "None": + if not _accept_kwargs and kwargs: + raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}") + super().__init__( + train_env_builder=train_env_builder, + train_env_config=train_env_config, + train_agent_builder=train_agent_builder, + train_agent_config=train_agent_config, + train_sampler_builder=train_sampler_builder, + train_sampler_config=train_sampler_config, + train_num_timesteps_per_iter=train_num_timesteps_per_iter, + train_num_episodes_per_iter=train_num_episodes_per_iter, + eval_freq=eval_freq, + eval_env_builder=eval_env_builder, + eval_env_config=eval_env_config, + eval_agent_builder=eval_agent_builder, + eval_agent_config=eval_agent_config, + eval_sampler_builder=eval_sampler_builder, + eval_sampler_config=eval_sampler_config, + eval_num_timesteps_per_iter=eval_num_timesteps_per_iter, + 
eval_num_episodes_per_iter=eval_num_episodes_per_iter, + policy_network_preprocessors=policy_network_preprocessors, + policy_network_model_builder=policy_network_model_builder, + policy_network_model_config=policy_network_model_config, + policy_network_distribution_builders=policy_network_distribution_builders, + policy_network_distribution_parametrizations=policy_network_distribution_parametrizations, + policy_network_distribution_configs=policy_network_distribution_configs, + policy_network_normalizing_flows=policy_network_normalizing_flows, + policy_network_sample_fn=policy_network_sample_fn, + policy_network_prediction_fn=policy_network_prediction_fn, + policy_network_postprocessors=policy_network_postprocessors, + policy_network_optimizer_builder=policy_network_optimizer_builder, + policy_network_optimizer_config=policy_network_optimizer_config, + policy_network_optimizer_lr_scheduler_builder=policy_network_optimizer_lr_scheduler_builder, + policy_network_optimizer_lr_scheduler_config=policy_network_optimizer_lr_scheduler_config, + value_network_preprocessors=value_network_preprocessors, + value_network_model_builder=value_network_model_builder, + value_network_model_config=value_network_model_config, + value_network_loss_builder=value_network_loss_builder, + value_network_loss_config=value_network_loss_config, + value_network_optimizer_builder=value_network_optimizer_builder, + value_network_optimizer_config=value_network_optimizer_config, + value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder, + value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config, + temperature_optimizer_builder=temperature_optimizer_builder, + temperature_optimizer_config=temperature_optimizer_config, + temperature_optimizer_lr_scheduler_builder=temperature_optimizer_lr_scheduler_builder, + temperature_optimizer_lr_scheduler_config=temperature_optimizer_lr_scheduler_config, + buffer_builder=buffer_builder, + buffer_config=buffer_config, + buffer_checkpoint=buffer_checkpoint, + dataloader_builder=dataloader_builder, + dataloader_config=dataloader_config, + discount=discount, + num_return_steps=num_return_steps, + num_updates_per_iter=num_updates_per_iter, + batch_size=batch_size, + max_trajectory_length=max_trajectory_length, + sync_freq=sync_freq, + polyak_weight=polyak_weight, + temperature=temperature, + max_grad_l2_norm=max_grad_l2_norm, + cumreward_window_size=cumreward_window_size, + seed=seed, + enable_amp=enable_amp, + enable_reproducibility=enable_reproducibility, + enable_anomaly_detection=enable_anomaly_detection, + enable_profiling=enable_profiling, + log_sys_usage=log_sys_usage, + suppress_warnings=suppress_warnings, + **kwargs, + ) + + # override + def setup(self, config: "Dict[str, Any]") -> "None": + self.config = SAC.Config(**self.config) + self.config["_accept_kwargs"] = True + super().setup(config) + if self.temperature_optimizer_builder is not None: + self._log_temperature = torch.zeros( + 1, device=self._device, requires_grad=True + ) + self._target_entropy = torch.as_tensor( + self._train_env.single_action_space.sample() + ).numel() + self._temperature_optimizer = self._build_temperature_optimizer() + self._temperature_optimizer_lr_scheduler = ( + self._build_temperature_optimizer_lr_scheduler() + ) + self.temperature = lambda *args: self._log_temperature.exp().item() + elif not isinstance(self.temperature, Schedule): + self.temperature = ConstantSchedule(self.temperature) + + # override + @property + def _checkpoint(self) 
-> "Dict[str, Any]": + checkpoint = super()._checkpoint + if self.temperature_optimizer_builder is not None: + checkpoint["log_temperature"] = self._log_temperature + checkpoint[ + "temperature_optimizer" + ] = self._temperature_optimizer.state_dict() + if self._temperature_optimizer_lr_scheduler is not None: + checkpoint[ + "temperature_optimizer_lr_scheduler" + ] = self._temperature_optimizer_lr_scheduler.state_dict() + checkpoint["temperature"] = self.temperature.state_dict() + return checkpoint + + # override + @_checkpoint.setter + def _checkpoint(self, value: "Dict[str, Any]") -> "None": + super()._checkpoint = value + if "log_temperature" in value: + self._log_temperature = value["log_temperature"] + if "temperature_optimizer" in value: + self._temperature_optimizer.load_state_dict(value["temperature_optimizer"]) + if "temperature_optimizer_lr_scheduler" in value: + self._temperature_optimizer_lr_scheduler.load_state_dict( + value["temperature_optimizer_lr_scheduler"] + ) + self.temperature.load_state_dict(value["temperature"]) + + def _build_temperature_optimizer(self) -> "Optimizer": + if self.temperature_optimizer_config is None: + self.temperature_optimizer_config: "Dict[str, Any]" = {} + return self.temperature_optimizer_builder( + [self._log_temperature], + **self.temperature_optimizer_config, + ) + + def _build_temperature_optimizer_lr_scheduler( + self, + ) -> "Optional[LRScheduler]": + if self.temperature_optimizer_lr_scheduler_builder is None: + return + if self.temperature_optimizer_lr_scheduler_config is None: + self.temperature_optimizer_lr_scheduler_config: "Dict[str, Any]" = {} + return self.temperature_optimizer_lr_scheduler_builder( + self._temperature_optimizer, + **self.temperature_optimizer_lr_scheduler_config, + ) + + # override + def _train_step(self) -> "Dict[str, Any]": + result = super()._train_step() + if self.temperature_optimizer_builder is None: + result["temperature"] = self.temperature() + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) + result["buffer_num_experiences"] = result.pop("buffer_num_experiences", None) + result["buffer_num_full_trajectories"] = result.pop( + "buffer_num_full_trajectories", None + ) + result.pop("delay", None) + result.pop("noise_stddev", None) + result.pop("noise_clip", None) + return result + + # override + def _train_on_batch( + self, + idx: "int", + experiences: "Dict[str, Tensor]", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + sync_freq = self.sync_freq() + if sync_freq < 1 or not float(sync_freq).is_integer(): + raise ValueError( + f"`sync_freq` ({sync_freq}) " + f"must be in the integer interval [1, inf)" + ) + sync_freq = int(sync_freq) + + temperature = self.temperature() + if temperature <= 0.0: + raise ValueError( + f"`temperature` ({temperature}) must be in the interval (0, inf)" + ) + + result = {} + + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + self._policy_network(experiences["observation"], mask=mask) + policy = self._policy_network.distribution + target_actions = policy.rsample() + target_log_probs = policy.log_prob(target_actions) + observations_target_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], target_actions] + ], + dim=-1, + ) + with torch.no_grad(): + target_action_values, _ = self._target_value_network( + observations_target_actions, mask=mask + ) + target_twin_action_values, _ = 
self._target_twin_value_network( + observations_target_actions, mask=mask + ) + target_action_values = target_action_values.min(target_twin_action_values) + with torch.no_grad(): + target_action_values -= temperature * target_log_probs + + # Discard next observation + experiences["observation"] = experiences["observation"][:, :-1, ...] + observations_target_actions = observations_target_actions[:, :-1, ...] + target_log_probs = target_log_probs[:, :-1, ...] + mask = mask[:, 1:] + targets, _ = n_step_return( + target_action_values, + experiences["reward"], + experiences["terminal"], + mask, + self.discount(), + self.num_return_steps(), + return_advantage=False, + ) + + # Compute action values + observations_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], experiences["action"]] + ], + dim=-1, + ) + action_values, _ = self._value_network(observations_actions, mask=mask) + twin_action_values, _ = self._twin_value_network( + observations_actions, mask=mask + ) + + result["value_network"], priority = self._train_on_batch_value_network( + action_values, + targets, + is_weight, + mask, + ) + + ( + result["twin_value_network"], + twin_priority, + ) = self._train_on_batch_twin_value_network( + twin_action_values, + targets, + is_weight, + mask, + ) + + if priority is not None: + priority += twin_priority + priority /= 2.0 + + result["policy_network"] = self._train_on_batch_policy_network( + observations_target_actions, + target_log_probs, + mask, + ) + + if self.temperature_optimizer_builder is not None: + result["temperature"] = self._train_on_batch_temperature( + target_log_probs, + mask, + ) + + # Synchronize + if idx % sync_freq == 0: + sync_polyak_( + self._value_network, + self._target_value_network, + self.polyak_weight(), + ) + sync_polyak_( + self._twin_value_network, + self._target_twin_value_network, + self.polyak_weight(), + ) + + self._grad_scaler.update() + return result, priority + + # override + def _train_on_batch_policy_network( + self, + observations_actions: "Tensor", + log_probs: "Tensor", + mask: "Tensor", + ) -> "Dict[str, Any]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + with freeze_params(self._value_network, self._twin_value_network): + action_values, _ = self._value_network(observations_actions, mask=mask) + twin_action_values, _ = self._twin_value_network( + observations_actions, mask=mask + ) + action_values = action_values.min(twin_action_values) + log_prob = log_probs[mask] + action_value = action_values[mask] + loss = self.temperature() * log_prob + loss -= action_value + loss = loss.mean() + optimize_result = self._optimize_policy_network(loss) + result = { + "log_prob": log_prob.mean().item(), + "action_value": action_value.mean().item(), + "loss": loss.item(), + } + result.update(optimize_result) + return result + + def _train_on_batch_temperature( + self, + log_probs: "Tensor", + mask: "Tensor", + ) -> "Dict[str, Any]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + with torch.no_grad(): + log_prob = log_probs[mask] + loss = log_prob + self._target_entropy + loss *= -self._log_temperature + loss = loss.mean() + optimize_result = self._optimize_temperature(loss) + result = { + "temperature": self.temperature(), + "log_prob": log_prob.mean().item(), + "loss": loss.item(), + } + result.update(optimize_result) + return result + + def _optimize_temperature(self, loss: "Tensor") -> 
"Dict[str, Any]": + max_grad_l2_norm = self.max_grad_l2_norm() + if max_grad_l2_norm <= 0.0: + raise ValueError( + f"`max_grad_l2_norm` ({max_grad_l2_norm}) must be in the interval (0, inf]" + ) + self._temperature_optimizer.zero_grad(set_to_none=True) + self._grad_scaler.scale(loss).backward() + self._grad_scaler.unscale_(self._temperature_optimizer) + grad_l2_norm = clip_grad_norm_(self._log_temperature, max_grad_l2_norm) + self._grad_scaler.step(self._temperature_optimizer) + result = { + "lr": self._temperature_optimizer.param_groups[0]["lr"], + "grad_l2_norm": min(grad_l2_norm.item(), max_grad_l2_norm), + } + if self._temperature_optimizer_lr_scheduler is not None: + self._temperature_optimizer_lr_scheduler.step() + return result + + @singledispatchmethod(use_weakrefs=False) + def _get_default_policy_network_distribution_builder( + self, + space: "spaces.Space", + ) -> "Callable[..., Distribution]": + raise NotImplementedError( + f"Unsupported space type: " + f"`{type(space).__module__}.{type(space).__name__}`. " + f"Register a custom space type through decorator " + f"`{type(self).__module__}.{type(self).__name__}." + f"_get_default_policy_network_distribution_builder.register`" + ) + + @singledispatchmethod(use_weakrefs=False) + def _get_default_policy_network_distribution_parametrization( + self, + space: "spaces.Space", + ) -> "Callable[..., DistributionParametrization]": + raise NotImplementedError( + f"Unsupported space type: " + f"`{type(space).__module__}.{type(space).__name__}`. " + f"Register a custom space type through decorator " + f"`{type(self).__module__}.{type(self).__name__}." + f"_get_default_policy_network_distribution_parametrization.register`" + ) + + @singledispatchmethod(use_weakrefs=False) + def _get_default_policy_network_distribution_config( + self, + space: "spaces.Space", + ) -> "Dict[str, Any]": + raise NotImplementedError( + f"Unsupported space type: " + f"`{type(space).__module__}.{type(space).__name__}`. " + f"Register a custom space type through decorator " + f"`{type(self).__module__}.{type(self).__name__}." + f"_get_default_policy_network_distribution_config.register`" + ) + + +class DistributedDataParallelSAC(DistributedDataParallelTD3): + """Distributed data parallel Soft Actor-Critic (SAC). 
+ + See Also + -------- + actorch.algorithms.sac.SAC + + """ + + _ALGORITHM_CLS = SAC # override + + +##################################################################################################### +# SAC._get_default_policy_network_distribution_builder implementation +##################################################################################################### + + +@SAC._get_default_policy_network_distribution_builder.register(spaces.Box) +def _get_default_policy_network_distribution_builder_box( + self, + space: "spaces.Box", +) -> "Callable[..., Normal]": + return Normal + + +##################################################################################################### +# SAC._get_default_policy_network_distribution_parametrization implementation +##################################################################################################### + + +@SAC._get_default_policy_network_distribution_parametrization.register(spaces.Box) +def _get_default_policy_network_distribution_parametrization_box( + self, + space: "spaces.Box", +) -> "DistributionParametrization": + return { + "loc": ( + {"loc": space.shape}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": space.shape}, + lambda x: x["log_scale"].exp(), + ), + } + + +##################################################################################################### +# SAC._get_default_policy_network_distribution_config implementation +##################################################################################################### + + +@SAC._get_default_policy_network_distribution_config.register(spaces.Box) +def _get_default_policy_network_distribution_config_box( + self, + space: "spaces.Box", +) -> "Dict[str, Any]": + return {"validate_args": False} diff --git a/actorch/algorithms/td3.py b/actorch/algorithms/td3.py index a922fb0..b25ec24 100644 --- a/actorch/algorithms/td3.py +++ b/actorch/algorithms/td3.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Twin Delayed Deep Deterministic Policy Gradient.""" +"""Twin Delayed Deep Deterministic Policy Gradient (TD3).""" import contextlib import copy @@ -33,12 +33,12 @@ from actorch.agents import Agent from actorch.algorithms.algorithm import RefOrFutureRef, Tunable from actorch.algorithms.ddpg import DDPG, DistributedDataParallelDDPG, Loss, LRScheduler -from actorch.algorithms.utils import prepare_model, sync_polyak +from actorch.algorithms.utils import prepare_model, sync_polyak_ from actorch.algorithms.value_estimation import n_step_return from actorch.buffers import Buffer from actorch.envs import BatchedEnv from actorch.models import Model -from actorch.networks import Processor +from actorch.networks import Network, Processor from actorch.samplers import Sampler from actorch.schedules import ConstantSchedule, Schedule @@ -50,13 +50,13 @@ class TD3(DDPG): - """Twin Delayed Deep Deterministic Policy Gradient. + """Twin Delayed Deep Deterministic Policy Gradient (TD3). References ---------- .. [1] S. Fujimoto, H. van Hoof, and D. Meger. "Addressing Function Approximation Error in Actor-Critic Methods". - In: ICML. 2018, pp. 1587–1596. + In: ICML. 2018, pp. 1587-1596. 
URL: https://arxiv.org/abs/1802.09477 """ @@ -201,19 +201,13 @@ def setup(self, config: "Dict[str, Any]") -> "None": self.config = TD3.Config(**self.config) self.config["_accept_kwargs"] = True super().setup(config) - self._twin_value_network = ( - self._build_value_network().train().to(self._device, non_blocking=True) - ) - self._twin_value_network_loss = ( - self._build_value_network_loss().train().to(self._device, non_blocking=True) - ) + self._twin_value_network = self._build_value_network() + self._twin_value_network_loss = self._build_value_network_loss() self._twin_value_network_optimizer = self._build_twin_value_network_optimizer() self._twin_value_network_optimizer_lr_scheduler = ( self._build_twin_value_network_optimizer_lr_scheduler() ) - self._target_twin_value_network = copy.deepcopy(self._twin_value_network) - self._target_twin_value_network.eval().to(self._device, non_blocking=True) - self._target_twin_value_network.requires_grad_(False) + self._target_twin_value_network = self._build_target_twin_value_network() if not isinstance(self.delay, Schedule): self.delay = ConstantSchedule(self.delay) if not isinstance(self.noise_stddev, Schedule): @@ -297,6 +291,11 @@ def _build_twin_value_network_optimizer_lr_scheduler( **self.value_network_optimizer_lr_scheduler_config, ) + def _build_target_twin_value_network(self) -> "Network": + target_twin_value_network = copy.deepcopy(self._twin_value_network) + target_twin_value_network.requires_grad_(False) + return target_twin_value_network.eval().to(self._device, non_blocking=True) + # override def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() @@ -306,9 +305,10 @@ def _train_step(self) -> "Dict[str, Any]": result["delay"] = self.delay() result["noise_stddev"] = self.noise_stddev() result["noise_clip"] = self.noise_clip() - result["buffer_num_experiences"] = result.pop("buffer_num_experiences") + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) + result["buffer_num_experiences"] = result.pop("buffer_num_experiences", None) result["buffer_num_full_trajectories"] = result.pop( - "buffer_num_full_trajectories" + "buffer_num_full_trajectories", None ) return result @@ -398,6 +398,7 @@ def _train_on_batch( mask, self.discount(), self.num_return_steps(), + return_advantage=False, ) # Compute action values @@ -441,17 +442,17 @@ def _train_on_batch( ) # Synchronize if idx % sync_freq == 0: - sync_polyak( + sync_polyak_( self._policy_network, self._target_policy_network, self.polyak_weight(), ) - sync_polyak( + sync_polyak_( self._value_network, self._target_value_network, self.polyak_weight(), ) - sync_polyak( + sync_polyak_( self._twin_value_network, self._target_twin_value_network, self.polyak_weight(), @@ -475,10 +476,12 @@ def _train_on_batch_twin_value_network( action_values = action_values[mask] target = targets[mask] loss = self._twin_value_network_loss(action_values, target) - loss *= is_weight[:, None].expand_as(mask)[mask] + priority = None + if self._buffer.is_prioritized: + loss *= is_weight[:, None].expand_as(mask)[mask] + priority = loss.detach().abs().to("cpu").numpy() loss = loss.mean() optimize_result = self._optimize_twin_value_network(loss) - priority = None result = { "action_value": action_values.mean().item(), "target": target.mean().item(), @@ -510,7 +513,7 @@ def _optimize_twin_value_network(self, loss: "Tensor") -> "Dict[str, Any]": class DistributedDataParallelTD3(DistributedDataParallelDDPG): - """Distributed data parallel Twin Delayed Deep Deterministic Policy Gradient. 
+ """Distributed data parallel Twin Delayed Deep Deterministic Policy Gradient (TD3). See Also -------- diff --git a/actorch/algorithms/trpo.py b/actorch/algorithms/trpo.py new file mode 100644 index 0000000..c4b749d --- /dev/null +++ b/actorch/algorithms/trpo.py @@ -0,0 +1,394 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Trust Region Policy Optimization (TRPO).""" + +import contextlib +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from gymnasium import Env +from numpy import ndarray +from torch import Tensor +from torch.cuda.amp import autocast +from torch.distributions import Distribution, kl_divergence +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from actorch.agents import Agent +from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler +from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.utils import normalize_ +from actorch.algorithms.value_estimation import lambda_return +from actorch.envs import BatchedEnv +from actorch.models import Model +from actorch.networks import DistributionParametrization, NormalizingFlow, Processor +from actorch.optimizers import CGBLS +from actorch.samplers import Sampler +from actorch.schedules import ConstantSchedule, Schedule + + +__all__ = [ + "DistributedDataParallelTRPO", + "TRPO", +] + + +class TRPO(A2C): + """Trust Region Policy Optimization (TRPO). + + References + ---------- + .. [1] J. Schulman, S. Levine, P. Abbeel, M. Jordan, and P. Moritz. + "Trust Region Policy Optimization". + In: ICML. 2015, pp. 1889-1897. 
+ URL: https://arxiv.org/abs/1502.05477 + + """ + + # override + class Config(dict): + """Keyword arguments expected in the configuration received by `setup`.""" + + def __init__( + self, + train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]", + train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1, + eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None, + eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_distribution_builders: "Tunable[RefOrFutureRef[Optional[Dict[str, Callable[..., Distribution]]]]]" = None, + policy_network_distribution_parametrizations: "Tunable[RefOrFutureRef[Optional[Dict[str, DistributionParametrization]]]]" = None, + policy_network_distribution_configs: "Tunable[RefOrFutureRef[Optional[Dict[str, Dict[str, Any]]]]]" = None, + policy_network_normalizing_flows: "Tunable[RefOrFutureRef[Optional[Dict[str, NormalizingFlow]]]]" = None, + policy_network_sample_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Distribution], Tensor]]]]" = None, + policy_network_prediction_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Tensor], Tensor]]]]" = None, + policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None, + value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + 
value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None, + dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99, + trace_decay: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.95, + normalize_advantage: "Tunable[RefOrFutureRef[bool]]" = False, + entropy_coeff: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.01, + max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100, + seed: "Tunable[RefOrFutureRef[int]]" = 0, + enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False, + enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False, + enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False, + suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False, + _accept_kwargs: "bool" = False, + **kwargs: "Any", + ) -> "None": + if not _accept_kwargs and kwargs: + raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}") + super().__init__( + train_env_builder=train_env_builder, + train_env_config=train_env_config, + train_agent_builder=train_agent_builder, + train_agent_config=train_agent_config, + train_sampler_builder=train_sampler_builder, + train_sampler_config=train_sampler_config, + train_num_timesteps_per_iter=train_num_timesteps_per_iter, + train_num_episodes_per_iter=train_num_episodes_per_iter, + eval_freq=eval_freq, + eval_env_builder=eval_env_builder, + eval_env_config=eval_env_config, + eval_agent_builder=eval_agent_builder, + eval_agent_config=eval_agent_config, + eval_sampler_builder=eval_sampler_builder, + eval_sampler_config=eval_sampler_config, + eval_num_timesteps_per_iter=eval_num_timesteps_per_iter, + eval_num_episodes_per_iter=eval_num_episodes_per_iter, + policy_network_preprocessors=policy_network_preprocessors, + policy_network_model_builder=policy_network_model_builder, + policy_network_model_config=policy_network_model_config, + policy_network_distribution_builders=policy_network_distribution_builders, + policy_network_distribution_parametrizations=policy_network_distribution_parametrizations, + policy_network_distribution_configs=policy_network_distribution_configs, + policy_network_normalizing_flows=policy_network_normalizing_flows, + policy_network_sample_fn=policy_network_sample_fn, + policy_network_prediction_fn=policy_network_prediction_fn, + policy_network_postprocessors=policy_network_postprocessors, + policy_network_optimizer_builder=policy_network_optimizer_builder, + policy_network_optimizer_config=policy_network_optimizer_config, + value_network_preprocessors=value_network_preprocessors, + value_network_model_builder=value_network_model_builder, + value_network_model_config=value_network_model_config, + value_network_loss_builder=value_network_loss_builder, + value_network_loss_config=value_network_loss_config, + value_network_optimizer_builder=value_network_optimizer_builder, + 
value_network_optimizer_config=value_network_optimizer_config, + value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder, + value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config, + dataloader_builder=dataloader_builder, + dataloader_config=dataloader_config, + discount=discount, + trace_decay=trace_decay, + normalize_advantage=normalize_advantage, + entropy_coeff=entropy_coeff, + max_grad_l2_norm=max_grad_l2_norm, + cumreward_window_size=cumreward_window_size, + seed=seed, + enable_amp=enable_amp, + enable_reproducibility=enable_reproducibility, + enable_anomaly_detection=enable_anomaly_detection, + enable_profiling=enable_profiling, + log_sys_usage=log_sys_usage, + suppress_warnings=suppress_warnings, + **kwargs, + ) + + # override + def setup(self, config: "Dict[str, Any]") -> "None": + self.config = TRPO.Config(**self.config) + self.config["_accept_kwargs"] = True + super().setup(config) + if not isinstance(self.trace_decay, Schedule): + self.trace_decay = ConstantSchedule(self.trace_decay) + + # override + @property + def _checkpoint(self) -> "Dict[str, Any]": + checkpoint = super()._checkpoint + checkpoint["trace_decay"] = self.trace_decay.state_dict() + return checkpoint + + # override + @_checkpoint.setter + def _checkpoint(self, value: "Dict[str, Any]") -> "None": + super()._checkpoint = value + self.trace_decay.load_state_dict(value["trace_decay"]) + + # override + def _build_policy_network_optimizer(self) -> "Optimizer": + if self.policy_network_optimizer_builder is None: + self.policy_network_optimizer_builder = CGBLS + if self.policy_network_optimizer_config is None: + self.policy_network_optimizer_config = { + "max_constraint": 0.01, + "num_cg_iters": 10, + "max_backtracks": 15, + "backtrack_ratio": 0.8, + "hvp_reg_coeff": 1e-5, + "accept_violation": False, + "epsilon": 1e-8, + } + if self.policy_network_optimizer_config is None: + self.policy_network_optimizer_config = {} + return self.policy_network_optimizer_builder( + self._policy_network.parameters(), + **self.policy_network_optimizer_config, + ) + + # override + def _train_step(self) -> "Dict[str, Any]": + result = super()._train_step() + self.trace_decay.step() + result["trace_decay"] = self.trace_decay() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) + result.pop("num_return_steps", None) + return result + + # override + def _train_on_batch( + self, + idx: "int", + experiences: "Dict[str, Tensor]", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + result = {} + + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + state_values, _ = self._value_network(experiences["observation"], mask=mask) + # Discard next observation + experiences["observation"] = experiences["observation"][:, :-1, ...] 
+ mask = mask[:, 1:] + with torch.no_grad(): + targets, advantages = lambda_return( + state_values, + experiences["reward"], + experiences["terminal"], + mask, + self.discount(), + self.trace_decay(), + ) + if self.normalize_advantage: + normalize_(advantages, dim=-1, mask=mask) + + # Discard next state value + state_values = state_values[:, :-1] + + result["value_network"], priority = self._train_on_batch_value_network( + state_values, + targets, + is_weight, + mask, + ) + result["policy_network"] = self._train_on_batch_policy_network( + experiences, + advantages, + mask, + ) + self._grad_scaler.update() + return result, priority + + # override + def _train_on_batch_policy_network( + self, + experiences: "Dict[str, Tensor]", + advantages: "Tensor", + mask: "Tensor", + ) -> "Dict[str, Any]": + entropy_coeff = self.entropy_coeff() + if entropy_coeff < 0.0: + raise ValueError( + f"`entropy_coeff` ({entropy_coeff}) must be in the interval [0, inf)" + ) + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + advantage = advantages[mask] + old_log_prob = experiences["log_prob"][mask] + with torch.no_grad(): + self._policy_network(experiences["observation"], mask=mask) + old_policy = self._policy_network.distribution + policy = loss = log_prob = entropy_bonus = kl_div = None + + def compute_loss() -> "Tensor": + nonlocal policy, loss, log_prob, entropy_bonus + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + self._policy_network(experiences["observation"], mask=mask) + policy = self._policy_network.distribution + log_prob = policy.log_prob(experiences["action"])[mask] + ratio = (log_prob - old_log_prob).exp() + loss = -advantage * ratio + entropy_bonus = None + if entropy_coeff != 0.0: + entropy_bonus = -entropy_coeff * policy.entropy()[mask] + loss += entropy_bonus + loss = loss.mean() + return loss + + def compute_kl_div() -> "Tensor": + nonlocal policy, kl_div + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + if policy is None: + self._policy_network(experiences["observation"], mask=mask) + policy = self._policy_network.distribution + kl_div = kl_divergence(old_policy, policy)[mask].mean() + policy = None + return kl_div + + loss = compute_loss() + optimize_result = self._optimize_policy_network( + loss, compute_loss, compute_kl_div + ) + result = { + "advantage": advantage.mean().item(), + "log_prob": log_prob.mean().item(), + "old_log_prob": old_log_prob.mean().item(), + } + if kl_div is not None: + result["kl_div"] = kl_div.item() + result["loss"] = loss.item() + if entropy_bonus is not None: + result["entropy_bonus"] = entropy_bonus.mean().item() + result.update(optimize_result) + return result + + # override + def _optimize_policy_network( + self, + loss: "Tensor", + loss_fn: "Callable[[], Tensor]", + constraint_fn: "Callable[[], Tensor]", + ) -> "Dict[str, Any]": + max_grad_l2_norm = self.max_grad_l2_norm() + if max_grad_l2_norm <= 0.0: + raise ValueError( + f"`max_grad_l2_norm` ({max_grad_l2_norm}) must be in the interval (0, inf]" + ) + self._policy_network_optimizer.zero_grad(set_to_none=True) + self._grad_scaler.scale(loss).backward(retain_graph=True) + self._grad_scaler.unscale_(self._policy_network_optimizer) + grad_l2_norm = clip_grad_norm_( + self._policy_network.parameters(), max_grad_l2_norm + ) + self._grad_scaler.step(self._policy_network_optimizer, loss_fn, constraint_fn) + result = { + "grad_l2_norm": 
min(grad_l2_norm.item(), max_grad_l2_norm), + } + return result + + +class DistributedDataParallelTRPO(DistributedDataParallelA2C): + """Distributed data parallel Trust Region Policy Optimization (TRPO). + + See Also + -------- + actorch.algorithms.trpo.TRPO + + """ + + _ALGORITHM_CLS = TRPO # override diff --git a/actorch/algorithms/utils.py b/actorch/algorithms/utils.py index 916d79b..456ee80 100644 --- a/actorch/algorithms/utils.py +++ b/actorch/algorithms/utils.py @@ -23,6 +23,7 @@ import ray.train.torch # Fix missing train.torch attribute import torch from ray import train +from torch import Tensor from torch import distributed as dist from torch.nn import Module from torch.nn.parallel import DataParallel, DistributedDataParallel @@ -32,11 +33,59 @@ "count_params", "freeze_params", "init_mock_train_session", + "normalize_", "prepare_model", - "sync_polyak", + "sync_polyak_", ] +def normalize_(input, dim: "int" = 0, mask: "Optional[Tensor]" = None) -> "Tensor": + """Normalize a tensor along a dimension, by subtracting the mean + and then dividing by the standard deviation (in-place). + + Parameters + ---------- + input: + The tensor. + dim: + The dimension. + mask: + The boolean tensor indicating which elements + are valid (True) and which are not (False). + Default to ``torch.ones_like(input, dtype=torch.bool)``. + + Returns + ------- + The normalized tensor. + + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import normalize_ + >>> + >>> + >>> input = torch.randn(2,3) + >>> mask = torch.rand(2, 3) > 0.5 + >>> output = normalize_(input, mask=mask) + + """ + if mask is None: + mask = torch.ones_like(input, dtype=torch.bool) + else: + input[~mask] = 0.0 + length = mask.sum(dim=dim, keepdim=True) + input_mean = input.sum(dim=dim, keepdim=True) / length + input -= input_mean + input *= mask + input_stddev = ( + ((input**2).sum(dim=dim, keepdim=True) / length).sqrt().clamp(min=1e-6) + ) + input /= input_stddev + input *= mask + return input + + def count_params(module: "Module") -> "Tuple[int, int]": """Return the number of trainable and non-trainable parameters in `module`. @@ -51,16 +100,26 @@ def count_params(module: "Module") -> "Tuple[int, int]": - The number of trainable parameters; - the number of non-trainable parameters. + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import count_params + >>> + >>> + >>> model = torch.nn.Linear(4, 2) + >>> trainable_count, non_trainable_count = count_params(model) + """ - num_trainable_params, num_non_trainable_params = 0, 0 + trainable_count, non_trainable_count = 0, 0 for param in module.parameters(): if param.requires_grad: - num_trainable_params += param.numel() + trainable_count += param.numel() else: - num_non_trainable_params += param.numel() + non_trainable_count += param.numel() for buffer in module.buffers(): - num_non_trainable_params += buffer.numel() - return num_trainable_params, num_non_trainable_params + non_trainable_count += buffer.numel() + return trainable_count, non_trainable_count @contextmanager @@ -73,6 +132,21 @@ def freeze_params(*modules: "Module") -> "Iterator[None]": modules: The modules. + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import freeze_params + >>> + >>> + >>> policy_model = torch.nn.Linear(4, 2) + >>> value_model = torch.nn.Linear(2, 1) + >>> input = torch.randn(3, 4) + >>> with freeze_params(value_model): + ... action = policy_model(input) + ... 
loss = -value_model(action).mean() + >>> loss.backward() + """ params = [ param @@ -89,19 +163,19 @@ def freeze_params(*modules: "Module") -> "Iterator[None]": param.requires_grad = True -def sync_polyak( +def sync_polyak_( source_module: "Module", target_module: "Module", polyak_weight: "float" = 0.001, -) -> "None": - """Synchronize `source_module` with `target_module` - through Polyak averaging. +) -> "Module": + """Synchronize a source module with a target module + through Polyak averaging (in-place). - For each `target_param` in `target_module`, - for each `source_param` in `source_module`: - `target_param` = - (1 - `polyak_weight`) * `target_param` - + `polyak_weight` * `source_param`. + For each `target_parameter` in `target_module`, + for each `source_parameter` in `source_module`: + `target_parameter` = + (1 - `polyak_weight`) * `target_parameter` + + `polyak_weight` * `source_parameter`. Parameters ---------- @@ -112,6 +186,10 @@ def sync_polyak( polyak_weight: The Polyak weight. + Returns + ------- + The synchronized target module. + Raises ------ ValueError @@ -124,16 +202,27 @@ def sync_polyak( In: SIAM Journal on Control and Optimization. 1992, pp. 838-855. URL: https://doi.org/10.1137/0330046 + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import sync_polyak_ + >>> + >>> + >>> model = torch.nn.Linear(4, 2) + >>> target_model = torch.nn.Linear(4, 2) + >>> sync_polyak_(model, target_model, polyak_weight=0.001) + """ if polyak_weight < 0.0 or polyak_weight > 1.0: raise ValueError( f"`polyak_weight` ({polyak_weight}) must be in the interval [0, 1]" ) if polyak_weight == 0.0: - return + return target_module if polyak_weight == 1.0: target_module.load_state_dict(source_module.state_dict()) - return + return target_module # Update parameters target_params = target_module.parameters() source_params = source_module.parameters() @@ -151,6 +240,7 @@ def sync_polyak( target_buffer *= (1 - polyak_weight) / polyak_weight target_buffer += source_buffer target_buffer *= polyak_weight + return target_module def init_mock_train_session() -> "None": diff --git a/actorch/algorithms/value_estimation/generalized_estimator.py b/actorch/algorithms/value_estimation/generalized_estimator.py index ceb15a7..4c6c1b2 100644 --- a/actorch/algorithms/value_estimation/generalized_estimator.py +++ b/actorch/algorithms/value_estimation/generalized_estimator.py @@ -32,7 +32,7 @@ ] -def generalized_estimator( +def generalized_estimator( # noqa: C901 state_values: "Union[Tensor, Distribution]", rewards: "Tensor", terminals: "Tensor", @@ -44,7 +44,8 @@ def generalized_estimator( discount: "float" = 0.99, num_return_steps: "int" = 1, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) generalized estimator targets and the corresponding advantages of a trajectory. @@ -81,12 +82,14 @@ def generalized_estimator( The number of return steps (`n` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) generalized estimator targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. 
Raises ------ @@ -199,6 +202,8 @@ def generalized_estimator( f"be equal to the shape of `rewards` ({rewards.shape})" ) targets, state_values, next_state_values = _compute_targets(*compute_targets_args) + if not return_advantage: + return targets, None advantages = _compute_advantages( targets, state_values, @@ -308,7 +313,7 @@ def _compute_distributional_targets( B, T = rewards.shape num_return_steps = min(num_return_steps, T) length = mask.sum(dim=1, keepdim=True) - idx = torch.arange(1, T + 1).expand(B, T) + idx = torch.arange(1, T + 1, device=length.device).expand(B, T) if state_values.batch_shape == (B, T + 1): next_state_values = distributional_gather( state_values, @@ -316,7 +321,7 @@ def _compute_distributional_targets( idx.clamp(max=length), mask * (~terminals), ) - idx = torch.arange(0, T).expand(B, T) + idx = torch.arange(0, T, device=length.device).expand(B, T) state_values = distributional_gather( state_values, 1, @@ -361,7 +366,9 @@ def _compute_distributional_targets( next_state_value_coeffs *= ~terminals next_state_value_coeffs *= mask # Gather - idx = torch.arange(num_return_steps - 1, T + num_return_steps - 1).expand(B, T) + idx = torch.arange( + num_return_steps - 1, T + num_return_steps - 1, device=length.device + ).expand(B, T) next_state_values = distributional_gather( next_state_values, 1, @@ -371,7 +378,7 @@ def _compute_distributional_targets( targets = TransformedDistribution( next_state_values, AffineTransform(offsets, next_state_value_coeffs), - next_state_values._validate_args, + validate_args=False, ).reduced_dist return targets, state_values, next_state_values coeffs = torch.stack( @@ -403,14 +410,11 @@ def _compute_distributional_targets( next_state_values, (B, T, num_return_steps) ) # Transform - validate_args = ( - state_or_action_values._validate_args or next_state_values._validate_args - ) targets = TransformedDistribution( CatDistribution( [state_or_action_values, next_state_values], dim=-1, - validate_args=validate_args, + validate_args=False, ), [ AffineTransform( @@ -420,7 +424,7 @@ def _compute_distributional_targets( SumTransform((2,)), SumTransform((num_return_steps,)), ], - validate_args=validate_args, + validate_args=False, ).reduced_dist return targets, state_values, next_state_values @@ -462,9 +466,8 @@ def _compute_advantages( trace_decay * targets + (1 - trace_decay) * state_values, [0, 1], )[:, 1:] - next_targets[torch.arange(B), length - 1] = next_state_values[ - torch.arange(B), length - 1 - ] + batch_idx = torch.arange(B) + next_targets[batch_idx, length - 1] = next_state_values[batch_idx, length - 1] action_values = rewards + discount * next_targets advantages = advantage_weights * (action_values - state_values) advantages *= mask diff --git a/actorch/algorithms/value_estimation/importance_sampling.py b/actorch/algorithms/value_estimation/importance_sampling.py index 91d1f53..754050d 100644 --- a/actorch/algorithms/value_estimation/importance_sampling.py +++ b/actorch/algorithms/value_estimation/importance_sampling.py @@ -40,7 +40,8 @@ def importance_sampling( log_is_weights: "Tensor", mask: "Optional[Tensor]" = None, discount: "float" = 0.99, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) importance sampling targets, a.k.a. IS, and the corresponding advantages of a trajectory. @@ -73,12 +74,14 @@ def importance_sampling( Default to ``torch.ones_like(rewards, dtype=torch.bool)``. 
discount: The discount factor (`gamma` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) importance sampling targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. References ---------- @@ -108,4 +111,5 @@ def importance_sampling( discount=discount, num_return_steps=rewards.shape[1], trace_decay=1.0, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/lambda_return.py b/actorch/algorithms/value_estimation/lambda_return.py index 8950462..095d7c0 100644 --- a/actorch/algorithms/value_estimation/lambda_return.py +++ b/actorch/algorithms/value_estimation/lambda_return.py @@ -37,7 +37,8 @@ def lambda_return( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) lambda returns, a.k.a. TD(lambda), and the corresponding advantages, a.k.a. GAE(lambda), of a trajectory. @@ -63,12 +64,14 @@ def lambda_return( The discount factor (`gamma` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) lambda returns, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. References ---------- @@ -94,4 +97,5 @@ def lambda_return( max_is_weight_trace=1.0, max_is_weight_delta=1.0, max_is_weight_advantage=1.0, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/monte_carlo_return.py b/actorch/algorithms/value_estimation/monte_carlo_return.py index 516d9e5..68d4c33 100644 --- a/actorch/algorithms/value_estimation/monte_carlo_return.py +++ b/actorch/algorithms/value_estimation/monte_carlo_return.py @@ -33,7 +33,8 @@ def monte_carlo_return( rewards: "Tensor", mask: "Optional[Tensor]" = None, discount: "float" = 0.99, -) -> "Tuple[Tensor, Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Tensor, Optional[Tensor]]": """Compute the Monte Carlo returns and the corresponding advantages of a trajectory. @@ -51,11 +52,14 @@ def monte_carlo_return( Default to ``torch.ones_like(rewards, dtype=torch.bool)``. discount: The discount factor (`gamma` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The Monte Carlo returns, shape: ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True + None otherwise, shape: ``[B, T]``. 
References ---------- @@ -72,4 +76,5 @@ def monte_carlo_return( mask=mask, discount=discount, num_return_steps=rewards.shape[1], + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/n_step_return.py b/actorch/algorithms/value_estimation/n_step_return.py index 1846b43..71b472e 100644 --- a/actorch/algorithms/value_estimation/n_step_return.py +++ b/actorch/algorithms/value_estimation/n_step_return.py @@ -37,7 +37,8 @@ def n_step_return( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, num_return_steps: "int" = 1, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) n-step returns, a.k.a. TD(n), and the corresponding advantages of a trajectory. @@ -63,12 +64,14 @@ def n_step_return( The discount factor (`gamma` in the literature). num_return_steps: The number of return steps (`n` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) n-step returns, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. References ---------- @@ -90,4 +93,5 @@ def n_step_return( max_is_weight_trace=1.0, max_is_weight_delta=1.0, max_is_weight_advantage=1.0, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/off_policy_lambda_return.py b/actorch/algorithms/value_estimation/off_policy_lambda_return.py index 7d7cf0b..c603e94 100644 --- a/actorch/algorithms/value_estimation/off_policy_lambda_return.py +++ b/actorch/algorithms/value_estimation/off_policy_lambda_return.py @@ -40,7 +40,8 @@ def off_policy_lambda_return( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) off-policy lambda returns, a.k.a. Harutyunyan's et al. Q(lambda), and the corresponding advantages of a trajectory. @@ -70,12 +71,14 @@ def off_policy_lambda_return( The discount factor (`gamma` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) off-policy lambda returns, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. 
References ---------- @@ -100,4 +103,5 @@ def off_policy_lambda_return( discount=discount, num_return_steps=rewards.shape[1], trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/retrace.py b/actorch/algorithms/value_estimation/retrace.py index b910f45..51a2605 100644 --- a/actorch/algorithms/value_estimation/retrace.py +++ b/actorch/algorithms/value_estimation/retrace.py @@ -43,7 +43,8 @@ def retrace( trace_decay: "float" = 1.0, max_is_weight_trace: "float" = 1.0, max_is_weight_advantage: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) Retrace targets, a.k.a. Retrace(lambda), and the corresponding advantages of a trajectory. @@ -82,12 +83,14 @@ def retrace( The maximum importance sampling weight for trace computation (`c_bar` in the literature). max_is_weight_advantage: The maximum importance sampling weight for advantage computation. + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) Retrace targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. Raises ------ @@ -126,4 +129,5 @@ def retrace( discount=discount, num_return_steps=rewards.shape[1], trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/tree_backup.py b/actorch/algorithms/value_estimation/tree_backup.py index 4ea1fa1..d8a3bda 100644 --- a/actorch/algorithms/value_estimation/tree_backup.py +++ b/actorch/algorithms/value_estimation/tree_backup.py @@ -41,7 +41,8 @@ def tree_backup( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) tree-backup targets, a.k.a. TB(lambda), and the corresponding advantages of a trajectory. @@ -74,12 +75,14 @@ def tree_backup( The discount factor (`gamma` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) tree-backup targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. 
References ---------- @@ -105,4 +108,5 @@ def tree_backup( discount=discount, num_return_steps=rewards.shape[1], trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/vtrace.py b/actorch/algorithms/value_estimation/vtrace.py index f6471bb..7c54aef 100644 --- a/actorch/algorithms/value_estimation/vtrace.py +++ b/actorch/algorithms/value_estimation/vtrace.py @@ -44,7 +44,8 @@ def vtrace( max_is_weight_trace: "float" = 1.0, max_is_weight_delta: "float" = 1.0, max_is_weight_advantage: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) (leaky) V-trace targets, a.k.a. V-trace(n), and the corresponding advantages of a trajectory. @@ -86,12 +87,14 @@ def vtrace( The maximum importance sampling weight for delta computation (`rho_bar` in the literature). max_is_weight_advantage: The maximum importance sampling weight for advantage computation. + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) (leaky) V-trace targets shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. Raises ------ @@ -148,4 +151,5 @@ def vtrace( discount=discount, num_return_steps=num_return_steps, trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/buffers/buffer.py b/actorch/buffers/buffer.py index 69ff10d..c1c9157 100644 --- a/actorch/buffers/buffer.py +++ b/actorch/buffers/buffer.py @@ -35,6 +35,9 @@ class Buffer(ABC, CheckpointableMixin): """Replay buffer that stores and samples batched experience trajectories.""" + is_prioritized = False + """Whether a priority is assigned to each trajectory.""" + _STATE_VARS = ["capacity", "spec", "_sampled_idx"] # override def __init__( diff --git a/actorch/buffers/proportional_buffer.py b/actorch/buffers/proportional_buffer.py index b8286db..b63e7fb 100644 --- a/actorch/buffers/proportional_buffer.py +++ b/actorch/buffers/proportional_buffer.py @@ -44,6 +44,8 @@ class ProportionalBuffer(UniformBuffer): """ + is_prioritized = True # override + _STATE_VARS = UniformBuffer._STATE_VARS + [ "prioritization", "bias_correction", diff --git a/actorch/buffers/rank_based_buffer.py b/actorch/buffers/rank_based_buffer.py index e54276a..58b994f 100644 --- a/actorch/buffers/rank_based_buffer.py +++ b/actorch/buffers/rank_based_buffer.py @@ -43,6 +43,8 @@ class RankBasedBuffer(ProportionalBuffer): """ + is_prioritized = True # override + _STATE_VARS = ProportionalBuffer._STATE_VARS # override _STATE_VARS.remove("epsilon") diff --git a/actorch/buffers/uniform_buffer.py b/actorch/buffers/uniform_buffer.py index 0f40362..6ed77de 100644 --- a/actorch/buffers/uniform_buffer.py +++ b/actorch/buffers/uniform_buffer.py @@ -55,6 +55,7 @@ def num_experiences(self) -> "int": # override @property def num_full_trajectories(self) -> "int": + # print(len(self._full_trajectory_start_idx)) return len(self._full_trajectory_start_idx) if self._num_experiences > 0 else 0 # override diff --git a/actorch/distributed/distributed_trainable.py b/actorch/distributed/distributed_trainable.py index ef807f4..0fd013f 100644 --- a/actorch/distributed/distributed_trainable.py +++ b/actorch/distributed/distributed_trainable.py @@ -45,7 
+45,7 @@ class DistributedTrainable(ABC, TuneDistributedTrainable): """Distributed Ray Tune trainable with configurable resource requirements. Derived classes must implement `step`, `save_checkpoint` and `load_checkpoint` - (see https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable-class-api). + (see https://docs.ray.io/en/releases-1.13.0/tune/api_docs/trainable.html#tune-trainable-class-api for Ray 1.13.0). """ @@ -67,7 +67,7 @@ def __init__( Default to ``[{"CPU": 1}]``. placement_strategy: The placement strategy - (see https://docs.ray.io/en/latest/ray-core/placement-group.html). + (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0). """ super().__init__( @@ -98,7 +98,7 @@ def resource_help(cls, config: "Dict[str, Any]") -> "str": " Default to ``[{}]``." + "\n" " placement_strategy:" + "\n" " The placement strategy." + "\n" - " (see https://docs.ray.io/en/latest/ray-core/placement-group.html)." + f" (see https://docs.ray.io/en/releases-{ray.__version__}/ray-core/placement-group.html)." ) # override diff --git a/actorch/distributed/sync_distributed_trainable.py b/actorch/distributed/sync_distributed_trainable.py index 3b4e805..fa3f231 100644 --- a/actorch/distributed/sync_distributed_trainable.py +++ b/actorch/distributed/sync_distributed_trainable.py @@ -100,7 +100,7 @@ def __init__( Default to ``{}``. placement_strategy: The placement strategy - (see https://docs.ray.io/en/latest/ray-core/placement-group.html). + (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0). reduction_mode: The reduction mode for worker results. Must be one of the following: @@ -225,7 +225,7 @@ def resource_help(cls, config: "Dict[str, Any]") -> "str": " Default to ``{}``." + "\n" " placement_strategy:" + "\n" " The placement strategy" + "\n" - " (see https://docs.ray.io/en/latest/ray-core/placement-group.html)." + f" (see https://docs.ray.io/en/releases-{ray.__version__}/ray-core/placement-group.html)." ) # override diff --git a/actorch/distributions/finite.py b/actorch/distributions/finite.py index e48d8e4..16dfca7 100644 --- a/actorch/distributions/finite.py +++ b/actorch/distributions/finite.py @@ -48,6 +48,8 @@ class Finite(Distribution): Examples -------- + >>> import torch + >>> >>> from actorch.distributions import Finite >>> >>> diff --git a/actorch/optimizers/cgbls.py b/actorch/optimizers/cgbls.py index c8d1abe..908c742 100644 --- a/actorch/optimizers/cgbls.py +++ b/actorch/optimizers/cgbls.py @@ -300,7 +300,7 @@ def _conjugate_gradient( References ---------- - .. [1] M.R. Hestenes, and E. Stiefel. + .. [1] M.R. Hestenes and E. Stiefel. "Methods of Conjugate Gradients for Solving Linear Systems". In: Journal of Research of the National Bureau of Standards. 1952, pp. 409-435. 
URL: http://dx.doi.org/10.6028/jres.049.044 @@ -389,7 +389,7 @@ def _backtracking_line_search( loss, constraint = 0.0, 0.0 for ratio in ratios: for i, step in enumerate(descent_steps): - params[i] -= ratio * step + params[i].copy_(prev_params[i] - ratio * step) loss = loss_fn() constraint = constraint_fn() @@ -411,6 +411,6 @@ warning_msg = " and".join(warning_msg.rsplit(",", 1)) _LOGGER.warning(warning_msg) for i, prev_param in enumerate(prev_params): - params[i] = prev_param + params[i].copy_(prev_param) return loss, constraint diff --git a/actorch/preconditioners/kfac.py b/actorch/preconditioners/kfac.py index ee90ce9..98734be 100644 --- a/actorch/preconditioners/kfac.py +++ b/actorch/preconditioners/kfac.py @@ -93,8 +93,8 @@ def update_AG(self, decay: "float") -> "None": ) self._dA, self._dG = A.new_zeros(A.shape[0]), G.new_zeros(G.shape[0]) self._QA, self._QG = A.new_zeros(A.shape), G.new_zeros(G.shape) - self._update_exp_moving_average(self._A, A, decay) - self._update_exp_moving_average(self._G, G, decay) + self._update_exp_moving_average_(self._A, A, decay) + self._update_exp_moving_average_(self._G, G, decay) def update_eigen_AG(self, epsilon: "float") -> "None": """Update eigenvalues and eigenvectors of A and G. @@ -128,14 +128,15 @@ def get_preconditioned_grad(self, damping: "float") -> "Tensor": v = self._QG @ v2 @ self._QA.t() return v - def _update_exp_moving_average( + def _update_exp_moving_average_( self, current: "Tensor", new: "Tensor", weight: "float" - ) -> "None": + ) -> "Tensor": if weight == 1.0: - return + return current current *= weight / (1 - weight) current += new current *= 1 - weight + return current @property @abstractmethod diff --git a/actorch/version.py b/actorch/version.py index 76f0825..cf4b122 100644 --- a/actorch/version.py +++ b/actorch/version.py @@ -28,7 +28,7 @@ "0" # Minor version to increment in case of backward compatible new functionality ) -_PATCH = "4" # Patch version to increment in case of backward compatible bug fixes +_PATCH = "5" # Patch version to increment in case of backward compatible bug fixes VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}" """The package version.""" diff --git a/docs/_static/images/actorch-overview.png b/docs/_static/images/actorch-overview.png new file mode 100644 index 0000000..7c65783 Binary files /dev/null and b/docs/_static/images/actorch-overview.png differ diff --git a/docs/index.rst b/docs/index.rst index 5281006..e9746dd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,3 +5,6 @@ Welcome to ACTorch! `actorch` is a deep reinforcement learning framework for fast prototyping based on `PyTorch <https://pytorch.org>`_. Visit the GitHub repository: https://github.com/lucadellalib/actorch + +.. image:: _static/images/actorch-overview.png + :width: 600 \ No newline at end of file diff --git a/examples/A2C-AsyncHyperBand_LunarLander-v2.py b/examples/A2C-AsyncHyperBand_LunarLander-v2.py index add1b62..c21725a 100644 --- a/examples/A2C-AsyncHyperBand_LunarLander-v2.py +++ b/examples/A2C-AsyncHyperBand_LunarLander-v2.py @@ -14,12 +14,13 @@ # limitations under the License. # ============================================================================== -"""Train Advantage Actor-Critic on LunarLander-v2. Tune the value network learning +"""Train Advantage Actor-Critic (A2C) on LunarLander-v2. Tune the value network learning rate using the asynchronous version of HyperBand (https://arxiv.org/abs/1810.05934).
""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run A2C-AsyncHyperBand_LunarLander-v2.py import gymnasium as gym diff --git a/examples/A2C_LunarLander-v2.py b/examples/A2C_LunarLander-v2.py index d92bfdb..85fb2e6 100644 --- a/examples/A2C_LunarLander-v2.py +++ b/examples/A2C_LunarLander-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Advantage Actor-Critic on LunarLander-v2.""" +"""Train Advantage Actor-Critic (A2C) on LunarLander-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run A2C_LunarLander-v2.py import gymnasium as gym diff --git a/examples/ACKTR_LunarLander-v2.py b/examples/ACKTR_LunarLander-v2.py index cdfa571..de0d291 100644 --- a/examples/ACKTR_LunarLander-v2.py +++ b/examples/ACKTR_LunarLander-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Actor-Critic Kronecker-Factored Trust Region on LunarLander-v2.""" +"""Train Actor-Critic Kronecker-Factored (ACKTR) Trust Region on LunarLander-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run ACKTR_LunarLander-v2.py import gymnasium as gym diff --git a/examples/AWR-NormalizingFlow_Pendulum-v1.py b/examples/AWR-AffineFlow_Pendulum-v1.py similarity index 95% rename from examples/AWR-NormalizingFlow_Pendulum-v1.py rename to examples/AWR-AffineFlow_Pendulum-v1.py index b63c0df..848db20 100644 --- a/examples/AWR-NormalizingFlow_Pendulum-v1.py +++ b/examples/AWR-AffineFlow_Pendulum-v1.py @@ -14,10 +14,11 @@ # limitations under the License. # ============================================================================== -"""Train Advantage-Weighted Regression equipped with a normalizing flow policy on Pendulum-v1.""" +"""Train Advantage-Weighted Regression (AWR) equipped with an affine flow policy on Pendulum-v1.""" # Navigate to `/examples`, open a terminal and run: -# actorch run AWR-NormalizingFlow_Pendulum-v1.py +# pip install gymnasium[classic_control] +# actorch run AWR-AffineFlow_Pendulum-v1.py import gymnasium as gym import torch diff --git a/examples/AWR_Pendulum-v1.py b/examples/AWR_Pendulum-v1.py index 0c42e92..3c89595 100644 --- a/examples/AWR_Pendulum-v1.py +++ b/examples/AWR_Pendulum-v1.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Advantage-Weighted Regression on Pendulum-v1.""" +"""Train Advantage-Weighted Regression (AWR) on Pendulum-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run AWR_Pendulum-v1.py import gymnasium as gym diff --git a/examples/D3PG-Finite_LunarLanderContinuous-v2.py b/examples/D3PG-Finite_LunarLanderContinuous-v2.py new file mode 100644 index 0000000..b2cd13e --- /dev/null +++ b/examples/D3PG-Finite_LunarLanderContinuous-v2.py @@ -0,0 +1,129 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Distributional Deep Deterministic Policy Gradient (D3PG) on LunarLanderContinuous-v2 +assuming a finite value distribution with 51 atoms (see https://arxiv.org/abs/1804.08617). + +""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] +# actorch run D3PG-Finite_LunarLanderContinuous-v2.py + +import gymnasium as gym +import torch +from torch import nn +from torch.optim import Adam + +from actorch import * + + +# Define custom model +class LayerNormFCNet(FCNet): + # override + def _setup_torso(self, in_shape): + super()._setup_torso(in_shape) + torso = nn.Sequential() + for module in self.torso: + torso.append(module) + if isinstance(module, nn.Linear): + torso.append( + nn.LayerNorm(module.out_features, elementwise_affine=False) + ) + self.torso = torso + + +experiment_params = ExperimentParams( + run_or_experiment=D3PG, + stop={"timesteps_total": int(4e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=D3PG.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("LunarLanderContinuous-v2", **kwargs), + config, + num_workers=1, + ), + train_agent_builder=OUNoiseAgent, + train_agent_config={ + "clip_action": True, + "device": "cpu", + "num_random_timesteps": 1000, + "mean": 0.0, + "volatility": 0.1, + "reversion_speed": 0.15, + }, + train_num_timesteps_per_iter=1024, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=LayerNormFCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + "head_activation_builder": nn.Tanh, + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 1e-3}, + value_network_model_builder=LayerNormFCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + }, + value_network_distribution_builder=Finite, + value_network_distribution_parametrization={ + "logits": ( + {"logits": (51,)}, + lambda x: x["logits"], + ), + }, + value_network_distribution_config={ + "atoms": torch.linspace(-10.0, 10.0, 51), + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 1e-3}, + buffer_config={ + "capacity": int(1e5), + "prioritization": 1.0, + "bias_correction": 0.4, + "epsilon": 1e-5, + }, + discount=0.995, + num_return_steps=5, + num_updates_per_iter=LambdaSchedule( + lambda iter: ( + 200 if iter >= 10 else 0 + ), # Fill buffer with some trajectories before training + ), + batch_size=128, + max_trajectory_length=10, + sync_freq=1, + polyak_weight=0.001, + max_grad_l2_norm=5.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/D3PG-Normal_LunarLanderContinuous-v2.py 
b/examples/D3PG-Normal_LunarLanderContinuous-v2.py new file mode 100644 index 0000000..d6c7227 --- /dev/null +++ b/examples/D3PG-Normal_LunarLanderContinuous-v2.py @@ -0,0 +1,130 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Distributional Deep Deterministic Policy Gradient (D3PG) on LunarLanderContinuous-v2 +assuming a normal value distribution. + +""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] +# actorch run D3PG-Normal_LunarLanderContinuous-v2.py + +import gymnasium as gym +from torch import nn +from torch.distributions import Normal +from torch.optim import Adam + +from actorch import * + + +# Define custom model +class LayerNormFCNet(FCNet): + # override + def _setup_torso(self, in_shape): + super()._setup_torso(in_shape) + torso = nn.Sequential() + for module in self.torso: + torso.append(module) + if isinstance(module, nn.Linear): + torso.append( + nn.LayerNorm(module.out_features, elementwise_affine=False) + ) + self.torso = torso + + +experiment_params = ExperimentParams( + run_or_experiment=D3PG, + stop={"timesteps_total": int(4e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=D3PG.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("LunarLanderContinuous-v2", **kwargs), + config, + num_workers=1, + ), + train_agent_builder=OUNoiseAgent, + train_agent_config={ + "clip_action": True, + "device": "cpu", + "num_random_timesteps": 1000, + "mean": 0.0, + "volatility": 0.1, + "reversion_speed": 0.15, + }, + train_num_timesteps_per_iter=1024, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=LayerNormFCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + "head_activation_builder": nn.Tanh, + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 1e-3}, + value_network_model_builder=LayerNormFCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + }, + value_network_distribution_builder=Normal, + value_network_distribution_parametrization={ + "loc": ( + {"loc": ()}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": ()}, + lambda x: x["log_scale"].clamp(-20.0, 2.0).exp(), + ), + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 1e-3}, + buffer_config={ + "capacity": int(1e5), + "prioritization": 1.0, + "bias_correction": 0.4, + "epsilon": 1e-5, + }, + discount=0.995, + num_return_steps=5, + num_updates_per_iter=LambdaSchedule( + lambda iter: ( + 200 if iter >= 
10 else 0 + ), # Fill buffer with some trajectories before training + ), + batch_size=128, + max_trajectory_length=10, + sync_freq=1, + polyak_weight=0.001, + max_grad_l2_norm=5.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/DDPG_LunarLanderContinuous-v2.py b/examples/DDPG_LunarLanderContinuous-v2.py index de4912a..c44c4fd 100644 --- a/examples/DDPG_LunarLanderContinuous-v2.py +++ b/examples/DDPG_LunarLanderContinuous-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Deep Deterministic Policy Gradient on LunarLanderContinuous-v2.""" +"""Train Deep Deterministic Policy Gradient (DDPG) on LunarLanderContinuous-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run DDPG_LunarLanderContinuous-v2.py import gymnasium as gym diff --git a/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py b/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py index f9319a5..f3a3df7 100644 --- a/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py +++ b/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train distributed data parallel Deep Deterministic Policy Gradient on LunarLanderContinuous-v2.""" +"""Train distributed data parallel Deep Deterministic Policy Gradient (DDPG) on LunarLanderContinuous-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run DistributedDataParallelDDPG_LunarLanderContinuous-v2.py import gymnasium as gym diff --git a/examples/DistributedDataParallelREINFORCE_CartPole-v1.py b/examples/DistributedDataParallelREINFORCE_CartPole-v1.py index ff470ab..b07ad36 100644 --- a/examples/DistributedDataParallelREINFORCE_CartPole-v1.py +++ b/examples/DistributedDataParallelREINFORCE_CartPole-v1.py @@ -17,6 +17,7 @@ """Train distributed data parallel REINFORCE on CartPole-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run DistributedDataParallelREINFORCE_CartPole-v1.py import gymnasium as gym diff --git a/examples/PPO-Laplace_Pendulum-v1.py b/examples/PPO-Laplace_Pendulum-v1.py new file mode 100644 index 0000000..32ad43e --- /dev/null +++ b/examples/PPO-Laplace_Pendulum-v1.py @@ -0,0 +1,94 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Train Proximal Policy Optimization (PPO) on Pendulum-v1 assuming a Laplace policy distribution.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] +# actorch run PPO-Laplace_Pendulum-v1.py + +import gymnasium as gym +from torch.distributions import Laplace +from torch.optim import Adam + +from actorch import * + + +experiment_params = ExperimentParams( + run_or_experiment=PPO, + stop={"timesteps_total": int(1e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=PPO.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("Pendulum-v1", **kwargs), + config, + num_workers=2, + ), + train_num_timesteps_per_iter=2048, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + "independent_heads": ["action/log_scale"], + }, + policy_network_distribution_builders={"action": Laplace}, + policy_network_distribution_parametrizations={ + "action": { + "loc": ( + {"loc": (1,)}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": (1,)}, + lambda x: x["log_scale"].exp(), + ), + }, + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 5e-5}, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 3e-3}, + discount=0.99, + trace_decay=0.95, + num_epochs=20, + minibatch_size=16, + ratio_clip=0.2, + normalize_advantage=True, + entropy_coeff=0.01, + max_grad_l2_norm=0.5, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/PPO_Pendulum-v1.py b/examples/PPO_Pendulum-v1.py index 0089d72..6cb22ea 100644 --- a/examples/PPO_Pendulum-v1.py +++ b/examples/PPO_Pendulum-v1.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Proximal Policy Optimization on Pendulum-v1.""" +"""Train Proximal Policy Optimization (PPO) on Pendulum-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run PPO_Pendulum-v1.py import gymnasium as gym diff --git a/examples/REINFORCE_CartPole-v1.py b/examples/REINFORCE_CartPole-v1.py index 2ab5c13..342ed36 100644 --- a/examples/REINFORCE_CartPole-v1.py +++ b/examples/REINFORCE_CartPole-v1.py @@ -17,6 +17,7 @@ """Train REINFORCE on CartPole-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run REINFORCE_CartPole-v1.py import gymnasium as gym diff --git a/examples/SAC_BipedalWalker-v3.py b/examples/SAC_BipedalWalker-v3.py new file mode 100644 index 0000000..55c768e --- /dev/null +++ b/examples/SAC_BipedalWalker-v3.py @@ -0,0 +1,129 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Soft Actor-Critic (SAC) on BipedalWalker-v3.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] +# actorch run SAC_BipedalWalker-v3.py + +import gymnasium as gym +import numpy as np +import torch.nn +from torch.distributions import Normal, TanhTransform +from torch.optim import Adam + +from actorch import * + + +class Wrapper(gym.Wrapper): + def __init__(self, env, action_noise=0.3, action_repeat=3, reward_scale=5): + super().__init__(env) + self.action_noise = action_noise + self.action_repeat = action_repeat + self.reward_scale = reward_scale + + def step(self, action): + action += self.action_noise * ( + 1 - 2 * np.random.random(self.action_space.shape) + ) + cumreward = 0.0 + for _ in range(self.action_repeat): + observation, reward, terminated, truncated, info = super().step(action) + cumreward += reward + if terminated: + return observation, 0.0, terminated, truncated, info + return observation, self.reward_scale * cumreward, terminated, truncated, info + + +experiment_params = ExperimentParams( + run_or_experiment=SAC, + stop={"timesteps_total": int(5e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=SAC.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: Wrapper( + gym.wrappers.TimeLimit( + gym.make("BipedalWalker-v3", **kwargs), + max_episode_steps=200, + ), + ), + config, + num_workers=1, + ), + train_num_episodes_per_iter=1, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + "head_activation_builder": torch.nn.Tanh, + }, + policy_network_distribution_builders={ + "action": lambda loc, scale: TransformedDistribution( + Normal(loc, scale), + TanhTransform(cache_size=1), # Use tanh normal to enforce action bounds + ), + }, + policy_network_distribution_parametrizations={ + "action": { + "loc": ( + {"loc": (4,)}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": (4,)}, + lambda x: x["log_scale"].clamp(-20.0, 2.0).exp(), + ), + }, + }, + policy_network_sample_fn=lambda d: d.sample(), # `mode` does not exist in closed-form for tanh normal + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 1e-3}, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 1e-3}, + buffer_config={"capacity": int(3e5)}, + discount=0.98, + num_return_steps=5, + num_updates_per_iter=10, + batch_size=1024, + max_trajectory_length=20, + 
sync_freq=1, + polyak_weight=0.001, + temperature=0.2, + max_grad_l2_norm=1.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/SAC_Pendulum-v1.py b/examples/SAC_Pendulum-v1.py new file mode 100644 index 0000000..b3efe24 --- /dev/null +++ b/examples/SAC_Pendulum-v1.py @@ -0,0 +1,82 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Soft Actor-Critic (SAC) on Pendulum-v1.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] +# actorch run SAC_Pendulum-v1.py + +import gymnasium as gym +from torch.optim import Adam + +from actorch import * + + +experiment_params = ExperimentParams( + run_or_experiment=SAC, + stop={"timesteps_total": int(1e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=SAC.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("Pendulum-v1", **kwargs), + config, + num_workers=2, + ), + train_num_timesteps_per_iter=500, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 5e-3}, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 5e-3}, + temperature_optimizer_builder=Adam, + temperature_optimizer_config={"lr": 1e-5}, + buffer_config={"capacity": int(1e5)}, + discount=0.99, + num_return_steps=1, + num_updates_per_iter=500, + batch_size=32, + max_trajectory_length=10, + sync_freq=1, + polyak_weight=0.001, + max_grad_l2_norm=2.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/TD3_LunarLanderContinuous-v2.py b/examples/TD3_LunarLanderContinuous-v2.py index 258e7c1..dee26b4 100644 --- a/examples/TD3_LunarLanderContinuous-v2.py +++ b/examples/TD3_LunarLanderContinuous-v2.py @@ -14,9 +14,10 @@ # limitations under the License. 
# ============================================================================== -"""Train Twin Delayed Deep Deterministic Policy Gradient on LunarLanderContinuous-v2.""" +"""Train Twin Delayed Deep Deterministic Policy Gradient (TD3) on LunarLanderContinuous-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run TD3_LunarLanderContinuous-v2.py import gymnasium as gym diff --git a/examples/TRPO_Pendulum-v1.py b/examples/TRPO_Pendulum-v1.py new file mode 100644 index 0000000..df1fbd0 --- /dev/null +++ b/examples/TRPO_Pendulum-v1.py @@ -0,0 +1,84 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Trust Region Policy Optimization (TRPO) on Pendulum-v1.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] +# actorch run TRPO_Pendulum-v1.py + +import gymnasium as gym +from torch.optim import Adam + +from actorch import * + + +experiment_params = ExperimentParams( + run_or_experiment=TRPO, + stop={"timesteps_total": int(2e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=TRPO.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("Pendulum-v1", **kwargs), + config, + num_workers=2, + ), + train_num_timesteps_per_iter=512, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + "independent_heads": ["action/log_scale"], + }, + policy_network_optimizer_config={ + "max_constraint": 0.01, + "num_cg_iters": 10, + "max_backtracks": 15, + "backtrack_ratio": 0.8, + "hvp_reg_coeff": 1e-5, + "accept_violation": False, + "epsilon": 1e-8, + }, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 3e-3}, + discount=0.9, + trace_decay=0.95, + normalize_advantage=False, + entropy_coeff=0.01, + max_grad_l2_norm=2.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 0b6e578..9d8c59d 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -123,6 +123,7 @@ def _cleanup(): algorithms.AWR, algorithms.PPO, algorithms.REINFORCE, + algorithms.TRPO, ], ) @pytest.mark.parametrize( @@ -198,7 +199,9 @@ def test_algorithm_mixed(algorithm_cls, model_cls, space): @pytest.mark.parametrize( "algorithm_cls", [ + 
algorithms.D3PG, algorithms.DDPG, + algorithms.SAC, algorithms.TD3, ], )
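
The `_backtracking_line_search` change in `actorch/optimizers/cgbls.py` above replaces the cumulative update `params[i] -= ratio * step` with `params[i].copy_(prev_params[i] - ratio * step)`, so every candidate ratio is applied to the saved parameter snapshot instead of compounding across trials, and the final restore is likewise an in-place copy. Below is a minimal standalone sketch of that restore-then-step semantics; it is plain PyTorch rather than the actorch API, and the acceptance test is simplified to a bare loss comparison (the real routine also checks a constraint).

```python
import torch


def backtracking_line_search_sketch(params, descent_steps, loss_fn, max_backtracks=10, backtrack_ratio=0.8):
    """Try geometrically shrinking step ratios; keep the first one that lowers the loss."""
    prev_params = [p.detach().clone() for p in params]  # snapshot of the current parameters
    initial_loss = loss_fn()
    for k in range(max_backtracks):
        ratio = backtrack_ratio ** k
        for p, prev, step in zip(params, prev_params, descent_steps):
            # Overwrite in place from the snapshot (as `copy_` does in the patch),
            # so a rejected ratio never leaks into the next trial
            p.copy_(prev - ratio * step)
        loss = loss_fn()
        if loss < initial_loss:
            return loss
    # No ratio improved the loss: restore the snapshot in place
    for p, prev in zip(params, prev_params):
        p.copy_(prev)
    return initial_loss


# Toy usage: minimize (x - 3)^2 from x = 0 with descent step -3 (i.e. x <- x - ratio * step)
x = torch.zeros(1)
final_loss = backtracking_line_search_sketch([x], [torch.tensor([-3.0])], lambda: float((x - 3.0) ** 2))
print(x, final_loss)  # x moves to 3.0 once an improving ratio is found
```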
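
The new `D3PG-Finite_LunarLanderContinuous-v2.py` example parametrizes the critic with a `Finite` value distribution over 51 atoms spanning [-10, 10]. As a hedged illustration in plain PyTorch (not actorch's `Finite` class), the scalar value estimate implied by such a distribution is the probability-weighted mean of the fixed atoms:

```python
import torch

atoms = torch.linspace(-10.0, 10.0, 51)   # fixed support, matching the example config
logits = torch.randn(51)                  # stand-in for the 51-logit value head output
probs = torch.softmax(logits, dim=-1)     # categorical distribution over the atoms
value_estimate = (probs * atoms).sum()    # scalar value = expected return under the distribution
print(value_estimate)
```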
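
The new `SAC_BipedalWalker-v3.py` example builds a tanh-normal policy by squashing a `Normal` base distribution through `TanhTransform`, and draws actions with `sample()` because the mode of a tanh normal has no closed form. A minimal sketch of that construction using plain `torch.distributions` (the shapes and values here are illustrative only):

```python
import torch
from torch.distributions import Normal, TanhTransform, TransformedDistribution

loc = torch.zeros(4)        # illustrative 4-dimensional action space
log_scale = torch.zeros(4)
policy = TransformedDistribution(
    Normal(loc, log_scale.clamp(-20.0, 2.0).exp()),  # same log-scale clamping as the example config
    TanhTransform(cache_size=1),                     # squashes actions into (-1, 1)
)
action = policy.sample()            # sampled rather than taking the (unavailable) mode
log_prob = policy.log_prob(action)  # includes the tanh Jacobian correction
print(action, log_prob)
```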