diff --git a/README.md b/README.md
index 8c39d88..5bc336b 100644
--- a/README.md
+++ b/README.md
@@ -14,33 +14,36 @@ Welcome to `actorch`, a deep reinforcement learning framework for fast prototypi
- [REINFORCE](https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)
- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
- [Actor-Critic Kronecker-Factored Trust Region (ACKTR)](https://arxiv.org/abs/1708.05144)
+- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
- [Advantage-Weighted Regression (AWR)](https://arxiv.org/abs/1910.00177)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
+- [Distributional Deep Deterministic Policy Gradient (D3PG)](https://arxiv.org/abs/1804.08617)
- [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/abs/1802.09477)
+- [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290)
---------------------------------------------------------------------------------------------------------
## 💡 Key features
- Support for [OpenAI Gymnasium](https://gymnasium.farama.org/) environments
-- Support for custom observation/action spaces
-- Support for custom multimodal input multimodal output models
-- Support for recurrent models (e.g. RNNs, LSTMs, GRUs, etc.)
-- Support for custom policy/value distributions
-- Support for custom preprocessing/postprocessing pipelines
-- Support for custom exploration strategies
+- Support for **custom observation/action spaces**
+- Support for **custom multimodal input/output models**
+- Support for **recurrent models** (e.g. RNNs, LSTMs, GRUs, etc.)
+- Support for **custom policy/value distributions**
+- Support for **custom preprocessing/postprocessing pipelines**
+- Support for **custom exploration strategies**
- Support for [normalizing flows](https://arxiv.org/abs/1906.02771)
- Batched environments (both for training and evaluation)
-- Batched trajectory replay
-- Batched and distributional value estimation (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561))
-- Data parallel and distributed data parallel multi-GPU training and evaluation
-- Automatic mixed precision training
-- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and hyperparameter tuning at any scale
-- Effortless experiment definition through Python-based configuration files
-- Built-in visualization tool to plot performance metrics
-- Modular object-oriented design
-- Detailed API documentation
+- Batched **trajectory replay**
+- Batched and **distributional value estimation** (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561))
+- Data parallel and distributed data parallel **multi-GPU training and evaluation**
+- Automatic **mixed precision training**
+- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and **hyperparameter tuning** at any scale
+- Effortless experiment definition through **Python-based configuration files**
+- Built-in **visualization tool** to plot performance metrics
+- Modular **object-oriented** design
+- Detailed **API documentation**
---------------------------------------------------------------------------------------------------------
@@ -161,7 +164,7 @@ experiment_params = ExperimentParams(
enable_amp=False,
enable_reproducibility=True,
log_sys_usage=True,
- suppress_warnings=False,
+ suppress_warnings=True,
),
)
```
@@ -197,6 +200,8 @@ You can find the generated plots in `plots`.
Congratulations, you ran your first experiment!
+See the `examples` directory for additional configuration files.
+
**HINT**: since a configuration file is a regular Python script, you can use all the
features of the language (e.g. inheritance).
@@ -217,6 +222,21 @@ features of the language (e.g. inheritance).
---------------------------------------------------------------------------------------------------------
+## 📖 Citation
+
+```
+@misc{DellaLibera2022ACTorch,
+ author = {Luca Della Libera},
+ title = {{ACTorch}: a Deep Reinforcement Learning Framework for Fast Prototyping},
+ year = {2022},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/lucadellalib/actorch}},
+}
+```
+
+---------------------------------------------------------------------------------------------------------
+
## 📧 Contact
[luca.dellalib@gmail.com](mailto:luca.dellalib@gmail.com)
diff --git a/actorch/algorithms/__init__.py b/actorch/algorithms/__init__.py
index bbb311d..3235bc7 100644
--- a/actorch/algorithms/__init__.py
+++ b/actorch/algorithms/__init__.py
@@ -20,7 +20,10 @@
from actorch.algorithms.acktr import *
from actorch.algorithms.algorithm import *
from actorch.algorithms.awr import *
+from actorch.algorithms.d3pg import *
from actorch.algorithms.ddpg import *
from actorch.algorithms.ppo import *
from actorch.algorithms.reinforce import *
+from actorch.algorithms.sac import *
from actorch.algorithms.td3 import *
+from actorch.algorithms.trpo import *
diff --git a/actorch/algorithms/a2c.py b/actorch/algorithms/a2c.py
index 9cce312..4a3e6e4 100644
--- a/actorch/algorithms/a2c.py
+++ b/actorch/algorithms/a2c.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
-"""Advantage Actor-Critic."""
+"""Advantage Actor-Critic (A2C)."""
import contextlib
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@@ -38,7 +38,7 @@
DistributedDataParallelREINFORCE,
LRScheduler,
)
-from actorch.algorithms.utils import prepare_model
+from actorch.algorithms.utils import normalize_, prepare_model
from actorch.algorithms.value_estimation import n_step_return
from actorch.distributions import Deterministic
from actorch.envs import BatchedEnv
@@ -65,7 +65,7 @@
class A2C(REINFORCE):
- """Advantage Actor-Critic.
+ """Advantage Actor-Critic (A2C).
References
----------
@@ -208,12 +208,8 @@ def setup(self, config: "Dict[str, Any]") -> "None":
self.config = A2C.Config(**self.config)
self.config["_accept_kwargs"] = True
super().setup(config)
- self._value_network = (
- self._build_value_network().train().to(self._device, non_blocking=True)
- )
- self._value_network_loss = (
- self._build_value_network_loss().train().to(self._device, non_blocking=True)
- )
+ self._value_network = self._build_value_network()
+ self._value_network_loss = self._build_value_network_loss()
self._value_network_optimizer = self._build_value_network_optimizer()
self._value_network_optimizer_lr_scheduler = (
self._build_value_network_optimizer_lr_scheduler()
@@ -324,16 +320,20 @@ def _build_value_network(self) -> "Network":
self.value_network_normalizing_flows,
)
self._log_graph(value_network.wrapped_model.model, "value_network_model")
- return value_network
+ return value_network.train().to(self._device, non_blocking=True)
def _build_value_network_loss(self) -> "Loss":
if self.value_network_loss_builder is None:
self.value_network_loss_builder = torch.nn.MSELoss
if self.value_network_loss_config is None:
self.value_network_loss_config: "Dict[str, Any]" = {}
- return self.value_network_loss_builder(
- reduction="none",
- **self.value_network_loss_config,
+ return (
+ self.value_network_loss_builder(
+ reduction="none",
+ **self.value_network_loss_config,
+ )
+ .train()
+ .to(self._device, non_blocking=True)
)
def _build_value_network_optimizer(self) -> "Optimizer":
@@ -374,6 +374,8 @@ def _train_step(self) -> "Dict[str, Any]":
result = super()._train_step()
self.num_return_steps.step()
result["num_return_steps"] = self.num_return_steps()
+ result["entropy_coeff"] = result.pop("entropy_coeff", None)
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
return result
# override
@@ -405,17 +407,7 @@ def _train_on_batch(
self.num_return_steps(),
)
if self.normalize_advantage:
- length = mask.sum(dim=1, keepdim=True)
- advantages_mean = advantages.sum(dim=1, keepdim=True) / length
- advantages -= advantages_mean
- advantages *= mask
- advantages_stddev = (
- ((advantages**2).sum(dim=1, keepdim=True) / length)
- .sqrt()
- .clamp(min=1e-6)
- )
- advantages /= advantages_stddev
- advantages *= mask
+ normalize_(advantages, dim=-1, mask=mask)
# Discard next state value
state_values = state_values[:, :-1]
@@ -449,10 +441,12 @@ def _train_on_batch_value_network(
state_value = state_values[mask]
target = targets[mask]
loss = self._value_network_loss(state_value, target)
- loss *= is_weight[:, None].expand_as(mask)[mask]
+ priority = None
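+        # For prioritized replay, apply the importance-sampling weights and use the per-element absolute loss as the new priorities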
+ if self._buffer.is_prioritized:
+ loss *= is_weight[:, None].expand_as(mask)[mask]
+ priority = loss.detach().abs().to("cpu").numpy()
loss = loss.mean()
optimize_result = self._optimize_value_network(loss)
- priority = None
result = {
"state_value": state_value.mean().item(),
"target": target.mean().item(),
@@ -490,7 +484,7 @@ def _get_default_value_network_preprocessor(
class DistributedDataParallelA2C(DistributedDataParallelREINFORCE):
- """Distributed data parallel Advantage Actor-Critic.
+ """Distributed data parallel Advantage Actor-Critic (A2C).
See Also
--------
diff --git a/actorch/algorithms/acktr.py b/actorch/algorithms/acktr.py
index 931b16c..4884914 100644
--- a/actorch/algorithms/acktr.py
+++ b/actorch/algorithms/acktr.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
-"""Actor-Critic Kronecker-Factored Trust Region."""
+"""Actor-Critic Kronecker-Factored Trust Region (ACKTR)."""
from typing import Any, Callable, Dict, Optional, Union
@@ -43,7 +43,7 @@
class ACKTR(A2C):
- """Actor-Critic Kronecker-Factored Trust Region.
+ """Actor-Critic Kronecker-Factored Trust Region (ACKTR).
References
----------
@@ -287,7 +287,7 @@ def _optimize_policy_network(self, loss: "Tensor") -> "Dict[str, Any]":
class DistributedDataParallelACKTR(DistributedDataParallelA2C):
- """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region.
+ """Distributed data parallel Actor-Critic Kronecker-Factored Trust Region (ACKTR).
See Also
--------
diff --git a/actorch/algorithms/algorithm.py b/actorch/algorithms/algorithm.py
index 1fcd34b..fdb568d 100644
--- a/actorch/algorithms/algorithm.py
+++ b/actorch/algorithms/algorithm.py
@@ -51,7 +51,6 @@
from ray.tune.syncer import NodeSyncer
from ray.tune.trial import ExportFormat
from torch import Tensor
-from torch.cuda.amp import autocast
from torch.distributions import Bernoulli, Categorical, Distribution, Normal
from torch.profiler import profile, record_function, tensorboard_trace_handler
from torch.utils.data import DataLoader
@@ -113,7 +112,7 @@ class Algorithm(ABC, Trainable):
_EXPORT_FORMATS = [ExportFormat.CHECKPOINT, ExportFormat.MODEL]
- _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True
+ _OFF_POLICY = True
class Config(dict):
"""Keyword arguments expected in the configuration received by `setup`."""
@@ -692,12 +691,7 @@ def _build_train_env(self) -> "BatchedEnv":
if self.train_env_config is None:
self.train_env_config = {}
- try:
- train_env = self.train_env_builder(
- **self.train_env_config,
- )
- except TypeError:
- train_env = self.train_env_builder(**self.train_env_config)
+ train_env = self.train_env_builder(**self.train_env_config)
if not isinstance(train_env, BatchedEnv):
train_env.close()
train_env = SerialBatchedEnv(self.train_env_builder, self.train_env_config)
@@ -866,7 +860,7 @@ def _build_policy_network(self) -> "PolicyNetwork": # noqa: C901
self.policy_network_postprocessors,
)
self._log_graph(policy_network.wrapped_model.model, "policy_network_model")
- return policy_network
+ return policy_network.train().to(self._device, non_blocking=True)
def _build_train_agent(self) -> "Agent":
if self.train_agent_builder is None:
@@ -989,14 +983,18 @@ def _build_dataloader(self) -> "DataLoader":
if self.dataloader_builder is None:
self.dataloader_builder = DataLoader
if self.dataloader_config is None:
- fork = torch.multiprocessing.get_start_method() == "fork"
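+            # Use a dataloader worker process only for off-policy algorithms with a non-prioritized buffer and the "fork" start method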
+ use_mp = (
+ self._OFF_POLICY
+ and not self._buffer.is_prioritized
+ and torch.multiprocessing.get_start_method() == "fork"
+ )
self.dataloader_config = {
- "num_workers": 1 if fork else 0,
+ "num_workers": 1 if use_mp else 0,
"pin_memory": True,
"timeout": 0,
"worker_init_fn": None,
"generator": None,
- "prefetch_factor": 1 if fork else 2,
+ "prefetch_factor": 1 if use_mp else 2,
"pin_memory_device": "",
}
if self.dataloader_config is None:
@@ -1037,20 +1035,15 @@ def _train_step(self) -> "Dict[str, Any]":
if self.train_num_episodes_per_iter:
train_num_episodes_per_iter = self.train_num_episodes_per_iter()
self.train_num_episodes_per_iter.step()
- with (
- autocast(**self.enable_amp)
- if self.enable_amp["enabled"]
- else contextlib.suppress()
+ for experience, done in self._train_sampler.sample(
+ train_num_timesteps_per_iter,
+ train_num_episodes_per_iter,
):
- for experience, done in self._train_sampler.sample(
- train_num_timesteps_per_iter,
- train_num_episodes_per_iter,
- ):
- self._buffer.add(experience, done)
+ self._buffer.add(experience, done)
result = self._train_sampler.stats
self._cumrewards += result["episode_cumreward"]
- if not self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH:
+ if not self._OFF_POLICY:
for schedule in self._buffer_dataset.schedules.values():
schedule.step()
train_epoch_result = self._train_epoch()
@@ -1060,7 +1053,7 @@ def _train_step(self) -> "Dict[str, Any]":
schedule.step()
for schedule in self._buffer.schedules.values():
schedule.step()
- if self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH:
+ if self._OFF_POLICY:
for schedule in self._buffer_dataset.schedules.values():
schedule.step()
return result
@@ -1113,16 +1106,11 @@ def _eval_step(self) -> "Dict[str, Any]":
eval_num_episodes_per_iter = self.eval_num_episodes_per_iter()
self.eval_num_episodes_per_iter.step()
self._eval_sampler.reset()
- with (
- autocast(**self.enable_amp)
- if self.enable_amp["enabled"]
- else contextlib.suppress()
+ for _ in self._eval_sampler.sample(
+ eval_num_timesteps_per_iter,
+ eval_num_episodes_per_iter,
):
- for _ in self._eval_sampler.sample(
- eval_num_timesteps_per_iter,
- eval_num_episodes_per_iter,
- ):
- pass
+ pass
for schedule in self._eval_agent.schedules.values():
schedule.step()
return self._eval_sampler.stats
@@ -1353,7 +1341,7 @@ def __init__(
Default to ``{}``.
placement_strategy:
The placement strategy
- (see https://docs.ray.io/en/latest/ray-core/placement-group.html).
+ (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0).
backend:
The backend for distributed execution
(see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group).
diff --git a/actorch/algorithms/awr.py b/actorch/algorithms/awr.py
index f13d6db..1e887c3 100644
--- a/actorch/algorithms/awr.py
+++ b/actorch/algorithms/awr.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
-"""Advantage-Weighted Regression."""
+"""Advantage-Weighted Regression (AWR)."""
import contextlib
from typing import Any, Callable, Dict, Optional, Tuple, Union
@@ -31,6 +31,7 @@
from actorch.agents import Agent
from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler
from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
+from actorch.algorithms.utils import normalize_
from actorch.algorithms.value_estimation import lambda_return
from actorch.buffers import Buffer
from actorch.datasets import BufferDataset
@@ -48,7 +49,7 @@
class AWR(A2C):
- """Advantage-Weighted Regression.
+ """Advantage-Weighted Regression (AWR).
References
----------
@@ -59,9 +60,7 @@ class AWR(A2C):
"""
- _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True # override
-
- _RESET_BUFFER = False # override
+ _OFF_POLICY = True # override
# override
class Config(dict):
@@ -277,6 +276,8 @@ def _train_step(self) -> "Dict[str, Any]":
)
result["weight_clip"] = self.weight_clip()
result["temperature"] = self.temperature()
+ result["entropy_coeff"] = result.pop("entropy_coeff", None)
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
result["buffer_num_experiences"] = self._buffer.num_experiences
result["buffer_num_full_trajectories"] = self._buffer.num_full_trajectories
result.pop("num_return_steps", None)
@@ -311,17 +312,7 @@ def _train_on_batch(
self.trace_decay(),
)
if self.normalize_advantage:
- length = mask.sum(dim=1, keepdim=True)
- advantages_mean = advantages.sum(dim=1, keepdim=True) / length
- advantages -= advantages_mean
- advantages *= mask
- advantages_stddev = (
- ((advantages**2).sum(dim=1, keepdim=True) / length)
- .sqrt()
- .clamp(min=1e-6)
- )
- advantages /= advantages_stddev
- advantages *= mask
+ normalize_(advantages, dim=-1, mask=mask)
# Discard next state value
state_values = state_values[:, :-1]
@@ -352,16 +343,19 @@ def _train_on_batch_policy_network(
raise ValueError(
f"`entropy_coeff` ({entropy_coeff}) must be in the interval [0, inf)"
)
+
weight_clip = self.weight_clip()
if weight_clip <= 0.0:
raise ValueError(
f"`weight_clip` ({weight_clip}) must be in the interval (0, inf)"
)
+
temperature = self.temperature()
if temperature <= 0.0:
raise ValueError(
f"`temperature` ({temperature}) must be in the interval (0, inf)"
)
+
with (
autocast(**self.enable_amp)
if self.enable_amp["enabled"]
@@ -392,7 +386,7 @@ def _train_on_batch_policy_network(
class DistributedDataParallelAWR(DistributedDataParallelA2C):
- """Distributed data parallel Advantage-Weighted Regression.
+ """Distributed data parallel Advantage-Weighted Regression (AWR).
See Also
--------
diff --git a/actorch/algorithms/d3pg.py b/actorch/algorithms/d3pg.py
new file mode 100644
index 0000000..ef201d0
--- /dev/null
+++ b/actorch/algorithms/d3pg.py
@@ -0,0 +1,430 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Distributional Deep Deterministic Policy Gradient (D3PG)."""
+
+import contextlib
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from gymnasium import Env
+from numpy import ndarray
+from torch import Tensor
+from torch.cuda.amp import autocast
+from torch.distributions import Distribution, kl_divergence
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from actorch.agents import Agent
+from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
+from actorch.algorithms.ddpg import DDPG, DistributedDataParallelDDPG, Loss, LRScheduler
+from actorch.algorithms.utils import freeze_params, sync_polyak_
+from actorch.algorithms.value_estimation import n_step_return
+from actorch.buffers import Buffer, ProportionalBuffer
+from actorch.distributions import Finite
+from actorch.envs import BatchedEnv
+from actorch.models import Model
+from actorch.networks import (
+ DistributionParametrization,
+ Network,
+ NormalizingFlow,
+ Processor,
+)
+from actorch.samplers import Sampler
+from actorch.schedules import Schedule
+
+
+__all__ = [
+ "D3PG",
+ "DistributedDataParallelD3PG",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class D3PG(DDPG):
+ """Distributional Deep Deterministic Policy Gradient (D3PG).
+
+ References
+ ----------
+ .. [1] G. Barth-Maron, M. W. Hoffman, D. Budden, W. Dabney, D. Horgan, D. TB,
+ A. Muldal, N. Heess, and T. Lillicrap.
+ "Distributed Distributional Deterministic Policy Gradients".
+ In: ICLR. 2018.
+ URL: https://arxiv.org/abs/1804.08617
+
+ """
+
+ # override
+ class Config(dict):
+ """Keyword arguments expected in the configuration received by `setup`."""
+
+ def __init__(
+ self,
+ train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]",
+ train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None,
+ train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None,
+ train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1,
+ eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None,
+ eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None,
+ eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None,
+ eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None,
+ policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ policy_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None,
+ policy_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None,
+ value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_distribution_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Distribution]]]]" = None,
+ value_network_distribution_parametrization: "Tunable[RefOrFutureRef[Optional[DistributionParametrization]]]" = None,
+ value_network_distribution_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_normalizing_flow: "Tunable[RefOrFutureRef[Optional[NormalizingFlow]]]" = None,
+ value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None,
+ value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None,
+ value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ buffer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Buffer]]]]" = None,
+ buffer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ buffer_checkpoint: "Tunable[RefOrFutureRef[bool]]" = False,
+ dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None,
+ dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99,
+ num_return_steps: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 3,
+ num_updates_per_iter: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1000,
+ batch_size: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 128,
+ max_trajectory_length: "Tunable[RefOrFutureRef[Union[int, float, Schedule]]]" = float( # noqa: B008
+ "inf"
+ ),
+ sync_freq: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1,
+ polyak_weight: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.001,
+ max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008
+ "inf"
+ ),
+ cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100,
+ seed: "Tunable[RefOrFutureRef[int]]" = 0,
+ enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False,
+ enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False,
+ enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False,
+ enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False,
+ log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False,
+ suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False,
+ _accept_kwargs: "bool" = False,
+ **kwargs: "Any",
+ ) -> "None":
+ if not _accept_kwargs and kwargs:
+ raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}")
+ super().__init__(
+ train_env_builder=train_env_builder,
+ train_env_config=train_env_config,
+ train_agent_builder=train_agent_builder,
+ train_agent_config=train_agent_config,
+ train_sampler_builder=train_sampler_builder,
+ train_sampler_config=train_sampler_config,
+ train_num_timesteps_per_iter=train_num_timesteps_per_iter,
+ train_num_episodes_per_iter=train_num_episodes_per_iter,
+ eval_freq=eval_freq,
+ eval_env_builder=eval_env_builder,
+ eval_env_config=eval_env_config,
+ eval_agent_builder=eval_agent_builder,
+ eval_agent_config=eval_agent_config,
+ eval_sampler_builder=eval_sampler_builder,
+ eval_sampler_config=eval_sampler_config,
+ eval_num_timesteps_per_iter=eval_num_timesteps_per_iter,
+ eval_num_episodes_per_iter=eval_num_episodes_per_iter,
+ policy_network_preprocessors=policy_network_preprocessors,
+ policy_network_model_builder=policy_network_model_builder,
+ policy_network_model_config=policy_network_model_config,
+ policy_network_postprocessors=policy_network_postprocessors,
+ policy_network_optimizer_builder=policy_network_optimizer_builder,
+ policy_network_optimizer_config=policy_network_optimizer_config,
+ policy_network_optimizer_lr_scheduler_builder=policy_network_optimizer_lr_scheduler_builder,
+ policy_network_optimizer_lr_scheduler_config=policy_network_optimizer_lr_scheduler_config,
+ value_network_preprocessors=value_network_preprocessors,
+ value_network_model_builder=value_network_model_builder,
+ value_network_model_config=value_network_model_config,
+ value_network_distribution_builder=value_network_distribution_builder,
+ value_network_distribution_parametrization=value_network_distribution_parametrization,
+ value_network_distribution_config=value_network_distribution_config,
+ value_network_normalizing_flow=value_network_normalizing_flow,
+ value_network_loss_builder=value_network_loss_builder,
+ value_network_loss_config=value_network_loss_config,
+ value_network_optimizer_builder=value_network_optimizer_builder,
+ value_network_optimizer_config=value_network_optimizer_config,
+ value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder,
+ value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config,
+ buffer_builder=buffer_builder,
+ buffer_config=buffer_config,
+ buffer_checkpoint=buffer_checkpoint,
+ dataloader_builder=dataloader_builder,
+ dataloader_config=dataloader_config,
+ discount=discount,
+ num_return_steps=num_return_steps,
+ num_updates_per_iter=num_updates_per_iter,
+ batch_size=batch_size,
+ max_trajectory_length=max_trajectory_length,
+ sync_freq=sync_freq,
+ polyak_weight=polyak_weight,
+ max_grad_l2_norm=max_grad_l2_norm,
+ cumreward_window_size=cumreward_window_size,
+ seed=seed,
+ enable_amp=enable_amp,
+ enable_reproducibility=enable_reproducibility,
+ enable_anomaly_detection=enable_anomaly_detection,
+ enable_profiling=enable_profiling,
+ log_sys_usage=log_sys_usage,
+ suppress_warnings=suppress_warnings,
+ **kwargs,
+ )
+
+ # override
+ def setup(self, config: "Dict[str, Any]") -> "None":
+ self.config = D3PG.Config(**self.config)
+ self.config["_accept_kwargs"] = True
+ super().setup(config)
+ self._warn_failed_logging = True
+
+ # override
+ def _build_buffer(self) -> "Buffer":
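+        # Default to prioritized experience replay (proportional variant)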
+ if self.buffer_builder is None:
+ self.buffer_builder = ProportionalBuffer
+ if self.buffer_config is None:
+ self.buffer_config = {
+ "capacity": int(1e5),
+ "prioritization": 1.0,
+ "bias_correction": 0.4,
+ "epsilon": 1e-5,
+ }
+ return super()._build_buffer()
+
+ # override
+ def _build_value_network(self) -> "Network":
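+        # Default to a categorical value distribution with 51 atoms evenly spaced on [-10.0, 10.0]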
+ if self.value_network_distribution_builder is None:
+ self.value_network_distribution_builder = Finite
+            if self.value_network_distribution_parametrization is None:
+                self.value_network_distribution_parametrization = {
+                    "logits": (
+                        {"logits": (51,)},
+                        lambda x: x["logits"],
+                    ),
+                }
+            if self.value_network_distribution_config is None:
+                self.value_network_distribution_config = {
+                    "atoms": torch.linspace(-10.0, 10.0, 51).to(self._device),
+                    "validate_args": False,
+                }
+ if self.value_network_distribution_parametrization is None:
+ self.value_network_distribution_parametrization = {}
+ if self.value_network_distribution_config is None:
+ self.value_network_distribution_config = {}
+
+ if self.value_network_normalizing_flow is not None:
+ self.value_network_normalizing_flows = {
+ "value": self.value_network_normalizing_flow,
+ }
+ else:
+ self.value_network_normalizing_flows: "Dict[str, NormalizingFlow]" = {}
+
+ self.value_network_distribution_builders = {
+ "value": self.value_network_distribution_builder,
+ }
+ self.value_network_distribution_parametrizations = {
+ "value": self.value_network_distribution_parametrization,
+ }
+ self.value_network_distribution_configs = {
+ "value": self.value_network_distribution_config,
+ }
+ return super()._build_value_network()
+
+ # override
+ def _train_on_batch(
+ self,
+ idx: "int",
+ experiences: "Dict[str, Tensor]",
+ is_weight: "Tensor",
+ mask: "Tensor",
+ ) -> "Tuple[Dict[str, Any], Optional[ndarray]]":
+ sync_freq = self.sync_freq()
+ if sync_freq < 1 or not float(sync_freq).is_integer():
+ raise ValueError(
+ f"`sync_freq` ({sync_freq}) "
+ f"must be in the integer interval [1, inf)"
+ )
+ sync_freq = int(sync_freq)
+
+ result = {}
+
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
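+            # Compute the target action-value distribution and the corresponding n-step return targets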
+ target_actions, _ = self._target_policy_network(
+ experiences["observation"], mask=mask
+ )
+ observations_target_actions = torch.cat(
+ [
+ x[..., None] if x.shape == mask.shape else x
+ for x in [experiences["observation"], target_actions]
+ ],
+ dim=-1,
+ )
+ self._target_value_network(observations_target_actions, mask=mask)
+ target_action_values = self._target_value_network.distribution
+ # Discard next observation
+ experiences["observation"] = experiences["observation"][:, :-1, ...]
+ mask = mask[:, 1:]
+ targets, _ = n_step_return(
+ target_action_values,
+ experiences["reward"],
+ experiences["terminal"],
+ mask,
+ self.discount(),
+ self.num_return_steps(),
+ return_advantage=False,
+ )
+
+ # Compute action values
+ observations_actions = torch.cat(
+ [
+ x[..., None] if x.shape == mask.shape else x
+ for x in [experiences["observation"], experiences["action"]]
+ ],
+ dim=-1,
+ )
+ self._value_network(observations_actions, mask=mask)
+ action_values = self._value_network.distribution
+
+ result["value_network"], priority = self._train_on_batch_value_network(
+ action_values,
+ targets,
+ is_weight,
+ mask,
+ )
+ result["policy_network"] = self._train_on_batch_policy_network(
+ experiences,
+ mask,
+ )
+
+ # Synchronize
+ if idx % sync_freq == 0:
+ sync_polyak_(
+ self._policy_network,
+ self._target_policy_network,
+ self.polyak_weight(),
+ )
+ sync_polyak_(
+ self._value_network,
+ self._target_value_network,
+ self.polyak_weight(),
+ )
+
+ self._grad_scaler.update()
+ return result, priority
+
+ # override
+ def _train_on_batch_policy_network(
+ self,
+ experiences: "Dict[str, Tensor]",
+ mask: "Tensor",
+ ) -> "Dict[str, Any]":
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ actions, _ = self._policy_network(experiences["observation"], mask=mask)
+ observations_actions = torch.cat(
+ [
+ x[..., None] if x.shape == mask.shape else x
+ for x in [experiences["observation"], actions]
+ ],
+ dim=-1,
+ )
+ with freeze_params(self._value_network):
+ self._value_network(observations_actions, mask=mask)
+ action_values = self._value_network.distribution
+ try:
+ action_values = action_values.mean
+ except Exception as e:
+                raise RuntimeError(f"Could not compute `action_values.mean`: {e}") from e
+ action_value = action_values[mask]
+ loss = -action_value.mean()
+ optimize_result = self._optimize_policy_network(loss)
+ result = {"loss": loss.item()}
+ result.update(optimize_result)
+ return result
+
+ # override
+ def _train_on_batch_value_network(
+ self,
+ action_values: "Distribution",
+ targets: "Distribution",
+ is_weight: "Tensor",
+ mask: "Tensor",
+ ) -> "Tuple[Dict[str, Any], Optional[ndarray]]":
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
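+            # Distributional critic loss: KL divergence between the target and predicted value distributions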
+ loss = kl_divergence(targets, action_values)[mask]
+ priority = None
+ if self._buffer.is_prioritized:
+ loss *= is_weight[:, None].expand_as(mask)[mask]
+ priority = loss.detach().abs().to("cpu").numpy()
+ loss = loss.mean()
+ optimize_result = self._optimize_value_network(loss)
+ result = {}
+ try:
+ result["action_value"] = action_values.mean[mask].mean().item()
+ result["target"] = targets.mean[mask].mean().item()
+ except Exception as e:
+ if self._warn_failed_logging:
+ _LOGGER.warning(f"Could not log `action_value` and/or `target`: {e}")
+ self._warn_failed_logging = False
+ result["loss"] = loss.item()
+ result.update(optimize_result)
+ return result, priority
+
+
+class DistributedDataParallelD3PG(DistributedDataParallelDDPG):
+ """Distributed data parallel Distributional Deep Deterministic Policy Gradient (D3PG).
+
+ See Also
+ --------
+ actorch.algorithms.d3pg.D3PG
+
+ """
+
+ _ALGORITHM_CLS = D3PG # override
diff --git a/actorch/algorithms/ddpg.py b/actorch/algorithms/ddpg.py
index 79cc4a6..17190cc 100644
--- a/actorch/algorithms/ddpg.py
+++ b/actorch/algorithms/ddpg.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
-"""Deep Deterministic Policy Gradient."""
+"""Deep Deterministic Policy Gradient (DDPG)."""
import contextlib
import copy
@@ -33,7 +33,7 @@
from actorch.agents import Agent, GaussianNoiseAgent
from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler
from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
-from actorch.algorithms.utils import freeze_params, prepare_model, sync_polyak
+from actorch.algorithms.utils import freeze_params, prepare_model, sync_polyak_
from actorch.algorithms.value_estimation import n_step_return
from actorch.buffers import Buffer
from actorch.datasets import BufferDataset
@@ -53,7 +53,7 @@
class DDPG(A2C):
- """Deep Deterministic Policy Gradient.
+ """Deep Deterministic Policy Gradient (DDPG).
References
----------
@@ -65,9 +65,7 @@ class DDPG(A2C):
"""
- _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True # override
-
- _RESET_BUFFER = False # override
+ _OFF_POLICY = True # override
# override
class Config(dict):
@@ -203,12 +201,8 @@ def setup(self, config: "Dict[str, Any]") -> "None":
self.config = DDPG.Config(**self.config)
self.config["_accept_kwargs"] = True
super().setup(config)
- self._target_policy_network = copy.deepcopy(self._policy_network)
- self._target_policy_network.eval().to(self._device, non_blocking=True)
- self._target_policy_network.requires_grad_(False)
- self._target_value_network = copy.deepcopy(self._value_network)
- self._target_value_network.eval().to(self._device, non_blocking=True)
- self._target_value_network.requires_grad_(False)
+ self._target_policy_network = self._build_target_policy_network()
+ self._target_value_network = self._build_target_value_network()
if not isinstance(self.num_updates_per_iter, Schedule):
self.num_updates_per_iter = ConstantSchedule(self.num_updates_per_iter)
if not isinstance(self.batch_size, Schedule):
@@ -300,6 +294,16 @@ def _build_value_network(self) -> "Network":
}
return super()._build_value_network()
+ def _build_target_policy_network(self) -> "Network":
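+        # Frozen copy of the policy network, kept in sync via Polyak averaging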
+ target_policy_network = copy.deepcopy(self._policy_network)
+ target_policy_network.requires_grad_(False)
+ return target_policy_network.eval().to(self._device, non_blocking=True)
+
+ def _build_target_value_network(self) -> "Network":
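+        # Frozen copy of the value network, kept in sync via Polyak averaging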
+ target_value_network = copy.deepcopy(self._value_network)
+ target_value_network.requires_grad_(False)
+ return target_value_network.eval().to(self._device, non_blocking=True)
+
# override
def _train_step(self) -> "Dict[str, Any]":
result = super()._train_step()
@@ -314,6 +318,7 @@ def _train_step(self) -> "Dict[str, Any]":
)
result["sync_freq"] = self.sync_freq()
result["polyak_weight"] = self.polyak_weight()
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
result["buffer_num_experiences"] = self._buffer.num_experiences
result["buffer_num_full_trajectories"] = self._buffer.num_full_trajectories
result.pop("entropy_coeff", None)
@@ -365,6 +370,7 @@ def _train_on_batch(
mask,
self.discount(),
self.num_return_steps(),
+ return_advantage=False,
)
# Compute action values
@@ -390,12 +396,12 @@ def _train_on_batch(
# Synchronize
if idx % sync_freq == 0:
- sync_polyak(
+ sync_polyak_(
self._policy_network,
self._target_policy_network,
self.polyak_weight(),
)
- sync_polyak(
+ sync_polyak_(
self._value_network,
self._target_value_network,
self.polyak_weight(),
@@ -425,7 +431,8 @@ def _train_on_batch_policy_network(
)
with freeze_params(self._value_network):
action_values, _ = self._value_network(observations_actions, mask=mask)
- loss = -action_values[mask].mean()
+ action_value = action_values[mask]
+ loss = -action_value.mean()
optimize_result = self._optimize_policy_network(loss)
result = {"loss": loss.item()}
result.update(optimize_result)
@@ -482,7 +489,7 @@ def _get_default_policy_network_distribution_config(
class DistributedDataParallelDDPG(DistributedDataParallelA2C):
- """Distributed data parallel Deep Deterministic Policy Gradient.
+ """Distributed data parallel Deep Deterministic Policy Gradient (DDPG).
See Also
--------
diff --git a/actorch/algorithms/ppo.py b/actorch/algorithms/ppo.py
index 2fd2963..eee2cf4 100644
--- a/actorch/algorithms/ppo.py
+++ b/actorch/algorithms/ppo.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
-"""Proximal Policy Optimization."""
+"""Proximal Policy Optimization (PPO)."""
import contextlib
from typing import Any, Callable, Dict, Optional, Tuple, Union
@@ -31,6 +31,7 @@
from actorch.agents import Agent
from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler
from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
+from actorch.algorithms.utils import normalize_
from actorch.algorithms.value_estimation import lambda_return
from actorch.envs import BatchedEnv
from actorch.models import Model
@@ -46,7 +47,7 @@
class PPO(A2C):
- """Proximal Policy Optimization.
+ """Proximal Policy Optimization (PPO).
References
----------
@@ -234,6 +235,8 @@ def _train_step(self) -> "Dict[str, Any]":
result["num_epochs"] = self.num_epochs()
result["minibatch_size"] = self.minibatch_size()
result["ratio_clip"] = self.ratio_clip()
+ result["entropy_coeff"] = result.pop("entropy_coeff", None)
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
result.pop("num_return_steps", None)
return result
@@ -301,22 +304,8 @@ def _train_on_batch(
else contextlib.suppress()
):
if self.normalize_advantage:
- length_batch = mask_batch.sum(dim=1, keepdim=True)
- advantages_batch_mean = (
- advantages_batch.sum(dim=1, keepdim=True) / length_batch
- )
- advantages_batch -= advantages_batch_mean
- advantages_batch *= mask_batch
- advantages_batch_stddev = (
- (
- (advantages_batch**2).sum(dim=1, keepdim=True)
- / length_batch
- )
- .sqrt()
- .clamp(min=1e-6)
- )
- advantages_batch /= advantages_batch_stddev
- advantages_batch *= mask_batch
+ normalize_(advantages_batch, dim=-1, mask=mask_batch)
+
state_values_batch, _ = self._value_network(
experiences_batch["observation"], mask=mask_batch
)
@@ -387,7 +376,7 @@ def _train_on_batch_policy_network(
class DistributedDataParallelPPO(DistributedDataParallelA2C):
- """Distributed data parallel Proximal Policy Optimization.
+ """Distributed data parallel Proximal Policy Optimization (PPO).
See Also
--------
diff --git a/actorch/algorithms/reinforce.py b/actorch/algorithms/reinforce.py
index 20db9d1..c0625ec 100644
--- a/actorch/algorithms/reinforce.py
+++ b/actorch/algorithms/reinforce.py
@@ -67,9 +67,7 @@ class REINFORCE(Algorithm):
"""
- _UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = False # override
-
- _RESET_BUFFER = True
+ _OFF_POLICY = False # override
# override
class Config(dict):
@@ -281,7 +279,7 @@ def _build_policy_network_optimizer_lr_scheduler(
# override
def _train_step(self) -> "Dict[str, Any]":
result = super()._train_step()
- if self._RESET_BUFFER:
+ if not self._OFF_POLICY:
self._buffer.reset()
self.discount.step()
self.entropy_coeff.step()
diff --git a/actorch/algorithms/sac.py b/actorch/algorithms/sac.py
new file mode 100644
index 0000000..6999ff8
--- /dev/null
+++ b/actorch/algorithms/sac.py
@@ -0,0 +1,605 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Soft Actor-Critic (SAC)."""
+
+import contextlib
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from gymnasium import Env, spaces
+from numpy import ndarray
+from torch import Tensor
+from torch.cuda.amp import autocast
+from torch.distributions import Distribution, Normal
+from torch.nn.utils import clip_grad_norm_
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from actorch.agents import Agent
+from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
+from actorch.algorithms.td3 import TD3, DistributedDataParallelTD3, Loss, LRScheduler
+from actorch.algorithms.utils import freeze_params, sync_polyak_
+from actorch.algorithms.value_estimation import n_step_return
+from actorch.buffers import Buffer
+from actorch.envs import BatchedEnv
+from actorch.models import Model
+from actorch.networks import DistributionParametrization, NormalizingFlow, Processor
+from actorch.samplers import Sampler
+from actorch.schedules import ConstantSchedule, Schedule
+from actorch.utils import singledispatchmethod
+
+
+__all__ = [
+ "DistributedDataParallelSAC",
+ "SAC",
+]
+
+
+class SAC(TD3):
+ """Soft Actor-Critic (SAC).
+
+ References
+ ----------
+ .. [1] T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine.
+ "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor".
+ In: ICML. 2018, pp. 1861-1870.
+ URL: https://arxiv.org/abs/1801.01290
+ .. [2] T. Haarnoja, A. Zhou, K. Hartikainen, G. Tucker, S. Ha, J. Tan, V. Kumar,
+ H. Zhu, A. Gupta, P. Abbeel, and S. Levine.
+ "Soft Actor-Critic Algorithms and Applications".
+ In: arXiv. 2018.
+ URL: https://arxiv.org/abs/1812.05905
+
+ """
+
+ # override
+ class Config(dict):
+ """Keyword arguments expected in the configuration received by `setup`."""
+
+ def __init__(
+ self,
+ train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]",
+ train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None,
+ train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None,
+ train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1,
+ eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None,
+ eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None,
+ eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None,
+ eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None,
+ policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ policy_network_distribution_builders: "Tunable[RefOrFutureRef[Optional[Dict[str, Callable[..., Distribution]]]]]" = None,
+ policy_network_distribution_parametrizations: "Tunable[RefOrFutureRef[Optional[Dict[str, DistributionParametrization]]]]" = None,
+ policy_network_distribution_configs: "Tunable[RefOrFutureRef[Optional[Dict[str, Dict[str, Any]]]]]" = None,
+ policy_network_normalizing_flows: "Tunable[RefOrFutureRef[Optional[Dict[str, NormalizingFlow]]]]" = None,
+ policy_network_sample_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Distribution], Tensor]]]]" = None,
+ policy_network_prediction_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Tensor], Tensor]]]]" = None,
+ policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ policy_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None,
+ policy_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None,
+ value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None,
+ value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None,
+ value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ temperature_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ temperature_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ temperature_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None,
+ temperature_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ buffer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Buffer]]]]" = None,
+ buffer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ buffer_checkpoint: "Tunable[RefOrFutureRef[bool]]" = False,
+ dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None,
+ dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99,
+ num_return_steps: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1,
+ num_updates_per_iter: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1000,
+ batch_size: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 128,
+ max_trajectory_length: "Tunable[RefOrFutureRef[Union[int, float, Schedule]]]" = float( # noqa: B008
+ "inf"
+ ),
+ sync_freq: "Tunable[RefOrFutureRef[Union[int, Schedule]]]" = 1,
+ polyak_weight: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.001,
+ temperature: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.1,
+ max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008
+ "inf"
+ ),
+ cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100,
+ seed: "Tunable[RefOrFutureRef[int]]" = 0,
+ enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False,
+ enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False,
+ enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False,
+ enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False,
+ log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False,
+ suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False,
+ _accept_kwargs: "bool" = False,
+ **kwargs: "Any",
+ ) -> "None":
+ if not _accept_kwargs and kwargs:
+ raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}")
+ super().__init__(
+ train_env_builder=train_env_builder,
+ train_env_config=train_env_config,
+ train_agent_builder=train_agent_builder,
+ train_agent_config=train_agent_config,
+ train_sampler_builder=train_sampler_builder,
+ train_sampler_config=train_sampler_config,
+ train_num_timesteps_per_iter=train_num_timesteps_per_iter,
+ train_num_episodes_per_iter=train_num_episodes_per_iter,
+ eval_freq=eval_freq,
+ eval_env_builder=eval_env_builder,
+ eval_env_config=eval_env_config,
+ eval_agent_builder=eval_agent_builder,
+ eval_agent_config=eval_agent_config,
+ eval_sampler_builder=eval_sampler_builder,
+ eval_sampler_config=eval_sampler_config,
+ eval_num_timesteps_per_iter=eval_num_timesteps_per_iter,
+ eval_num_episodes_per_iter=eval_num_episodes_per_iter,
+ policy_network_preprocessors=policy_network_preprocessors,
+ policy_network_model_builder=policy_network_model_builder,
+ policy_network_model_config=policy_network_model_config,
+ policy_network_distribution_builders=policy_network_distribution_builders,
+ policy_network_distribution_parametrizations=policy_network_distribution_parametrizations,
+ policy_network_distribution_configs=policy_network_distribution_configs,
+ policy_network_normalizing_flows=policy_network_normalizing_flows,
+ policy_network_sample_fn=policy_network_sample_fn,
+ policy_network_prediction_fn=policy_network_prediction_fn,
+ policy_network_postprocessors=policy_network_postprocessors,
+ policy_network_optimizer_builder=policy_network_optimizer_builder,
+ policy_network_optimizer_config=policy_network_optimizer_config,
+ policy_network_optimizer_lr_scheduler_builder=policy_network_optimizer_lr_scheduler_builder,
+ policy_network_optimizer_lr_scheduler_config=policy_network_optimizer_lr_scheduler_config,
+ value_network_preprocessors=value_network_preprocessors,
+ value_network_model_builder=value_network_model_builder,
+ value_network_model_config=value_network_model_config,
+ value_network_loss_builder=value_network_loss_builder,
+ value_network_loss_config=value_network_loss_config,
+ value_network_optimizer_builder=value_network_optimizer_builder,
+ value_network_optimizer_config=value_network_optimizer_config,
+ value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder,
+ value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config,
+ temperature_optimizer_builder=temperature_optimizer_builder,
+ temperature_optimizer_config=temperature_optimizer_config,
+ temperature_optimizer_lr_scheduler_builder=temperature_optimizer_lr_scheduler_builder,
+ temperature_optimizer_lr_scheduler_config=temperature_optimizer_lr_scheduler_config,
+ buffer_builder=buffer_builder,
+ buffer_config=buffer_config,
+ buffer_checkpoint=buffer_checkpoint,
+ dataloader_builder=dataloader_builder,
+ dataloader_config=dataloader_config,
+ discount=discount,
+ num_return_steps=num_return_steps,
+ num_updates_per_iter=num_updates_per_iter,
+ batch_size=batch_size,
+ max_trajectory_length=max_trajectory_length,
+ sync_freq=sync_freq,
+ polyak_weight=polyak_weight,
+ temperature=temperature,
+ max_grad_l2_norm=max_grad_l2_norm,
+ cumreward_window_size=cumreward_window_size,
+ seed=seed,
+ enable_amp=enable_amp,
+ enable_reproducibility=enable_reproducibility,
+ enable_anomaly_detection=enable_anomaly_detection,
+ enable_profiling=enable_profiling,
+ log_sys_usage=log_sys_usage,
+ suppress_warnings=suppress_warnings,
+ **kwargs,
+ )
+
+ # override
+ def setup(self, config: "Dict[str, Any]") -> "None":
+ self.config = SAC.Config(**self.config)
+ self.config["_accept_kwargs"] = True
+ super().setup(config)
+ if self.temperature_optimizer_builder is not None:
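+            # Automatic temperature tuning: learn the log-temperature, with the target entropy derived from the action dimensionality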
+ self._log_temperature = torch.zeros(
+ 1, device=self._device, requires_grad=True
+ )
+ self._target_entropy = torch.as_tensor(
+ self._train_env.single_action_space.sample()
+ ).numel()
+ self._temperature_optimizer = self._build_temperature_optimizer()
+ self._temperature_optimizer_lr_scheduler = (
+ self._build_temperature_optimizer_lr_scheduler()
+ )
+ self.temperature = lambda *args: self._log_temperature.exp().item()
+ elif not isinstance(self.temperature, Schedule):
+ self.temperature = ConstantSchedule(self.temperature)
+
+ # override
+ @property
+ def _checkpoint(self) -> "Dict[str, Any]":
+ checkpoint = super()._checkpoint
+ if self.temperature_optimizer_builder is not None:
+ checkpoint["log_temperature"] = self._log_temperature
+ checkpoint[
+ "temperature_optimizer"
+ ] = self._temperature_optimizer.state_dict()
+ if self._temperature_optimizer_lr_scheduler is not None:
+ checkpoint[
+ "temperature_optimizer_lr_scheduler"
+ ] = self._temperature_optimizer_lr_scheduler.state_dict()
+ checkpoint["temperature"] = self.temperature.state_dict()
+ return checkpoint
+
+ # override
+ @_checkpoint.setter
+ def _checkpoint(self, value: "Dict[str, Any]") -> "None":
+        super(SAC, type(self))._checkpoint.fset(self, value)
+ if "log_temperature" in value:
+ self._log_temperature = value["log_temperature"]
+ if "temperature_optimizer" in value:
+ self._temperature_optimizer.load_state_dict(value["temperature_optimizer"])
+ if "temperature_optimizer_lr_scheduler" in value:
+ self._temperature_optimizer_lr_scheduler.load_state_dict(
+ value["temperature_optimizer_lr_scheduler"]
+ )
+ self.temperature.load_state_dict(value["temperature"])
+
+ def _build_temperature_optimizer(self) -> "Optimizer":
+ if self.temperature_optimizer_config is None:
+ self.temperature_optimizer_config: "Dict[str, Any]" = {}
+ return self.temperature_optimizer_builder(
+ [self._log_temperature],
+ **self.temperature_optimizer_config,
+ )
+
+ def _build_temperature_optimizer_lr_scheduler(
+ self,
+ ) -> "Optional[LRScheduler]":
+ if self.temperature_optimizer_lr_scheduler_builder is None:
+ return
+ if self.temperature_optimizer_lr_scheduler_config is None:
+ self.temperature_optimizer_lr_scheduler_config: "Dict[str, Any]" = {}
+ return self.temperature_optimizer_lr_scheduler_builder(
+ self._temperature_optimizer,
+ **self.temperature_optimizer_lr_scheduler_config,
+ )
+
+ # override
+ def _train_step(self) -> "Dict[str, Any]":
+ result = super()._train_step()
+ if self.temperature_optimizer_builder is None:
+ result["temperature"] = self.temperature()
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
+ result["buffer_num_experiences"] = result.pop("buffer_num_experiences", None)
+ result["buffer_num_full_trajectories"] = result.pop(
+ "buffer_num_full_trajectories", None
+ )
+ result.pop("delay", None)
+ result.pop("noise_stddev", None)
+ result.pop("noise_clip", None)
+ return result
+
+ # override
+ def _train_on_batch(
+ self,
+ idx: "int",
+ experiences: "Dict[str, Tensor]",
+ is_weight: "Tensor",
+ mask: "Tensor",
+ ) -> "Tuple[Dict[str, Any], Optional[ndarray]]":
+ sync_freq = self.sync_freq()
+ if sync_freq < 1 or not float(sync_freq).is_integer():
+ raise ValueError(
+ f"`sync_freq` ({sync_freq}) "
+ f"must be in the integer interval [1, inf)"
+ )
+ sync_freq = int(sync_freq)
+
+ temperature = self.temperature()
+ if temperature <= 0.0:
+ raise ValueError(
+ f"`temperature` ({temperature}) must be in the interval (0, inf)"
+ )
+
+ result = {}
+
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ self._policy_network(experiences["observation"], mask=mask)
+ policy = self._policy_network.distribution
+ target_actions = policy.rsample()
+ target_log_probs = policy.log_prob(target_actions)
+ observations_target_actions = torch.cat(
+ [
+ x[..., None] if x.shape == mask.shape else x
+ for x in [experiences["observation"], target_actions]
+ ],
+ dim=-1,
+ )
+ with torch.no_grad():
+ target_action_values, _ = self._target_value_network(
+ observations_target_actions, mask=mask
+ )
+ target_twin_action_values, _ = self._target_twin_value_network(
+ observations_target_actions, mask=mask
+ )
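+ # Soft target value: minimum of the twin target critics minus the entropy term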
+ target_action_values = target_action_values.min(target_twin_action_values)
+ with torch.no_grad():
+ target_action_values -= temperature * target_log_probs
+
+ # Discard next observation
+ experiences["observation"] = experiences["observation"][:, :-1, ...]
+ observations_target_actions = observations_target_actions[:, :-1, ...]
+ target_log_probs = target_log_probs[:, :-1, ...]
+ mask = mask[:, 1:]
+ targets, _ = n_step_return(
+ target_action_values,
+ experiences["reward"],
+ experiences["terminal"],
+ mask,
+ self.discount(),
+ self.num_return_steps(),
+ return_advantage=False,
+ )
+
+ # Compute action values
+ observations_actions = torch.cat(
+ [
+ x[..., None] if x.shape == mask.shape else x
+ for x in [experiences["observation"], experiences["action"]]
+ ],
+ dim=-1,
+ )
+ action_values, _ = self._value_network(observations_actions, mask=mask)
+ twin_action_values, _ = self._twin_value_network(
+ observations_actions, mask=mask
+ )
+
+ result["value_network"], priority = self._train_on_batch_value_network(
+ action_values,
+ targets,
+ is_weight,
+ mask,
+ )
+
+ (
+ result["twin_value_network"],
+ twin_priority,
+ ) = self._train_on_batch_twin_value_network(
+ twin_action_values,
+ targets,
+ is_weight,
+ mask,
+ )
+
+ if priority is not None:
+ priority += twin_priority
+ priority /= 2.0
+
+ result["policy_network"] = self._train_on_batch_policy_network(
+ observations_target_actions,
+ target_log_probs,
+ mask,
+ )
+
+ if self.temperature_optimizer_builder is not None:
+ result["temperature"] = self._train_on_batch_temperature(
+ target_log_probs,
+ mask,
+ )
+
+ # Synchronize
+ if idx % sync_freq == 0:
+ sync_polyak_(
+ self._value_network,
+ self._target_value_network,
+ self.polyak_weight(),
+ )
+ sync_polyak_(
+ self._twin_value_network,
+ self._target_twin_value_network,
+ self.polyak_weight(),
+ )
+
+ self._grad_scaler.update()
+ return result, priority
+
+ # override
+ def _train_on_batch_policy_network(
+ self,
+ observations_actions: "Tensor",
+ log_probs: "Tensor",
+ mask: "Tensor",
+ ) -> "Dict[str, Any]":
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ with freeze_params(self._value_network, self._twin_value_network):
+ action_values, _ = self._value_network(observations_actions, mask=mask)
+ twin_action_values, _ = self._twin_value_network(
+ observations_actions, mask=mask
+ )
+ action_values = action_values.min(twin_action_values)
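+ # Policy loss: minimize E[temperature * log_prob - min(Q1, Q2)]
+ # (gradients flow through the reparametrized actions)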
+ log_prob = log_probs[mask]
+ action_value = action_values[mask]
+ loss = self.temperature() * log_prob
+ loss -= action_value
+ loss = loss.mean()
+ optimize_result = self._optimize_policy_network(loss)
+ result = {
+ "log_prob": log_prob.mean().item(),
+ "action_value": action_value.mean().item(),
+ "loss": loss.item(),
+ }
+ result.update(optimize_result)
+ return result
+
+ def _train_on_batch_temperature(
+ self,
+ log_probs: "Tensor",
+ mask: "Tensor",
+ ) -> "Dict[str, Any]":
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ with torch.no_grad():
+ log_prob = log_probs[mask]
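+ # Temperature loss: -log_temperature * (log_prob + target_entropy); the temperature
+ # increases when the policy entropy drops below the target entropy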
+ loss = log_prob + self._target_entropy
+ loss *= -self._log_temperature
+ loss = loss.mean()
+ optimize_result = self._optimize_temperature(loss)
+ result = {
+ "temperature": self.temperature(),
+ "log_prob": log_prob.mean().item(),
+ "loss": loss.item(),
+ }
+ result.update(optimize_result)
+ return result
+
+ def _optimize_temperature(self, loss: "Tensor") -> "Dict[str, Any]":
+ max_grad_l2_norm = self.max_grad_l2_norm()
+ if max_grad_l2_norm <= 0.0:
+ raise ValueError(
+ f"`max_grad_l2_norm` ({max_grad_l2_norm}) must be in the interval (0, inf]"
+ )
+ self._temperature_optimizer.zero_grad(set_to_none=True)
+ self._grad_scaler.scale(loss).backward()
+ self._grad_scaler.unscale_(self._temperature_optimizer)
+ grad_l2_norm = clip_grad_norm_(self._log_temperature, max_grad_l2_norm)
+ self._grad_scaler.step(self._temperature_optimizer)
+ result = {
+ "lr": self._temperature_optimizer.param_groups[0]["lr"],
+ "grad_l2_norm": min(grad_l2_norm.item(), max_grad_l2_norm),
+ }
+ if self._temperature_optimizer_lr_scheduler is not None:
+ self._temperature_optimizer_lr_scheduler.step()
+ return result
+
+ @singledispatchmethod(use_weakrefs=False)
+ def _get_default_policy_network_distribution_builder(
+ self,
+ space: "spaces.Space",
+ ) -> "Callable[..., Distribution]":
+ raise NotImplementedError(
+ f"Unsupported space type: "
+ f"`{type(space).__module__}.{type(space).__name__}`. "
+ f"Register a custom space type through decorator "
+ f"`{type(self).__module__}.{type(self).__name__}."
+ f"_get_default_policy_network_distribution_builder.register`"
+ )
+
+ @singledispatchmethod(use_weakrefs=False)
+ def _get_default_policy_network_distribution_parametrization(
+ self,
+ space: "spaces.Space",
+ ) -> "Callable[..., DistributionParametrization]":
+ raise NotImplementedError(
+ f"Unsupported space type: "
+ f"`{type(space).__module__}.{type(space).__name__}`. "
+ f"Register a custom space type through decorator "
+ f"`{type(self).__module__}.{type(self).__name__}."
+ f"_get_default_policy_network_distribution_parametrization.register`"
+ )
+
+ @singledispatchmethod(use_weakrefs=False)
+ def _get_default_policy_network_distribution_config(
+ self,
+ space: "spaces.Space",
+ ) -> "Dict[str, Any]":
+ raise NotImplementedError(
+ f"Unsupported space type: "
+ f"`{type(space).__module__}.{type(space).__name__}`. "
+ f"Register a custom space type through decorator "
+ f"`{type(self).__module__}.{type(self).__name__}."
+ f"_get_default_policy_network_distribution_config.register`"
+ )
+
+
+class DistributedDataParallelSAC(DistributedDataParallelTD3):
+ """Distributed data parallel Soft Actor-Critic (SAC).
+
+ See Also
+ --------
+ actorch.algorithms.sac.SAC
+
+ """
+
+ _ALGORITHM_CLS = SAC # override
+
+
+#####################################################################################################
+# SAC._get_default_policy_network_distribution_builder implementation
+#####################################################################################################
+
+
+@SAC._get_default_policy_network_distribution_builder.register(spaces.Box)
+def _get_default_policy_network_distribution_builder_box(
+ self,
+ space: "spaces.Box",
+) -> "Callable[..., Normal]":
+ return Normal
+
+
+#####################################################################################################
+# SAC._get_default_policy_network_distribution_parametrization implementation
+#####################################################################################################
+
+
+@SAC._get_default_policy_network_distribution_parametrization.register(spaces.Box)
+def _get_default_policy_network_distribution_parametrization_box(
+ self,
+ space: "spaces.Box",
+) -> "DistributionParametrization":
+ return {
+ "loc": (
+ {"loc": space.shape},
+ lambda x: x["loc"],
+ ),
+ "scale": (
+ {"log_scale": space.shape},
+ lambda x: x["log_scale"].exp(),
+ ),
+ }
+
+
+#####################################################################################################
+# SAC._get_default_policy_network_distribution_config implementation
+#####################################################################################################
+
+
+@SAC._get_default_policy_network_distribution_config.register(spaces.Box)
+def _get_default_policy_network_distribution_config_box(
+ self,
+ space: "spaces.Box",
+) -> "Dict[str, Any]":
+ return {"validate_args": False}
diff --git a/actorch/algorithms/td3.py b/actorch/algorithms/td3.py
index a922fb0..b25ec24 100644
--- a/actorch/algorithms/td3.py
+++ b/actorch/algorithms/td3.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
-"""Twin Delayed Deep Deterministic Policy Gradient."""
+"""Twin Delayed Deep Deterministic Policy Gradient (TD3)."""
import contextlib
import copy
@@ -33,12 +33,12 @@
from actorch.agents import Agent
from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
from actorch.algorithms.ddpg import DDPG, DistributedDataParallelDDPG, Loss, LRScheduler
-from actorch.algorithms.utils import prepare_model, sync_polyak
+from actorch.algorithms.utils import prepare_model, sync_polyak_
from actorch.algorithms.value_estimation import n_step_return
from actorch.buffers import Buffer
from actorch.envs import BatchedEnv
from actorch.models import Model
-from actorch.networks import Processor
+from actorch.networks import Network, Processor
from actorch.samplers import Sampler
from actorch.schedules import ConstantSchedule, Schedule
@@ -50,13 +50,13 @@
class TD3(DDPG):
- """Twin Delayed Deep Deterministic Policy Gradient.
+ """Twin Delayed Deep Deterministic Policy Gradient (TD3).
References
----------
.. [1] S. Fujimoto, H. van Hoof, and D. Meger.
"Addressing Function Approximation Error in Actor-Critic Methods".
- In: ICML. 2018, pp. 1587–1596.
+ In: ICML. 2018, pp. 1587-1596.
URL: https://arxiv.org/abs/1802.09477
"""
@@ -201,19 +201,13 @@ def setup(self, config: "Dict[str, Any]") -> "None":
self.config = TD3.Config(**self.config)
self.config["_accept_kwargs"] = True
super().setup(config)
- self._twin_value_network = (
- self._build_value_network().train().to(self._device, non_blocking=True)
- )
- self._twin_value_network_loss = (
- self._build_value_network_loss().train().to(self._device, non_blocking=True)
- )
+ self._twin_value_network = self._build_value_network()
+ self._twin_value_network_loss = self._build_value_network_loss()
self._twin_value_network_optimizer = self._build_twin_value_network_optimizer()
self._twin_value_network_optimizer_lr_scheduler = (
self._build_twin_value_network_optimizer_lr_scheduler()
)
- self._target_twin_value_network = copy.deepcopy(self._twin_value_network)
- self._target_twin_value_network.eval().to(self._device, non_blocking=True)
- self._target_twin_value_network.requires_grad_(False)
+ self._target_twin_value_network = self._build_target_twin_value_network()
if not isinstance(self.delay, Schedule):
self.delay = ConstantSchedule(self.delay)
if not isinstance(self.noise_stddev, Schedule):
@@ -297,6 +291,11 @@ def _build_twin_value_network_optimizer_lr_scheduler(
**self.value_network_optimizer_lr_scheduler_config,
)
+ def _build_target_twin_value_network(self) -> "Network":
+ target_twin_value_network = copy.deepcopy(self._twin_value_network)
+ target_twin_value_network.requires_grad_(False)
+ return target_twin_value_network.eval().to(self._device, non_blocking=True)
+
# override
def _train_step(self) -> "Dict[str, Any]":
result = super()._train_step()
@@ -306,9 +305,10 @@ def _train_step(self) -> "Dict[str, Any]":
result["delay"] = self.delay()
result["noise_stddev"] = self.noise_stddev()
result["noise_clip"] = self.noise_clip()
- result["buffer_num_experiences"] = result.pop("buffer_num_experiences")
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
+ result["buffer_num_experiences"] = result.pop("buffer_num_experiences", None)
result["buffer_num_full_trajectories"] = result.pop(
- "buffer_num_full_trajectories"
+ "buffer_num_full_trajectories", None
)
return result
@@ -398,6 +398,7 @@ def _train_on_batch(
mask,
self.discount(),
self.num_return_steps(),
+ return_advantage=False,
)
# Compute action values
@@ -441,17 +442,17 @@ def _train_on_batch(
)
# Synchronize
if idx % sync_freq == 0:
- sync_polyak(
+ sync_polyak_(
self._policy_network,
self._target_policy_network,
self.polyak_weight(),
)
- sync_polyak(
+ sync_polyak_(
self._value_network,
self._target_value_network,
self.polyak_weight(),
)
- sync_polyak(
+ sync_polyak_(
self._twin_value_network,
self._target_twin_value_network,
self.polyak_weight(),
@@ -475,10 +476,12 @@ def _train_on_batch_twin_value_network(
action_values = action_values[mask]
target = targets[mask]
loss = self._twin_value_network_loss(action_values, target)
- loss *= is_weight[:, None].expand_as(mask)[mask]
+ priority = None
+ if self._buffer.is_prioritized:
+ loss *= is_weight[:, None].expand_as(mask)[mask]
+ priority = loss.detach().abs().to("cpu").numpy()
loss = loss.mean()
optimize_result = self._optimize_twin_value_network(loss)
- priority = None
result = {
"action_value": action_values.mean().item(),
"target": target.mean().item(),
@@ -510,7 +513,7 @@ def _optimize_twin_value_network(self, loss: "Tensor") -> "Dict[str, Any]":
class DistributedDataParallelTD3(DistributedDataParallelDDPG):
- """Distributed data parallel Twin Delayed Deep Deterministic Policy Gradient.
+ """Distributed data parallel Twin Delayed Deep Deterministic Policy Gradient (TD3).
See Also
--------
diff --git a/actorch/algorithms/trpo.py b/actorch/algorithms/trpo.py
new file mode 100644
index 0000000..c4b749d
--- /dev/null
+++ b/actorch/algorithms/trpo.py
@@ -0,0 +1,394 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Trust Region Policy Optimization (TRPO)."""
+
+import contextlib
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from gymnasium import Env
+from numpy import ndarray
+from torch import Tensor
+from torch.cuda.amp import autocast
+from torch.distributions import Distribution, kl_divergence
+from torch.nn.utils import clip_grad_norm_
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from actorch.agents import Agent
+from actorch.algorithms.a2c import A2C, DistributedDataParallelA2C, Loss, LRScheduler
+from actorch.algorithms.algorithm import RefOrFutureRef, Tunable
+from actorch.algorithms.utils import normalize_
+from actorch.algorithms.value_estimation import lambda_return
+from actorch.envs import BatchedEnv
+from actorch.models import Model
+from actorch.networks import DistributionParametrization, NormalizingFlow, Processor
+from actorch.optimizers import CGBLS
+from actorch.samplers import Sampler
+from actorch.schedules import ConstantSchedule, Schedule
+
+
+__all__ = [
+ "DistributedDataParallelTRPO",
+ "TRPO",
+]
+
+
+class TRPO(A2C):
+ """Trust Region Policy Optimization (TRPO).
+
+ References
+ ----------
+ .. [1] J. Schulman, S. Levine, P. Abbeel, M. Jordan, and P. Moritz.
+ "Trust Region Policy Optimization".
+ In: ICML. 2015, pp. 1889-1897.
+ URL: https://arxiv.org/abs/1502.05477
+
+ """
+
+ # override
+ class Config(dict):
+ """Keyword arguments expected in the configuration received by `setup`."""
+
+ def __init__(
+ self,
+ train_env_builder: "Tunable[RefOrFutureRef[Callable[..., Union[Env, BatchedEnv]]]]",
+ train_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None,
+ train_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None,
+ train_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ train_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ train_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ eval_freq: "Tunable[RefOrFutureRef[Optional[int]]]" = 1,
+ eval_env_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Union[Env, BatchedEnv]]]]]" = None,
+ eval_env_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_agent_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Agent]]]]" = None,
+ eval_agent_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_sampler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Sampler]]]]" = None,
+ eval_sampler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ eval_num_timesteps_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ eval_num_episodes_per_iter: "Tunable[RefOrFutureRef[Optional[Union[int, float, Schedule]]]]" = None,
+ policy_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ policy_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None,
+ policy_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ policy_network_distribution_builders: "Tunable[RefOrFutureRef[Optional[Dict[str, Callable[..., Distribution]]]]]" = None,
+ policy_network_distribution_parametrizations: "Tunable[RefOrFutureRef[Optional[Dict[str, DistributionParametrization]]]]" = None,
+ policy_network_distribution_configs: "Tunable[RefOrFutureRef[Optional[Dict[str, Dict[str, Any]]]]]" = None,
+ policy_network_normalizing_flows: "Tunable[RefOrFutureRef[Optional[Dict[str, NormalizingFlow]]]]" = None,
+ policy_network_sample_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Distribution], Tensor]]]]" = None,
+ policy_network_prediction_fn: "Tunable[RefOrFutureRef[Optional[Callable[[Tensor], Tensor]]]]" = None,
+ policy_network_postprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ policy_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ policy_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_preprocessors: "Tunable[RefOrFutureRef[Optional[Dict[str, Processor]]]]" = None,
+ value_network_model_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Model]]]]" = None,
+ value_network_model_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_loss_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Loss]]]]" = None,
+ value_network_loss_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_optimizer_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., Optimizer]]]]" = None,
+ value_network_optimizer_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ value_network_optimizer_lr_scheduler_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., LRScheduler]]]]" = None,
+ value_network_optimizer_lr_scheduler_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ dataloader_builder: "Tunable[RefOrFutureRef[Optional[Callable[..., DataLoader]]]]" = None,
+ dataloader_config: "Tunable[RefOrFutureRef[Optional[Dict[str, Any]]]]" = None,
+ discount: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.99,
+ trace_decay: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.95,
+ normalize_advantage: "Tunable[RefOrFutureRef[bool]]" = False,
+ entropy_coeff: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = 0.01,
+ max_grad_l2_norm: "Tunable[RefOrFutureRef[Union[float, Schedule]]]" = float( # noqa: B008
+ "inf"
+ ),
+ cumreward_window_size: "Tunable[RefOrFutureRef[int]]" = 100,
+ seed: "Tunable[RefOrFutureRef[int]]" = 0,
+ enable_amp: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False,
+ enable_reproducibility: "Tunable[RefOrFutureRef[bool]]" = False,
+ enable_anomaly_detection: "Tunable[RefOrFutureRef[bool]]" = False,
+ enable_profiling: "Tunable[RefOrFutureRef[Union[bool, Dict[str, Any]]]]" = False,
+ log_sys_usage: "Tunable[RefOrFutureRef[bool]]" = False,
+ suppress_warnings: "Tunable[RefOrFutureRef[bool]]" = False,
+ _accept_kwargs: "bool" = False,
+ **kwargs: "Any",
+ ) -> "None":
+ if not _accept_kwargs and kwargs:
+ raise ValueError(f"Unexpected configuration arguments: {list(kwargs)}")
+ super().__init__(
+ train_env_builder=train_env_builder,
+ train_env_config=train_env_config,
+ train_agent_builder=train_agent_builder,
+ train_agent_config=train_agent_config,
+ train_sampler_builder=train_sampler_builder,
+ train_sampler_config=train_sampler_config,
+ train_num_timesteps_per_iter=train_num_timesteps_per_iter,
+ train_num_episodes_per_iter=train_num_episodes_per_iter,
+ eval_freq=eval_freq,
+ eval_env_builder=eval_env_builder,
+ eval_env_config=eval_env_config,
+ eval_agent_builder=eval_agent_builder,
+ eval_agent_config=eval_agent_config,
+ eval_sampler_builder=eval_sampler_builder,
+ eval_sampler_config=eval_sampler_config,
+ eval_num_timesteps_per_iter=eval_num_timesteps_per_iter,
+ eval_num_episodes_per_iter=eval_num_episodes_per_iter,
+ policy_network_preprocessors=policy_network_preprocessors,
+ policy_network_model_builder=policy_network_model_builder,
+ policy_network_model_config=policy_network_model_config,
+ policy_network_distribution_builders=policy_network_distribution_builders,
+ policy_network_distribution_parametrizations=policy_network_distribution_parametrizations,
+ policy_network_distribution_configs=policy_network_distribution_configs,
+ policy_network_normalizing_flows=policy_network_normalizing_flows,
+ policy_network_sample_fn=policy_network_sample_fn,
+ policy_network_prediction_fn=policy_network_prediction_fn,
+ policy_network_postprocessors=policy_network_postprocessors,
+ policy_network_optimizer_builder=policy_network_optimizer_builder,
+ policy_network_optimizer_config=policy_network_optimizer_config,
+ value_network_preprocessors=value_network_preprocessors,
+ value_network_model_builder=value_network_model_builder,
+ value_network_model_config=value_network_model_config,
+ value_network_loss_builder=value_network_loss_builder,
+ value_network_loss_config=value_network_loss_config,
+ value_network_optimizer_builder=value_network_optimizer_builder,
+ value_network_optimizer_config=value_network_optimizer_config,
+ value_network_optimizer_lr_scheduler_builder=value_network_optimizer_lr_scheduler_builder,
+ value_network_optimizer_lr_scheduler_config=value_network_optimizer_lr_scheduler_config,
+ dataloader_builder=dataloader_builder,
+ dataloader_config=dataloader_config,
+ discount=discount,
+ trace_decay=trace_decay,
+ normalize_advantage=normalize_advantage,
+ entropy_coeff=entropy_coeff,
+ max_grad_l2_norm=max_grad_l2_norm,
+ cumreward_window_size=cumreward_window_size,
+ seed=seed,
+ enable_amp=enable_amp,
+ enable_reproducibility=enable_reproducibility,
+ enable_anomaly_detection=enable_anomaly_detection,
+ enable_profiling=enable_profiling,
+ log_sys_usage=log_sys_usage,
+ suppress_warnings=suppress_warnings,
+ **kwargs,
+ )
+
+ # override
+ def setup(self, config: "Dict[str, Any]") -> "None":
+ self.config = TRPO.Config(**self.config)
+ self.config["_accept_kwargs"] = True
+ super().setup(config)
+ if not isinstance(self.trace_decay, Schedule):
+ self.trace_decay = ConstantSchedule(self.trace_decay)
+
+ # override
+ @property
+ def _checkpoint(self) -> "Dict[str, Any]":
+ checkpoint = super()._checkpoint
+ checkpoint["trace_decay"] = self.trace_decay.state_dict()
+ return checkpoint
+
+ # override
+ @_checkpoint.setter
+ def _checkpoint(self, value: "Dict[str, Any]") -> "None":
+ super()._checkpoint = value
+ self.trace_decay.load_state_dict(value["trace_decay"])
+
+ # override
+ def _build_policy_network_optimizer(self) -> "Optimizer":
+ if self.policy_network_optimizer_builder is None:
+ self.policy_network_optimizer_builder = CGBLS
+ if self.policy_network_optimizer_config is None:
+ self.policy_network_optimizer_config = {
+ "max_constraint": 0.01,
+ "num_cg_iters": 10,
+ "max_backtracks": 15,
+ "backtrack_ratio": 0.8,
+ "hvp_reg_coeff": 1e-5,
+ "accept_violation": False,
+ "epsilon": 1e-8,
+ }
+ if self.policy_network_optimizer_config is None:
+ self.policy_network_optimizer_config = {}
+ return self.policy_network_optimizer_builder(
+ self._policy_network.parameters(),
+ **self.policy_network_optimizer_config,
+ )
+
+ # override
+ def _train_step(self) -> "Dict[str, Any]":
+ result = super()._train_step()
+ self.trace_decay.step()
+ result["trace_decay"] = self.trace_decay()
+ result["entropy_coeff"] = result.pop("entropy_coeff", None)
+ result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
+ result.pop("num_return_steps", None)
+ return result
+
+ # override
+ def _train_on_batch(
+ self,
+ idx: "int",
+ experiences: "Dict[str, Tensor]",
+ is_weight: "Tensor",
+ mask: "Tensor",
+ ) -> "Tuple[Dict[str, Any], Optional[ndarray]]":
+ result = {}
+
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ state_values, _ = self._value_network(experiences["observation"], mask=mask)
+ # Discard next observation
+ experiences["observation"] = experiences["observation"][:, :-1, ...]
+ mask = mask[:, 1:]
+ with torch.no_grad():
+ targets, advantages = lambda_return(
+ state_values,
+ experiences["reward"],
+ experiences["terminal"],
+ mask,
+ self.discount(),
+ self.trace_decay(),
+ )
+ if self.normalize_advantage:
+ normalize_(advantages, dim=-1, mask=mask)
+
+ # Discard next state value
+ state_values = state_values[:, :-1]
+
+ result["value_network"], priority = self._train_on_batch_value_network(
+ state_values,
+ targets,
+ is_weight,
+ mask,
+ )
+ result["policy_network"] = self._train_on_batch_policy_network(
+ experiences,
+ advantages,
+ mask,
+ )
+ self._grad_scaler.update()
+ return result, priority
+
+ # override
+ def _train_on_batch_policy_network(
+ self,
+ experiences: "Dict[str, Tensor]",
+ advantages: "Tensor",
+ mask: "Tensor",
+ ) -> "Dict[str, Any]":
+ entropy_coeff = self.entropy_coeff()
+ if entropy_coeff < 0.0:
+ raise ValueError(
+ f"`entropy_coeff` ({entropy_coeff}) must be in the interval [0, inf)"
+ )
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ advantage = advantages[mask]
+ old_log_prob = experiences["log_prob"][mask]
+ with torch.no_grad():
+ self._policy_network(experiences["observation"], mask=mask)
+ old_policy = self._policy_network.distribution
+ policy = loss = log_prob = entropy_bonus = kl_div = None
+
+ def compute_loss() -> "Tensor":
+ nonlocal policy, loss, log_prob, entropy_bonus
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ self._policy_network(experiences["observation"], mask=mask)
+ policy = self._policy_network.distribution
+ log_prob = policy.log_prob(experiences["action"])[mask]
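+ # Surrogate objective: -E[exp(log_prob - old_log_prob) * advantage]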
+ ratio = (log_prob - old_log_prob).exp()
+ loss = -advantage * ratio
+ entropy_bonus = None
+ if entropy_coeff != 0.0:
+ entropy_bonus = -entropy_coeff * policy.entropy()[mask]
+ loss += entropy_bonus
+ loss = loss.mean()
+ return loss
+
+ def compute_kl_div() -> "Tensor":
+ nonlocal policy, kl_div
+ with (
+ autocast(**self.enable_amp)
+ if self.enable_amp["enabled"]
+ else contextlib.suppress()
+ ):
+ if policy is None:
+ self._policy_network(experiences["observation"], mask=mask)
+ policy = self._policy_network.distribution
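+ # Trust region constraint: mean KL divergence between the old and the new policy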
+ kl_div = kl_divergence(old_policy, policy)[mask].mean()
+ policy = None
+ return kl_div
+
+ loss = compute_loss()
+ optimize_result = self._optimize_policy_network(
+ loss, compute_loss, compute_kl_div
+ )
+ result = {
+ "advantage": advantage.mean().item(),
+ "log_prob": log_prob.mean().item(),
+ "old_log_prob": old_log_prob.mean().item(),
+ }
+ if kl_div is not None:
+ result["kl_div"] = kl_div.item()
+ result["loss"] = loss.item()
+ if entropy_bonus is not None:
+ result["entropy_bonus"] = entropy_bonus.mean().item()
+ result.update(optimize_result)
+ return result
+
+ # override
+ def _optimize_policy_network(
+ self,
+ loss: "Tensor",
+ loss_fn: "Callable[[], Tensor]",
+ constraint_fn: "Callable[[], Tensor]",
+ ) -> "Dict[str, Any]":
+ max_grad_l2_norm = self.max_grad_l2_norm()
+ if max_grad_l2_norm <= 0.0:
+ raise ValueError(
+ f"`max_grad_l2_norm` ({max_grad_l2_norm}) must be in the interval (0, inf]"
+ )
+ self._policy_network_optimizer.zero_grad(set_to_none=True)
+ self._grad_scaler.scale(loss).backward(retain_graph=True)
+ self._grad_scaler.unscale_(self._policy_network_optimizer)
+ grad_l2_norm = clip_grad_norm_(
+ self._policy_network.parameters(), max_grad_l2_norm
+ )
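+ # CGBLS performs a conjugate gradient step followed by a backtracking line
+ # search that enforces the constraint returned by `constraint_fn`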
+ self._grad_scaler.step(self._policy_network_optimizer, loss_fn, constraint_fn)
+ result = {
+ "grad_l2_norm": min(grad_l2_norm.item(), max_grad_l2_norm),
+ }
+ return result
+
+
+class DistributedDataParallelTRPO(DistributedDataParallelA2C):
+ """Distributed data parallel Trust Region Policy Optimization (TRPO).
+
+ See Also
+ --------
+ actorch.algorithms.trpo.TRPO
+
+ """
+
+ _ALGORITHM_CLS = TRPO # override
diff --git a/actorch/algorithms/utils.py b/actorch/algorithms/utils.py
index 916d79b..456ee80 100644
--- a/actorch/algorithms/utils.py
+++ b/actorch/algorithms/utils.py
@@ -23,6 +23,7 @@
import ray.train.torch # Fix missing train.torch attribute
import torch
from ray import train
+from torch import Tensor
from torch import distributed as dist
from torch.nn import Module
from torch.nn.parallel import DataParallel, DistributedDataParallel
@@ -32,11 +33,59 @@
"count_params",
"freeze_params",
"init_mock_train_session",
+ "normalize_",
"prepare_model",
- "sync_polyak",
+ "sync_polyak_",
]
+def normalize_(input: "Tensor", dim: "int" = 0, mask: "Optional[Tensor]" = None) -> "Tensor":
+ """Normalize a tensor along a dimension by subtracting the mean
+ and then dividing by the standard deviation (in-place).
+
+ Parameters
+ ----------
+ input:
+ The tensor.
+ dim:
+ The dimension.
+ mask:
+ The boolean tensor indicating which elements
+ are valid (True) and which are not (False).
+ Default to ``torch.ones_like(input, dtype=torch.bool)``.
+
+ Returns
+ -------
+ The normalized tensor.
+
+ Examples
+ --------
+ >>> import torch
+ >>>
+ >>> from actorch.algorithms.utils import normalize_
+ >>>
+ >>>
+ >>> input = torch.randn(2, 3)
+ >>> mask = torch.rand(2, 3) > 0.5
+ >>> output = normalize_(input, mask=mask)
+
+ """
+ if mask is None:
+ mask = torch.ones_like(input, dtype=torch.bool)
+ else:
+ input[~mask] = 0.0
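+ # Compute mean and standard deviation over the valid (unmasked) elements only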
+ length = mask.sum(dim=dim, keepdim=True)
+ input_mean = input.sum(dim=dim, keepdim=True) / length
+ input -= input_mean
+ input *= mask
+ input_stddev = (
+ ((input**2).sum(dim=dim, keepdim=True) / length).sqrt().clamp(min=1e-6)
+ )
+ input /= input_stddev
+ input *= mask
+ return input
+
+
def count_params(module: "Module") -> "Tuple[int, int]":
"""Return the number of trainable and non-trainable
parameters in `module`.
@@ -51,16 +100,26 @@ def count_params(module: "Module") -> "Tuple[int, int]":
- The number of trainable parameters;
- the number of non-trainable parameters.
+ Examples
+ --------
+ >>> import torch
+ >>>
+ >>> from actorch.algorithms.utils import count_params
+ >>>
+ >>>
+ >>> model = torch.nn.Linear(4, 2)
+ >>> trainable_count, non_trainable_count = count_params(model)
+
"""
- num_trainable_params, num_non_trainable_params = 0, 0
+ trainable_count, non_trainable_count = 0, 0
for param in module.parameters():
if param.requires_grad:
- num_trainable_params += param.numel()
+ trainable_count += param.numel()
else:
- num_non_trainable_params += param.numel()
+ non_trainable_count += param.numel()
for buffer in module.buffers():
- num_non_trainable_params += buffer.numel()
- return num_trainable_params, num_non_trainable_params
+ non_trainable_count += buffer.numel()
+ return trainable_count, non_trainable_count
@contextmanager
@@ -73,6 +132,21 @@ def freeze_params(*modules: "Module") -> "Iterator[None]":
modules:
The modules.
+ Examples
+ --------
+ >>> import torch
+ >>>
+ >>> from actorch.algorithms.utils import freeze_params
+ >>>
+ >>>
+ >>> policy_model = torch.nn.Linear(4, 2)
+ >>> value_model = torch.nn.Linear(2, 1)
+ >>> input = torch.randn(3, 4)
+ >>> with freeze_params(value_model):
+ ... action = policy_model(input)
+ ... loss = -value_model(action).mean()
+ >>> loss.backward()
+
"""
params = [
param
@@ -89,19 +163,19 @@ def freeze_params(*modules: "Module") -> "Iterator[None]":
param.requires_grad = True
-def sync_polyak(
+def sync_polyak_(
source_module: "Module",
target_module: "Module",
polyak_weight: "float" = 0.001,
-) -> "None":
- """Synchronize `source_module` with `target_module`
- through Polyak averaging.
+) -> "Module":
+ """Synchronize a source module with a target module
+ through Polyak averaging (in-place).
- For each `target_param` in `target_module`,
- for each `source_param` in `source_module`:
- `target_param` =
- (1 - `polyak_weight`) * `target_param`
- + `polyak_weight` * `source_param`.
+ For each `target_parameter` in `target_module`,
+ for each `source_parameter` in `source_module`:
+ `target_parameter` =
+ (1 - `polyak_weight`) * `target_parameter`
+ + `polyak_weight` * `source_parameter`.
Parameters
----------
@@ -112,6 +186,10 @@ def sync_polyak(
polyak_weight:
The Polyak weight.
+ Returns
+ -------
+ The synchronized target module.
+
Raises
------
ValueError
@@ -124,16 +202,27 @@ def sync_polyak(
In: SIAM Journal on Control and Optimization. 1992, pp. 838-855.
URL: https://doi.org/10.1137/0330046
+ Examples
+ --------
+ >>> import torch
+ >>>
+ >>> from actorch.algorithms.utils import sync_polyak_
+ >>>
+ >>>
+ >>> model = torch.nn.Linear(4, 2)
+ >>> target_model = torch.nn.Linear(4, 2)
+ >>> sync_polyak_(model, target_model, polyak_weight=0.001)
+
"""
if polyak_weight < 0.0 or polyak_weight > 1.0:
raise ValueError(
f"`polyak_weight` ({polyak_weight}) must be in the interval [0, 1]"
)
if polyak_weight == 0.0:
- return
+ return target_module
if polyak_weight == 1.0:
target_module.load_state_dict(source_module.state_dict())
- return
+ return target_module
# Update parameters
target_params = target_module.parameters()
source_params = source_module.parameters()
@@ -151,6 +240,7 @@ def sync_polyak(
target_buffer *= (1 - polyak_weight) / polyak_weight
target_buffer += source_buffer
target_buffer *= polyak_weight
+ return target_module
def init_mock_train_session() -> "None":
diff --git a/actorch/algorithms/value_estimation/generalized_estimator.py b/actorch/algorithms/value_estimation/generalized_estimator.py
index ceb15a7..4c6c1b2 100644
--- a/actorch/algorithms/value_estimation/generalized_estimator.py
+++ b/actorch/algorithms/value_estimation/generalized_estimator.py
@@ -32,7 +32,7 @@
]
-def generalized_estimator(
+def generalized_estimator( # noqa: C901
state_values: "Union[Tensor, Distribution]",
rewards: "Tensor",
terminals: "Tensor",
@@ -44,7 +44,8 @@ def generalized_estimator(
discount: "float" = 0.99,
num_return_steps: "int" = 1,
trace_decay: "float" = 1.0,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) generalized estimator targets
and the corresponding advantages of a trajectory.
@@ -81,12 +82,14 @@ def generalized_estimator(
The number of return steps (`n` in the literature).
trace_decay:
The trace-decay parameter (`lambda` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) generalized estimator targets,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
Raises
------
@@ -199,6 +202,8 @@ def generalized_estimator(
f"be equal to the shape of `rewards` ({rewards.shape})"
)
targets, state_values, next_state_values = _compute_targets(*compute_targets_args)
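+ # Skip advantage computation when only the targets are needed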
+ if not return_advantage:
+ return targets, None
advantages = _compute_advantages(
targets,
state_values,
@@ -308,7 +313,7 @@ def _compute_distributional_targets(
B, T = rewards.shape
num_return_steps = min(num_return_steps, T)
length = mask.sum(dim=1, keepdim=True)
- idx = torch.arange(1, T + 1).expand(B, T)
+ idx = torch.arange(1, T + 1, device=length.device).expand(B, T)
if state_values.batch_shape == (B, T + 1):
next_state_values = distributional_gather(
state_values,
@@ -316,7 +321,7 @@ def _compute_distributional_targets(
idx.clamp(max=length),
mask * (~terminals),
)
- idx = torch.arange(0, T).expand(B, T)
+ idx = torch.arange(0, T, device=length.device).expand(B, T)
state_values = distributional_gather(
state_values,
1,
@@ -361,7 +366,9 @@ def _compute_distributional_targets(
next_state_value_coeffs *= ~terminals
next_state_value_coeffs *= mask
# Gather
- idx = torch.arange(num_return_steps - 1, T + num_return_steps - 1).expand(B, T)
+ idx = torch.arange(
+ num_return_steps - 1, T + num_return_steps - 1, device=length.device
+ ).expand(B, T)
next_state_values = distributional_gather(
next_state_values,
1,
@@ -371,7 +378,7 @@ def _compute_distributional_targets(
targets = TransformedDistribution(
next_state_values,
AffineTransform(offsets, next_state_value_coeffs),
- next_state_values._validate_args,
+ validate_args=False,
).reduced_dist
return targets, state_values, next_state_values
coeffs = torch.stack(
@@ -403,14 +410,11 @@ def _compute_distributional_targets(
next_state_values, (B, T, num_return_steps)
)
# Transform
- validate_args = (
- state_or_action_values._validate_args or next_state_values._validate_args
- )
targets = TransformedDistribution(
CatDistribution(
[state_or_action_values, next_state_values],
dim=-1,
- validate_args=validate_args,
+ validate_args=False,
),
[
AffineTransform(
@@ -420,7 +424,7 @@ def _compute_distributional_targets(
SumTransform((2,)),
SumTransform((num_return_steps,)),
],
- validate_args=validate_args,
+ validate_args=False,
).reduced_dist
return targets, state_values, next_state_values
@@ -462,9 +466,8 @@ def _compute_advantages(
trace_decay * targets + (1 - trace_decay) * state_values,
[0, 1],
)[:, 1:]
- next_targets[torch.arange(B), length - 1] = next_state_values[
- torch.arange(B), length - 1
- ]
+ batch_idx = torch.arange(B)
+ next_targets[batch_idx, length - 1] = next_state_values[batch_idx, length - 1]
action_values = rewards + discount * next_targets
advantages = advantage_weights * (action_values - state_values)
advantages *= mask
diff --git a/actorch/algorithms/value_estimation/importance_sampling.py b/actorch/algorithms/value_estimation/importance_sampling.py
index 91d1f53..754050d 100644
--- a/actorch/algorithms/value_estimation/importance_sampling.py
+++ b/actorch/algorithms/value_estimation/importance_sampling.py
@@ -40,7 +40,8 @@ def importance_sampling(
log_is_weights: "Tensor",
mask: "Optional[Tensor]" = None,
discount: "float" = 0.99,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) importance sampling targets,
a.k.a. IS, and the corresponding advantages of a trajectory.
@@ -73,12 +74,14 @@ def importance_sampling(
Default to ``torch.ones_like(rewards, dtype=torch.bool)``.
discount:
The discount factor (`gamma` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) importance sampling targets,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
References
----------
@@ -108,4 +111,5 @@ def importance_sampling(
discount=discount,
num_return_steps=rewards.shape[1],
trace_decay=1.0,
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/lambda_return.py b/actorch/algorithms/value_estimation/lambda_return.py
index 8950462..095d7c0 100644
--- a/actorch/algorithms/value_estimation/lambda_return.py
+++ b/actorch/algorithms/value_estimation/lambda_return.py
@@ -37,7 +37,8 @@ def lambda_return(
mask: "Optional[Tensor]" = None,
discount: "float" = 0.99,
trace_decay: "float" = 1.0,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) lambda returns, a.k.a. TD(lambda),
and the corresponding advantages, a.k.a. GAE(lambda), of a trajectory.
@@ -63,12 +64,14 @@ def lambda_return(
The discount factor (`gamma` in the literature).
trace_decay:
The trace-decay parameter (`lambda` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) lambda returns,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
References
----------
@@ -94,4 +97,5 @@ def lambda_return(
max_is_weight_trace=1.0,
max_is_weight_delta=1.0,
max_is_weight_advantage=1.0,
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/monte_carlo_return.py b/actorch/algorithms/value_estimation/monte_carlo_return.py
index 516d9e5..68d4c33 100644
--- a/actorch/algorithms/value_estimation/monte_carlo_return.py
+++ b/actorch/algorithms/value_estimation/monte_carlo_return.py
@@ -33,7 +33,8 @@ def monte_carlo_return(
rewards: "Tensor",
mask: "Optional[Tensor]" = None,
discount: "float" = 0.99,
-) -> "Tuple[Tensor, Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Tensor, Optional[Tensor]]":
"""Compute the Monte Carlo returns and the corresponding advantages of
a trajectory.
@@ -51,11 +52,14 @@ def monte_carlo_return(
Default to ``torch.ones_like(rewards, dtype=torch.bool)``.
discount:
The discount factor (`gamma` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The Monte Carlo returns, shape: ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True,
+ None otherwise, shape: ``[B, T]``.
References
----------
@@ -72,4 +76,5 @@ def monte_carlo_return(
mask=mask,
discount=discount,
num_return_steps=rewards.shape[1],
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/n_step_return.py b/actorch/algorithms/value_estimation/n_step_return.py
index 1846b43..71b472e 100644
--- a/actorch/algorithms/value_estimation/n_step_return.py
+++ b/actorch/algorithms/value_estimation/n_step_return.py
@@ -37,7 +37,8 @@ def n_step_return(
mask: "Optional[Tensor]" = None,
discount: "float" = 0.99,
num_return_steps: "int" = 1,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) n-step returns, a.k.a. TD(n),
and the corresponding advantages of a trajectory.
@@ -63,12 +64,14 @@ def n_step_return(
The discount factor (`gamma` in the literature).
num_return_steps:
The number of return steps (`n` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) n-step returns,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
References
----------
@@ -90,4 +93,5 @@ def n_step_return(
max_is_weight_trace=1.0,
max_is_weight_delta=1.0,
max_is_weight_advantage=1.0,
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/off_policy_lambda_return.py b/actorch/algorithms/value_estimation/off_policy_lambda_return.py
index 7d7cf0b..c603e94 100644
--- a/actorch/algorithms/value_estimation/off_policy_lambda_return.py
+++ b/actorch/algorithms/value_estimation/off_policy_lambda_return.py
@@ -40,7 +40,8 @@ def off_policy_lambda_return(
mask: "Optional[Tensor]" = None,
discount: "float" = 0.99,
trace_decay: "float" = 1.0,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) off-policy lambda returns, a.k.a.
Harutyunyan's et al. Q(lambda), and the corresponding advantages of a
trajectory.
@@ -70,12 +71,14 @@ def off_policy_lambda_return(
The discount factor (`gamma` in the literature).
trace_decay:
The trace-decay parameter (`lambda` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) off-policy lambda returns,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
References
----------
@@ -100,4 +103,5 @@ def off_policy_lambda_return(
discount=discount,
num_return_steps=rewards.shape[1],
trace_decay=trace_decay,
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/retrace.py b/actorch/algorithms/value_estimation/retrace.py
index b910f45..51a2605 100644
--- a/actorch/algorithms/value_estimation/retrace.py
+++ b/actorch/algorithms/value_estimation/retrace.py
@@ -43,7 +43,8 @@ def retrace(
trace_decay: "float" = 1.0,
max_is_weight_trace: "float" = 1.0,
max_is_weight_advantage: "float" = 1.0,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) Retrace targets, a.k.a. Retrace(lambda),
and the corresponding advantages of a trajectory.
@@ -82,12 +83,14 @@ def retrace(
The maximum importance sampling weight for trace computation (`c_bar` in the literature).
max_is_weight_advantage:
The maximum importance sampling weight for advantage computation.
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) Retrace targets,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
Raises
------
@@ -126,4 +129,5 @@ def retrace(
discount=discount,
num_return_steps=rewards.shape[1],
trace_decay=trace_decay,
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/tree_backup.py b/actorch/algorithms/value_estimation/tree_backup.py
index 4ea1fa1..d8a3bda 100644
--- a/actorch/algorithms/value_estimation/tree_backup.py
+++ b/actorch/algorithms/value_estimation/tree_backup.py
@@ -41,7 +41,8 @@ def tree_backup(
mask: "Optional[Tensor]" = None,
discount: "float" = 0.99,
trace_decay: "float" = 1.0,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) tree-backup targets, a.k.a. TB(lambda),
and the corresponding advantages of a trajectory.
@@ -74,12 +75,14 @@ def tree_backup(
The discount factor (`gamma` in the literature).
trace_decay:
The trace-decay parameter (`lambda` in the literature).
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) tree-backup targets,
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
References
----------
@@ -105,4 +108,5 @@ def tree_backup(
discount=discount,
num_return_steps=rewards.shape[1],
trace_decay=trace_decay,
+ return_advantage=return_advantage,
)
diff --git a/actorch/algorithms/value_estimation/vtrace.py b/actorch/algorithms/value_estimation/vtrace.py
index f6471bb..7c54aef 100644
--- a/actorch/algorithms/value_estimation/vtrace.py
+++ b/actorch/algorithms/value_estimation/vtrace.py
@@ -44,7 +44,8 @@ def vtrace(
max_is_weight_trace: "float" = 1.0,
max_is_weight_delta: "float" = 1.0,
max_is_weight_advantage: "float" = 1.0,
-) -> "Tuple[Union[Tensor, Distribution], Tensor]":
+ return_advantage: "bool" = True,
+) -> "Tuple[Union[Tensor, Distribution], Optional[Tensor]]":
"""Compute the (possibly distributional) (leaky) V-trace targets, a.k.a.
V-trace(n), and the corresponding advantages of a trajectory.
@@ -86,12 +87,14 @@ def vtrace(
The maximum importance sampling weight for delta computation (`rho_bar` in the literature).
max_is_weight_advantage:
The maximum importance sampling weight for advantage computation.
+ return_advantage:
+ True to additionally return the advantages, False otherwise.
Returns
-------
- The (possibly distributional) (leaky) V-trace targets
shape (or batch shape if distributional, assuming an empty event shape): ``[B, T]``;
- - the corresponding advantages, shape: ``[B, T]``.
+ - the corresponding advantages if `return_advantage` is True, None otherwise, shape: ``[B, T]``.
Raises
------
@@ -148,4 +151,5 @@ def vtrace(
discount=discount,
num_return_steps=num_return_steps,
trace_decay=trace_decay,
+ return_advantage=return_advantage,
)
diff --git a/actorch/buffers/buffer.py b/actorch/buffers/buffer.py
index 69ff10d..c1c9157 100644
--- a/actorch/buffers/buffer.py
+++ b/actorch/buffers/buffer.py
@@ -35,6 +35,9 @@
class Buffer(ABC, CheckpointableMixin):
"""Replay buffer that stores and samples batched experience trajectories."""
+ is_prioritized = False
+ """Whether a priority is assigned to each trajectory."""
+
_STATE_VARS = ["capacity", "spec", "_sampled_idx"] # override
def __init__(
diff --git a/actorch/buffers/proportional_buffer.py b/actorch/buffers/proportional_buffer.py
index b8286db..b63e7fb 100644
--- a/actorch/buffers/proportional_buffer.py
+++ b/actorch/buffers/proportional_buffer.py
@@ -44,6 +44,8 @@ class ProportionalBuffer(UniformBuffer):
"""
+ is_prioritized = True # override
+
_STATE_VARS = UniformBuffer._STATE_VARS + [
"prioritization",
"bias_correction",
diff --git a/actorch/buffers/rank_based_buffer.py b/actorch/buffers/rank_based_buffer.py
index e54276a..58b994f 100644
--- a/actorch/buffers/rank_based_buffer.py
+++ b/actorch/buffers/rank_based_buffer.py
@@ -43,6 +43,8 @@ class RankBasedBuffer(ProportionalBuffer):
"""
+ is_prioritized = True # override
+
_STATE_VARS = ProportionalBuffer._STATE_VARS # override
_STATE_VARS.remove("epsilon")
diff --git a/actorch/buffers/uniform_buffer.py b/actorch/buffers/uniform_buffer.py
index 0f40362..6ed77de 100644
--- a/actorch/buffers/uniform_buffer.py
+++ b/actorch/buffers/uniform_buffer.py
@@ -55,6 +55,7 @@ def num_experiences(self) -> "int":
# override
@property
def num_full_trajectories(self) -> "int":
return len(self._full_trajectory_start_idx) if self._num_experiences > 0 else 0
# override
diff --git a/actorch/distributed/distributed_trainable.py b/actorch/distributed/distributed_trainable.py
index ef807f4..0fd013f 100644
--- a/actorch/distributed/distributed_trainable.py
+++ b/actorch/distributed/distributed_trainable.py
@@ -45,7 +45,7 @@ class DistributedTrainable(ABC, TuneDistributedTrainable):
"""Distributed Ray Tune trainable with configurable resource requirements.
Derived classes must implement `step`, `save_checkpoint` and `load_checkpoint`
- (see https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable-class-api).
+ (see https://docs.ray.io/en/releases-1.13.0/tune/api_docs/trainable.html#tune-trainable-class-api for Ray 1.13.0).
"""
@@ -67,7 +67,7 @@ def __init__(
Default to ``[{"CPU": 1}]``.
placement_strategy:
The placement strategy
- (see https://docs.ray.io/en/latest/ray-core/placement-group.html).
+ (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0).
"""
super().__init__(
@@ -98,7 +98,7 @@ def resource_help(cls, config: "Dict[str, Any]") -> "str":
" Default to ``[{}]``." + "\n"
" placement_strategy:" + "\n"
" The placement strategy." + "\n"
- " (see https://docs.ray.io/en/latest/ray-core/placement-group.html)."
+ f" (see https://docs.ray.io/en/releases-{ray.__version__}/ray-core/placement-group.html)."
)
# override
diff --git a/actorch/distributed/sync_distributed_trainable.py b/actorch/distributed/sync_distributed_trainable.py
index 3b4e805..fa3f231 100644
--- a/actorch/distributed/sync_distributed_trainable.py
+++ b/actorch/distributed/sync_distributed_trainable.py
@@ -100,7 +100,7 @@ def __init__(
Default to ``{}``.
placement_strategy:
The placement strategy
- (see https://docs.ray.io/en/latest/ray-core/placement-group.html).
+ (see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0).
reduction_mode:
The reduction mode for worker results.
Must be one of the following:
@@ -225,7 +225,7 @@ def resource_help(cls, config: "Dict[str, Any]") -> "str":
" Default to ``{}``." + "\n"
" placement_strategy:" + "\n"
" The placement strategy" + "\n"
- " (see https://docs.ray.io/en/latest/ray-core/placement-group.html)."
+ f" (see https://docs.ray.io/en/releases-{ray.__version__}/ray-core/placement-group.html)."
)
# override
diff --git a/actorch/distributions/finite.py b/actorch/distributions/finite.py
index e48d8e4..16dfca7 100644
--- a/actorch/distributions/finite.py
+++ b/actorch/distributions/finite.py
@@ -48,6 +48,8 @@ class Finite(Distribution):
Examples
--------
+ >>> import torch
+ >>>
>>> from actorch.distributions import Finite
>>>
>>>
diff --git a/actorch/optimizers/cgbls.py b/actorch/optimizers/cgbls.py
index c8d1abe..908c742 100644
--- a/actorch/optimizers/cgbls.py
+++ b/actorch/optimizers/cgbls.py
@@ -300,7 +300,7 @@ def _conjugate_gradient(
References
----------
- .. [1] M.R. Hestenes, and E. Stiefel.
+ .. [1] M.R. Hestenes and E. Stiefel.
"Methods of Conjugate Gradients for Solving Linear Systems".
In: Journal of Research of the National Bureau of Standards. 1952, pp. 409-435.
URL: http://dx.doi.org/10.6028/jres.049.044
@@ -389,7 +389,7 @@ def _backtracking_line_search(
loss, constraint = 0.0, 0.0
for ratio in ratios:
for i, step in enumerate(descent_steps):
- params[i] -= ratio * step
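+ # Start from the previous parameters so that successive backtracking ratios do not accumulate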
+ params[i].copy_(prev_params[i] - ratio * step)
loss = loss_fn()
constraint = constraint_fn()
@@ -411,6 +411,6 @@ def _backtracking_line_search(
warning_msg = " and".join(warning_msg.rsplit(",", 1))
_LOGGER.warning(warning_msg)
for i, prev_param in enumerate(prev_params):
- params[i] = prev_param
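+ # Restore by copying into the original tensors; rebinding the list entries would leave the model parameters untouched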
+ params[i].copy_(prev_param)
return loss, constraint
diff --git a/actorch/preconditioners/kfac.py b/actorch/preconditioners/kfac.py
index ee90ce9..98734be 100644
--- a/actorch/preconditioners/kfac.py
+++ b/actorch/preconditioners/kfac.py
@@ -93,8 +93,8 @@ def update_AG(self, decay: "float") -> "None":
)
self._dA, self._dG = A.new_zeros(A.shape[0]), G.new_zeros(G.shape[0])
self._QA, self._QG = A.new_zeros(A.shape), G.new_zeros(G.shape)
- self._update_exp_moving_average(self._A, A, decay)
- self._update_exp_moving_average(self._G, G, decay)
+ self._update_exp_moving_average_(self._A, A, decay)
+ self._update_exp_moving_average_(self._G, G, decay)
def update_eigen_AG(self, epsilon: "float") -> "None":
"""Update eigenvalues and eigenvectors of A and G.
@@ -128,14 +128,15 @@ def get_preconditioned_grad(self, damping: "float") -> "Tensor":
v = self._QG @ v2 @ self._QA.t()
return v
- def _update_exp_moving_average(
+ def _update_exp_moving_average_(
self, current: "Tensor", new: "Tensor", weight: "float"
- ) -> "None":
+ ) -> "Tensor":
if weight == 1.0:
- return
+ return current
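+ # In-place exponential moving average: current = weight * current + (1 - weight) * new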
current *= weight / (1 - weight)
current += new
current *= 1 - weight
+ return current
@property
@abstractmethod
diff --git a/actorch/version.py b/actorch/version.py
index 76f0825..cf4b122 100644
--- a/actorch/version.py
+++ b/actorch/version.py
@@ -28,7 +28,7 @@
"0" # Minor version to increment in case of backward compatible new functionality
)
-_PATCH = "4" # Patch version to increment in case of backward compatible bug fixes
+_PATCH = "5" # Patch version to increment in case of backward compatible bug fixes
VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}"
"""The package version."""
diff --git a/docs/_static/images/actorch-overview.png b/docs/_static/images/actorch-overview.png
new file mode 100644
index 0000000..7c65783
Binary files /dev/null and b/docs/_static/images/actorch-overview.png differ
diff --git a/docs/index.rst b/docs/index.rst
index 5281006..e9746dd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -5,3 +5,6 @@ Welcome to ACTorch!
`actorch` is a deep reinforcement learning framework for fast prototyping based on
`PyTorch <https://pytorch.org>`_.
Visit the GitHub repository: https://github.com/lucadellalib/actorch
+
+.. image:: _static/images/actorch-overview.png
+ :width: 600
\ No newline at end of file
diff --git a/examples/A2C-AsyncHyperBand_LunarLander-v2.py b/examples/A2C-AsyncHyperBand_LunarLander-v2.py
index add1b62..c21725a 100644
--- a/examples/A2C-AsyncHyperBand_LunarLander-v2.py
+++ b/examples/A2C-AsyncHyperBand_LunarLander-v2.py
@@ -14,12 +14,13 @@
# limitations under the License.
# ==============================================================================
-"""Train Advantage Actor-Critic on LunarLander-v2. Tune the value network learning
+"""Train Advantage Actor-Critic (A2C) on LunarLander-v2. Tune the value network learning
rate using the asynchronous version of HyperBand (https://arxiv.org/abs/1810.05934).
"""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
# actorch run A2C-AsyncHyperBand_LunarLander-v2.py
import gymnasium as gym
diff --git a/examples/A2C_LunarLander-v2.py b/examples/A2C_LunarLander-v2.py
index d92bfdb..85fb2e6 100644
--- a/examples/A2C_LunarLander-v2.py
+++ b/examples/A2C_LunarLander-v2.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train Advantage Actor-Critic on LunarLander-v2."""
+"""Train Advantage Actor-Critic (A2C) on LunarLander-v2."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
# actorch run A2C_LunarLander-v2.py
import gymnasium as gym
diff --git a/examples/ACKTR_LunarLander-v2.py b/examples/ACKTR_LunarLander-v2.py
index cdfa571..de0d291 100644
--- a/examples/ACKTR_LunarLander-v2.py
+++ b/examples/ACKTR_LunarLander-v2.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train Actor-Critic Kronecker-Factored Trust Region on LunarLander-v2."""
+"""Train Actor-Critic Kronecker-Factored (ACKTR) Trust Region on LunarLander-v2."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
# actorch run ACKTR_LunarLander-v2.py
import gymnasium as gym
diff --git a/examples/AWR-NormalizingFlow_Pendulum-v1.py b/examples/AWR-AffineFlow_Pendulum-v1.py
similarity index 95%
rename from examples/AWR-NormalizingFlow_Pendulum-v1.py
rename to examples/AWR-AffineFlow_Pendulum-v1.py
index b63c0df..848db20 100644
--- a/examples/AWR-NormalizingFlow_Pendulum-v1.py
+++ b/examples/AWR-AffineFlow_Pendulum-v1.py
@@ -14,10 +14,11 @@
# limitations under the License.
# ==============================================================================
-"""Train Advantage-Weighted Regression equipped with a normalizing flow policy on Pendulum-v1."""
+"""Train Advantage-Weighted Regression (AWR) equipped with an affine flow policy on Pendulum-v1."""
# Navigate to `/examples`, open a terminal and run:
-# actorch run AWR-NormalizingFlow_Pendulum-v1.py
+# pip install gymnasium[classic_control]
+# actorch run AWR-AffineFlow_Pendulum-v1.py
import gymnasium as gym
import torch
diff --git a/examples/AWR_Pendulum-v1.py b/examples/AWR_Pendulum-v1.py
index 0c42e92..3c89595 100644
--- a/examples/AWR_Pendulum-v1.py
+++ b/examples/AWR_Pendulum-v1.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train Advantage-Weighted Regression on Pendulum-v1."""
+"""Train Advantage-Weighted Regression (AWR) on Pendulum-v1."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
# actorch run AWR_Pendulum-v1.py
import gymnasium as gym
diff --git a/examples/D3PG-Finite_LunarLanderContinuous-v2.py b/examples/D3PG-Finite_LunarLanderContinuous-v2.py
new file mode 100644
index 0000000..b2cd13e
--- /dev/null
+++ b/examples/D3PG-Finite_LunarLanderContinuous-v2.py
@@ -0,0 +1,129 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Distributional Deep Deterministic Policy Gradient (D3PG) on LunarLanderContinuous-v2
+assuming a finite value distribution with 51 atoms (see https://arxiv.org/abs/1804.08617).
+"""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
+# actorch run D3PG-Finite_LunarLanderContinuous-v2.py
+
+import gymnasium as gym
+import torch
+from torch import nn
+from torch.optim import Adam
+
+from actorch import *
+
+
+# Define custom model
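+# FCNet variant that inserts a non-affine LayerNorm after each Linear layer in the torso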
+class LayerNormFCNet(FCNet):
+ # override
+ def _setup_torso(self, in_shape):
+ super()._setup_torso(in_shape)
+ torso = nn.Sequential()
+ for module in self.torso:
+ torso.append(module)
+ if isinstance(module, nn.Linear):
+ torso.append(
+ nn.LayerNorm(module.out_features, elementwise_affine=False)
+ )
+ self.torso = torso
+
+
+experiment_params = ExperimentParams(
+ run_or_experiment=D3PG,
+ stop={"timesteps_total": int(4e5)},
+ resources_per_trial={"cpu": 1, "gpu": 0},
+ checkpoint_freq=10,
+ checkpoint_at_end=True,
+ log_to_file=True,
+ export_formats=["checkpoint", "model"],
+ config=D3PG.Config(
+ train_env_builder=lambda **config: ParallelBatchedEnv(
+ lambda **kwargs: gym.make("LunarLanderContinuous-v2", **kwargs),
+ config,
+ num_workers=1,
+ ),
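+ # Ornstein-Uhlenbeck exploration noise on top of the deterministic policy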
+ train_agent_builder=OUNoiseAgent,
+ train_agent_config={
+ "clip_action": True,
+ "device": "cpu",
+ "num_random_timesteps": 1000,
+ "mean": 0.0,
+ "volatility": 0.1,
+ "reversion_speed": 0.15,
+ },
+ train_num_timesteps_per_iter=1024,
+ eval_freq=10,
+ eval_env_config={"render_mode": None},
+ eval_num_episodes_per_iter=10,
+ policy_network_model_builder=LayerNormFCNet,
+ policy_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 400, "bias": True},
+ {"out_features": 300, "bias": True},
+ ],
+ "head_activation_builder": nn.Tanh,
+ },
+ policy_network_optimizer_builder=Adam,
+ policy_network_optimizer_config={"lr": 1e-3},
+ value_network_model_builder=LayerNormFCNet,
+ value_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 400, "bias": True},
+ {"out_features": 300, "bias": True},
+ ],
+ },
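+ # Categorical value distribution over 51 atoms evenly spaced in [-10, 10], parametrized by logits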
+ value_network_distribution_builder=Finite,
+ value_network_distribution_parametrization={
+ "logits": (
+ {"logits": (51,)},
+ lambda x: x["logits"],
+ ),
+ },
+ value_network_distribution_config={
+ "atoms": torch.linspace(-10.0, 10.0, 51),
+ },
+ value_network_optimizer_builder=Adam,
+ value_network_optimizer_config={"lr": 1e-3},
+ buffer_config={
+ "capacity": int(1e5),
+ "prioritization": 1.0,
+ "bias_correction": 0.4,
+ "epsilon": 1e-5,
+ },
+ discount=0.995,
+ num_return_steps=5,
+ num_updates_per_iter=LambdaSchedule(
+ lambda iter: (
+ 200 if iter >= 10 else 0
+ ), # Fill buffer with some trajectories before training
+ ),
+ batch_size=128,
+ max_trajectory_length=10,
+ sync_freq=1,
+ polyak_weight=0.001,
+ max_grad_l2_norm=5.0,
+ seed=0,
+ enable_amp=False,
+ enable_reproducibility=True,
+ log_sys_usage=True,
+ suppress_warnings=True,
+ ),
+)
diff --git a/examples/D3PG-Normal_LunarLanderContinuous-v2.py b/examples/D3PG-Normal_LunarLanderContinuous-v2.py
new file mode 100644
index 0000000..d6c7227
--- /dev/null
+++ b/examples/D3PG-Normal_LunarLanderContinuous-v2.py
@@ -0,0 +1,130 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Distributional Deep Deterministic Policy Gradient (D3PG) on LunarLanderContinuous-v2
+assuming a normal value distribution.
+"""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
+# actorch run D3PG-Normal_LunarLanderContinuous-v2.py
+
+import gymnasium as gym
+from torch import nn
+from torch.distributions import Normal
+from torch.optim import Adam
+
+from actorch import *
+
+
+# Define custom model
+class LayerNormFCNet(FCNet):
+ # override
+ def _setup_torso(self, in_shape):
+ super()._setup_torso(in_shape)
+ torso = nn.Sequential()
+ for module in self.torso:
+ torso.append(module)
+ if isinstance(module, nn.Linear):
+ torso.append(
+ nn.LayerNorm(module.out_features, elementwise_affine=False)
+ )
+ self.torso = torso
+
+
+experiment_params = ExperimentParams(
+ run_or_experiment=D3PG,
+ stop={"timesteps_total": int(4e5)},
+ resources_per_trial={"cpu": 1, "gpu": 0},
+ checkpoint_freq=10,
+ checkpoint_at_end=True,
+ log_to_file=True,
+ export_formats=["checkpoint", "model"],
+ config=D3PG.Config(
+ train_env_builder=lambda **config: ParallelBatchedEnv(
+ lambda **kwargs: gym.make("LunarLanderContinuous-v2", **kwargs),
+ config,
+ num_workers=1,
+ ),
+ train_agent_builder=OUNoiseAgent,
+ train_agent_config={
+ "clip_action": True,
+ "device": "cpu",
+ "num_random_timesteps": 1000,
+ "mean": 0.0,
+ "volatility": 0.1,
+ "reversion_speed": 0.15,
+ },
+ train_num_timesteps_per_iter=1024,
+ eval_freq=10,
+ eval_env_config={"render_mode": None},
+ eval_num_episodes_per_iter=10,
+ policy_network_model_builder=LayerNormFCNet,
+ policy_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 400, "bias": True},
+ {"out_features": 300, "bias": True},
+ ],
+ "head_activation_builder": nn.Tanh,
+ },
+ policy_network_optimizer_builder=Adam,
+ policy_network_optimizer_config={"lr": 1e-3},
+ value_network_model_builder=LayerNormFCNet,
+ value_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 400, "bias": True},
+ {"out_features": 300, "bias": True},
+ ],
+ },
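+ # Normal value distribution: the scale head predicts log-std, clamped to [-20, 2] before exponentiation for numerical stability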
+ value_network_distribution_builder=Normal,
+ value_network_distribution_parametrization={
+ "loc": (
+ {"loc": ()},
+ lambda x: x["loc"],
+ ),
+ "scale": (
+ {"log_scale": ()},
+ lambda x: x["log_scale"].clamp(-20.0, 2.0).exp(),
+ ),
+ },
+ value_network_optimizer_builder=Adam,
+ value_network_optimizer_config={"lr": 1e-3},
+ buffer_config={
+ "capacity": int(1e5),
+ "prioritization": 1.0,
+ "bias_correction": 0.4,
+ "epsilon": 1e-5,
+ },
+ discount=0.995,
+ num_return_steps=5,
+ num_updates_per_iter=LambdaSchedule(
+ lambda iter: (
+ 200 if iter >= 10 else 0
+ ), # Fill buffer with some trajectories before training
+ ),
+ batch_size=128,
+ max_trajectory_length=10,
+ sync_freq=1,
+ polyak_weight=0.001,
+ max_grad_l2_norm=5.0,
+ seed=0,
+ enable_amp=False,
+ enable_reproducibility=True,
+ log_sys_usage=True,
+ suppress_warnings=True,
+ ),
+)
diff --git a/examples/DDPG_LunarLanderContinuous-v2.py b/examples/DDPG_LunarLanderContinuous-v2.py
index de4912a..c44c4fd 100644
--- a/examples/DDPG_LunarLanderContinuous-v2.py
+++ b/examples/DDPG_LunarLanderContinuous-v2.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train Deep Deterministic Policy Gradient on LunarLanderContinuous-v2."""
+"""Train Deep Deterministic Policy Gradient (DDPG) on LunarLanderContinuous-v2."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
# actorch run DDPG_LunarLanderContinuous-v2.py
import gymnasium as gym
diff --git a/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py b/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py
index f9319a5..f3a3df7 100644
--- a/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py
+++ b/examples/DistributedDataParallelDDPG_LunarLanderContinuous-v2.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train distributed data parallel Deep Deterministic Policy Gradient on LunarLanderContinuous-v2."""
+"""Train distributed data parallel Deep Deterministic Policy Gradient (DDPG) on LunarLanderContinuous-v2."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
# actorch run DistributedDataParallelDDPG_LunarLanderContinuous-v2.py
import gymnasium as gym
diff --git a/examples/DistributedDataParallelREINFORCE_CartPole-v1.py b/examples/DistributedDataParallelREINFORCE_CartPole-v1.py
index ff470ab..b07ad36 100644
--- a/examples/DistributedDataParallelREINFORCE_CartPole-v1.py
+++ b/examples/DistributedDataParallelREINFORCE_CartPole-v1.py
@@ -17,6 +17,7 @@
"""Train distributed data parallel REINFORCE on CartPole-v1."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
# actorch run DistributedDataParallelREINFORCE_CartPole-v1.py
import gymnasium as gym
diff --git a/examples/PPO-Laplace_Pendulum-v1.py b/examples/PPO-Laplace_Pendulum-v1.py
new file mode 100644
index 0000000..32ad43e
--- /dev/null
+++ b/examples/PPO-Laplace_Pendulum-v1.py
@@ -0,0 +1,94 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Proximal Policy Optimization (PPO) on Pendulum-v1 assuming a Laplace policy distribution."""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
+# actorch run PPO-Laplace_Pendulum-v1.py
+
+import gymnasium as gym
+from torch.distributions import Laplace
+from torch.optim import Adam
+
+from actorch import *
+
+
+experiment_params = ExperimentParams(
+ run_or_experiment=PPO,
+ stop={"timesteps_total": int(1e5)},
+ resources_per_trial={"cpu": 1, "gpu": 0},
+ checkpoint_freq=10,
+ checkpoint_at_end=True,
+ log_to_file=True,
+ export_formats=["checkpoint", "model"],
+ config=PPO.Config(
+ train_env_builder=lambda **config: ParallelBatchedEnv(
+ lambda **kwargs: gym.make("Pendulum-v1", **kwargs),
+ config,
+ num_workers=2,
+ ),
+ train_num_timesteps_per_iter=2048,
+ eval_freq=10,
+ eval_env_config={"render_mode": None},
+ eval_num_episodes_per_iter=10,
+ policy_network_model_builder=FCNet,
+ policy_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 256, "bias": True},
+ {"out_features": 256, "bias": True},
+ ],
+ "independent_heads": ["action/log_scale"],
+ },
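+ # Laplace policy distribution: the scale head predicts log-scale, exponentiated to guarantee positivity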
+ policy_network_distribution_builders={"action": Laplace},
+ policy_network_distribution_parametrizations={
+ "action": {
+ "loc": (
+ {"loc": (1,)},
+ lambda x: x["loc"],
+ ),
+ "scale": (
+ {"log_scale": (1,)},
+ lambda x: x["log_scale"].exp(),
+ ),
+ },
+ },
+ policy_network_optimizer_builder=Adam,
+ policy_network_optimizer_config={"lr": 5e-5},
+ value_network_model_builder=FCNet,
+ value_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 256, "bias": True},
+ {"out_features": 256, "bias": True},
+ ],
+ },
+ value_network_optimizer_builder=Adam,
+ value_network_optimizer_config={"lr": 3e-3},
+ discount=0.99,
+ trace_decay=0.95,
+ num_epochs=20,
+ minibatch_size=16,
+ ratio_clip=0.2,
+ normalize_advantage=True,
+ entropy_coeff=0.01,
+ max_grad_l2_norm=0.5,
+ seed=0,
+ enable_amp=False,
+ enable_reproducibility=True,
+ log_sys_usage=True,
+ suppress_warnings=True,
+ ),
+)
diff --git a/examples/PPO_Pendulum-v1.py b/examples/PPO_Pendulum-v1.py
index 0089d72..6cb22ea 100644
--- a/examples/PPO_Pendulum-v1.py
+++ b/examples/PPO_Pendulum-v1.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train Proximal Policy Optimization on Pendulum-v1."""
+"""Train Proximal Policy Optimization (PPO) on Pendulum-v1."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
# actorch run PPO_Pendulum-v1.py
import gymnasium as gym
diff --git a/examples/REINFORCE_CartPole-v1.py b/examples/REINFORCE_CartPole-v1.py
index 2ab5c13..342ed36 100644
--- a/examples/REINFORCE_CartPole-v1.py
+++ b/examples/REINFORCE_CartPole-v1.py
@@ -17,6 +17,7 @@
"""Train REINFORCE on CartPole-v1."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
# actorch run REINFORCE_CartPole-v1.py
import gymnasium as gym
diff --git a/examples/SAC_BipedalWalker-v3.py b/examples/SAC_BipedalWalker-v3.py
new file mode 100644
index 0000000..55c768e
--- /dev/null
+++ b/examples/SAC_BipedalWalker-v3.py
@@ -0,0 +1,129 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Soft Actor-Critic (SAC) on BipedalWalker-v3."""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
+# actorch run SAC_BipedalWalker-v3.py
+
+import gymnasium as gym
+import numpy as np
+import torch.nn
+from torch.distributions import Normal, TanhTransform
+from torch.optim import Adam
+
+from actorch import *
+
+
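+# Environment wrapper: adds uniform noise in [-action_noise, action_noise] to each action, repeats it action_repeat times,
+# and scales the summed reward by reward_scale; returns zero reward if the episode terminates during the repeat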
+class Wrapper(gym.Wrapper):
+ def __init__(self, env, action_noise=0.3, action_repeat=3, reward_scale=5):
+ super().__init__(env)
+ self.action_noise = action_noise
+ self.action_repeat = action_repeat
+ self.reward_scale = reward_scale
+
+ def step(self, action):
+ action += self.action_noise * (
+ 1 - 2 * np.random.random(self.action_space.shape)
+ )
+ cumreward = 0.0
+ for _ in range(self.action_repeat):
+ observation, reward, terminated, truncated, info = super().step(action)
+ cumreward += reward
+ if terminated:
+ return observation, 0.0, terminated, truncated, info
+ return observation, self.reward_scale * cumreward, terminated, truncated, info
+
+
+experiment_params = ExperimentParams(
+ run_or_experiment=SAC,
+ stop={"timesteps_total": int(5e5)},
+ resources_per_trial={"cpu": 1, "gpu": 0},
+ checkpoint_freq=10,
+ checkpoint_at_end=True,
+ log_to_file=True,
+ export_formats=["checkpoint", "model"],
+ config=SAC.Config(
+ train_env_builder=lambda **config: ParallelBatchedEnv(
+ lambda **kwargs: Wrapper(
+ gym.wrappers.TimeLimit(
+ gym.make("BipedalWalker-v3", **kwargs),
+ max_episode_steps=200,
+ ),
+ ),
+ config,
+ num_workers=1,
+ ),
+ train_num_episodes_per_iter=1,
+ eval_freq=10,
+ eval_env_config={"render_mode": None},
+ eval_num_episodes_per_iter=10,
+ policy_network_model_builder=FCNet,
+ policy_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 400, "bias": True},
+ {"out_features": 300, "bias": True},
+ ],
+ "head_activation_builder": torch.nn.Tanh,
+ },
+ policy_network_distribution_builders={
+ "action": lambda loc, scale: TransformedDistribution(
+ Normal(loc, scale),
+ TanhTransform(cache_size=1), # Use tanh normal to enforce action bounds
+ ),
+ },
+ policy_network_distribution_parametrizations={
+ "action": {
+ "loc": (
+ {"loc": (4,)},
+ lambda x: x["loc"],
+ ),
+ "scale": (
+ {"log_scale": (4,)},
+ lambda x: x["log_scale"].clamp(-20.0, 2.0).exp(),
+ ),
+ },
+ },
+ policy_network_sample_fn=lambda d: d.sample(), # `mode` does not exist in closed-form for tanh normal
+ policy_network_optimizer_builder=Adam,
+ policy_network_optimizer_config={"lr": 1e-3},
+ value_network_model_builder=FCNet,
+ value_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 400, "bias": True},
+ {"out_features": 300, "bias": True},
+ ],
+ },
+ value_network_optimizer_builder=Adam,
+ value_network_optimizer_config={"lr": 1e-3},
+ buffer_config={"capacity": int(3e5)},
+ discount=0.98,
+ num_return_steps=5,
+ num_updates_per_iter=10,
+ batch_size=1024,
+ max_trajectory_length=20,
+ sync_freq=1,
+ polyak_weight=0.001,
+ temperature=0.2,
+ max_grad_l2_norm=1.0,
+ seed=0,
+ enable_amp=False,
+ enable_reproducibility=True,
+ log_sys_usage=True,
+ suppress_warnings=True,
+ ),
+)
diff --git a/examples/SAC_Pendulum-v1.py b/examples/SAC_Pendulum-v1.py
new file mode 100644
index 0000000..b3efe24
--- /dev/null
+++ b/examples/SAC_Pendulum-v1.py
@@ -0,0 +1,82 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Soft Actor-Critic (SAC) on Pendulum-v1."""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
+# actorch run SAC_Pendulum-v1.py
+
+import gymnasium as gym
+from torch.optim import Adam
+
+from actorch import *
+
+
+experiment_params = ExperimentParams(
+ run_or_experiment=SAC,
+ stop={"timesteps_total": int(1e5)},
+ resources_per_trial={"cpu": 1, "gpu": 0},
+ checkpoint_freq=10,
+ checkpoint_at_end=True,
+ log_to_file=True,
+ export_formats=["checkpoint", "model"],
+ config=SAC.Config(
+ train_env_builder=lambda **config: ParallelBatchedEnv(
+ lambda **kwargs: gym.make("Pendulum-v1", **kwargs),
+ config,
+ num_workers=2,
+ ),
+ train_num_timesteps_per_iter=500,
+ eval_freq=10,
+ eval_env_config={"render_mode": None},
+ eval_num_episodes_per_iter=10,
+ policy_network_model_builder=FCNet,
+ policy_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 256, "bias": True},
+ {"out_features": 256, "bias": True},
+ ],
+ },
+ policy_network_optimizer_builder=Adam,
+ policy_network_optimizer_config={"lr": 5e-3},
+ value_network_model_builder=FCNet,
+ value_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 256, "bias": True},
+ {"out_features": 256, "bias": True},
+ ],
+ },
+ value_network_optimizer_builder=Adam,
+ value_network_optimizer_config={"lr": 5e-3},
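+ # The entropy temperature is optimized here, rather than kept fixed as in SAC_BipedalWalker-v3.py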
+ temperature_optimizer_builder=Adam,
+ temperature_optimizer_config={"lr": 1e-5},
+ buffer_config={"capacity": int(1e5)},
+ discount=0.99,
+ num_return_steps=1,
+ num_updates_per_iter=500,
+ batch_size=32,
+ max_trajectory_length=10,
+ sync_freq=1,
+ polyak_weight=0.001,
+ max_grad_l2_norm=2.0,
+ seed=0,
+ enable_amp=False,
+ enable_reproducibility=True,
+ log_sys_usage=True,
+ suppress_warnings=True,
+ ),
+)
diff --git a/examples/TD3_LunarLanderContinuous-v2.py b/examples/TD3_LunarLanderContinuous-v2.py
index 258e7c1..dee26b4 100644
--- a/examples/TD3_LunarLanderContinuous-v2.py
+++ b/examples/TD3_LunarLanderContinuous-v2.py
@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
-"""Train Twin Delayed Deep Deterministic Policy Gradient on LunarLanderContinuous-v2."""
+"""Train Twin Delayed Deep Deterministic Policy Gradient (TD3) on LunarLanderContinuous-v2."""
# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[box2d]
# actorch run TD3_LunarLanderContinuous-v2.py
import gymnasium as gym
diff --git a/examples/TRPO_Pendulum-v1.py b/examples/TRPO_Pendulum-v1.py
new file mode 100644
index 0000000..df1fbd0
--- /dev/null
+++ b/examples/TRPO_Pendulum-v1.py
@@ -0,0 +1,84 @@
+# ==============================================================================
+# Copyright 2022 Luca Della Libera.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train Trust Region Policy Optimization (TRPO) on Pendulum-v1."""
+
+# Navigate to `/examples`, open a terminal and run:
+# pip install gymnasium[classic_control]
+# actorch run TRPO_Pendulum-v1.py
+
+import gymnasium as gym
+from torch.optim import Adam
+
+from actorch import *
+
+
+experiment_params = ExperimentParams(
+ run_or_experiment=TRPO,
+ stop={"timesteps_total": int(2e5)},
+ resources_per_trial={"cpu": 1, "gpu": 0},
+ checkpoint_freq=10,
+ checkpoint_at_end=True,
+ log_to_file=True,
+ export_formats=["checkpoint", "model"],
+ config=TRPO.Config(
+ train_env_builder=lambda **config: ParallelBatchedEnv(
+ lambda **kwargs: gym.make("Pendulum-v1", **kwargs),
+ config,
+ num_workers=2,
+ ),
+ train_num_timesteps_per_iter=512,
+ eval_freq=10,
+ eval_env_config={"render_mode": None},
+ eval_num_episodes_per_iter=10,
+ policy_network_model_builder=FCNet,
+ policy_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 256, "bias": True},
+ {"out_features": 256, "bias": True},
+ ],
+ "independent_heads": ["action/log_scale"],
+ },
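+ # Settings for the trust-region policy optimizer (conjugate gradient + backtracking line search); no builder is specified, so TRPO's default is used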
+ policy_network_optimizer_config={
+ "max_constraint": 0.01,
+ "num_cg_iters": 10,
+ "max_backtracks": 15,
+ "backtrack_ratio": 0.8,
+ "hvp_reg_coeff": 1e-5,
+ "accept_violation": False,
+ "epsilon": 1e-8,
+ },
+ value_network_model_builder=FCNet,
+ value_network_model_config={
+ "torso_fc_configs": [
+ {"out_features": 256, "bias": True},
+ {"out_features": 256, "bias": True},
+ ],
+ },
+ value_network_optimizer_builder=Adam,
+ value_network_optimizer_config={"lr": 3e-3},
+ discount=0.9,
+ trace_decay=0.95,
+ normalize_advantage=False,
+ entropy_coeff=0.01,
+ max_grad_l2_norm=2.0,
+ seed=0,
+ enable_amp=False,
+ enable_reproducibility=True,
+ log_sys_usage=True,
+ suppress_warnings=True,
+ ),
+)
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 0b6e578..9d8c59d 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -123,6 +123,7 @@ def _cleanup():
algorithms.AWR,
algorithms.PPO,
algorithms.REINFORCE,
+ algorithms.TRPO,
],
)
@pytest.mark.parametrize(
@@ -198,7 +199,9 @@ def test_algorithm_mixed(algorithm_cls, model_cls, space):
@pytest.mark.parametrize(
"algorithm_cls",
[
+ algorithms.D3PG,
algorithms.DDPG,
+ algorithms.SAC,
algorithms.TD3,
],
)