Add TRPO, D3PG and SAC, minor improvements and bug fixes
lucadellalib committed Mar 12, 2023
1 parent 758f7e4 commit 2eb6064
Showing 53 changed files with 2,456 additions and 229 deletions.
52 changes: 36 additions & 16 deletions README.md
@@ -14,33 +14,36 @@ Welcome to `actorch`, a deep reinforcement learning framework for fast prototyping
- [REINFORCE](https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)
- [Advantage Actor-Critic (A2C)](https://arxiv.org/abs/1602.01783)
- [Actor-Critic Kronecker-Factored Trust Region (ACKTR)](https://arxiv.org/abs/1708.05144)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
- [Advantage-Weighted Regression (AWR)](https://arxiv.org/abs/1910.00177)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
- [Distributional Deep Deterministic Policy Gradient (D3PG)](https://arxiv.org/abs/1804.08617)
- [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/abs/1802.09477)
- [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290)

---------------------------------------------------------------------------------------------------------

## 💡 Key features

- Support for [OpenAI Gymnasium](https://gymnasium.farama.org/) environments
- Support for custom observation/action spaces
- Support for custom multimodal input / multimodal output models
- Support for recurrent models (e.g. RNNs, LSTMs, GRUs, etc.)
- Support for custom policy/value distributions
- Support for custom preprocessing/postprocessing pipelines
- Support for custom exploration strategies
- Support for **custom observation/action spaces**
- Support for **custom multimodal input / multimodal output models**
- Support for **recurrent models** (e.g. RNNs, LSTMs, GRUs, etc.)
- Support for **custom policy/value distributions**
- Support for **custom preprocessing/postprocessing pipelines**
- Support for **custom exploration strategies**
- Support for [normalizing flows](https://arxiv.org/abs/1906.02771)
- Batched environments (both for training and evaluation)
- Batched trajectory replay
- Batched and distributional value estimation (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561))
- Data parallel and distributed data parallel multi-GPU training and evaluation
- Automatic mixed precision training
- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and hyperparameter tuning at any scale
- Effortless experiment definition through Python-based configuration files
- Built-in visualization tool to plot performance metrics
- Modular object-oriented design
- Detailed API documentation
- Batched **trajectory replay**
- Batched and **distributional value estimation** (e.g. batched and distributional [Retrace](https://arxiv.org/abs/1606.02647) and [V-trace](https://arxiv.org/abs/1802.01561))
- Data parallel and distributed data parallel **multi-GPU training and evaluation**
- Automatic **mixed precision training**
- Integration with [Ray Tune](https://docs.ray.io/en/releases-1.13.0/tune/index.html) for experiment execution and **hyperparameter tuning** at any scale
- Effortless experiment definition through **Python-based configuration files**
- Built-in **visualization tool** to plot performance metrics
- Modular **object-oriented** design
- Detailed **API documentation**

---------------------------------------------------------------------------------------------------------

@@ -161,7 +164,7 @@ experiment_params = ExperimentParams(
enable_amp=False,
enable_reproducibility=True,
log_sys_usage=True,
suppress_warnings=False,
suppress_warnings=True,
),
)
```
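For context, the snippet above is only the tail of the quickstart configuration. A minimal sketch of the overall shape follows; apart from `ExperimentParams`, `A2C.Config` (both of which appear elsewhere in this commit), and the four flags shown above, every field name and value is an illustrative assumption rather than the repository's actual example.

```python
from actorch import *  # assumed package-level re-exports

experiment_params = ExperimentParams(
    run_or_experiment=A2C,            # assumption: any of the algorithms listed above
    stop={"training_iteration": 50},  # assumption: Ray Tune-style stopping criterion
    config=A2C.Config(                # A2C.Config is defined in actorch/algorithms/a2c.py
        enable_amp=False,
        enable_reproducibility=True,
        log_sys_usage=True,
        suppress_warnings=True,
    ),
)
```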
@@ -197,6 +200,8 @@ You can find the generated plots in `plots`.

Congratulations, you ran your first experiment!

See `examples` for additional configuration file examples.

**HINT**: since a configuration file is a regular Python script, you can use all the
features of the language (e.g. inheritance).
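For example, a follow-up experiment could import an earlier configuration and override a single field instead of duplicating the whole file. The module name below and the dict-style access are assumptions for illustration only:

```python
# Hypothetical file: a2c_cartpole_quiet.py
import copy

# Assumed base config module; reuse its experiment parameters.
from a2c_cartpole import experiment_params as base_experiment_params

experiment_params = copy.deepcopy(base_experiment_params)
# Override one nested field (assuming ExperimentParams and <Algorithm>.Config
# behave like dictionaries, as <Algorithm>.Config does elsewhere in this commit).
experiment_params["config"]["suppress_warnings"] = False
```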

@@ -217,6 +222,21 @@ features of the language (e.g. inheritance).

---------------------------------------------------------------------------------------------------------

## @ Citation

```
@misc{DellaLibera2022ACTorch,
author = {Luca Della Libera},
title = {{ACTorch}: a Deep Reinforcement Learning Framework for Fast Prototyping},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/lucadellalib/actorch}},
}
```

---------------------------------------------------------------------------------------------------------

## 📧 Contact

[luca.dellalib@gmail.com](mailto:luca.dellalib@gmail.com)
3 changes: 3 additions & 0 deletions actorch/algorithms/__init__.py
@@ -20,7 +20,10 @@
from actorch.algorithms.acktr import *
from actorch.algorithms.algorithm import *
from actorch.algorithms.awr import *
from actorch.algorithms.d3pg import *
from actorch.algorithms.ddpg import *
from actorch.algorithms.ppo import *
from actorch.algorithms.reinforce import *
from actorch.algorithms.sac import *
from actorch.algorithms.td3 import *
from actorch.algorithms.trpo import *
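Assuming the new modules follow the package's existing naming convention (e.g. `a2c.py` exposes `A2C`), the newly added algorithms should now be importable directly from `actorch.algorithms`:

```python
# Illustrative only; class names are assumed from the module naming pattern.
from actorch.algorithms import D3PG, SAC, TRPO
```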
48 changes: 21 additions & 27 deletions actorch/algorithms/a2c.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================

"""Advantage Actor-Critic."""
"""Advantage Actor-Critic (A2C)."""

import contextlib
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
@@ -38,7 +38,7 @@
DistributedDataParallelREINFORCE,
LRScheduler,
)
from actorch.algorithms.utils import prepare_model
from actorch.algorithms.utils import normalize_, prepare_model
from actorch.algorithms.value_estimation import n_step_return
from actorch.distributions import Deterministic
from actorch.envs import BatchedEnv
@@ -65,7 +65,7 @@


class A2C(REINFORCE):
"""Advantage Actor-Critic.
"""Advantage Actor-Critic (A2C).
References
----------
@@ -208,12 +208,8 @@ def setup(self, config: "Dict[str, Any]") -> "None":
self.config = A2C.Config(**self.config)
self.config["_accept_kwargs"] = True
super().setup(config)
self._value_network = (
self._build_value_network().train().to(self._device, non_blocking=True)
)
self._value_network_loss = (
self._build_value_network_loss().train().to(self._device, non_blocking=True)
)
self._value_network = self._build_value_network()
self._value_network_loss = self._build_value_network_loss()
self._value_network_optimizer = self._build_value_network_optimizer()
self._value_network_optimizer_lr_scheduler = (
self._build_value_network_optimizer_lr_scheduler()
@@ -324,16 +320,20 @@ def _build_value_network(self) -> "Network":
self.value_network_normalizing_flows,
)
self._log_graph(value_network.wrapped_model.model, "value_network_model")
return value_network
return value_network.train().to(self._device, non_blocking=True)

def _build_value_network_loss(self) -> "Loss":
if self.value_network_loss_builder is None:
self.value_network_loss_builder = torch.nn.MSELoss
if self.value_network_loss_config is None:
self.value_network_loss_config: "Dict[str, Any]" = {}
return self.value_network_loss_builder(
reduction="none",
**self.value_network_loss_config,
return (
self.value_network_loss_builder(
reduction="none",
**self.value_network_loss_config,
)
.train()
.to(self._device, non_blocking=True)
)

def _build_value_network_optimizer(self) -> "Optimizer":
@@ -374,6 +374,8 @@ def _train_step(self) -> "Dict[str, Any]":
result = super()._train_step()
self.num_return_steps.step()
result["num_return_steps"] = self.num_return_steps()
result["entropy_coeff"] = result.pop("entropy_coeff", None)
result["max_grad_l2_norm"] = result.pop("max_grad_l2_norm", None)
return result

# override
@@ -405,17 +407,7 @@ def _train_on_batch(
self.num_return_steps(),
)
if self.normalize_advantage:
length = mask.sum(dim=1, keepdim=True)
advantages_mean = advantages.sum(dim=1, keepdim=True) / length
advantages -= advantages_mean
advantages *= mask
advantages_stddev = (
((advantages**2).sum(dim=1, keepdim=True) / length)
.sqrt()
.clamp(min=1e-6)
)
advantages /= advantages_stddev
advantages *= mask
normalize_(advantages, dim=-1, mask=mask)

# Discard next state value
state_values = state_values[:, :-1]
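The inline masked normalization removed above is now delegated to the `normalize_` helper imported from `actorch.algorithms.utils`. The following is a minimal sketch of such a masked, in-place standardization, reconstructed from the removed lines; it is an assumption about the helper's behavior, not its actual implementation:

```python
from typing import Optional

import torch
from torch import Tensor


def normalize_(input: Tensor, dim: int = -1, mask: Optional[Tensor] = None, epsilon: float = 1e-6) -> Tensor:
    """Standardize `input` in-place along `dim`, ignoring masked-out elements."""
    if mask is None:
        mask = torch.ones_like(input, dtype=torch.bool)
    length = mask.sum(dim=dim, keepdim=True)
    mean = (input * mask).sum(dim=dim, keepdim=True) / length
    input -= mean
    input *= mask
    stddev = ((input**2).sum(dim=dim, keepdim=True) / length).sqrt().clamp(min=epsilon)
    input /= stddev
    input *= mask
    return input
```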
@@ -449,10 +441,12 @@ def _train_on_batch_value_network(
state_value = state_values[mask]
target = targets[mask]
loss = self._value_network_loss(state_value, target)
loss *= is_weight[:, None].expand_as(mask)[mask]
priority = None
if self._buffer.is_prioritized:
loss *= is_weight[:, None].expand_as(mask)[mask]
priority = loss.detach().abs().to("cpu").numpy()
loss = loss.mean()
optimize_result = self._optimize_value_network(loss)
priority = None
result = {
"state_value": state_value.mean().item(),
"target": target.mean().item(),
@@ -490,7 +484,7 @@ def _get_default_value_network_preprocessor(


class DistributedDataParallelA2C(DistributedDataParallelREINFORCE):
"""Distributed data parallel Advantage Actor-Critic.
"""Distributed data parallel Advantage Actor-Critic (A2C).
See Also
--------
6 changes: 3 additions & 3 deletions actorch/algorithms/acktr.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================

"""Actor-Critic Kronecker-Factored Trust Region."""
"""Actor-Critic Kronecker-Factored Trust Region (ACKTR)."""

from typing import Any, Callable, Dict, Optional, Union

@@ -43,7 +43,7 @@


class ACKTR(A2C):
"""Actor-Critic Kronecker-Factored Trust Region.
"""Actor-Critic Kronecker-Factored Trust Region (ACKTR).
References
----------
@@ -287,7 +287,7 @@ def _optimize_policy_network(self, loss: "Tensor") -> "Dict[str, Any]":


class DistributedDataParallelACKTR(DistributedDataParallelA2C):
"""Distributed data parallel Actor-Critic Kronecker-Factored Trust Region.
"""Distributed data parallel Actor-Critic Kronecker-Factored Trust Region (ACKTR).
See Also
--------
54 changes: 21 additions & 33 deletions actorch/algorithms/algorithm.py
@@ -51,7 +51,6 @@
from ray.tune.syncer import NodeSyncer
from ray.tune.trial import ExportFormat
from torch import Tensor
from torch.cuda.amp import autocast
from torch.distributions import Bernoulli, Categorical, Distribution, Normal
from torch.profiler import profile, record_function, tensorboard_trace_handler
from torch.utils.data import DataLoader
@@ -113,7 +112,7 @@ class Algorithm(ABC, Trainable):

_EXPORT_FORMATS = [ExportFormat.CHECKPOINT, ExportFormat.MODEL]

_UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH = True
_OFF_POLICY = True

class Config(dict):
"""Keyword arguments expected in the configuration received by `setup`."""
@@ -692,12 +691,7 @@ def _build_train_env(self) -> "BatchedEnv":
if self.train_env_config is None:
self.train_env_config = {}

try:
train_env = self.train_env_builder(
**self.train_env_config,
)
except TypeError:
train_env = self.train_env_builder(**self.train_env_config)
train_env = self.train_env_builder(**self.train_env_config)
if not isinstance(train_env, BatchedEnv):
train_env.close()
train_env = SerialBatchedEnv(self.train_env_builder, self.train_env_config)
@@ -866,7 +860,7 @@ def _build_policy_network(self) -> "PolicyNetwork": # noqa: C901
self.policy_network_postprocessors,
)
self._log_graph(policy_network.wrapped_model.model, "policy_network_model")
return policy_network
return policy_network.train().to(self._device, non_blocking=True)

def _build_train_agent(self) -> "Agent":
if self.train_agent_builder is None:
@@ -989,14 +983,18 @@ def _build_dataloader(self) -> "DataLoader":
if self.dataloader_builder is None:
self.dataloader_builder = DataLoader
if self.dataloader_config is None:
fork = torch.multiprocessing.get_start_method() == "fork"
use_mp = (
self._OFF_POLICY
and not self._buffer.is_prioritized
and torch.multiprocessing.get_start_method() == "fork"
)
self.dataloader_config = {
"num_workers": 1 if fork else 0,
"num_workers": 1 if use_mp else 0,
"pin_memory": True,
"timeout": 0,
"worker_init_fn": None,
"generator": None,
"prefetch_factor": 1 if fork else 2,
"prefetch_factor": 1 if use_mp else 2,
"pin_memory_device": "",
}
if self.dataloader_config is None:
@@ -1037,20 +1035,15 @@ def _train_step(self) -> "Dict[str, Any]":
if self.train_num_episodes_per_iter:
train_num_episodes_per_iter = self.train_num_episodes_per_iter()
self.train_num_episodes_per_iter.step()
with (
autocast(**self.enable_amp)
if self.enable_amp["enabled"]
else contextlib.suppress()
for experience, done in self._train_sampler.sample(
train_num_timesteps_per_iter,
train_num_episodes_per_iter,
):
for experience, done in self._train_sampler.sample(
train_num_timesteps_per_iter,
train_num_episodes_per_iter,
):
self._buffer.add(experience, done)
self._buffer.add(experience, done)
result = self._train_sampler.stats
self._cumrewards += result["episode_cumreward"]

if not self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH:
if not self._OFF_POLICY:
for schedule in self._buffer_dataset.schedules.values():
schedule.step()
train_epoch_result = self._train_epoch()
@@ -1060,7 +1053,7 @@
schedule.step()
for schedule in self._buffer.schedules.values():
schedule.step()
if self._UPDATE_BUFFER_DATASET_SCHEDULES_AFTER_TRAIN_EPOCH:
if self._OFF_POLICY:
for schedule in self._buffer_dataset.schedules.values():
schedule.step()
return result
@@ -1113,16 +1106,11 @@ def _eval_step(self) -> "Dict[str, Any]":
eval_num_episodes_per_iter = self.eval_num_episodes_per_iter()
self.eval_num_episodes_per_iter.step()
self._eval_sampler.reset()
with (
autocast(**self.enable_amp)
if self.enable_amp["enabled"]
else contextlib.suppress()
for _ in self._eval_sampler.sample(
eval_num_timesteps_per_iter,
eval_num_episodes_per_iter,
):
for _ in self._eval_sampler.sample(
eval_num_timesteps_per_iter,
eval_num_episodes_per_iter,
):
pass
pass
for schedule in self._eval_agent.schedules.values():
schedule.step()
return self._eval_sampler.stats
@@ -1353,7 +1341,7 @@ def __init__(
Default to ``{}``.
placement_strategy:
The placement strategy
(see https://docs.ray.io/en/latest/ray-core/placement-group.html).
(see https://docs.ray.io/en/releases-1.13.0/ray-core/placement-group.html for Ray 1.13.0).
backend:
The backend for distributed execution
(see https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group).