From 2eb60648ce352354d8003c81b01bd046f3498806 Mon Sep 17 00:00:00 2001 From: Luca Della Libera Date: Sun, 12 Mar 2023 15:46:17 -0400 Subject: [PATCH] Add TRPO, D3PG and SAC, minor improvements and bug fixes --- README.md | 52 +- actorch/algorithms/__init__.py | 3 + actorch/algorithms/a2c.py | 48 +- actorch/algorithms/acktr.py | 6 +- actorch/algorithms/algorithm.py | 54 +- actorch/algorithms/awr.py | 28 +- actorch/algorithms/d3pg.py | 430 +++++++++++++ actorch/algorithms/ddpg.py | 39 +- actorch/algorithms/ppo.py | 27 +- actorch/algorithms/reinforce.py | 6 +- actorch/algorithms/sac.py | 605 ++++++++++++++++++ actorch/algorithms/td3.py | 47 +- actorch/algorithms/trpo.py | 394 ++++++++++++ actorch/algorithms/utils.py | 124 +++- .../value_estimation/generalized_estimator.py | 33 +- .../value_estimation/importance_sampling.py | 8 +- .../value_estimation/lambda_return.py | 8 +- .../value_estimation/monte_carlo_return.py | 9 +- .../value_estimation/n_step_return.py | 8 +- .../off_policy_lambda_return.py | 8 +- .../algorithms/value_estimation/retrace.py | 8 +- .../value_estimation/tree_backup.py | 8 +- actorch/algorithms/value_estimation/vtrace.py | 8 +- actorch/buffers/buffer.py | 3 + actorch/buffers/proportional_buffer.py | 2 + actorch/buffers/rank_based_buffer.py | 2 + actorch/buffers/uniform_buffer.py | 1 + actorch/distributed/distributed_trainable.py | 6 +- .../distributed/sync_distributed_trainable.py | 4 +- actorch/distributions/finite.py | 2 + actorch/optimizers/cgbls.py | 6 +- actorch/preconditioners/kfac.py | 11 +- actorch/version.py | 2 +- docs/_static/images/actorch-overview.png | Bin 0 -> 39825 bytes docs/index.rst | 3 + examples/A2C-AsyncHyperBand_LunarLander-v2.py | 3 +- examples/A2C_LunarLander-v2.py | 3 +- examples/ACKTR_LunarLander-v2.py | 3 +- ...um-v1.py => AWR-AffineFlow_Pendulum-v1.py} | 5 +- examples/AWR_Pendulum-v1.py | 3 +- .../D3PG-Finite_LunarLanderContinuous-v2.py | 129 ++++ .../D3PG-Normal_LunarLanderContinuous-v2.py | 130 ++++ examples/DDPG_LunarLanderContinuous-v2.py | 3 +- ...taParallelDDPG_LunarLanderContinuous-v2.py | 3 +- ...ibutedDataParallelREINFORCE_CartPole-v1.py | 1 + examples/PPO-Laplace_Pendulum-v1.py | 94 +++ examples/PPO_Pendulum-v1.py | 3 +- examples/REINFORCE_CartPole-v1.py | 1 + examples/SAC_BipedalWalker-v3.py | 129 ++++ examples/SAC_Pendulum-v1.py | 82 +++ examples/TD3_LunarLanderContinuous-v2.py | 3 +- examples/TRPO_Pendulum-v1.py | 84 +++ tests/test_algorithms.py | 3 + 53 files changed, 2456 insertions(+), 229 deletions(-) create mode 100644 actorch/algorithms/d3pg.py create mode 100644 actorch/algorithms/sac.py create mode 100644 actorch/algorithms/trpo.py create mode 100644 docs/_static/images/actorch-overview.png rename examples/{AWR-NormalizingFlow_Pendulum-v1.py => AWR-AffineFlow_Pendulum-v1.py} (95%) create mode 100644 examples/D3PG-Finite_LunarLanderContinuous-v2.py create mode 100644 examples/D3PG-Normal_LunarLanderContinuous-v2.py create mode 100644 examples/PPO-Laplace_Pendulum-v1.py create mode 100644 examples/SAC_BipedalWalker-v3.py create mode 100644 examples/SAC_Pendulum-v1.py create mode 100644 examples/TRPO_Pendulum-v1.py diff --git a/README.md b/README.md index 8c39d88..5bc336b 100644 --- a/README.md +++ b/README.md @@ -14,33 +14,36 @@ Welcome to `actorch`, a deep reinforcement learning framework for fast prototypi - [REINFORCE](https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) - [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783) - [Actor-Critic Kronecker-Factored Trust Region 
(ACKTR)](https://arxiv.org/abs/1708.05144) +- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) - [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) - [Advantage-Weighted Regression (AWR)](https://arxiv.org/abs/1910.00177) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971) +- [Distributional Deep Deterministic Policy Gradient (D3PG)](https://arxiv.org/abs/1804.08617) - [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/abs/1802.09477) +- [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290) --------------------------------------------------------------------------------------------------------- ## 💡 Key features - Support for [OpenAI Gymnasium](https://gymnasium.farama.org/) environments -- Support for custom observation/action spaces -- Support for custom multimodal input multimodal output models -- Support for recurrent models (e.g. RNNs, LSTMs, GRUs, etc.) -- Support for custom policy/value distributions -- Support for custom preprocessing/postprocessing pipelines -- Support for custom exploration strategies +- Support for **custom observation/action spaces** +- Support for **custom multimodal input multimodal output models** +- Support for **recurrent models** (e.g. RNNs, LSTMs, GRUs, etc.) +- Support for **custom policy/value distributions** +- Support for **custom preprocessing/postprocessing pipelines** +- Support for **custom exploration strategies** - Support for [normalizing flows](https://arxiv.org/abs/1906.02771) - Batched environments (both for training and evaluation) -- Batched trajectory replay -- Batched and distributional value estimation (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561)) -- Data parallel and distributed data parallel multi-GPU training and evaluation -- Automatic mixed precision training -- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and hyperparameter tuning at any scale -- Effortless experiment definition through Python-based configuration files -- Built-in visualization tool to plot performance metrics -- Modular object-oriented design -- Detailed API documentation +- Batched **trajectory replay** +- Batched and **distributional value estimation** (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561)) +- Data parallel and distributed data parallel **multi-GPU training and evaluation** +- Automatic **mixed precision training** +- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and **hyperparameter tuning** at any scale +- Effortless experiment definition through **Python-based configuration files** +- Built-in **visualization tool** to plot performance metrics +- Modular **object-oriented** design +- Detailed **API documentation** --------------------------------------------------------------------------------------------------------- @@ -161,7 +164,7 @@ experiment_params = ExperimentParams( enable_amp=False, enable_reproducibility=True, log_sys_usage=True, - suppress_warnings=False, + suppress_warnings=True, ), ) ``` @@ -197,6 +200,8 @@ You can find the generated plots in `plots`. Congratulations, you ran your first experiment! +See `examples` for additional configuration file examples. 
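+For instance, a minimal configuration for the newly added SAC algorithm might look
+roughly like the sketch below (the argument values are purely illustrative, and
+`ExperimentParams` is assumed to accept the usual Ray Tune `tune.run` arguments as in
+the quickstart above; see `examples/SAC_Pendulum-v1.py` for a complete configuration):
+
+```python
+import gymnasium as gym
+
+from actorch import *  # Assumes the top-level package re-exports SAC and ExperimentParams
+
+experiment_params = ExperimentParams(
+    run_or_experiment=SAC,
+    stop={"timesteps_total": int(1e5)},
+    config=SAC.Config(
+        # Build a single Pendulum-v1 environment; non-batched environments are
+        # automatically wrapped in a batched environment by the algorithm
+        train_env_builder=lambda **kwargs: gym.make("Pendulum-v1", **kwargs),
+        train_num_timesteps_per_iter=400,
+        eval_freq=10,
+        eval_num_episodes_per_iter=5,
+        discount=0.99,
+        num_updates_per_iter=200,
+        batch_size=128,
+        temperature=0.1,
+        seed=0,
+    ),
+)
+```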
+ **HINT**: since a configuration file is a regular Python script, you can use all the features of the language (e.g. inheritance). @@ -217,6 +222,21 @@ features of the language (e.g. inheritance). --------------------------------------------------------------------------------------------------------- +## @ Citation + +``` +@misc{DellaLibera2022ACTorch, + author = {Luca Della Libera}, + title = {{ACTorch}: a Deep Reinforcement Learning Framework for Fast Prototyping}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/lucadellalib/actorch}}, +} +``` + +--------------------------------------------------------------------------------------------------------- + ## 📧 Contact [luca.dellalib@gmail.com](mailto:luca.dellalib@gmail.com) diff --git a/actorch/algorithms/__init__.py b/actorch/algorithms/__init__.py index bbb311d..3235bc7 100644 --- a/actorch/algorithms/__init__.py +++ b/actorch/algorithms/__init__.py @@ -20,7 +20,10 @@ from actorch.algorithms.acktr import * from actorch.algorithms.algorithm import * from actorch.algorithms.awr import * +from actorch.algorithms.d3pg import * from actorch.algorithms.ddpg import * from actorch.algorithms.ppo import * from actorch.algorithms.reinforce import * +from actorch.algorithms.sac import * from actorch.algorithms.td3 import * +from actorch.algorithms.trpo import * diff --git a/actorch/algorithms/a2c.py b/actorch/algorithms/a2c.py index 9cce312..4a3e6e4 100644 --- a/actorch/algorithms/a2c.py +++ b/actorch/algorithms/a2c.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Advantage Actor-Critic.""" +"""Advantage Actor-Critic (A2C).""" import contextlib from typing import Any, Callable, Dict, Optional, Tuple, Type, Union @@ -38,7 +38,7 @@ DistributedDataParallelREINFORCE, LRScheduler, ) -from actorch.algorithms.utils import prepare_model +from actorch.algorithms.utils import normalize_, prepare_model from actorch.algorithms.value_estimation import n_step_return from actorch.distributions import Deterministic from actorch.envs import BatchedEnv @@ -65,7 +65,7 @@ class A2C(REINFORCE): - """Advantage Actor-Critic. + """Advantage Actor-Critic (A2C). 
References ---------- @@ -208,12 +208,8 @@ def setup(self, config: "Dict[str, Any]") -> "None": self.config = A2C.Config(**self.config) self.config["_accept_kwargs"] = True super().setup(config) - self._value_network = ( - self._build_value_network().train().to(self._device, non_blocking=True) - ) - self._value_network_loss = ( - self._build_value_network_loss().train().to(self._device, non_blocking=True) - ) + self._value_network = self._build_value_network() + self._value_network_loss = self._build_value_network_loss() self._value_network_optimizer = self._build_value_network_optimizer() self._value_network_optimizer_lr_scheduler = ( self._build_value_network_optimizer_lr_scheduler() @@ -324,16 +320,20 @@ def _build_value_network(self) -> "Network": self.value_network_normalizing_flows, ) self._log_graph(value_network.wrapped_model.model, "value_network_model") - return value_network + return value_network.train().to(self._device, non_blocking=True) def _build_value_network_loss(self) -> "Loss": if self.value_network_loss_builder is None: self.value_network_loss_builder = torch.nn.MSELoss if self.value_network_loss_config is None: self.value_network_loss_config: "Dict[str, Any]" = {} - return self.value_network_loss_builder( - reduction="none", - **self.value_network_loss_config, + return ( + self.value_network_loss_builder( + reduction="none", + **self.value_network_loss_config, + ) + .train() + .to(self._device, non_blocking=True) ) def _build_value_network_optimizer(self) -> "Optimizer": @@ -374,6 +374,8 @@ def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() self.num_return_steps.step() result["num_return_steps"] = self.num_return_steps() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) return result # override @@ -405,17 +407,7 @@ def _train_on_batch( self.num_return_steps(), ) if self.normalize_advantage: - length = mask.sum(dim=1, keepdim=True) - advantages_mean = advantages.sum(dim=1, keepdim=True) / length - advantages -= advantages_mean - advantages *= mask - advantages_stddev = ( - ((advantages**2).sum(dim=1, keepdim=True) / length) - .sqrt() - .clamp(min=1e-6) - ) - advantages /= advantages_stddev - advantages *= mask + normalize_(advantages, dim=-1, mask=mask) # Discard next state value state_values = state_values[:, :-1] @@ -449,10 +441,12 @@ def _train_on_batch_value_network( state_value = state_values[mask] target = targets[mask] loss = self._value_network_loss(state_value, target) - loss *= is_weight[:, None].expand_as(mask)[mask] + priority = None + if self._buffer.is_prioritized: + loss *= is_weight[:, None].expand_as(mask)[mask] + priority = loss.detach().abs().to("cpu").numpy() loss = loss.mean() optimize_result = self._optimize_value_network(loss) - priority = None result = { "state_value": state_value.mean().item(), "target": target.mean().item(), @@ -490,7 +484,7 @@ def _get_default_value_network_preprocessor( class DistributedDataParallelA2C(DistributedDataParallelREINFORCE): - """Distributed data parallel Advantage Actor-Critic. + """Distributed data parallel Advantage Actor-Critic (A2C). See Also -------- diff --git a/actorch/algorithms/acktr.py b/actorch/algorithms/acktr.py index 931b16c..4884914 100644 --- a/actorch/algorithms/acktr.py +++ b/actorch/algorithms/acktr.py @@ -14,7 +14,7 @@ # limitations under the License. 
# ============================================================================== -"""Actor-Critic Kronecker-Factored Trust Region.""" +"""Actor-Critic Kronecker-Factored Trust Region (ACKTR).""" from typing import Any, Callable, Dict, Optional, Union @@ -43,7 +43,7 @@ class ACKTR(A2C): - """Actor-Critic Kronecker-Factored Trust Region. + """Actor-Critic Kronecker-Factored Trust Region (ACKTR). References ---------- @@ -287,7 +287,7 @@ def _optimize_policy_network(self, loss: "Tensor") -> "Dict[str, Any]": class DistributedDataParallelACKTR(DistributedDataParallelA2C): - """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region. + """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region (ACKTR). See Also -------- diff --git a/actorch/algorithms/algorithm.py b/actorch/algorithms/algorithm.py index 1fcd34b..fdb568d 100644 --- a/actorch/algorithms/algorithm.py +++ b/actorch/algorithms/algorithm.py @@ -51,7 +51,6 @@ from ray.tune.syncer import NodeSyncer from ray.tune.trial import ExportFormat from torch import Tensor -from torch.cuda.amp import autocast from torch.distributions import Bernoulli, Categorical, Distribution, Normal from torch.profiler import profile, record_function, tensorboard_trace_handler from torch.utils.data import DataLoader @@ -113,7 +112,7 @@ class Algorithm(ABC, Trainable): _EXPORT_FORMATS = [ExportFormat.CHECKPOINT, ExportFormat.MODEL] - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True + _OFF_POLICY = True class Config(dict): """Keyword arguments expected in the configuration received by `setup`.""" @@ -692,12 +691,7 @@ def _build_train_env(self) -> "BatchedEnv": if self.train_env_config is None: self.train_env_config = {} - try: - train_env = self.train_env_builder( - **self.train_env_config, - ) - except TypeError: - train_env = self.train_env_builder(**self.train_env_config) + train_env = self.train_env_builder(**self.train_env_config) if not isinstance(train_env, BatchedEnv): train_env.close() train_env = SerialBatchedEnv(self.train_env_builder, self.train_env_config) @@ -866,7 +860,7 @@ def _build_policy_network(self) -> "PolicyNetwork": # noqa: C901 self.policy_network_postprocessors, ) self._log_graph(policy_network.wrapped_model.model, "policy_network_model") - return policy_network + return policy_network.train().to(self._device, non_blocking=True) def _build_train_agent(self) -> "Agent": if self.train_agent_builder is None: @@ -989,14 +983,18 @@ def _build_dataloader(self) -> "DataLoader": if self.dataloader_builder is None: self.dataloader_builder = DataLoader if self.dataloader_config is None: - fork = torch.multiprocessing.get_start_method() == "fork" + use_mp = ( + self._OFF_POLICY + and not self._buffer.is_prioritized + and torch.multiprocessing.get_start_method() == "fork" + ) self.dataloader_config = { - "num_workers": 1 if fork else 0, + "num_workers": 1 if use_mp else 0, "pin_memory": True, "timeout": 0, "worker_init_fn": None, "generator": None, - "prefetch_factor": 1 if fork else 2, + "prefetch_factor": 1 if use_mp else 2, "pin_memory_device": "", } if self.dataloader_config is None: @@ -1037,20 +1035,15 @@ def _train_step(self) -> "Dict[str, Any]": if self.train_num_episodes_per_iter: train_num_episodes_per_iter = self.train_num_episodes_per_iter() self.train_num_episodes_per_iter.step() - with ( - autocast(**self.enable_amp) - if self.enable_amp["enabled"] - else contextlib.suppress() + for experience, done in self._train_sampler.sample( + train_num_timesteps_per_iter, + train_num_episodes_per_iter, ): - 
for experience, done in self._train_sampler.sample( - train_num_timesteps_per_iter, - train_num_episodes_per_iter, - ): - self._buffer.add(experience, done) + self._buffer.add(experience, done) result = self._train_sampler.stats self._cumrewards += result["episode_cumreward"] - if not self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH: + if not self._OFF_POLICY: for schedule in self._buffer_dataset.schedules.values(): schedule.step() train_epoch_result = self._train_epoch() @@ -1060,7 +1053,7 @@ def _train_step(self) -> "Dict[str, Any]": schedule.step() for schedule in self._buffer.schedules.values(): schedule.step() - if self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH: + if self._OFF_POLICY: for schedule in self._buffer_dataset.schedules.values(): schedule.step() return result @@ -1113,16 +1106,11 @@ def _eval_step(self) -> "Dict[str, Any]": eval_num_episodes_per_iter = self.eval_num_episodes_per_iter() self.eval_num_episodes_per_iter.step() self._eval_sampler.reset() - with ( - autocast(**self.enable_amp) - if self.enable_amp["enabled"] - else contextlib.suppress() + for _ in self._eval_sampler.sample( + eval_num_timesteps_per_iter, + eval_num_episodes_per_iter, ): - for _ in self._eval_sampler.sample( - eval_num_timesteps_per_iter, - eval_num_episodes_per_iter, - ): - pass + pass for schedule in self._eval_agent.schedules.values(): schedule.step() return self._eval_sampler.stats @@ -1353,7 +1341,7 @@ def __init__( Default to ``{}``. placement_strategy: The placement strategy - (see https://docs.ray.io/en/latest/ray-core/placement-group.html). + (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0). backend: The backend for distributed execution (see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group). diff --git a/actorch/algorithms/awr.py b/actorch/algorithms/awr.py index f13d6db..1e887c3 100644 --- a/actorch/algorithms/awr.py +++ b/actorch/algorithms/awr.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Advantage-Weighted Regression.""" +"""Advantage-Weighted Regression (AWR).""" import contextlib from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -31,6 +31,7 @@ from actorch.agents import Agent from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.utils import normalize_ from actorch.algorithms.value_estimation import lambda_return from actorch.buffers import Buffer from actorch.datasets import BufferDataset @@ -48,7 +49,7 @@ class AWR(A2C): - """Advantage-Weighted Regression. + """Advantage-Weighted Regression (AWR). 
References ---------- @@ -59,9 +60,7 @@ class AWR(A2C): """ - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True # override - - _RESET_BUFFER = False # override + _OFF_POLICY = True # override # override class Config(dict): @@ -277,6 +276,8 @@ def _train_step(self) -> "Dict[str, Any]": ) result["weight_clip"] = self.weight_clip() result["temperature"] = self.temperature() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) result["buffer_num_experiences"] = self._buffer.num_experiences result["buffer_num_full_trajectories"] = self._buffer.num_full_trajectories result.pop("num_return_steps", None) @@ -311,17 +312,7 @@ def _train_on_batch( self.trace_decay(), ) if self.normalize_advantage: - length = mask.sum(dim=1, keepdim=True) - advantages_mean = advantages.sum(dim=1, keepdim=True) / length - advantages -= advantages_mean - advantages *= mask - advantages_stddev = ( - ((advantages**2).sum(dim=1, keepdim=True) / length) - .sqrt() - .clamp(min=1e-6) - ) - advantages /= advantages_stddev - advantages *= mask + normalize_(advantages, dim=-1, mask=mask) # Discard next state value state_values = state_values[:, :-1] @@ -352,16 +343,19 @@ def _train_on_batch_policy_network( raise ValueError( f"`entropy_coeff` ({entropy_coeff}) must be in the interval [0, inf)" ) + weight_clip = self.weight_clip() if weight_clip <= 0.0: raise ValueError( f"`weight_clip` ({weight_clip}) must be in the interval (0, inf)" ) + temperature = self.temperature() if temperature <= 0.0: raise ValueError( f"`temperature` ({temperature}) must be in the interval (0, inf)" ) + with ( autocast(**self.enable_amp) if self.enable_amp["enabled"] @@ -392,7 +386,7 @@ def _train_on_batch_policy_network( class DistributedDataParallelAWR(DistributedDataParallelA2C): - """Distributed data parallel Advantage-Weighted Regression. + """Distributed data parallel Advantage-Weighted Regression (AWR). See Also -------- diff --git a/actorch/algorithms/d3pg.py b/actorch/algorithms/d3pg.py new file mode 100644 index 0000000..ef201d0 --- /dev/null +++ b/actorch/algorithms/d3pg.py @@ -0,0 +1,430 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Distributional Deep Deterministic Policy Gradient (D3PG).""" + +import contextlib +import logging +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from gymnasium import Env +from numpy import ndarray +from torch import Tensor +from torch.cuda.amp import autocast +from torch.distributions import Distribution, kl_divergence +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from actorch.agents import Agent +from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.ddpg import DDPG, DistributedDataParallelDDPG, Loss, LRScheduler +from actorch.algorithms.utils import freeze_params, sync_polyak_ +from actorch.algorithms.value_estimation import n_step_return +from actorch.buffers import Buffer, ProportionalBuffer +from actorch.distributions import Finite +from actorch.envs import BatchedEnv +from actorch.models import Model +from actorch.networks import ( + DistributionParametrization, + Network, + NormalizingFlow, + Processor, +) +from actorch.samplers import Sampler +from actorch.schedules import Schedule + + +__all__ = [ + "D3PG", + "DistributedDataParallelD3PG", +] + + +_LOGGER = logging.getLogger(__name__) + + +class D3PG(DDPG): + """Distributional Deep Deterministic Policy Gradient (D3PG). + + References + ---------- + .. [1] G. Barth-Maron, M. W. Hoffman, D. Budden, W. Dabney, D. Horgan, D. TB, + A. Muldal, N. Heess, and T. Lillicrap. + "Distributed Distributional Deterministic Policy Gradients". + In: ICLR. 2018. + URL: https://arxiv.org/abs/1804.08617 + + """ + + # override + class Config(dict): + """Keyword arguments expected in the configuration received by `setup`.""" + + def __init__( + self, + train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]", + train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1, + eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None, + eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + policy_network_model_config: 
"Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + policy_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_distribution_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Distribution]]]]" = None, + value_network_distribution_parametrization: "Tunable[RefOrFutureRef[Optional[DistributionParametrization]]]" = None, + value_network_distribution_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_normalizing_flow: "Tunable[RefOrFutureRef[Optional[NormalizingFlow]]]" = None, + value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None, + value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Buffer]]]]" = None, + buffer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_checkpoint: "Tunable[RefOrFutureRef[bool]]" = False, + dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None, + dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99, + num_return_steps: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 3, + num_updates_per_iter: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1000, + batch_size: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 128, + max_trajectory_length: "Tunable[RefOrFutureRef[Union[int, float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + sync_freq: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1, + polyak_weight: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.001, + max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100, + seed: "Tunable[RefOrFutureRef[int]]" = 0, + enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False, + enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False, + enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False, + suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False, + _accept_kwargs: "bool" = False, + **kwargs: "Any", + ) -> "None": + if 
not _accept_kwargs and kwargs: + raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}") + super().__init__( + train_env_builder=train_env_builder, + train_env_config=train_env_config, + train_agent_builder=train_agent_builder, + train_agent_config=train_agent_config, + train_sampler_builder=train_sampler_builder, + train_sampler_config=train_sampler_config, + train_num_timesteps_per_iter=train_num_timesteps_per_iter, + train_num_episodes_per_iter=train_num_episodes_per_iter, + eval_freq=eval_freq, + eval_env_builder=eval_env_builder, + eval_env_config=eval_env_config, + eval_agent_builder=eval_agent_builder, + eval_agent_config=eval_agent_config, + eval_sampler_builder=eval_sampler_builder, + eval_sampler_config=eval_sampler_config, + eval_num_timesteps_per_iter=eval_num_timesteps_per_iter, + eval_num_episodes_per_iter=eval_num_episodes_per_iter, + policy_network_preprocessors=policy_network_preprocessors, + policy_network_model_builder=policy_network_model_builder, + policy_network_model_config=policy_network_model_config, + policy_network_postprocessors=policy_network_postprocessors, + policy_network_optimizer_builder=policy_network_optimizer_builder, + policy_network_optimizer_config=policy_network_optimizer_config, + policy_network_optimizer_lr_scheduler_builder=policy_network_optimizer_lr_scheduler_builder, + policy_network_optimizer_lr_scheduler_config=policy_network_optimizer_lr_scheduler_config, + value_network_preprocessors=value_network_preprocessors, + value_network_model_builder=value_network_model_builder, + value_network_model_config=value_network_model_config, + value_network_distribution_builder=value_network_distribution_builder, + value_network_distribution_parametrization=value_network_distribution_parametrization, + value_network_distribution_config=value_network_distribution_config, + value_network_normalizing_flow=value_network_normalizing_flow, + value_network_loss_builder=value_network_loss_builder, + value_network_loss_config=value_network_loss_config, + value_network_optimizer_builder=value_network_optimizer_builder, + value_network_optimizer_config=value_network_optimizer_config, + value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder, + value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config, + buffer_builder=buffer_builder, + buffer_config=buffer_config, + buffer_checkpoint=buffer_checkpoint, + dataloader_builder=dataloader_builder, + dataloader_config=dataloader_config, + discount=discount, + num_return_steps=num_return_steps, + num_updates_per_iter=num_updates_per_iter, + batch_size=batch_size, + max_trajectory_length=max_trajectory_length, + sync_freq=sync_freq, + polyak_weight=polyak_weight, + max_grad_l2_norm=max_grad_l2_norm, + cumreward_window_size=cumreward_window_size, + seed=seed, + enable_amp=enable_amp, + enable_reproducibility=enable_reproducibility, + enable_anomaly_detection=enable_anomaly_detection, + enable_profiling=enable_profiling, + log_sys_usage=log_sys_usage, + suppress_warnings=suppress_warnings, + **kwargs, + ) + + # override + def setup(self, config: "Dict[str, Any]") -> "None": + self.config = D3PG.Config(**self.config) + self.config["_accept_kwargs"] = True + super().setup(config) + self._warn_failed_logging = True + + # override + def _build_buffer(self) -> "Buffer": + if self.buffer_builder is None: + self.buffer_builder = ProportionalBuffer + if self.buffer_config is None: + self.buffer_config = { + "capacity": int(1e5), + "prioritization": 1.0, + 
"bias_correction": 0.4, + "epsilon": 1e-5, + } + return super()._build_buffer() + + # override + def _build_value_network(self) -> "Network": + if self.value_network_distribution_builder is None: + self.value_network_distribution_builder = Finite + if self.value_network_distribution_parametrization is None: + self.value_network_distribution_parametrization = { + "logits": ( + {"logits": (51,)}, + lambda x: x["logits"], + ), + } + if self.value_network_distribution_config is None: + self.value_network_distribution_config = { + "atoms": torch.linspace(-10.0, 10.0, 51).to(self._device), + "validate_args": False, + } + if self.value_network_distribution_parametrization is None: + self.value_network_distribution_parametrization = {} + if self.value_network_distribution_config is None: + self.value_network_distribution_config = {} + + if self.value_network_normalizing_flow is not None: + self.value_network_normalizing_flows = { + "value": self.value_network_normalizing_flow, + } + else: + self.value_network_normalizing_flows: "Dict[str, NormalizingFlow]" = {} + + self.value_network_distribution_builders = { + "value": self.value_network_distribution_builder, + } + self.value_network_distribution_parametrizations = { + "value": self.value_network_distribution_parametrization, + } + self.value_network_distribution_configs = { + "value": self.value_network_distribution_config, + } + return super()._build_value_network() + + # override + def _train_on_batch( + self, + idx: "int", + experiences: "Dict[str, Tensor]", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + sync_freq = self.sync_freq() + if sync_freq < 1 or not float(sync_freq).is_integer(): + raise ValueError( + f"`sync_freq` ({sync_freq}) " + f"must be in the integer interval [1, inf)" + ) + sync_freq = int(sync_freq) + + result = {} + + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + target_actions, _ = self._target_policy_network( + experiences["observation"], mask=mask + ) + observations_target_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], target_actions] + ], + dim=-1, + ) + self._target_value_network(observations_target_actions, mask=mask) + target_action_values = self._target_value_network.distribution + # Discard next observation + experiences["observation"] = experiences["observation"][:, :-1, ...] 
+ mask = mask[:, 1:] + targets, _ = n_step_return( + target_action_values, + experiences["reward"], + experiences["terminal"], + mask, + self.discount(), + self.num_return_steps(), + return_advantage=False, + ) + + # Compute action values + observations_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], experiences["action"]] + ], + dim=-1, + ) + self._value_network(observations_actions, mask=mask) + action_values = self._value_network.distribution + + result["value_network"], priority = self._train_on_batch_value_network( + action_values, + targets, + is_weight, + mask, + ) + result["policy_network"] = self._train_on_batch_policy_network( + experiences, + mask, + ) + + # Synchronize + if idx % sync_freq == 0: + sync_polyak_( + self._policy_network, + self._target_policy_network, + self.polyak_weight(), + ) + sync_polyak_( + self._value_network, + self._target_value_network, + self.polyak_weight(), + ) + + self._grad_scaler.update() + return result, priority + + # override + def _train_on_batch_policy_network( + self, + experiences: "Dict[str, Tensor]", + mask: "Tensor", + ) -> "Dict[str, Any]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + actions, _ = self._policy_network(experiences["observation"], mask=mask) + observations_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], actions] + ], + dim=-1, + ) + with freeze_params(self._value_network): + self._value_network(observations_actions, mask=mask) + action_values = self._value_network.distribution + try: + action_values = action_values.mean + except Exception as e: + raise RuntimeError(f"Could not compute `action_values.mean`: {e}") + action_value = action_values[mask] + loss = -action_value.mean() + optimize_result = self._optimize_policy_network(loss) + result = {"loss": loss.item()} + result.update(optimize_result) + return result + + # override + def _train_on_batch_value_network( + self, + action_values: "Distribution", + targets: "Distribution", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + loss = kl_divergence(targets, action_values)[mask] + priority = None + if self._buffer.is_prioritized: + loss *= is_weight[:, None].expand_as(mask)[mask] + priority = loss.detach().abs().to("cpu").numpy() + loss = loss.mean() + optimize_result = self._optimize_value_network(loss) + result = {} + try: + result["action_value"] = action_values.mean[mask].mean().item() + result["target"] = targets.mean[mask].mean().item() + except Exception as e: + if self._warn_failed_logging: + _LOGGER.warning(f"Could not log `action_value` and/or `target`: {e}") + self._warn_failed_logging = False + result["loss"] = loss.item() + result.update(optimize_result) + return result, priority + + +class DistributedDataParallelD3PG(DistributedDataParallelDDPG): + """Distributed data parallel Distributional Deep Deterministic Policy Gradient (D3PG). + + See Also + -------- + actorch.algorithms.d3pg.D3PG + + """ + + _ALGORITHM_CLS = D3PG # override diff --git a/actorch/algorithms/ddpg.py b/actorch/algorithms/ddpg.py index 79cc4a6..17190cc 100644 --- a/actorch/algorithms/ddpg.py +++ b/actorch/algorithms/ddpg.py @@ -14,7 +14,7 @@ # limitations under the License. 
# ============================================================================== -"""Deep Deterministic Policy Gradient.""" +"""Deep Deterministic Policy Gradient (DDPG).""" import contextlib import copy @@ -33,7 +33,7 @@ from actorch.agents import Agent, GaussianNoiseAgent from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler from actorch.algorithms.algorithm import RefOrFutureRef, Tunable -from actorch.algorithms.utils import freeze_params, prepare_model, sync_polyak +from actorch.algorithms.utils import freeze_params, prepare_model, sync_polyak_ from actorch.algorithms.value_estimation import n_step_return from actorch.buffers import Buffer from actorch.datasets import BufferDataset @@ -53,7 +53,7 @@ class DDPG(A2C): - """Deep Deterministic Policy Gradient. + """Deep Deterministic Policy Gradient (DDPG). References ---------- @@ -65,9 +65,7 @@ class DDPG(A2C): """ - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True # override - - _RESET_BUFFER = False # override + _OFF_POLICY = True # override # override class Config(dict): @@ -203,12 +201,8 @@ def setup(self, config: "Dict[str, Any]") -> "None": self.config = DDPG.Config(**self.config) self.config["_accept_kwargs"] = True super().setup(config) - self._target_policy_network = copy.deepcopy(self._policy_network) - self._target_policy_network.eval().to(self._device, non_blocking=True) - self._target_policy_network.requires_grad_(False) - self._target_value_network = copy.deepcopy(self._value_network) - self._target_value_network.eval().to(self._device, non_blocking=True) - self._target_value_network.requires_grad_(False) + self._target_policy_network = self._build_target_policy_network() + self._target_value_network = self._build_target_value_network() if not isinstance(self.num_updates_per_iter, Schedule): self.num_updates_per_iter = ConstantSchedule(self.num_updates_per_iter) if not isinstance(self.batch_size, Schedule): @@ -300,6 +294,16 @@ def _build_value_network(self) -> "Network": } return super()._build_value_network() + def _build_target_policy_network(self) -> "Network": + target_policy_network = copy.deepcopy(self._policy_network) + target_policy_network.requires_grad_(False) + return target_policy_network.eval().to(self._device, non_blocking=True) + + def _build_target_value_network(self) -> "Network": + target_value_network = copy.deepcopy(self._value_network) + target_value_network.requires_grad_(False) + return target_value_network.eval().to(self._device, non_blocking=True) + # override def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() @@ -314,6 +318,7 @@ def _train_step(self) -> "Dict[str, Any]": ) result["sync_freq"] = self.sync_freq() result["polyak_weight"] = self.polyak_weight() + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) result["buffer_num_experiences"] = self._buffer.num_experiences result["buffer_num_full_trajectories"] = self._buffer.num_full_trajectories result.pop("entropy_coeff", None) @@ -365,6 +370,7 @@ def _train_on_batch( mask, self.discount(), self.num_return_steps(), + return_advantage=False, ) # Compute action values @@ -390,12 +396,12 @@ def _train_on_batch( # Synchronize if idx % sync_freq == 0: - sync_polyak( + sync_polyak_( self._policy_network, self._target_policy_network, self.polyak_weight(), ) - sync_polyak( + sync_polyak_( self._value_network, self._target_value_network, self.polyak_weight(), @@ -425,7 +431,8 @@ def _train_on_batch_policy_network( ) with freeze_params(self._value_network): action_values, _ = 
self._value_network(observations_actions, mask=mask) - loss = -action_values[mask].mean() + action_value = action_values[mask] + loss = -action_value.mean() optimize_result = self._optimize_policy_network(loss) result = {"loss": loss.item()} result.update(optimize_result) @@ -482,7 +489,7 @@ def _get_default_policy_network_distribution_config( class DistributedDataParallelDDPG(DistributedDataParallelA2C): - """Distributed data parallel Deep Deterministic Policy Gradient. + """Distributed data parallel Deep Deterministic Policy Gradient (DDPG). See Also -------- diff --git a/actorch/algorithms/ppo.py b/actorch/algorithms/ppo.py index 2fd2963..eee2cf4 100644 --- a/actorch/algorithms/ppo.py +++ b/actorch/algorithms/ppo.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Proximal Policy Optimization.""" +"""Proximal Policy Optimization (PPO).""" import contextlib from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -31,6 +31,7 @@ from actorch.agents import Agent from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.utils import normalize_ from actorch.algorithms.value_estimation import lambda_return from actorch.envs import BatchedEnv from actorch.models import Model @@ -46,7 +47,7 @@ class PPO(A2C): - """Proximal Policy Optimization. + """Proximal Policy Optimization (PPO). References ---------- @@ -234,6 +235,8 @@ def _train_step(self) -> "Dict[str, Any]": result["num_epochs"] = self.num_epochs() result["minibatch_size"] = self.minibatch_size() result["ratio_clip"] = self.ratio_clip() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) result.pop("num_return_steps", None) return result @@ -301,22 +304,8 @@ def _train_on_batch( else contextlib.suppress() ): if self.normalize_advantage: - length_batch = mask_batch.sum(dim=1, keepdim=True) - advantages_batch_mean = ( - advantages_batch.sum(dim=1, keepdim=True) / length_batch - ) - advantages_batch -= advantages_batch_mean - advantages_batch *= mask_batch - advantages_batch_stddev = ( - ( - (advantages_batch**2).sum(dim=1, keepdim=True) - / length_batch - ) - .sqrt() - .clamp(min=1e-6) - ) - advantages_batch /= advantages_batch_stddev - advantages_batch *= mask_batch + normalize_(advantages_batch, dim=-1, mask=mask_batch) + state_values_batch, _ = self._value_network( experiences_batch["observation"], mask=mask_batch ) @@ -387,7 +376,7 @@ def _train_on_batch_policy_network( class DistributedDataParallelPPO(DistributedDataParallelA2C): - """Distributed data parallel Proximal Policy Optimization. + """Distributed data parallel Proximal Policy Optimization (PPO). 
See Also -------- diff --git a/actorch/algorithms/reinforce.py b/actorch/algorithms/reinforce.py index 20db9d1..c0625ec 100644 --- a/actorch/algorithms/reinforce.py +++ b/actorch/algorithms/reinforce.py @@ -67,9 +67,7 @@ class REINFORCE(Algorithm): """ - _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = False # override - - _RESET_BUFFER = True + _OFF_POLICY = False # override # override class Config(dict): @@ -281,7 +279,7 @@ def _build_policy_network_optimizer_lr_scheduler( # override def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() - if self._RESET_BUFFER: + if not self._OFF_POLICY: self._buffer.reset() self.discount.step() self.entropy_coeff.step() diff --git a/actorch/algorithms/sac.py b/actorch/algorithms/sac.py new file mode 100644 index 0000000..6999ff8 --- /dev/null +++ b/actorch/algorithms/sac.py @@ -0,0 +1,605 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Soft Actor-Critic (SAC).""" + +import contextlib +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from gymnasium import Env, spaces +from numpy import ndarray +from torch import Tensor +from torch.cuda.amp import autocast +from torch.distributions import Distribution, Normal +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from actorch.agents import Agent +from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.td3 import TD3, DistributedDataParallelTD3, Loss, LRScheduler +from actorch.algorithms.utils import freeze_params, sync_polyak_ +from actorch.algorithms.value_estimation import n_step_return +from actorch.buffers import Buffer +from actorch.envs import BatchedEnv +from actorch.models import Model +from actorch.networks import DistributionParametrization, NormalizingFlow, Processor +from actorch.samplers import Sampler +from actorch.schedules import ConstantSchedule, Schedule +from actorch.utils import singledispatchmethod + + +__all__ = [ + "DistributedDataParallelSAC", + "SAC", +] + + +class SAC(TD3): + """Soft Actor-Critic (SAC). + + References + ---------- + .. [1] T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine. + "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor". + In: ICML. 2018, pp. 1861-1870. + URL: https://arxiv.org/abs/1801.01290 + .. [2] T. Haarnoja, A. Zhou, K. Hartikainen, G. Tucker, S. Ha, J. Tan, V. Kumar, + H. Zhu, A. Gupta, P. Abbeel, and S. Levine. + "Soft Actor-Critic Algorithms and Applications". + In: arXiv. 2018. 
+ URL: https://arxiv.org/abs/1812.05905 + + """ + + # override + class Config(dict): + """Keyword arguments expected in the configuration received by `setup`.""" + + def __init__( + self, + train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]", + train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1, + eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None, + eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_distribution_builders: "Tunable[RefOrFutureRef[Optional[Dict[str, Callable[..., Distribution]]]]]" = None, + policy_network_distribution_parametrizations: "Tunable[RefOrFutureRef[Optional[Dict[str, DistributionParametrization]]]]" = None, + policy_network_distribution_configs: "Tunable[RefOrFutureRef[Optional[Dict[str, Dict[str, Any]]]]]" = None, + policy_network_normalizing_flows: "Tunable[RefOrFutureRef[Optional[Dict[str, NormalizingFlow]]]]" = None, + policy_network_sample_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Distribution], Tensor]]]]" = None, + policy_network_prediction_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Tensor], Tensor]]]]" = None, + policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + policy_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_loss_builder: 
"Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None, + value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + temperature_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + temperature_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + temperature_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + temperature_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Buffer]]]]" = None, + buffer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + buffer_checkpoint: "Tunable[RefOrFutureRef[bool]]" = False, + dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None, + dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99, + num_return_steps: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1, + num_updates_per_iter: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1000, + batch_size: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 128, + max_trajectory_length: "Tunable[RefOrFutureRef[Union[int, float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + sync_freq: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1, + polyak_weight: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.001, + temperature: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.1, + max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100, + seed: "Tunable[RefOrFutureRef[int]]" = 0, + enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False, + enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False, + enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False, + suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False, + _accept_kwargs: "bool" = False, + **kwargs: "Any", + ) -> "None": + if not _accept_kwargs and kwargs: + raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}") + super().__init__( + train_env_builder=train_env_builder, + train_env_config=train_env_config, + train_agent_builder=train_agent_builder, + train_agent_config=train_agent_config, + train_sampler_builder=train_sampler_builder, + train_sampler_config=train_sampler_config, + train_num_timesteps_per_iter=train_num_timesteps_per_iter, + train_num_episodes_per_iter=train_num_episodes_per_iter, + eval_freq=eval_freq, + eval_env_builder=eval_env_builder, + eval_env_config=eval_env_config, + eval_agent_builder=eval_agent_builder, + eval_agent_config=eval_agent_config, + eval_sampler_builder=eval_sampler_builder, + eval_sampler_config=eval_sampler_config, + eval_num_timesteps_per_iter=eval_num_timesteps_per_iter, + 
eval_num_episodes_per_iter=eval_num_episodes_per_iter, + policy_network_preprocessors=policy_network_preprocessors, + policy_network_model_builder=policy_network_model_builder, + policy_network_model_config=policy_network_model_config, + policy_network_distribution_builders=policy_network_distribution_builders, + policy_network_distribution_parametrizations=policy_network_distribution_parametrizations, + policy_network_distribution_configs=policy_network_distribution_configs, + policy_network_normalizing_flows=policy_network_normalizing_flows, + policy_network_sample_fn=policy_network_sample_fn, + policy_network_prediction_fn=policy_network_prediction_fn, + policy_network_postprocessors=policy_network_postprocessors, + policy_network_optimizer_builder=policy_network_optimizer_builder, + policy_network_optimizer_config=policy_network_optimizer_config, + policy_network_optimizer_lr_scheduler_builder=policy_network_optimizer_lr_scheduler_builder, + policy_network_optimizer_lr_scheduler_config=policy_network_optimizer_lr_scheduler_config, + value_network_preprocessors=value_network_preprocessors, + value_network_model_builder=value_network_model_builder, + value_network_model_config=value_network_model_config, + value_network_loss_builder=value_network_loss_builder, + value_network_loss_config=value_network_loss_config, + value_network_optimizer_builder=value_network_optimizer_builder, + value_network_optimizer_config=value_network_optimizer_config, + value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder, + value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config, + temperature_optimizer_builder=temperature_optimizer_builder, + temperature_optimizer_config=temperature_optimizer_config, + temperature_optimizer_lr_scheduler_builder=temperature_optimizer_lr_scheduler_builder, + temperature_optimizer_lr_scheduler_config=temperature_optimizer_lr_scheduler_config, + buffer_builder=buffer_builder, + buffer_config=buffer_config, + buffer_checkpoint=buffer_checkpoint, + dataloader_builder=dataloader_builder, + dataloader_config=dataloader_config, + discount=discount, + num_return_steps=num_return_steps, + num_updates_per_iter=num_updates_per_iter, + batch_size=batch_size, + max_trajectory_length=max_trajectory_length, + sync_freq=sync_freq, + polyak_weight=polyak_weight, + temperature=temperature, + max_grad_l2_norm=max_grad_l2_norm, + cumreward_window_size=cumreward_window_size, + seed=seed, + enable_amp=enable_amp, + enable_reproducibility=enable_reproducibility, + enable_anomaly_detection=enable_anomaly_detection, + enable_profiling=enable_profiling, + log_sys_usage=log_sys_usage, + suppress_warnings=suppress_warnings, + **kwargs, + ) + + # override + def setup(self, config: "Dict[str, Any]") -> "None": + self.config = SAC.Config(**self.config) + self.config["_accept_kwargs"] = True + super().setup(config) + if self.temperature_optimizer_builder is not None: + self._log_temperature = torch.zeros( + 1, device=self._device, requires_grad=True + ) + self._target_entropy = torch.as_tensor( + self._train_env.single_action_space.sample() + ).numel() + self._temperature_optimizer = self._build_temperature_optimizer() + self._temperature_optimizer_lr_scheduler = ( + self._build_temperature_optimizer_lr_scheduler() + ) + self.temperature = lambda *args: self._log_temperature.exp().item() + elif not isinstance(self.temperature, Schedule): + self.temperature = ConstantSchedule(self.temperature) + + # override + @property + def _checkpoint(self) 
-> "Dict[str, Any]": + checkpoint = super()._checkpoint + if self.temperature_optimizer_builder is not None: + checkpoint["log_temperature"] = self._log_temperature + checkpoint[ + "temperature_optimizer" + ] = self._temperature_optimizer.state_dict() + if self._temperature_optimizer_lr_scheduler is not None: + checkpoint[ + "temperature_optimizer_lr_scheduler" + ] = self._temperature_optimizer_lr_scheduler.state_dict() + checkpoint["temperature"] = self.temperature.state_dict() + return checkpoint + + # override + @_checkpoint.setter + def _checkpoint(self, value: "Dict[str, Any]") -> "None": + super()._checkpoint = value + if "log_temperature" in value: + self._log_temperature = value["log_temperature"] + if "temperature_optimizer" in value: + self._temperature_optimizer.load_state_dict(value["temperature_optimizer"]) + if "temperature_optimizer_lr_scheduler" in value: + self._temperature_optimizer_lr_scheduler.load_state_dict( + value["temperature_optimizer_lr_scheduler"] + ) + self.temperature.load_state_dict(value["temperature"]) + + def _build_temperature_optimizer(self) -> "Optimizer": + if self.temperature_optimizer_config is None: + self.temperature_optimizer_config: "Dict[str, Any]" = {} + return self.temperature_optimizer_builder( + [self._log_temperature], + **self.temperature_optimizer_config, + ) + + def _build_temperature_optimizer_lr_scheduler( + self, + ) -> "Optional[LRScheduler]": + if self.temperature_optimizer_lr_scheduler_builder is None: + return + if self.temperature_optimizer_lr_scheduler_config is None: + self.temperature_optimizer_lr_scheduler_config: "Dict[str, Any]" = {} + return self.temperature_optimizer_lr_scheduler_builder( + self._temperature_optimizer, + **self.temperature_optimizer_lr_scheduler_config, + ) + + # override + def _train_step(self) -> "Dict[str, Any]": + result = super()._train_step() + if self.temperature_optimizer_builder is None: + result["temperature"] = self.temperature() + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) + result["buffer_num_experiences"] = result.pop("buffer_num_experiences", None) + result["buffer_num_full_trajectories"] = result.pop( + "buffer_num_full_trajectories", None + ) + result.pop("delay", None) + result.pop("noise_stddev", None) + result.pop("noise_clip", None) + return result + + # override + def _train_on_batch( + self, + idx: "int", + experiences: "Dict[str, Tensor]", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + sync_freq = self.sync_freq() + if sync_freq < 1 or not float(sync_freq).is_integer(): + raise ValueError( + f"`sync_freq` ({sync_freq}) " + f"must be in the integer interval [1, inf)" + ) + sync_freq = int(sync_freq) + + temperature = self.temperature() + if temperature <= 0.0: + raise ValueError( + f"`temperature` ({temperature}) must be in the interval (0, inf)" + ) + + result = {} + + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + self._policy_network(experiences["observation"], mask=mask) + policy = self._policy_network.distribution + target_actions = policy.rsample() + target_log_probs = policy.log_prob(target_actions) + observations_target_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], target_actions] + ], + dim=-1, + ) + with torch.no_grad(): + target_action_values, _ = self._target_value_network( + observations_target_actions, mask=mask + ) + target_twin_action_values, _ = 
self._target_twin_value_network( + observations_target_actions, mask=mask + ) + target_action_values = target_action_values.min(target_twin_action_values) + with torch.no_grad(): + target_action_values -= temperature * target_log_probs + + # Discard next observation + experiences["observation"] = experiences["observation"][:, :-1, ...] + observations_target_actions = observations_target_actions[:, :-1, ...] + target_log_probs = target_log_probs[:, :-1, ...] + mask = mask[:, 1:] + targets, _ = n_step_return( + target_action_values, + experiences["reward"], + experiences["terminal"], + mask, + self.discount(), + self.num_return_steps(), + return_advantage=False, + ) + + # Compute action values + observations_actions = torch.cat( + [ + x[..., None] if x.shape == mask.shape else x + for x in [experiences["observation"], experiences["action"]] + ], + dim=-1, + ) + action_values, _ = self._value_network(observations_actions, mask=mask) + twin_action_values, _ = self._twin_value_network( + observations_actions, mask=mask + ) + + result["value_network"], priority = self._train_on_batch_value_network( + action_values, + targets, + is_weight, + mask, + ) + + ( + result["twin_value_network"], + twin_priority, + ) = self._train_on_batch_twin_value_network( + twin_action_values, + targets, + is_weight, + mask, + ) + + if priority is not None: + priority += twin_priority + priority /= 2.0 + + result["policy_network"] = self._train_on_batch_policy_network( + observations_target_actions, + target_log_probs, + mask, + ) + + if self.temperature_optimizer_builder is not None: + result["temperature"] = self._train_on_batch_temperature( + target_log_probs, + mask, + ) + + # Synchronize + if idx % sync_freq == 0: + sync_polyak_( + self._value_network, + self._target_value_network, + self.polyak_weight(), + ) + sync_polyak_( + self._twin_value_network, + self._target_twin_value_network, + self.polyak_weight(), + ) + + self._grad_scaler.update() + return result, priority + + # override + def _train_on_batch_policy_network( + self, + observations_actions: "Tensor", + log_probs: "Tensor", + mask: "Tensor", + ) -> "Dict[str, Any]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + with freeze_params(self._value_network, self._twin_value_network): + action_values, _ = self._value_network(observations_actions, mask=mask) + twin_action_values, _ = self._twin_value_network( + observations_actions, mask=mask + ) + action_values = action_values.min(twin_action_values) + log_prob = log_probs[mask] + action_value = action_values[mask] + loss = self.temperature() * log_prob + loss -= action_value + loss = loss.mean() + optimize_result = self._optimize_policy_network(loss) + result = { + "log_prob": log_prob.mean().item(), + "action_value": action_value.mean().item(), + "loss": loss.item(), + } + result.update(optimize_result) + return result + + def _train_on_batch_temperature( + self, + log_probs: "Tensor", + mask: "Tensor", + ) -> "Dict[str, Any]": + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + with torch.no_grad(): + log_prob = log_probs[mask] + loss = log_prob + self._target_entropy + loss *= -self._log_temperature + loss = loss.mean() + optimize_result = self._optimize_temperature(loss) + result = { + "temperature": self.temperature(), + "log_prob": log_prob.mean().item(), + "loss": loss.item(), + } + result.update(optimize_result) + return result + + def _optimize_temperature(self, loss: "Tensor") -> 
"Dict[str, Any]": + max_grad_l2_norm = self.max_grad_l2_norm() + if max_grad_l2_norm <= 0.0: + raise ValueError( + f"`max_grad_l2_norm` ({max_grad_l2_norm}) must be in the interval (0, inf]" + ) + self._temperature_optimizer.zero_grad(set_to_none=True) + self._grad_scaler.scale(loss).backward() + self._grad_scaler.unscale_(self._temperature_optimizer) + grad_l2_norm = clip_grad_norm_(self._log_temperature, max_grad_l2_norm) + self._grad_scaler.step(self._temperature_optimizer) + result = { + "lr": self._temperature_optimizer.param_groups[0]["lr"], + "grad_l2_norm": min(grad_l2_norm.item(), max_grad_l2_norm), + } + if self._temperature_optimizer_lr_scheduler is not None: + self._temperature_optimizer_lr_scheduler.step() + return result + + @singledispatchmethod(use_weakrefs=False) + def _get_default_policy_network_distribution_builder( + self, + space: "spaces.Space", + ) -> "Callable[..., Distribution]": + raise NotImplementedError( + f"Unsupported space type: " + f"`{type(space).__module__}.{type(space).__name__}`. " + f"Register a custom space type through decorator " + f"`{type(self).__module__}.{type(self).__name__}." + f"_get_default_policy_network_distribution_builder.register`" + ) + + @singledispatchmethod(use_weakrefs=False) + def _get_default_policy_network_distribution_parametrization( + self, + space: "spaces.Space", + ) -> "Callable[..., DistributionParametrization]": + raise NotImplementedError( + f"Unsupported space type: " + f"`{type(space).__module__}.{type(space).__name__}`. " + f"Register a custom space type through decorator " + f"`{type(self).__module__}.{type(self).__name__}." + f"_get_default_policy_network_distribution_parametrization.register`" + ) + + @singledispatchmethod(use_weakrefs=False) + def _get_default_policy_network_distribution_config( + self, + space: "spaces.Space", + ) -> "Dict[str, Any]": + raise NotImplementedError( + f"Unsupported space type: " + f"`{type(space).__module__}.{type(space).__name__}`. " + f"Register a custom space type through decorator " + f"`{type(self).__module__}.{type(self).__name__}." + f"_get_default_policy_network_distribution_config.register`" + ) + + +class DistributedDataParallelSAC(DistributedDataParallelTD3): + """Distributed data parallel Soft Actor-Critic (SAC). 
+ + See Also + -------- + actorch.algorithms.sac.SAC + + """ + + _ALGORITHM_CLS = SAC # override + + +##################################################################################################### +# SAC._get_default_policy_network_distribution_builder implementation +##################################################################################################### + + +@SAC._get_default_policy_network_distribution_builder.register(spaces.Box) +def _get_default_policy_network_distribution_builder_box( + self, + space: "spaces.Box", +) -> "Callable[..., Normal]": + return Normal + + +##################################################################################################### +# SAC._get_default_policy_network_distribution_parametrization implementation +##################################################################################################### + + +@SAC._get_default_policy_network_distribution_parametrization.register(spaces.Box) +def _get_default_policy_network_distribution_parametrization_box( + self, + space: "spaces.Box", +) -> "DistributionParametrization": + return { + "loc": ( + {"loc": space.shape}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": space.shape}, + lambda x: x["log_scale"].exp(), + ), + } + + +##################################################################################################### +# SAC._get_default_policy_network_distribution_config implementation +##################################################################################################### + + +@SAC._get_default_policy_network_distribution_config.register(spaces.Box) +def _get_default_policy_network_distribution_config_box( + self, + space: "spaces.Box", +) -> "Dict[str, Any]": + return {"validate_args": False} diff --git a/actorch/algorithms/td3.py b/actorch/algorithms/td3.py index a922fb0..b25ec24 100644 --- a/actorch/algorithms/td3.py +++ b/actorch/algorithms/td3.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Twin Delayed Deep Deterministic Policy Gradient.""" +"""Twin Delayed Deep Deterministic Policy Gradient (TD3).""" import contextlib import copy @@ -33,12 +33,12 @@ from actorch.agents import Agent from actorch.algorithms.algorithm import RefOrFutureRef, Tunable from actorch.algorithms.ddpg import DDPG, DistributedDataParallelDDPG, Loss, LRScheduler -from actorch.algorithms.utils import prepare_model, sync_polyak +from actorch.algorithms.utils import prepare_model, sync_polyak_ from actorch.algorithms.value_estimation import n_step_return from actorch.buffers import Buffer from actorch.envs import BatchedEnv from actorch.models import Model -from actorch.networks import Processor +from actorch.networks import Network, Processor from actorch.samplers import Sampler from actorch.schedules import ConstantSchedule, Schedule @@ -50,13 +50,13 @@ class TD3(DDPG): - """Twin Delayed Deep Deterministic Policy Gradient. + """Twin Delayed Deep Deterministic Policy Gradient (TD3). References ---------- .. [1] S. Fujimoto, H. van Hoof, and D. Meger. "Addressing Function Approximation Error in Actor-Critic Methods". - In: ICML. 2018, pp. 1587–1596. + In: ICML. 2018, pp. 1587-1596. 
URL: https://arxiv.org/abs/1802.09477 """ @@ -201,19 +201,13 @@ def setup(self, config: "Dict[str, Any]") -> "None": self.config = TD3.Config(**self.config) self.config["_accept_kwargs"] = True super().setup(config) - self._twin_value_network = ( - self._build_value_network().train().to(self._device, non_blocking=True) - ) - self._twin_value_network_loss = ( - self._build_value_network_loss().train().to(self._device, non_blocking=True) - ) + self._twin_value_network = self._build_value_network() + self._twin_value_network_loss = self._build_value_network_loss() self._twin_value_network_optimizer = self._build_twin_value_network_optimizer() self._twin_value_network_optimizer_lr_scheduler = ( self._build_twin_value_network_optimizer_lr_scheduler() ) - self._target_twin_value_network = copy.deepcopy(self._twin_value_network) - self._target_twin_value_network.eval().to(self._device, non_blocking=True) - self._target_twin_value_network.requires_grad_(False) + self._target_twin_value_network = self._build_target_twin_value_network() if not isinstance(self.delay, Schedule): self.delay = ConstantSchedule(self.delay) if not isinstance(self.noise_stddev, Schedule): @@ -297,6 +291,11 @@ def _build_twin_value_network_optimizer_lr_scheduler( **self.value_network_optimizer_lr_scheduler_config, ) + def _build_target_twin_value_network(self) -> "Network": + target_twin_value_network = copy.deepcopy(self._twin_value_network) + target_twin_value_network.requires_grad_(False) + return target_twin_value_network.eval().to(self._device, non_blocking=True) + # override def _train_step(self) -> "Dict[str, Any]": result = super()._train_step() @@ -306,9 +305,10 @@ def _train_step(self) -> "Dict[str, Any]": result["delay"] = self.delay() result["noise_stddev"] = self.noise_stddev() result["noise_clip"] = self.noise_clip() - result["buffer_num_experiences"] = result.pop("buffer_num_experiences") + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) + result["buffer_num_experiences"] = result.pop("buffer_num_experiences", None) result["buffer_num_full_trajectories"] = result.pop( - "buffer_num_full_trajectories" + "buffer_num_full_trajectories", None ) return result @@ -398,6 +398,7 @@ def _train_on_batch( mask, self.discount(), self.num_return_steps(), + return_advantage=False, ) # Compute action values @@ -441,17 +442,17 @@ def _train_on_batch( ) # Synchronize if idx % sync_freq == 0: - sync_polyak( + sync_polyak_( self._policy_network, self._target_policy_network, self.polyak_weight(), ) - sync_polyak( + sync_polyak_( self._value_network, self._target_value_network, self.polyak_weight(), ) - sync_polyak( + sync_polyak_( self._twin_value_network, self._target_twin_value_network, self.polyak_weight(), @@ -475,10 +476,12 @@ def _train_on_batch_twin_value_network( action_values = action_values[mask] target = targets[mask] loss = self._twin_value_network_loss(action_values, target) - loss *= is_weight[:, None].expand_as(mask)[mask] + priority = None + if self._buffer.is_prioritized: + loss *= is_weight[:, None].expand_as(mask)[mask] + priority = loss.detach().abs().to("cpu").numpy() loss = loss.mean() optimize_result = self._optimize_twin_value_network(loss) - priority = None result = { "action_value": action_values.mean().item(), "target": target.mean().item(), @@ -510,7 +513,7 @@ def _optimize_twin_value_network(self, loss: "Tensor") -> "Dict[str, Any]": class DistributedDataParallelTD3(DistributedDataParallelDDPG): - """Distributed data parallel Twin Delayed Deep Deterministic Policy Gradient. 
+ """Distributed data parallel Twin Delayed Deep Deterministic Policy Gradient (TD3). See Also -------- diff --git a/actorch/algorithms/trpo.py b/actorch/algorithms/trpo.py new file mode 100644 index 0000000..c4b749d --- /dev/null +++ b/actorch/algorithms/trpo.py @@ -0,0 +1,394 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Trust Region Policy Optimization (TRPO).""" + +import contextlib +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import torch +from gymnasium import Env +from numpy import ndarray +from torch import Tensor +from torch.cuda.amp import autocast +from torch.distributions import Distribution, kl_divergence +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from actorch.agents import Agent +from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler +from actorch.algorithms.algorithm import RefOrFutureRef, Tunable +from actorch.algorithms.utils import normalize_ +from actorch.algorithms.value_estimation import lambda_return +from actorch.envs import BatchedEnv +from actorch.models import Model +from actorch.networks import DistributionParametrization, NormalizingFlow, Processor +from actorch.optimizers import CGBLS +from actorch.samplers import Sampler +from actorch.schedules import ConstantSchedule, Schedule + + +__all__ = [ + "DistributedDataParallelTRPO", + "TRPO", +] + + +class TRPO(A2C): + """Trust Region Policy Optimization (TRPO). + + References + ---------- + .. [1] J. Schulman, S. Levine, P. Abbeel, M. Jordan, and P. Moritz. + "Trust Region Policy Optimization". + In: ICML. 2015, pp. 1889-1897. 
+ URL: https://arxiv.org/abs/1502.05477 + + """ + + # override + class Config(dict): + """Keyword arguments expected in the configuration received by `setup`.""" + + def __init__( + self, + train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]", + train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1, + eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None, + eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None, + eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None, + eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None, + policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + policy_network_distribution_builders: "Tunable[RefOrFutureRef[Optional[Dict[str, Callable[..., Distribution]]]]]" = None, + policy_network_distribution_parametrizations: "Tunable[RefOrFutureRef[Optional[Dict[str, DistributionParametrization]]]]" = None, + policy_network_distribution_configs: "Tunable[RefOrFutureRef[Optional[Dict[str, Dict[str, Any]]]]]" = None, + policy_network_normalizing_flows: "Tunable[RefOrFutureRef[Optional[Dict[str, NormalizingFlow]]]]" = None, + policy_network_sample_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Distribution], Tensor]]]]" = None, + policy_network_prediction_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Tensor], Tensor]]]]" = None, + policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None, + value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None, + value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None, + value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None, + 
value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None, + value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None, + dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None, + discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99, + trace_decay: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.95, + normalize_advantage: "Tunable[RefOrFutureRef[bool]]" = False, + entropy_coeff: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.01, + max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008 + "inf" + ), + cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100, + seed: "Tunable[RefOrFutureRef[int]]" = 0, + enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False, + enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False, + enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False, + log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False, + suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False, + _accept_kwargs: "bool" = False, + **kwargs: "Any", + ) -> "None": + if not _accept_kwargs and kwargs: + raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}") + super().__init__( + train_env_builder=train_env_builder, + train_env_config=train_env_config, + train_agent_builder=train_agent_builder, + train_agent_config=train_agent_config, + train_sampler_builder=train_sampler_builder, + train_sampler_config=train_sampler_config, + train_num_timesteps_per_iter=train_num_timesteps_per_iter, + train_num_episodes_per_iter=train_num_episodes_per_iter, + eval_freq=eval_freq, + eval_env_builder=eval_env_builder, + eval_env_config=eval_env_config, + eval_agent_builder=eval_agent_builder, + eval_agent_config=eval_agent_config, + eval_sampler_builder=eval_sampler_builder, + eval_sampler_config=eval_sampler_config, + eval_num_timesteps_per_iter=eval_num_timesteps_per_iter, + eval_num_episodes_per_iter=eval_num_episodes_per_iter, + policy_network_preprocessors=policy_network_preprocessors, + policy_network_model_builder=policy_network_model_builder, + policy_network_model_config=policy_network_model_config, + policy_network_distribution_builders=policy_network_distribution_builders, + policy_network_distribution_parametrizations=policy_network_distribution_parametrizations, + policy_network_distribution_configs=policy_network_distribution_configs, + policy_network_normalizing_flows=policy_network_normalizing_flows, + policy_network_sample_fn=policy_network_sample_fn, + policy_network_prediction_fn=policy_network_prediction_fn, + policy_network_postprocessors=policy_network_postprocessors, + policy_network_optimizer_builder=policy_network_optimizer_builder, + policy_network_optimizer_config=policy_network_optimizer_config, + value_network_preprocessors=value_network_preprocessors, + value_network_model_builder=value_network_model_builder, + value_network_model_config=value_network_model_config, + value_network_loss_builder=value_network_loss_builder, + value_network_loss_config=value_network_loss_config, + value_network_optimizer_builder=value_network_optimizer_builder, + 
value_network_optimizer_config=value_network_optimizer_config, + value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder, + value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config, + dataloader_builder=dataloader_builder, + dataloader_config=dataloader_config, + discount=discount, + trace_decay=trace_decay, + normalize_advantage=normalize_advantage, + entropy_coeff=entropy_coeff, + max_grad_l2_norm=max_grad_l2_norm, + cumreward_window_size=cumreward_window_size, + seed=seed, + enable_amp=enable_amp, + enable_reproducibility=enable_reproducibility, + enable_anomaly_detection=enable_anomaly_detection, + enable_profiling=enable_profiling, + log_sys_usage=log_sys_usage, + suppress_warnings=suppress_warnings, + **kwargs, + ) + + # override + def setup(self, config: "Dict[str, Any]") -> "None": + self.config = TRPO.Config(**self.config) + self.config["_accept_kwargs"] = True + super().setup(config) + if not isinstance(self.trace_decay, Schedule): + self.trace_decay = ConstantSchedule(self.trace_decay) + + # override + @property + def _checkpoint(self) -> "Dict[str, Any]": + checkpoint = super()._checkpoint + checkpoint["trace_decay"] = self.trace_decay.state_dict() + return checkpoint + + # override + @_checkpoint.setter + def _checkpoint(self, value: "Dict[str, Any]") -> "None": + super()._checkpoint = value + self.trace_decay.load_state_dict(value["trace_decay"]) + + # override + def _build_policy_network_optimizer(self) -> "Optimizer": + if self.policy_network_optimizer_builder is None: + self.policy_network_optimizer_builder = CGBLS + if self.policy_network_optimizer_config is None: + self.policy_network_optimizer_config = { + "max_constraint": 0.01, + "num_cg_iters": 10, + "max_backtracks": 15, + "backtrack_ratio": 0.8, + "hvp_reg_coeff": 1e-5, + "accept_violation": False, + "epsilon": 1e-8, + } + if self.policy_network_optimizer_config is None: + self.policy_network_optimizer_config = {} + return self.policy_network_optimizer_builder( + self._policy_network.parameters(), + **self.policy_network_optimizer_config, + ) + + # override + def _train_step(self) -> "Dict[str, Any]": + result = super()._train_step() + self.trace_decay.step() + result["trace_decay"] = self.trace_decay() + result["entropy_coeff"] = result.pop("entropy_coeff", None) + result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None) + result.pop("num_return_steps", None) + return result + + # override + def _train_on_batch( + self, + idx: "int", + experiences: "Dict[str, Tensor]", + is_weight: "Tensor", + mask: "Tensor", + ) -> "Tuple[Dict[str, Any], Optional[ndarray]]": + result = {} + + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + state_values, _ = self._value_network(experiences["observation"], mask=mask) + # Discard next observation + experiences["observation"] = experiences["observation"][:, :-1, ...] 
+ mask = mask[:, 1:] + with torch.no_grad(): + targets, advantages = lambda_return( + state_values, + experiences["reward"], + experiences["terminal"], + mask, + self.discount(), + self.trace_decay(), + ) + if self.normalize_advantage: + normalize_(advantages, dim=-1, mask=mask) + + # Discard next state value + state_values = state_values[:, :-1] + + result["value_network"], priority = self._train_on_batch_value_network( + state_values, + targets, + is_weight, + mask, + ) + result["policy_network"] = self._train_on_batch_policy_network( + experiences, + advantages, + mask, + ) + self._grad_scaler.update() + return result, priority + + # override + def _train_on_batch_policy_network( + self, + experiences: "Dict[str, Tensor]", + advantages: "Tensor", + mask: "Tensor", + ) -> "Dict[str, Any]": + entropy_coeff = self.entropy_coeff() + if entropy_coeff < 0.0: + raise ValueError( + f"`entropy_coeff` ({entropy_coeff}) must be in the interval [0, inf)" + ) + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + advantage = advantages[mask] + old_log_prob = experiences["log_prob"][mask] + with torch.no_grad(): + self._policy_network(experiences["observation"], mask=mask) + old_policy = self._policy_network.distribution + policy = loss = log_prob = entropy_bonus = kl_div = None + + def compute_loss() -> "Tensor": + nonlocal policy, loss, log_prob, entropy_bonus + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + self._policy_network(experiences["observation"], mask=mask) + policy = self._policy_network.distribution + log_prob = policy.log_prob(experiences["action"])[mask] + ratio = (log_prob - old_log_prob).exp() + loss = -advantage * ratio + entropy_bonus = None + if entropy_coeff != 0.0: + entropy_bonus = -entropy_coeff * policy.entropy()[mask] + loss += entropy_bonus + loss = loss.mean() + return loss + + def compute_kl_div() -> "Tensor": + nonlocal policy, kl_div + with ( + autocast(**self.enable_amp) + if self.enable_amp["enabled"] + else contextlib.suppress() + ): + if policy is None: + self._policy_network(experiences["observation"], mask=mask) + policy = self._policy_network.distribution + kl_div = kl_divergence(old_policy, policy)[mask].mean() + policy = None + return kl_div + + loss = compute_loss() + optimize_result = self._optimize_policy_network( + loss, compute_loss, compute_kl_div + ) + result = { + "advantage": advantage.mean().item(), + "log_prob": log_prob.mean().item(), + "old_log_prob": old_log_prob.mean().item(), + } + if kl_div is not None: + result["kl_div"] = kl_div.item() + result["loss"] = loss.item() + if entropy_bonus is not None: + result["entropy_bonus"] = entropy_bonus.mean().item() + result.update(optimize_result) + return result + + # override + def _optimize_policy_network( + self, + loss: "Tensor", + loss_fn: "Callable[[], Tensor]", + constraint_fn: "Callable[[], Tensor]", + ) -> "Dict[str, Any]": + max_grad_l2_norm = self.max_grad_l2_norm() + if max_grad_l2_norm <= 0.0: + raise ValueError( + f"`max_grad_l2_norm` ({max_grad_l2_norm}) must be in the interval (0, inf]" + ) + self._policy_network_optimizer.zero_grad(set_to_none=True) + self._grad_scaler.scale(loss).backward(retain_graph=True) + self._grad_scaler.unscale_(self._policy_network_optimizer) + grad_l2_norm = clip_grad_norm_( + self._policy_network.parameters(), max_grad_l2_norm + ) + self._grad_scaler.step(self._policy_network_optimizer, loss_fn, constraint_fn) + result = { + "grad_l2_norm": 
min(grad_l2_norm.item(), max_grad_l2_norm), + } + return result + + +class DistributedDataParallelTRPO(DistributedDataParallelA2C): + """Distributed data parallel Trust Region Policy Optimization (TRPO). + + See Also + -------- + actorch.algorithms.trpo.TRPO + + """ + + _ALGORITHM_CLS = TRPO # override diff --git a/actorch/algorithms/utils.py b/actorch/algorithms/utils.py index 916d79b..456ee80 100644 --- a/actorch/algorithms/utils.py +++ b/actorch/algorithms/utils.py @@ -23,6 +23,7 @@ import ray.train.torch # Fix missing train.torch attribute import torch from ray import train +from torch import Tensor from torch import distributed as dist from torch.nn import Module from torch.nn.parallel import DataParallel, DistributedDataParallel @@ -32,11 +33,59 @@ "count_params", "freeze_params", "init_mock_train_session", + "normalize_", "prepare_model", - "sync_polyak", + "sync_polyak_", ] +def normalize_(input, dim: "int" = 0, mask: "Optional[Tensor]" = None) -> "Tensor": + """Normalize a tensor along a dimension, by subtracting the mean + and then dividing by the standard deviation (in-place). + + Parameters + ---------- + input: + The tensor. + dim: + The dimension. + mask: + The boolean tensor indicating which elements + are valid (True) and which are not (False). + Default to ``torch.ones_like(input, dtype=torch.bool)``. + + Returns + ------- + The normalized tensor. + + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import normalize_ + >>> + >>> + >>> input = torch.randn(2,3) + >>> mask = torch.rand(2, 3) > 0.5 + >>> output = normalize_(input, mask=mask) + + """ + if mask is None: + mask = torch.ones_like(input, dtype=torch.bool) + else: + input[~mask] = 0.0 + length = mask.sum(dim=dim, keepdim=True) + input_mean = input.sum(dim=dim, keepdim=True) / length + input -= input_mean + input *= mask + input_stddev = ( + ((input**2).sum(dim=dim, keepdim=True) / length).sqrt().clamp(min=1e-6) + ) + input /= input_stddev + input *= mask + return input + + def count_params(module: "Module") -> "Tuple[int, int]": """Return the number of trainable and non-trainable parameters in `module`. @@ -51,16 +100,26 @@ def count_params(module: "Module") -> "Tuple[int, int]": - The number of trainable parameters; - the number of non-trainable parameters. + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import count_params + >>> + >>> + >>> model = torch.nn.Linear(4, 2) + >>> trainable_count, non_trainable_count = count_params(model) + """ - num_trainable_params, num_non_trainable_params = 0, 0 + trainable_count, non_trainable_count = 0, 0 for param in module.parameters(): if param.requires_grad: - num_trainable_params += param.numel() + trainable_count += param.numel() else: - num_non_trainable_params += param.numel() + non_trainable_count += param.numel() for buffer in module.buffers(): - num_non_trainable_params += buffer.numel() - return num_trainable_params, num_non_trainable_params + non_trainable_count += buffer.numel() + return trainable_count, non_trainable_count @contextmanager @@ -73,6 +132,21 @@ def freeze_params(*modules: "Module") -> "Iterator[None]": modules: The modules. + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import freeze_params + >>> + >>> + >>> policy_model = torch.nn.Linear(4, 2) + >>> value_model = torch.nn.Linear(2, 1) + >>> input = torch.randn(3, 4) + >>> with freeze_params(value_model): + ... action = policy_model(input) + ... 
loss = -value_model(action).mean() + >>> loss.backward() + """ params = [ param @@ -89,19 +163,19 @@ def freeze_params(*modules: "Module") -> "Iterator[None]": param.requires_grad = True -def sync_polyak( +def sync_polyak_( source_module: "Module", target_module: "Module", polyak_weight: "float" = 0.001, -) -> "None": - """Synchronize `source_module` with `target_module` - through Polyak averaging. +) -> "Module": + """Synchronize a source module with a target module + through Polyak averaging (in-place). - For each `target_param` in `target_module`, - for each `source_param` in `source_module`: - `target_param` = - (1 - `polyak_weight`) * `target_param` - + `polyak_weight` * `source_param`. + For each `target_parameter` in `target_module`, + for each `source_parameter` in `source_module`: + `target_parameter` = + (1 - `polyak_weight`) * `target_parameter` + + `polyak_weight` * `source_parameter`. Parameters ---------- @@ -112,6 +186,10 @@ def sync_polyak( polyak_weight: The Polyak weight. + Returns + ------- + The synchronized target module. + Raises ------ ValueError @@ -124,16 +202,27 @@ def sync_polyak( In: SIAM Journal on Control and Optimization. 1992, pp. 838-855. URL: https://doi.org/10.1137/0330046 + Examples + -------- + >>> import torch + >>> + >>> from actorch.algorithms.utils import sync_polyak_ + >>> + >>> + >>> model = torch.nn.Linear(4, 2) + >>> target_model = torch.nn.Linear(4, 2) + >>> sync_polyak_(model, target_model, polyak_weight=0.001) + """ if polyak_weight < 0.0 or polyak_weight > 1.0: raise ValueError( f"`polyak_weight` ({polyak_weight}) must be in the interval [0, 1]" ) if polyak_weight == 0.0: - return + return target_module if polyak_weight == 1.0: target_module.load_state_dict(source_module.state_dict()) - return + return target_module # Update parameters target_params = target_module.parameters() source_params = source_module.parameters() @@ -151,6 +240,7 @@ def sync_polyak( target_buffer *= (1 - polyak_weight) / polyak_weight target_buffer += source_buffer target_buffer *= polyak_weight + return target_module def init_mock_train_session() -> "None": diff --git a/actorch/algorithms/value_estimation/generalized_estimator.py b/actorch/algorithms/value_estimation/generalized_estimator.py index ceb15a7..4c6c1b2 100644 --- a/actorch/algorithms/value_estimation/generalized_estimator.py +++ b/actorch/algorithms/value_estimation/generalized_estimator.py @@ -32,7 +32,7 @@ ] -def generalized_estimator( +def generalized_estimator( # noqa: C901 state_values: "Union[Tensor, Distribution]", rewards: "Tensor", terminals: "Tensor", @@ -44,7 +44,8 @@ def generalized_estimator( discount: "float" = 0.99, num_return_steps: "int" = 1, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) generalized estimator targets and the corresponding advantages of a trajectory. @@ -81,12 +82,14 @@ def generalized_estimator( The number of return steps (`n` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) generalized estimator targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. 
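A minimal usage sketch of the new `return_advantage` flag, mirroring the `n_step_return` call sites in the SAC and TD3 hunks above; the import path is the one used by those modules, while the tensor shapes and values below are illustrative assumptions:

    import torch

    from actorch.algorithms.value_estimation import n_step_return

    # Illustrative batch of B trajectories of length T
    B, T = 2, 5
    action_values = torch.randn(B, T + 1)            # bootstrap values, one extra time step
    rewards = torch.randn(B, T)
    terminals = torch.zeros(B, T, dtype=torch.bool)
    mask = torch.ones(B, T, dtype=torch.bool)

    # With `return_advantage=False` the advantage computation is skipped
    # and None is returned in its place
    targets, advantages = n_step_return(
        action_values,
        rewards,
        terminals,
        mask,
        discount=0.99,
        num_return_steps=3,
        return_advantage=False,
    )
    assert advantages is None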
Raises ------ @@ -199,6 +202,8 @@ def generalized_estimator( f"be equal to the shape of `rewards` ({rewards.shape})" ) targets, state_values, next_state_values = _compute_targets(*compute_targets_args) + if not return_advantage: + return targets, None advantages = _compute_advantages( targets, state_values, @@ -308,7 +313,7 @@ def _compute_distributional_targets( B, T = rewards.shape num_return_steps = min(num_return_steps, T) length = mask.sum(dim=1, keepdim=True) - idx = torch.arange(1, T + 1).expand(B, T) + idx = torch.arange(1, T + 1, device=length.device).expand(B, T) if state_values.batch_shape == (B, T + 1): next_state_values = distributional_gather( state_values, @@ -316,7 +321,7 @@ def _compute_distributional_targets( idx.clamp(max=length), mask * (~terminals), ) - idx = torch.arange(0, T).expand(B, T) + idx = torch.arange(0, T, device=length.device).expand(B, T) state_values = distributional_gather( state_values, 1, @@ -361,7 +366,9 @@ def _compute_distributional_targets( next_state_value_coeffs *= ~terminals next_state_value_coeffs *= mask # Gather - idx = torch.arange(num_return_steps - 1, T + num_return_steps - 1).expand(B, T) + idx = torch.arange( + num_return_steps - 1, T + num_return_steps - 1, device=length.device + ).expand(B, T) next_state_values = distributional_gather( next_state_values, 1, @@ -371,7 +378,7 @@ def _compute_distributional_targets( targets = TransformedDistribution( next_state_values, AffineTransform(offsets, next_state_value_coeffs), - next_state_values._validate_args, + validate_args=False, ).reduced_dist return targets, state_values, next_state_values coeffs = torch.stack( @@ -403,14 +410,11 @@ def _compute_distributional_targets( next_state_values, (B, T, num_return_steps) ) # Transform - validate_args = ( - state_or_action_values._validate_args or next_state_values._validate_args - ) targets = TransformedDistribution( CatDistribution( [state_or_action_values, next_state_values], dim=-1, - validate_args=validate_args, + validate_args=False, ), [ AffineTransform( @@ -420,7 +424,7 @@ def _compute_distributional_targets( SumTransform((2,)), SumTransform((num_return_steps,)), ], - validate_args=validate_args, + validate_args=False, ).reduced_dist return targets, state_values, next_state_values @@ -462,9 +466,8 @@ def _compute_advantages( trace_decay * targets + (1 - trace_decay) * state_values, [0, 1], )[:, 1:] - next_targets[torch.arange(B), length - 1] = next_state_values[ - torch.arange(B), length - 1 - ] + batch_idx = torch.arange(B) + next_targets[batch_idx, length - 1] = next_state_values[batch_idx, length - 1] action_values = rewards + discount * next_targets advantages = advantage_weights * (action_values - state_values) advantages *= mask diff --git a/actorch/algorithms/value_estimation/importance_sampling.py b/actorch/algorithms/value_estimation/importance_sampling.py index 91d1f53..754050d 100644 --- a/actorch/algorithms/value_estimation/importance_sampling.py +++ b/actorch/algorithms/value_estimation/importance_sampling.py @@ -40,7 +40,8 @@ def importance_sampling( log_is_weights: "Tensor", mask: "Optional[Tensor]" = None, discount: "float" = 0.99, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) importance sampling targets, a.k.a. IS, and the corresponding advantages of a trajectory. @@ -73,12 +74,14 @@ def importance_sampling( Default to ``torch.ones_like(rewards, dtype=torch.bool)``. 
discount: The discount factor (`gamma` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) importance sampling targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. References ---------- @@ -108,4 +111,5 @@ def importance_sampling( discount=discount, num_return_steps=rewards.shape[1], trace_decay=1.0, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/lambda_return.py b/actorch/algorithms/value_estimation/lambda_return.py index 8950462..095d7c0 100644 --- a/actorch/algorithms/value_estimation/lambda_return.py +++ b/actorch/algorithms/value_estimation/lambda_return.py @@ -37,7 +37,8 @@ def lambda_return( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) lambda returns, a.k.a. TD(lambda), and the corresponding advantages, a.k.a. GAE(lambda), of a trajectory. @@ -63,12 +64,14 @@ def lambda_return( The discount factor (`gamma` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) lambda returns, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. References ---------- @@ -94,4 +97,5 @@ def lambda_return( max_is_weight_trace=1.0, max_is_weight_delta=1.0, max_is_weight_advantage=1.0, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/monte_carlo_return.py b/actorch/algorithms/value_estimation/monte_carlo_return.py index 516d9e5..68d4c33 100644 --- a/actorch/algorithms/value_estimation/monte_carlo_return.py +++ b/actorch/algorithms/value_estimation/monte_carlo_return.py @@ -33,7 +33,8 @@ def monte_carlo_return( rewards: "Tensor", mask: "Optional[Tensor]" = None, discount: "float" = 0.99, -) -> "Tuple[Tensor, Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Tensor, Optional[Tensor]]": """Compute the Monte Carlo returns and the corresponding advantages of a trajectory. @@ -51,11 +52,14 @@ def monte_carlo_return( Default to ``torch.ones_like(rewards, dtype=torch.bool)``. discount: The discount factor (`gamma` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The Monte Carlo returns, shape: ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True + None otherwise, shape: ``[B, T]``. 
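The same flag applies to the simplest member of the family; a short sketch, assuming the parameters visible in this hunk (`rewards`, `mask`, `discount`, `return_advantage`) form the full signature and that the function is exported from `actorch.algorithms.value_estimation` like its siblings:

    import torch

    from actorch.algorithms.value_estimation import monte_carlo_return

    # Illustrative shapes; no bootstrap values are needed for Monte Carlo returns
    rewards = torch.randn(3, 8)                       # [B, T]
    mask = torch.ones(3, 8, dtype=torch.bool)

    returns, advantages = monte_carlo_return(rewards=rewards, mask=mask, discount=0.99)
    targets, no_advantages = monte_carlo_return(
        rewards=rewards, mask=mask, return_advantage=False
    )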
References ---------- @@ -72,4 +76,5 @@ def monte_carlo_return( mask=mask, discount=discount, num_return_steps=rewards.shape[1], + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/n_step_return.py b/actorch/algorithms/value_estimation/n_step_return.py index 1846b43..71b472e 100644 --- a/actorch/algorithms/value_estimation/n_step_return.py +++ b/actorch/algorithms/value_estimation/n_step_return.py @@ -37,7 +37,8 @@ def n_step_return( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, num_return_steps: "int" = 1, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) n-step returns, a.k.a. TD(n), and the corresponding advantages of a trajectory. @@ -63,12 +64,14 @@ def n_step_return( The discount factor (`gamma` in the literature). num_return_steps: The number of return steps (`n` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) n-step returns, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. References ---------- @@ -90,4 +93,5 @@ def n_step_return( max_is_weight_trace=1.0, max_is_weight_delta=1.0, max_is_weight_advantage=1.0, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/off_policy_lambda_return.py b/actorch/algorithms/value_estimation/off_policy_lambda_return.py index 7d7cf0b..c603e94 100644 --- a/actorch/algorithms/value_estimation/off_policy_lambda_return.py +++ b/actorch/algorithms/value_estimation/off_policy_lambda_return.py @@ -40,7 +40,8 @@ def off_policy_lambda_return( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) off-policy lambda returns, a.k.a. Harutyunyan's et al. Q(lambda), and the corresponding advantages of a trajectory. @@ -70,12 +71,14 @@ def off_policy_lambda_return( The discount factor (`gamma` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) off-policy lambda returns, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. 
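The on-policy counterpart of these estimators is used by the new TRPO algorithm above, which pairs `lambda_return` with the in-place `normalize_` helper; a minimal sketch of that pattern, with illustrative shapes and values:

    import torch

    from actorch.algorithms.utils import normalize_
    from actorch.algorithms.value_estimation import lambda_return

    # Illustrative batch of B trajectories of length T
    B, T = 4, 10
    state_values = torch.randn(B, T + 1)              # includes the bootstrap value for the next state
    rewards = torch.randn(B, T)
    terminals = torch.zeros(B, T, dtype=torch.bool)
    mask = torch.ones(B, T, dtype=torch.bool)

    targets, advantages = lambda_return(
        state_values,
        rewards,
        terminals,
        mask,
        discount=0.99,
        trace_decay=0.95,
    )

    # Masked, in-place advantage normalization along the time dimension,
    # as done by TRPO when `normalize_advantage` is enabled
    normalize_(advantages, dim=-1, mask=mask)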
References ---------- @@ -100,4 +103,5 @@ def off_policy_lambda_return( discount=discount, num_return_steps=rewards.shape[1], trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/retrace.py b/actorch/algorithms/value_estimation/retrace.py index b910f45..51a2605 100644 --- a/actorch/algorithms/value_estimation/retrace.py +++ b/actorch/algorithms/value_estimation/retrace.py @@ -43,7 +43,8 @@ def retrace( trace_decay: "float" = 1.0, max_is_weight_trace: "float" = 1.0, max_is_weight_advantage: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) Retrace targets, a.k.a. Retrace(lambda), and the corresponding advantages of a trajectory. @@ -82,12 +83,14 @@ def retrace( The maximum importance sampling weight for trace computation (`c_bar` in the literature). max_is_weight_advantage: The maximum importance sampling weight for advantage computation. + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) Retrace targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. Raises ------ @@ -126,4 +129,5 @@ def retrace( discount=discount, num_return_steps=rewards.shape[1], trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/tree_backup.py b/actorch/algorithms/value_estimation/tree_backup.py index 4ea1fa1..d8a3bda 100644 --- a/actorch/algorithms/value_estimation/tree_backup.py +++ b/actorch/algorithms/value_estimation/tree_backup.py @@ -41,7 +41,8 @@ def tree_backup( mask: "Optional[Tensor]" = None, discount: "float" = 0.99, trace_decay: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) tree-backup targets, a.k.a. TB(lambda), and the corresponding advantages of a trajectory. @@ -74,12 +75,14 @@ def tree_backup( The discount factor (`gamma` in the literature). trace_decay: The trace-decay parameter (`lambda` in the literature). + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) tree-backup targets, shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. 
References ---------- @@ -105,4 +108,5 @@ def tree_backup( discount=discount, num_return_steps=rewards.shape[1], trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/algorithms/value_estimation/vtrace.py b/actorch/algorithms/value_estimation/vtrace.py index f6471bb..7c54aef 100644 --- a/actorch/algorithms/value_estimation/vtrace.py +++ b/actorch/algorithms/value_estimation/vtrace.py @@ -44,7 +44,8 @@ def vtrace( max_is_weight_trace: "float" = 1.0, max_is_weight_delta: "float" = 1.0, max_is_weight_advantage: "float" = 1.0, -) -> "Tuple[Union[Tensor, Distribution], Tensor]": + return_advantage: "bool" = True, +) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]": """Compute the (possibly distributional) (leaky) V-trace targets, a.k.a. V-trace(n), and the corresponding advantages of a trajectory. @@ -86,12 +87,14 @@ def vtrace( The maximum importance sampling weight for delta computation (`rho_bar` in the literature). max_is_weight_advantage: The maximum importance sampling weight for advantage computation. + return_advantage: + True to additionally return the advantages, False otherwise. Returns ------- - The (possibly distributional) (leaky) V-trace targets shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``; - - the corresponding advantages, shape: ``[B, T]``. + - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``. Raises ------ @@ -148,4 +151,5 @@ def vtrace( discount=discount, num_return_steps=num_return_steps, trace_decay=trace_decay, + return_advantage=return_advantage, ) diff --git a/actorch/buffers/buffer.py b/actorch/buffers/buffer.py index 69ff10d..c1c9157 100644 --- a/actorch/buffers/buffer.py +++ b/actorch/buffers/buffer.py @@ -35,6 +35,9 @@ class Buffer(ABC, CheckpointableMixin): """Replay buffer that stores and samples batched experience trajectories.""" + is_prioritized = False + """Whether a priority is assigned to each trajectory.""" + _STATE_VARS = ["capacity", "spec", "_sampled_idx"] # override def __init__( diff --git a/actorch/buffers/proportional_buffer.py b/actorch/buffers/proportional_buffer.py index b8286db..b63e7fb 100644 --- a/actorch/buffers/proportional_buffer.py +++ b/actorch/buffers/proportional_buffer.py @@ -44,6 +44,8 @@ class ProportionalBuffer(UniformBuffer): """ + is_prioritized = True # override + _STATE_VARS = UniformBuffer._STATE_VARS + [ "prioritization", "bias_correction", diff --git a/actorch/buffers/rank_based_buffer.py b/actorch/buffers/rank_based_buffer.py index e54276a..58b994f 100644 --- a/actorch/buffers/rank_based_buffer.py +++ b/actorch/buffers/rank_based_buffer.py @@ -43,6 +43,8 @@ class RankBasedBuffer(ProportionalBuffer): """ + is_prioritized = True # override + _STATE_VARS = ProportionalBuffer._STATE_VARS # override _STATE_VARS.remove("epsilon") diff --git a/actorch/buffers/uniform_buffer.py b/actorch/buffers/uniform_buffer.py index 0f40362..6ed77de 100644 --- a/actorch/buffers/uniform_buffer.py +++ b/actorch/buffers/uniform_buffer.py @@ -55,6 +55,7 @@ def num_experiences(self) -> "int": # override @property def num_full_trajectories(self) -> "int": + # print(len(self._full_trajectory_start_idx)) return len(self._full_trajectory_start_idx) if self._num_experiences > 0 else 0 # override diff --git a/actorch/distributed/distributed_trainable.py b/actorch/distributed/distributed_trainable.py index ef807f4..0fd013f 100644 --- a/actorch/distributed/distributed_trainable.py +++ b/actorch/distributed/distributed_trainable.py @@ -45,7 
+45,7 @@ class DistributedTrainable(ABC, TuneDistributedTrainable): """Distributed Ray Tune trainable with configurable resource requirements. Derived classes must implement `step`, `save_checkpoint` and `load_checkpoint` - (see https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable-class-api). + (see https://docs.ray.io/en/releases-1.13.0/tune/api_docs/trainable.html#tune-trainable-class-api for Ray 1.13.0). """ @@ -67,7 +67,7 @@ def __init__( Default to ``[{"CPU": 1}]``. placement_strategy: The placement strategy - (see https://docs.ray.io/en/latest/ray-core/placement-group.html). + (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0). """ super().__init__( @@ -98,7 +98,7 @@ def resource_help(cls, config: "Dict[str, Any]") -> "str": " Default to ``[{}]``." + "\n" " placement_strategy:" + "\n" " The placement strategy." + "\n" - " (see https://docs.ray.io/en/latest/ray-core/placement-group.html)." + f" (see https://docs.ray.io/en/releases-{ray.__version__}/ray-core/placement-group.html)." ) # override diff --git a/actorch/distributed/sync_distributed_trainable.py b/actorch/distributed/sync_distributed_trainable.py index 3b4e805..fa3f231 100644 --- a/actorch/distributed/sync_distributed_trainable.py +++ b/actorch/distributed/sync_distributed_trainable.py @@ -100,7 +100,7 @@ def __init__( Default to ``{}``. placement_strategy: The placement strategy - (see https://docs.ray.io/en/latest/ray-core/placement-group.html). + (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0). reduction_mode: The reduction mode for worker results. Must be one of the following: @@ -225,7 +225,7 @@ def resource_help(cls, config: "Dict[str, Any]") -> "str": " Default to ``{}``." + "\n" " placement_strategy:" + "\n" " The placement strategy" + "\n" - " (see https://docs.ray.io/en/latest/ray-core/placement-group.html)." + f" (see https://docs.ray.io/en/releases-{ray.__version__}/ray-core/placement-group.html)." ) # override diff --git a/actorch/distributions/finite.py b/actorch/distributions/finite.py index e48d8e4..16dfca7 100644 --- a/actorch/distributions/finite.py +++ b/actorch/distributions/finite.py @@ -48,6 +48,8 @@ class Finite(Distribution): Examples -------- + >>> import torch + >>> >>> from actorch.distributions import Finite >>> >>> diff --git a/actorch/optimizers/cgbls.py b/actorch/optimizers/cgbls.py index c8d1abe..908c742 100644 --- a/actorch/optimizers/cgbls.py +++ b/actorch/optimizers/cgbls.py @@ -300,7 +300,7 @@ def _conjugate_gradient( References ---------- - .. [1] M.R. Hestenes, and E. Stiefel. + .. [1] M.R. Hestenes and E. Stiefel. "Methods of Conjugate Gradients for Solving Linear Systems". In: Journal of Research of the National Bureau of Standards. 1952, pp. 409-435. 
URL: http://dx.doi.org/10.6028/jres.049.044 @@ -389,7 +389,7 @@ def _backtracking_line_search( loss, constraint = 0.0, 0.0 for ratio in ratios: for i, step in enumerate(descent_steps): - params[i] -= ratio * step + params[i].copy_(prev_params[i] - ratio * step) loss = loss_fn() constraint = constraint_fn() @@ -411,6 +411,6 @@ warning_msg = " and".join(warning_msg.rsplit(",", 1)) _LOGGER.warning(warning_msg) for i, prev_param in enumerate(prev_params): - params[i] = prev_param + params[i].copy_(prev_param) return loss, constraint diff --git a/actorch/preconditioners/kfac.py b/actorch/preconditioners/kfac.py index ee90ce9..98734be 100644 --- a/actorch/preconditioners/kfac.py +++ b/actorch/preconditioners/kfac.py @@ -93,8 +93,8 @@ def update_AG(self, decay: "float") -> "None": ) self._dA, self._dG = A.new_zeros(A.shape[0]), G.new_zeros(G.shape[0]) self._QA, self._QG = A.new_zeros(A.shape), G.new_zeros(G.shape) - self._update_exp_moving_average(self._A, A, decay) - self._update_exp_moving_average(self._G, G, decay) + self._update_exp_moving_average_(self._A, A, decay) + self._update_exp_moving_average_(self._G, G, decay) def update_eigen_AG(self, epsilon: "float") -> "None": """Update eigenvalues and eigenvectors of A and G. @@ -128,14 +128,15 @@ def get_preconditioned_grad(self, damping: "float") -> "Tensor": v = self._QG @ v2 @ self._QA.t() return v - def _update_exp_moving_average( + def _update_exp_moving_average_( self, current: "Tensor", new: "Tensor", weight: "float" - ) -> "None": + ) -> "Tensor": if weight == 1.0: - return + return current current *= weight / (1 - weight) current += new current *= 1 - weight + return current @property @abstractmethod diff --git a/actorch/version.py b/actorch/version.py index 76f0825..cf4b122 100644 --- a/actorch/version.py +++ b/actorch/version.py @@ -28,7 +28,7 @@ "0" # Minor version to increment in case of backward compatible new functionality ) -_PATCH = "4" # Patch version to increment in case of backward compatible bug fixes +_PATCH = "5" # Patch version to increment in case of backward compatible bug fixes VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}" """The package version.""" diff --git a/docs/_static/images/actorch-overview.png b/docs/_static/images/actorch-overview.png new file mode 100644 index 0000000000000000000000000000000000000000..7c65783674e1ba6bc959badce6c09c6dc2df61ab Binary files /dev/null and b/docs/_static/images/actorch-overview.png differ
z*7DfTlIPPIx*H2RebM)3zyb{X_xjrs^lr@Dk~;JBOi&PfWrwvhe2hb{aGZPQqGaDOFa`##&)x>&(vZlk`{+O zfmC-7LK#(1B`LW6sZ*xe#)1cRQlDY35wz?2sr&F`#2gX{{l)XWd0aQnTy~s^5m|bc zn_sdGa=QZTCB!V(9T4}TM0CMKoAx`c@_wmEuVffi{2jb ziL}EwIG;|0*VjwW#cV+JEB;N04pgtV?nin`kRzg{Z+x|fgy<;s7nSx6Aoh^>e1Tza7fE=2C)x|-rBlNUI4(3IC_MWDRi zG5+)%){Z@*MF?4ee}7t#%&1rilEMzq+ck^vJevwFeni_$ZrHG2U)ZekW#2ZbtXm@4 zzwRQhseD8vCp9X;oBCPELcaE-o(+Wm%^0637(&{Y_XaW}}tf z&1rURcX#*2o1`PHK0HGO7N2a)HouKMur4(PvhEf4N>+#8OYiTKJ;6c&Tfh@Sgr9)Y z?bEihU5MljS(80jFj%E_#6IO7^ZB}LDNi=P=4SQi;QzW;4Ecix4_GZ>M_QYGcrSF& zBU*pM&9#r#Cvtjv+HM5j53aL_f%`ikSWcq;7v`r;OfG->_H7pFTB@chy*Xdpb}Y!q>b+{Z z+ygkX?n8?6qd*KX>M|sL{&`ePGDzhT`#>ZG{Xoq1jGu(iHf0Md`heDyQ|_DX8+?3x zS)x#Erw4e!HQ#!BQIUV$G)&nZ+np=7RmAU{bR#}KXk)0#Nrcp;??dv`f!^RZyIJ54 z4W^-%#U0wBN+7Wzh_{d@|JS5*;26;SamN84Azvg}B4j)~*7_!P+Cl6~j?Nf&P=fx& z2RFpn7>^toa{wa3p${KEaE+o4u>AAFu-1h2Cm^&G*==Y4<)xon^!*;)GovE8y2iLn z77b6@yr8(axE_F`?Uq0PwOL^dMl~L(?9t1)Lq_uNWz6YMIa3sLnSj!CZ^WxtuR3A9 zLMcm5O4a)bL1-!+nle0ADmuCl4e8$>+kiI_K!O*Kb8!a$44-n#0rcU+hYN1wSg*pU z-v^JRa2X86;hAkkDLKywf<+0W0fT{C`0bRQ(}mtC5NNdGk!@+$UV zoxoxyP9xpP6)$Xk2HwIn_so#0NsffU27~wG0?8(-Q!O7&pVu{lVM;tzQ3Ub1HiX3b z4#~Oy^)xk@jn@~DCfGo{lD`H=*S?u^VKx~JTTY=a^SG%dyU*iEM$vB5-rz#HOigw5 z(E{-ng+AfdJN?EldrK&}8bPdJrjLjTu_$pPOs>{vD-m=1XT+is!*+xkA>koaKz-vT zWF=NQAk@Dh;C<(jZ$EQ-Z~yRUa-V>}IZ_e_E|J*r%T;IZseAD91bB`>gP#4OJZ-m< zvenms=DF+s;ek#fLn~y^KjOoWh?U(?_qWx!W`X|_5thw~7)m6W9;SSryvPMr&{-z? zZ7}fAo{gvCC(t}N>LcYbn)T>clls*qx=ogiGRQLTEh1d<%Rn=jc{7rGuSUB53A74l z&=Rz|y5_cb+HqSm{ji~Vj01N`26K8}-l0ALbD>1bP1o)?dfxLX45KGs5^$Awm4SaQ zUI*^dU=R)k?;N44JfBO?pPgv|ZM~fy*=m`_Tt8YO^i}a;FoMC}6mf2mTBtF=f5-GQ cF;f28ryQvmIdYn90|o!+Xc(%OtJ+8WAH-6Zy8r+H literal 0 HcmV?d00001 diff --git a/docs/index.rst b/docs/index.rst index 5281006..e9746dd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,3 +5,6 @@ Welcome to ACTorch! `actorch` is a deep reinforcement learning framework for fast prototyping based on `PyTorch `_. Visit the GitHub repository: https://github.com/lucadellalib/actorch + +.. image:: docs/_static/images/actorch-overview.png + :width: 600 \ No newline at end of file diff --git a/examples/A2C-AsyncHyperBand_LunarLander-v2.py b/examples/A2C-AsyncHyperBand_LunarLander-v2.py index add1b62..c21725a 100644 --- a/examples/A2C-AsyncHyperBand_LunarLander-v2.py +++ b/examples/A2C-AsyncHyperBand_LunarLander-v2.py @@ -14,12 +14,13 @@ # limitations under the License. # ============================================================================== -"""Train Advantage Actor-Critic on LunarLander-v2. Tune the value network learning +"""Train Advantage Actor-Critic (A2C) on LunarLander-v2. Tune the value network learning rate using the asynchronous version of HyperBand (https://arxiv.org/abs/1810.05934). """ # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run A2C-AsyncHyperBand_LunarLander-v2.py import gymnasium as gym diff --git a/examples/A2C_LunarLander-v2.py b/examples/A2C_LunarLander-v2.py index d92bfdb..85fb2e6 100644 --- a/examples/A2C_LunarLander-v2.py +++ b/examples/A2C_LunarLander-v2.py @@ -14,9 +14,10 @@ # limitations under the License. 
# ============================================================================== -"""Train Advantage Actor-Critic on LunarLander-v2.""" +"""Train Advantage Actor-Critic (A2C) on LunarLander-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run A2C_LunarLander-v2.py import gymnasium as gym diff --git a/examples/ACKTR_LunarLander-v2.py b/examples/ACKTR_LunarLander-v2.py index cdfa571..de0d291 100644 --- a/examples/ACKTR_LunarLander-v2.py +++ b/examples/ACKTR_LunarLander-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Actor-Critic Kronecker-Factored Trust Region on LunarLander-v2.""" +"""Train Actor-Critic Kronecker-Factored Trust Region (ACKTR) on LunarLander-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run ACKTR_LunarLander-v2.py import gymnasium as gym diff --git a/examples/AWR-NormalizingFlow_Pendulum-v1.py b/examples/AWR-AffineFlow_Pendulum-v1.py similarity index 95% rename from examples/AWR-NormalizingFlow_Pendulum-v1.py rename to examples/AWR-AffineFlow_Pendulum-v1.py index b63c0df..848db20 100644 --- a/examples/AWR-NormalizingFlow_Pendulum-v1.py +++ b/examples/AWR-AffineFlow_Pendulum-v1.py @@ -14,10 +14,11 @@ # limitations under the License. # ============================================================================== -"""Train Advantage-Weighted Regression equipped with a normalizing flow policy on Pendulum-v1.""" +"""Train Advantage-Weighted Regression (AWR) equipped with an affine flow policy on Pendulum-v1.""" # Navigate to `/examples`, open a terminal and run: -# actorch run AWR-NormalizingFlow_Pendulum-v1.py +# pip install gymnasium[classic_control] +# actorch run AWR-AffineFlow_Pendulum-v1.py import gymnasium as gym import torch diff --git a/examples/AWR_Pendulum-v1.py b/examples/AWR_Pendulum-v1.py index 0c42e92..3c89595 100644 --- a/examples/AWR_Pendulum-v1.py +++ b/examples/AWR_Pendulum-v1.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Advantage-Weighted Regression on Pendulum-v1.""" +"""Train Advantage-Weighted Regression (AWR) on Pendulum-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run AWR_Pendulum-v1.py import gymnasium as gym diff --git a/examples/D3PG-Finite_LunarLanderContinuous-v2.py b/examples/D3PG-Finite_LunarLanderContinuous-v2.py new file mode 100644 index 0000000..b2cd13e --- /dev/null +++ b/examples/D3PG-Finite_LunarLanderContinuous-v2.py @@ -0,0 +1,129 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Train Distributional Deep Deterministic Policy Gradient (D3PG) on LunarLanderContinuous-v2 +assuming a finite value distribution with 51 atoms (see https://arxiv.org/abs/1804.08617). + +""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] +# actorch run D3PG-Finite_LunarLanderContinuous-v2.py + +import gymnasium as gym +import torch +from torch import nn +from torch.optim import Adam + +from actorch import * + + +# Define custom model +class LayerNormFCNet(FCNet): + # override + def _setup_torso(self, in_shape): + super()._setup_torso(in_shape) + torso = nn.Sequential() + for module in self.torso: + torso.append(module) + if isinstance(module, nn.Linear): + torso.append( + nn.LayerNorm(module.out_features, elementwise_affine=False) + ) + self.torso = torso + + +experiment_params = ExperimentParams( + run_or_experiment=D3PG, + stop={"timesteps_total": int(4e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=D3PG.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("LunarLanderContinuous-v2", **kwargs), + config, + num_workers=1, + ), + train_agent_builder=OUNoiseAgent, + train_agent_config={ + "clip_action": True, + "device": "cpu", + "num_random_timesteps": 1000, + "mean": 0.0, + "volatility": 0.1, + "reversion_speed": 0.15, + }, + train_num_timesteps_per_iter=1024, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=LayerNormFCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + "head_activation_builder": nn.Tanh, + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 1e-3}, + value_network_model_builder=LayerNormFCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + }, + value_network_distribution_builder=Finite, + value_network_distribution_parametrization={ + "logits": ( + {"logits": (51,)}, + lambda x: x["logits"], + ), + }, + value_network_distribution_config={ + "atoms": torch.linspace(-10.0, 10.0, 51), + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 1e-3}, + buffer_config={ + "capacity": int(1e5), + "prioritization": 1.0, + "bias_correction": 0.4, + "epsilon": 1e-5, + }, + discount=0.995, + num_return_steps=5, + num_updates_per_iter=LambdaSchedule( + lambda iter: ( + 200 if iter >= 10 else 0 + ), # Fill buffer with some trajectories before training + ), + batch_size=128, + max_trajectory_length=10, + sync_freq=1, + polyak_weight=0.001, + max_grad_l2_norm=5.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/D3PG-Normal_LunarLanderContinuous-v2.py b/examples/D3PG-Normal_LunarLanderContinuous-v2.py new file mode 100644 index 0000000..d6c7227 --- /dev/null +++ b/examples/D3PG-Normal_LunarLanderContinuous-v2.py @@ -0,0 +1,130 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Distributional Deep Deterministic Policy Gradient (D3PG) on LunarLanderContinuous-v2 +assuming a normal value distribution. + +""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] +# actorch run D3PG-Normal_LunarLanderContinuous-v2.py + +import gymnasium as gym +from torch import nn +from torch.distributions import Normal +from torch.optim import Adam + +from actorch import * + + +# Define custom model +class LayerNormFCNet(FCNet): + # override + def _setup_torso(self, in_shape): + super()._setup_torso(in_shape) + torso = nn.Sequential() + for module in self.torso: + torso.append(module) + if isinstance(module, nn.Linear): + torso.append( + nn.LayerNorm(module.out_features, elementwise_affine=False) + ) + self.torso = torso + + +experiment_params = ExperimentParams( + run_or_experiment=D3PG, + stop={"timesteps_total": int(4e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=D3PG.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("LunarLanderContinuous-v2", **kwargs), + config, + num_workers=1, + ), + train_agent_builder=OUNoiseAgent, + train_agent_config={ + "clip_action": True, + "device": "cpu", + "num_random_timesteps": 1000, + "mean": 0.0, + "volatility": 0.1, + "reversion_speed": 0.15, + }, + train_num_timesteps_per_iter=1024, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=LayerNormFCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + "head_activation_builder": nn.Tanh, + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 1e-3}, + value_network_model_builder=LayerNormFCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + }, + value_network_distribution_builder=Normal, + value_network_distribution_parametrization={ + "loc": ( + {"loc": ()}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": ()}, + lambda x: x["log_scale"].clamp(-20.0, 2.0).exp(), + ), + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 1e-3}, + buffer_config={ + "capacity": int(1e5), + "prioritization": 1.0, + "bias_correction": 0.4, + "epsilon": 1e-5, + }, + discount=0.995, + num_return_steps=5, + num_updates_per_iter=LambdaSchedule( + lambda iter: ( + 200 if iter >= 10 else 0 + ), # Fill buffer with some trajectories before training + ), + batch_size=128, + max_trajectory_length=10, + sync_freq=1, + polyak_weight=0.001, + max_grad_l2_norm=5.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/DDPG_LunarLanderContinuous-v2.py b/examples/DDPG_LunarLanderContinuous-v2.py index de4912a..c44c4fd 100644 --- 
a/examples/DDPG_LunarLanderContinuous-v2.py +++ b/examples/DDPG_LunarLanderContinuous-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Deep Deterministic Policy Gradient on LunarLanderContinuous-v2.""" +"""Train Deep Deterministic Policy Gradient (DDPG) on LunarLanderContinuous-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run DDPG_LunarLanderContinuous-v2.py import gymnasium as gym diff --git a/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py b/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py index f9319a5..f3a3df7 100644 --- a/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py +++ b/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train distributed data parallel Deep Deterministic Policy Gradient on LunarLanderContinuous-v2.""" +"""Train distributed data parallel Deep Deterministic Policy Gradient (DDPG) on LunarLanderContinuous-v2.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] # actorch run DistributedDataParallelDDPG_LunarLanderContinuous-v2.py import gymnasium as gym diff --git a/examples/DistributedDataParallelREINFORCE_CartPole-v1.py b/examples/DistributedDataParallelREINFORCE_CartPole-v1.py index ff470ab..b07ad36 100644 --- a/examples/DistributedDataParallelREINFORCE_CartPole-v1.py +++ b/examples/DistributedDataParallelREINFORCE_CartPole-v1.py @@ -17,6 +17,7 @@ """Train distributed data parallel REINFORCE on CartPole-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run DistributedDataParallelREINFORCE_CartPole-v1.py import gymnasium as gym diff --git a/examples/PPO-Laplace_Pendulum-v1.py b/examples/PPO-Laplace_Pendulum-v1.py new file mode 100644 index 0000000..32ad43e --- /dev/null +++ b/examples/PPO-Laplace_Pendulum-v1.py @@ -0,0 +1,94 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Train Proximal Policy Optimization (PPO) on Pendulum-v1 assuming a Laplace policy distribution.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] +# actorch run PPO-Laplace_Pendulum-v1.py + +import gymnasium as gym +from torch.distributions import Laplace +from torch.optim import Adam + +from actorch import * + + +experiment_params = ExperimentParams( + run_or_experiment=PPO, + stop={"timesteps_total": int(1e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=PPO.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("Pendulum-v1", **kwargs), + config, + num_workers=2, + ), + train_num_timesteps_per_iter=2048, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + "independent_heads": ["action/log_scale"], + }, + policy_network_distribution_builders={"action": Laplace}, + policy_network_distribution_parametrizations={ + "action": { + "loc": ( + {"loc": (1,)}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": (1,)}, + lambda x: x["log_scale"].exp(), + ), + }, + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 5e-5}, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 3e-3}, + discount=0.99, + trace_decay=0.95, + num_epochs=20, + minibatch_size=16, + ratio_clip=0.2, + normalize_advantage=True, + entropy_coeff=0.01, + max_grad_l2_norm=0.5, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/PPO_Pendulum-v1.py b/examples/PPO_Pendulum-v1.py index 0089d72..6cb22ea 100644 --- a/examples/PPO_Pendulum-v1.py +++ b/examples/PPO_Pendulum-v1.py @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================== -"""Train Proximal Policy Optimization on Pendulum-v1.""" +"""Train Proximal Policy Optimization (PPO) on Pendulum-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run PPO_Pendulum-v1.py import gymnasium as gym diff --git a/examples/REINFORCE_CartPole-v1.py b/examples/REINFORCE_CartPole-v1.py index 2ab5c13..342ed36 100644 --- a/examples/REINFORCE_CartPole-v1.py +++ b/examples/REINFORCE_CartPole-v1.py @@ -17,6 +17,7 @@ """Train REINFORCE on CartPole-v1.""" # Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] # actorch run REINFORCE_CartPole-v1.py import gymnasium as gym diff --git a/examples/SAC_BipedalWalker-v3.py b/examples/SAC_BipedalWalker-v3.py new file mode 100644 index 0000000..55c768e --- /dev/null +++ b/examples/SAC_BipedalWalker-v3.py @@ -0,0 +1,129 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Soft Actor-Critic (SAC) on BipedalWalker-v3.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[box2d] +# actorch run SAC_BipedalWalker-v3.py + +import gymnasium as gym +import numpy as np +import torch.nn +from torch.distributions import Normal, TanhTransform +from torch.optim import Adam + +from actorch import * + + +class Wrapper(gym.Wrapper): + def __init__(self, env, action_noise=0.3, action_repeat=3, reward_scale=5): + super().__init__(env) + self.action_noise = action_noise + self.action_repeat = action_repeat + self.reward_scale = reward_scale + + def step(self, action): + action += self.action_noise * ( + 1 - 2 * np.random.random(self.action_space.shape) + ) + cumreward = 0.0 + for _ in range(self.action_repeat): + observation, reward, terminated, truncated, info = super().step(action) + cumreward += reward + if terminated: + return observation, 0.0, terminated, truncated, info + return observation, self.reward_scale * cumreward, terminated, truncated, info + + +experiment_params = ExperimentParams( + run_or_experiment=SAC, + stop={"timesteps_total": int(5e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=SAC.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: Wrapper( + gym.wrappers.TimeLimit( + gym.make("BipedalWalker-v3", **kwargs), + max_episode_steps=200, + ), + ), + config, + num_workers=1, + ), + train_num_episodes_per_iter=1, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + "head_activation_builder": torch.nn.Tanh, + }, + policy_network_distribution_builders={ + "action": lambda loc, scale: TransformedDistribution( + Normal(loc, scale), + TanhTransform(cache_size=1), # Use tanh normal to enforce action bounds + ), + }, + policy_network_distribution_parametrizations={ + "action": { + "loc": ( + {"loc": (4,)}, + lambda x: x["loc"], + ), + "scale": ( + {"log_scale": (4,)}, + lambda x: x["log_scale"].clamp(-20.0, 2.0).exp(), + ), + }, + }, + policy_network_sample_fn=lambda d: d.sample(), # `mode` does not exist in closed-form for tanh normal + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 1e-3}, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 400, "bias": True}, + {"out_features": 300, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 1e-3}, + buffer_config={"capacity": int(3e5)}, + discount=0.98, + num_return_steps=5, + num_updates_per_iter=10, + batch_size=1024, + max_trajectory_length=20, + 
sync_freq=1, + polyak_weight=0.001, + temperature=0.2, + max_grad_l2_norm=1.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/SAC_Pendulum-v1.py b/examples/SAC_Pendulum-v1.py new file mode 100644 index 0000000..b3efe24 --- /dev/null +++ b/examples/SAC_Pendulum-v1.py @@ -0,0 +1,82 @@ +# ============================================================================== +# Copyright 2022 Luca Della Libera. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Train Soft Actor-Critic (SAC) on Pendulum-v1.""" + +# Navigate to `/examples`, open a terminal and run: +# pip install gymnasium[classic_control] +# actorch run SAC_Pendulum-v1.py + +import gymnasium as gym +from torch.optim import Adam + +from actorch import * + + +experiment_params = ExperimentParams( + run_or_experiment=SAC, + stop={"timesteps_total": int(1e5)}, + resources_per_trial={"cpu": 1, "gpu": 0}, + checkpoint_freq=10, + checkpoint_at_end=True, + log_to_file=True, + export_formats=["checkpoint", "model"], + config=SAC.Config( + train_env_builder=lambda **config: ParallelBatchedEnv( + lambda **kwargs: gym.make("Pendulum-v1", **kwargs), + config, + num_workers=2, + ), + train_num_timesteps_per_iter=500, + eval_freq=10, + eval_env_config={"render_mode": None}, + eval_num_episodes_per_iter=10, + policy_network_model_builder=FCNet, + policy_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + policy_network_optimizer_builder=Adam, + policy_network_optimizer_config={"lr": 5e-3}, + value_network_model_builder=FCNet, + value_network_model_config={ + "torso_fc_configs": [ + {"out_features": 256, "bias": True}, + {"out_features": 256, "bias": True}, + ], + }, + value_network_optimizer_builder=Adam, + value_network_optimizer_config={"lr": 5e-3}, + temperature_optimizer_builder=Adam, + temperature_optimizer_config={"lr": 1e-5}, + buffer_config={"capacity": int(1e5)}, + discount=0.99, + num_return_steps=1, + num_updates_per_iter=500, + batch_size=32, + max_trajectory_length=10, + sync_freq=1, + polyak_weight=0.001, + max_grad_l2_norm=2.0, + seed=0, + enable_amp=False, + enable_reproducibility=True, + log_sys_usage=True, + suppress_warnings=True, + ), +) diff --git a/examples/TD3_LunarLanderContinuous-v2.py b/examples/TD3_LunarLanderContinuous-v2.py index 258e7c1..dee26b4 100644 --- a/examples/TD3_LunarLanderContinuous-v2.py +++ b/examples/TD3_LunarLanderContinuous-v2.py @@ -14,9 +14,10 @@ # limitations under the License. 
 # ==============================================================================
 
-"""Train Twin Delayed Deep Deterministic Policy Gradient on LunarLanderContinuous-v2."""
+"""Train Twin Delayed Deep Deterministic Policy Gradient (TD3) on LunarLanderContinuous-v2."""
 
 # Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
 # actorch run TD3_LunarLanderContinuous-v2.py
 
 import gymnasium as gym
diff --git a/examples/TRPO_Pendulum-v1.py b/examples/TRPO_Pendulum-v1.py
new file mode 100644
index 0000000..df1fbd0
--- /dev/null
+++ b/examples/TRPO_Pendulum-v1.py
@@ -0,0 +1,84 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Trust Region Policy Optimization (TRPO) on Pendulum-v1."""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
+# actorch run TRPO_Pendulum-v1.py
+
+import gymnasium as gym
+from torch.optim import Adam
+
+from actorch import *
+
+
+experiment_params = ExperimentParams(
+    run_or_experiment=TRPO,
+    stop={"timesteps_total": int(2e5)},
+    resources_per_trial={"cpu": 1, "gpu": 0},
+    checkpoint_freq=10,
+    checkpoint_at_end=True,
+    log_to_file=True,
+    export_formats=["checkpoint", "model"],
+    config=TRPO.Config(
+        train_env_builder=lambda **config: ParallelBatchedEnv(
+            lambda **kwargs: gym.make("Pendulum-v1", **kwargs),
+            config,
+            num_workers=2,
+        ),
+        train_num_timesteps_per_iter=512,
+        eval_freq=10,
+        eval_env_config={"render_mode": None},
+        eval_num_episodes_per_iter=10,
+        policy_network_model_builder=FCNet,
+        policy_network_model_config={
+            "torso_fc_configs": [
+                {"out_features": 256, "bias": True},
+                {"out_features": 256, "bias": True},
+            ],
+            "independent_heads": ["action/log_scale"],
+        },
+        policy_network_optimizer_config={
+            "max_constraint": 0.01,
+            "num_cg_iters": 10,
+            "max_backtracks": 15,
+            "backtrack_ratio": 0.8,
+            "hvp_reg_coeff": 1e-5,
+            "accept_violation": False,
+            "epsilon": 1e-8,
+        },
+        value_network_model_builder=FCNet,
+        value_network_model_config={
+            "torso_fc_configs": [
+                {"out_features": 256, "bias": True},
+                {"out_features": 256, "bias": True},
+            ],
+        },
+        value_network_optimizer_builder=Adam,
+        value_network_optimizer_config={"lr": 3e-3},
+        discount=0.9,
+        trace_decay=0.95,
+        normalize_advantage=False,
+        entropy_coeff=0.01,
+        max_grad_l2_norm=2.0,
+        seed=0,
+        enable_amp=False,
+        enable_reproducibility=True,
+        log_sys_usage=True,
+        suppress_warnings=True,
+    ),
+)
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 0b6e578..9d8c59d 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -123,6 +123,7 @@ def _cleanup():
         algorithms.AWR,
         algorithms.PPO,
         algorithms.REINFORCE,
+        algorithms.TRPO,
     ],
 )
 @pytest.mark.parametrize(
@@ -198,7 +199,9 @@ def test_algorithm_mixed(algorithm_cls, model_cls, space):
 @pytest.mark.parametrize(
     "algorithm_cls",
     [
+        algorithms.D3PG,
         algorithms.DDPG,
+        algorithms.SAC,
         algorithms.TD3,
     ],
 )
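
For reference, the tanh-squashed Gaussian action distribution configured in SAC_BipedalWalker-v3.py above can be sketched with `torch.distributions` alone. The snippet below is a minimal standalone illustration, not part of the patch: the action dimension 4 and the (-20, 2) log-scale clamp mirror the values used in that example, and all variable names are illustrative.

import torch
from torch.distributions import Normal, TanhTransform, TransformedDistribution

# Tanh-squashed Gaussian: a Normal base distribution pushed through tanh,
# which bounds actions to (-1, 1); BipedalWalker-v3 has a 4-dimensional action space.
loc = torch.zeros(4)
log_scale = torch.zeros(4).clamp(-20.0, 2.0)  # same clamp as the `log_scale` parametrization above
dist = TransformedDistribution(Normal(loc, log_scale.exp()), TanhTransform(cache_size=1))

action = dist.sample()  # no closed-form mode exists, hence `sample()` in the config above
log_prob = dist.log_prob(action)  # includes the tanh change-of-variables correction
print(action, log_prob)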