diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
deleted file mode 100644
index 7cbc19856..000000000
--- a/.github/ISSUE_TEMPLATE/bug_report.md
+++ /dev/null
@@ -1,66 +0,0 @@
----
-name: "\U0001F41B Bug Report"
-about: Submit a bug report to help us improve Stable-Baselines3
-labels: bug
-title: "[Bug] bug title"
----
-
-**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
-Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
-
-
-If your issue is related to a **custom gym environment**, please use the custom gym env template.
-
-### 🐛 Bug
-
-A clear and concise description of what the bug is.
-
-
-### To Reproduce
-
-Steps to reproduce the behavior.
-
-Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
-
-Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
-for both code and stack traces.
-
-```python
-from stable_baselines3 import ...
-
-```
-
-```bash
-Traceback (most recent call last): File ...
-
-```
-
-### Expected behavior
-
-A clear and concise description of what you expected to happen.
-
-
-### System Info
-
-Describe the characteristic of your environment:
- * Describe how the library was installed (pip, docker, source, ...)
- * GPU models and configuration
- * Python version
- * PyTorch version
- * Gym version
- * Versions of any other relevant libraries
-
-You can use `sb3.get_system_info()` to print relevant packages info:
-```python
-import stable_baselines3 as sb3
-sb3.get_system_info()
-```
-
-### Additional context
-Add any other context about the problem here.
-
-### Checklist
-
-- [ ] I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo (**required**)
-- [ ] I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) (**required**)
-- [ ] I have provided a minimal working example to reproduce the bug (**required**)
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 000000000..8defe9a1e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,71 @@
+name: "\U0001F41B Bug Report"
+description: Submit a bug report to help us improve Stable-Baselines3
+title: "[Bug]: bug title"
+labels: ["bug"]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
+ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
+
+ If your issue is related to a **custom gym environment**, please use the custom gym env template.
+ - type: textarea
+ id: description
+ attributes:
+ label: 🐛 Bug
+ description: A clear and concise description of what the bug is.
+ validations:
+ required: true
+ - type: textarea
+ id: reproduce
+ attributes:
+ label: To Reproduce
+ description: |
+ Steps to reproduce the behavior. Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
+ Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
+ value: |
+ ```python
+ from stable_baselines3 import ...
+
+ ```
+
+ - type: textarea
+ id: traceback
+ attributes:
+ label: Relevant log output / Error message
+ description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
+ placeholder: "Traceback (most recent call last): File ..."
+ render: shell
+
+ - type: textarea
+ id: system-info
+ attributes:
+ label: System Info
+ description: |
+ Describe the characteristics of your environment:
+ * Describe how the library was installed (pip, docker, source, ...)
+ * GPU models and configuration
+ * Python version
+ * PyTorch version
+ * Gym version
+ * Versions of any other relevant libraries
+
+ You can use `sb3.get_system_info()` to print relevant packages info:
+ ```python
+ import stable_baselines3 as sb3
+ sb3.get_system_info()
+ ```
+ - type: checkboxes
+ id: terms
+ attributes:
+ label: Checklist
+ options:
+ - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
+ required: true
+ - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
+ required: true
+ - label: I have provided a minimal working example to reproduce the bug
+ required: true
+ - label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
+ required: true
diff --git a/.github/ISSUE_TEMPLATE/custom_env.md b/.github/ISSUE_TEMPLATE/custom_env.md
deleted file mode 100644
index 0a12a68bb..000000000
--- a/.github/ISSUE_TEMPLATE/custom_env.md
+++ /dev/null
@@ -1,95 +0,0 @@
----
-name: "\U0001F916 Custom Gym Environment Issue"
-about: How to report an issue when using a custom Gym environment
-labels: question, custom gym env
----
-
-**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
-Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
-
-### 🤖 Custom Gym Environment
-
-**Please check your environment first using**:
-
-```python
-from stable_baselines3.common.env_checker import check_env
-
-env = CustomEnv(arg1, ...)
-# It will check your custom environment and output additional warnings if needed
-check_env(env)
-```
-
-### Describe the bug
-
-A clear and concise description of what the bug is.
-
-### Code example
-
-Please try to provide a minimal example to reproduce the bug.
-For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods
-(see working example below).
-Error messages and stack traces are also helpful.
-
-Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks)
-for both code and stack traces.
-
-```python
-import gym
-import numpy as np
-
-from stable_baselines3 import A2C
-from stable_baselines3.common.env_checker import check_env
-
-
-class CustomEnv(gym.Env):
-
- def __init__(self):
- super(CustomEnv, self).__init__()
- self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
- self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))
-
- def reset(self):
- return self.observation_space.sample()
-
- def step(self, action):
- obs = self.observation_space.sample()
- reward = 1.0
- done = False
- info = {}
- return obs, reward, done, info
-
-env = CustomEnv()
-check_env(env)
-
-model = A2C("MlpPolicy", env, verbose=1).learn(1000)
-```
-
-```bash
-Traceback (most recent call last): File ...
-
-```
-
-### System Info
-Describe the characteristic of your environment:
- * Describe how the library was installed (pip, docker, source, ...)
- * GPU models and configuration
- * Python version
- * PyTorch version
- * Gym version
- * Versions of any other relevant libraries
-
-You can use `sb3.get_system_info()` to print relevant packages info:
-```python
-import stable_baselines3 as sb3
-sb3.get_system_info()
-```
-
-### Additional context
-Add any other context about the problem here.
-
-### Checklist
-
-- [ ] I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) (**required**)
-- [ ] I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo (**required**)
-- [ ] I have checked my env using the env checker (**required**)
-- [ ] I have provided a minimal working example to reproduce the bug (**required**)
diff --git a/.github/ISSUE_TEMPLATE/custom_env.yml b/.github/ISSUE_TEMPLATE/custom_env.yml
new file mode 100644
index 000000000..006fb5a92
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/custom_env.yml
@@ -0,0 +1,108 @@
+name: "\U0001F916 Custom Gym Environment Issue"
+description: How to report an issue when using a custom Gym environment
+labels: ["question", "custom gym env"]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
+ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
+
+ **Please check your environment first using**:
+ ```python
+ from stable_baselines3.common.env_checker import check_env
+
+ env = CustomEnv(arg1, ...)
+ # It will check your custom environment and output additional warnings if needed
+ check_env(env)
+ ```
+ - type: textarea
+ id: description
+ attributes:
+ label: 🐛 Bug
+ description: A clear and concise description of what the bug is.
+ validations:
+ required: true
+ - type: textarea
+ id: code-example
+ attributes:
+ label: Code example
+ description: |
+ Please try to provide a minimal example to reproduce the bug.
+ For a custom environment, you need to give at least the observation space, action space, `reset()` and `step()` methods (see working example below).
+ Error messages and stack traces are also helpful.
+ Please use the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
+ value: |
+ ```python
+ import gymnasium as gym
+ import numpy as np
+
+ from stable_baselines3 import A2C
+ from stable_baselines3.common.env_checker import check_env
+
+
+ class CustomEnv(gym.Env):
+
+ def __init__(self):
+ super().__init__()
+ self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
+ self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))
+
+ def reset(self):
+ return self.observation_space.sample(), {}
+
+ def step(self, action):
+ obs = self.observation_space.sample()
+ reward = 1.0
+ done = False
+ truncated = False
+ info = {}
+ return obs, reward, done, truncated, info
+
+ env = CustomEnv()
+ check_env(env)
+
+ model = A2C("MlpPolicy", env, verbose=1).learn(1000)
+ ```
+
+ - type: textarea
+ id: traceback
+ attributes:
+ label: Relevant log output / Error message
+ description: Please copy and paste any relevant log output / error message. This will be automatically formatted into code, so no need for backticks.
+ placeholder: "Traceback (most recent call last): File ..."
+ render: shell
+
+ - type: textarea
+ id: system-info
+ attributes:
+ label: System Info
+ description: |
+ Describe the characteristics of your environment:
+ * Describe how the library was installed (pip, docker, source, ...)
+ * GPU models and configuration
+ * Python version
+ * PyTorch version
+ * Gym version
+ * Versions of any other relevant libraries
+
+ You can use `sb3.get_system_info()` to print relevant packages info:
+ ```python
+ import stable_baselines3 as sb3
+ sb3.get_system_info()
+ ```
+ - type: checkboxes
+ id: terms
+ attributes:
+ label: Checklist
+ options:
+ - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
+ required: true
+ - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
+ required: true
+ - label: I have provided a minimal working example to reproduce the bug
+ required: true
+ - label: I have checked my env using the env checker
+ required: true
+ - label: I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
+ required: true
diff --git a/.github/ISSUE_TEMPLATE/documentation.md b/.github/ISSUE_TEMPLATE/documentation.md
deleted file mode 100644
index 59e5da5fb..000000000
--- a/.github/ISSUE_TEMPLATE/documentation.md
+++ /dev/null
@@ -1,21 +0,0 @@
----
-name: "\U0001F4DA Documentation"
-about: Report an issue related to Stable-Baselines3 documentation
-labels: documentation
----
-
-**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
-Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
-
-### 📚 Documentation
-
-A clear and concise description of what should be improved in the documentation.
-
-### Checklist
-
-- [ ] I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) (**required**)
-- [ ] I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo (**required**)
-
-
-
-
diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml
new file mode 100644
index 000000000..025e2d3b7
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/documentation.yml
@@ -0,0 +1,25 @@
+name: "\U0001F4DA Documentation"
+description: Report an issue related to Stable-Baselines3 documentation
+labels: ["documentation"]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
+ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
+ - type: textarea
+ id: description
+ attributes:
+ label: 📚 Documentation
+ description: A clear and concise description of what should be improved in the documentation.
+ validations:
+ required: true
+ - type: checkboxes
+ id: terms
+ attributes:
+ label: Checklist
+ options:
+ - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
+ required: true
+ - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
+ required: true
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
deleted file mode 100644
index a65086335..000000000
--- a/.github/ISSUE_TEMPLATE/feature_request.md
+++ /dev/null
@@ -1,39 +0,0 @@
----
-name: "\U0001F680Feature Request"
-about: How to create an issue for requesting a feature
-labels: enhancement
-title: "[Feature Request] request title"
----
-
-**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
-Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
-
-
-### 🚀 Feature
-
-A clear and concise description of the feature proposal.
-
-### Motivation
-
-Please outline the motivation for the proposal.
-Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]".
-If this is related to another GitHub issue, please link here too.
-
-### Pitch
-
-A clear and concise description of what you want to happen.
-
-### Alternatives
-
-A clear and concise description of any alternative solutions or features you've considered, if any.
-
-### Additional context
-
-Add any other context or screenshots about the feature request here.
-
-### Checklist
-
-- [ ] I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo (**required**)
-
-
-
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 000000000..1d598a153
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,44 @@
+name: "\U0001F680 Feature Request"
+description: How to create an issue for requesting a feature
+title: "[Feature Request] request title"
+labels: ["enhancement"]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
+ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
+ - type: textarea
+ id: description
+ attributes:
+ label: 🚀 Feature
+ description: A clear and concise description of the feature proposal.
+ validations:
+ required: true
+ - type: textarea
+ id: motivation
+ attributes:
+ label: Motivation
+ description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too.
+ - type: textarea
+ id: pitch
+ attributes:
+ label: Pitch
+ description: A clear and concise description of what you want to happen.
+ - type: textarea
+ id: alternatives
+ attributes:
+ label: Alternatives
+ description: A clear and concise description of any alternative solutions or features you've considered, if any.
+ - type: textarea
+ id: additional-context
+ attributes:
+ label: Additional context
+ description: Add any other context or screenshots about the feature request here.
+ - type: checkboxes
+ id: terms
+ attributes:
+ label: Checklist
+ options:
+ - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
+ required: true
diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md
deleted file mode 100644
index b3288d889..000000000
--- a/.github/ISSUE_TEMPLATE/question.md
+++ /dev/null
@@ -1,26 +0,0 @@
----
-name: ❓Question
-about: How to ask a question regarding Stable-Baselines3
-labels: question
-title: "[Question] question title"
----
-
-**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
-Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
-
-
-### Question
-
-Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first.
-
-### Additional context
-
-Add any other context about the question here.
-
-
-### Checklist
-
-- [ ] I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/) (**required**)
-- [ ] I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo (**required**)
-
-
diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml
new file mode 100644
index 000000000..b2fb2f526
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/question.yml
@@ -0,0 +1,30 @@
+name: "❓ Question"
+description: How to ask a question regarding Stable-Baselines3
+title: "[Question] question title"
+labels: ["question"]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
+ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
+ - type: textarea
+ id: question
+ attributes:
+ label: ❓ Question
+ description: Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using stable-baselines3. Make sure to check out the documentation first.
+ validations:
+ required: true
+ - type: checkboxes
+ id: terms
+ attributes:
+ label: Checklist
+ options:
+ - label: I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
+ required: true
+ - label: I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
+ required: true
+ - label: If code there is, it is minimal and working
+ required: true
+ - label: If code there is, it is formatted using the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
+ required: true
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 45ca8f56a..1e32b4053 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,14 +1,17 @@
-image: stablebaselines/stable-baselines3-cpu:1.4.1a0
+image: stablebaselines/stable-baselines3-cpu:1.5.1a6
type-check:
script:
+ - pip install pytype --upgrade
- make type
pytest:
script:
+ - pip install tqdm rich # for progress bar
- python --version
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
- MKL_THREADING_LAYER=GNU make pytest
+ coverage: '/^TOTAL.+?(\d+\%)$/'
doc-build:
script:
diff --git a/Dockerfile b/Dockerfile
index 8dfbbbf4c..96588ef91 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,6 +3,9 @@ FROM $PARENT_IMAGE
ARG PYTORCH_DEPS=cpuonly
ARG PYTHON_VERSION=3.7
+# for tzdata
+ENV DEBIAN_FRONTEND="noninteractive" TZ="Europe/Paris"
+
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
@@ -20,7 +23,7 @@ RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \
- /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \
+ /opt/conda/bin/conda install -y pytorch=1.11 $PYTORCH_DEPS -c pytorch && \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH
diff --git a/Makefile b/Makefile
index 9954c7d7b..02851cf45 100644
--- a/Makefile
+++ b/Makefile
@@ -29,7 +29,8 @@ check-codestyle:
commit-checks: format type lint
doc:
- cd docs && make html
+ # Prevent weird error due to protobuf
+ cd docs && PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp make html
spelling:
cd docs && make spelling
diff --git a/README.md b/README.md
index f7275478f..f633cdeaf 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
-[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
+
+![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
+[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
@@ -77,7 +79,7 @@ Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.h
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
-This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
+This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
@@ -115,22 +117,24 @@ Most of the library tries to follow a sklearn-like syntax for the Reinforcement
Here is a quick example of how to train and run PPO on a cartpole environment:
```python
-import gym
+import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
-model.learn(total_timesteps=10000)
+model.learn(total_timesteps=10_000)
-obs = env.reset()
+vec_env = model.get_env()
+obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
- obs, reward, done, info = env.step(action)
- env.render()
- if done:
- obs = env.reset()
+ obs, reward, done, info = vec_env.step(action)
+ vec_env.render()
+ # VecEnv resets automatically
+ # if done:
+ # obs = env.reset()
env.close()
```
@@ -140,7 +144,7 @@ Or just train a model with a one liner if [the environment is registered in Gym]
```python
from stable_baselines3 import PPO
-model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
+model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.
@@ -172,6 +176,7 @@ All the following examples can be executed online using Google colab notebooks:
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| QR-DQN[1](#f1) | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
+| RecurrentPPO[1](#f1) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TQC[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
@@ -231,7 +236,7 @@ To cite this repository in publications:
## Maintainers
-Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli).
+Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
diff --git a/docs/_static/img/split_graph.png b/docs/_static/img/split_graph.png
new file mode 100644
index 000000000..c966c5565
Binary files /dev/null and b/docs/_static/img/split_graph.png differ
diff --git a/docs/conda_env.yml b/docs/conda_env.yml
index a01d37bce..7b89ba92b 100644
--- a/docs/conda_env.yml
+++ b/docs/conda_env.yml
@@ -4,11 +4,11 @@ channels:
- defaults
dependencies:
- cpuonly=1.0=0
- - pip=21.1
+ - pip=22.1.1
- python=3.7
- - pytorch=1.8.1=py3.7_cpu_0
+ - pytorch=1.11.0=py3.7_cpu_0
- pip:
- - gym>=0.17.2
+ - gym==0.26
- cloudpickle
- opencv-python-headless
- pandas
@@ -16,5 +16,5 @@ dependencies:
- matplotlib
- sphinx_autodoc_typehints
- sphinx>=4.2
- # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- sphinx_rtd_theme>=1.0
+ - sphinx_copybutton
diff --git a/docs/conf.py b/docs/conf.py
index 088f8a067..b44be6f66 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
@@ -25,6 +24,14 @@
except ImportError:
enable_spell_check = False
+# Try to enable copy button
+try:
+ import sphinx_copybutton # noqa: F401
+
+ enable_copy_button = True
+except ImportError:
+ enable_copy_button = False
+
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))
@@ -46,13 +53,13 @@ def __getattr__(cls, name):
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt")
-with open(version_file, "r") as file_handler:
+with open(version_file) as file_handler:
__version__ = file_handler.read().strip()
# -- Project information -----------------------------------------------------
project = "Stable Baselines3"
-copyright = "2020, Stable Baselines3"
+copyright = "2022, Stable Baselines3"
author = "Stable Baselines3 Contributors"
# The short X.Y version
@@ -84,6 +91,9 @@ def __getattr__(cls, name):
if enable_spell_check:
extensions.append("sphinxcontrib.spelling")
+if enable_copy_button:
+ extensions.append("sphinx_copybutton")
+
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -101,7 +111,7 @@ def __getattr__(cls, name):
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
-language = None
+language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst
index 474a04786..55aba3572 100644
--- a/docs/guide/algos.rst
+++ b/docs/guide/algos.rst
@@ -15,6 +15,7 @@ DQN ❌ ✔️ ❌ ❌
HER ✔️ ✔️ ❌ ❌ ❌
PPO ✔️ ✔️ ✔️ ✔️ ✔️
QR-DQN [#f1]_ ❌ ️ ✔️ ❌ ❌ ✔️
+RecurrentPPO [#f1]_ ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ❌ ❌ ❌ ✔️
TD3 ✔️ ❌ ❌ ❌ ✔️
TQC [#f1]_ ✔️ ❌ ❌ ❌ ✔️
@@ -26,8 +27,8 @@ Maskable PPO [#f1]_ ❌ ✔️ ✔️ ✔
.. [#f1] Implemented in `SB3 Contrib <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib>`_
.. note::
- ``Tuple`` observation spaces are not supported by any environment
- however single-level ``Dict`` spaces are (cf. :ref:`Examples `).
+ ``Tuple`` observation spaces are not supported by any environment,
+ however, single-level ``Dict`` spaces are (cf. :ref:`Examples `).
Actions ``gym.spaces``:
diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst
index 6c7f4ebcd..e08c00367 100644
--- a/docs/guide/callbacks.rst
+++ b/docs/guide/callbacks.rst
@@ -27,7 +27,7 @@ You can find two examples of custom callbacks in the documentation: one for savi
"""
A custom callback that derives from ``BaseCallback``.
- :param verbose: (int) Verbosity level 0: not output 1: info 2: debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, verbose=0):
super(CustomCallback, self).__init__(verbose)
@@ -121,7 +121,7 @@ A child callback is for instance :ref:`StopTrainingOnRewardThreshold`
+If you are using the ``VecNormalize`` wrapper, you can save the
+corresponding statistics using ``save_vecnormalize`` (``False`` by default).
.. warning::
@@ -168,14 +172,20 @@ and optionally a prefix for the checkpoints (``rl_model`` by default).
.. code-block:: python
- from stable_baselines3 import SAC
- from stable_baselines3.common.callbacks import CheckpointCallback
- # Save a checkpoint every 1000 steps
- checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/',
- name_prefix='rl_model')
+ from stable_baselines3 import SAC
+ from stable_baselines3.common.callbacks import CheckpointCallback
+
+ # Save a checkpoint every 1000 steps
+ checkpoint_callback = CheckpointCallback(
+ save_freq=1000,
+ save_path="./logs/",
+ name_prefix="rl_model",
+ save_replay_buffer=True,
+ save_vecnormalize=True,
+ )
- model = SAC('MlpPolicy', 'Pendulum-v1')
- model.learn(2000, callback=checkpoint_callback)
+ model = SAC("MlpPolicy", "Pendulum-v1")
+ model.learn(2000, callback=checkpoint_callback)
.. _EvalCallback:
@@ -200,21 +210,44 @@ It will save the best model if ``best_model_save_path`` folder is specified and
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
# Separate evaluation env
- eval_env = gym.make('Pendulum-v1')
+ eval_env = gym.make("Pendulum-v1")
# Use deterministic actions for evaluation
- eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
- log_path='./logs/', eval_freq=500,
+ eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
+ log_path="./logs/", eval_freq=500,
deterministic=True, render=False)
- model = SAC('MlpPolicy', 'Pendulum-v1')
+ model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(5000, callback=eval_callback)
+.. _ProgressBarCallback:
+
+ProgressBarCallback
+^^^^^^^^^^^^^^^^^^^
+
+Display a progress bar with the current progress, elapsed time and estimated remaining time.
+This callback is integrated inside SB3 via the ``progress_bar`` argument of the ``learn()`` method.
+
+.. note::
+
+ This callback requires ``tqdm`` and ``rich`` packages to be installed. This is done automatically when using ``pip install stable-baselines3[extra]``
+
+
+.. code-block:: python
+
+ from stable_baselines3 import PPO
+ from stable_baselines3.common.callbacks import ProgressBarCallback
+
+ model = PPO("MlpPolicy", "Pendulum-v1")
+ # Display progress bar using the progress bar callback
+ # this is equivalent to model.learn(100_000, callback=ProgressBarCallback())
+ model.learn(100_000, progress_bar=True)
+
.. _Callbacklist:
@@ -227,20 +260,20 @@ Alternatively, you can pass directly a list of callbacks to the ``learn()`` meth
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
- checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
+ checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
# Separate evaluation env
- eval_env = gym.make('Pendulum-v1')
- eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
- log_path='./logs/results', eval_freq=500)
+ eval_env = gym.make("Pendulum-v1")
+ eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/best_model",
+ log_path="./logs/results", eval_freq=500)
# Create the callback list
callback = CallbackList([checkpoint_callback, eval_callback])
- model = SAC('MlpPolicy', 'Pendulum-v1')
+ model = SAC("MlpPolicy", "Pendulum-v1")
# Equivalent to:
# model.learn(5000, callback=[checkpoint_callback, eval_callback])
model.learn(5000, callback=callback)
@@ -257,18 +290,18 @@ It must be used with the :ref:`EvalCallback` and use the event triggered by a ne
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
# Separate evaluation env
- eval_env = gym.make('Pendulum-v1')
+ eval_env = gym.make("Pendulum-v1")
# Stop training when the model reaches the reward threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
- model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1)
+ model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the reward threshold is reached
model.learn(int(1e10), callback=eval_callback)
@@ -289,17 +322,17 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
# this is equivalent to defining CheckpointCallback(save_freq=500)
# checkpoint_callback will be triggered every 500 steps
- checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
+ checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
- model = PPO('MlpPolicy', 'Pendulum-v1', verbose=1)
+ model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(int(2e4), callback=event_callback)
@@ -328,7 +361,7 @@ and in total for ``max_episodes * n_envs`` episodes.
# Stops training when the model reaches the maximum number of episodes
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
- model = A2C('MlpPolicy', 'Pendulum-v1', verbose=1)
+ model = A2C("MlpPolicy", "Pendulum-v1", verbose=1)
# Almost infinite number of timesteps, but the training will stop
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
@@ -346,7 +379,7 @@ It must be used with the :ref:`EvalCallback` and use the event triggered after e
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
diff --git a/docs/guide/checking_nan.rst b/docs/guide/checking_nan.rst
index 92f80938b..f3b04a1a1 100644
--- a/docs/guide/checking_nan.rst
+++ b/docs/guide/checking_nan.rst
@@ -51,7 +51,7 @@ which defines for the python process, how it should handle floating point error.
import numpy as np
- np.seterr(all='raise') # define before your code.
+ np.seterr(all="raise") # define before your code.
print("numpy test:")
@@ -66,7 +66,7 @@ but this will also avoid overflow issues on floating point numbers:
import numpy as np
- np.seterr(all='raise') # define before your code.
+ np.seterr(all="raise") # define before your code.
print("numpy overflow test:")
@@ -81,11 +81,11 @@ but will not avoid the propagation issues:
import numpy as np
- np.seterr(all='raise') # define before your code.
+ np.seterr(all="raise") # define before your code.
print("numpy propagation test:")
- a = np.float64('NaN')
+ a = np.float64("NaN")
b = np.float64(1.0)
val = a + b # this will neither warn nor raise anything
print(val)
@@ -100,8 +100,8 @@ It will monitor the actions, observations, and rewards, indicating what action o
.. code-block:: python
- import gym
- from gym import spaces
+ import gymnasium as gym
+ from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
@@ -109,7 +109,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
class NanAndInfEnv(gym.Env):
"""Custom Environment that raised NaNs and Infs"""
- metadata = {'render.modes': ['human']}
+ metadata = {"render.modes": ["human"]}
def __init__(self):
super(NanAndInfEnv, self).__init__()
@@ -119,9 +119,9 @@ It will monitor the actions, observations, and rewards, indicating what action o
def step(self, _action):
randf = np.random.rand()
if randf > 0.99:
- obs = float('NaN')
+ obs = float("NaN")
elif randf > 0.98:
- obs = float('inf')
+ obs = float("inf")
else:
obs = randf
return [obs], 0.0, False, {}
@@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
def reset(self):
return [0.0]
- def render(self, mode='human', close=False):
+ def render(self, mode="human", close=False):
pass
# Create environment
@@ -137,7 +137,7 @@ It will monitor the actions, observations, and rewards, indicating what action o
env = VecCheckNan(env, raise_exception=True)
# Instantiate the agent
- model = PPO('MlpPolicy', env)
+ model = PPO("MlpPolicy", env)
# Train the agent
model.learn(total_timesteps=int(2e5)) # this will crash explaining that the invalid value originated from the environment.
diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst
index 2e2d1f76e..3a7bf9264 100644
--- a/docs/guide/custom_env.rst
+++ b/docs/guide/custom_env.rst
@@ -22,12 +22,12 @@ That is to say, your environment must implement the following methods (and inher
.. code-block:: python
- import gym
- from gym import spaces
+ import gymnasium as gym
+ from gymnasium import spaces
class CustomEnv(gym.Env):
"""Custom Environment that follows gym interface"""
- metadata = {'render.modes': ['human']}
+ metadata = {"render.modes": ["human"]}
def __init__(self, arg1, arg2, ...):
super(CustomEnv, self).__init__()
@@ -45,7 +45,7 @@ That is to say, your environment must implement the following methods (and inher
def reset(self):
...
return observation # reward, done, info can't be included
- def render(self, mode='human'):
+ def render(self, mode="human"):
...
def close (self):
...
@@ -58,7 +58,7 @@ Then you can define and train a RL agent with:
# Instantiate the env
env = CustomEnv(arg1, ...)
# Define and Train the agent
- model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
+ model = A2C("CnnPolicy", env).learn(total_timesteps=1000)
To check that your environment follows the Gym interface that SB3 supports, please use:
diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index 1b8f9fb7f..52589342d 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -62,7 +62,7 @@ using ``policy_kwargs`` parameter:
.. code-block:: python
- import gym
+ import gymnasium as gym
import torch as th
from stable_baselines3 import PPO
@@ -95,14 +95,15 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
.. note::
By default the feature extractor is shared between the actor and the critic to save computation (when applicable).
- However, this can be changed by defining a custom policy for on-policy algorithms or setting
- ``share_features_extractor=False`` in the ``policy_kwargs`` for off-policy algorithms
- (and when applicable).
+ However, this can be changed by defining a custom policy for on-policy algorithms
+ (see `issue #1066 <https://github.com/DLR-RM/stable-baselines3/issues/1066>`_
+ for more information) or setting ``share_features_extractor=False`` in the
+ ``policy_kwargs`` for off-policy algorithms (and when applicable).
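+
+ For example, a minimal sketch (assuming an off-policy algorithm such as SAC, where
+ ``share_features_extractor`` is exposed via ``policy_kwargs``, as noted above):
+
+ .. code-block:: python
+
+    from stable_baselines3 import SAC
+
+    # Illustrative only: use separate feature extractors
+    # for the actor and the critic
+    model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(share_features_extractor=False))
+    model.learn(total_timesteps=1_000)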
.. code-block:: python
- import gym
+ import gymnasium as gym
import torch as th
import torch.nn as nn
@@ -169,7 +170,7 @@ downsampling and "vector" with a single linear layer.
.. code-block:: python
- import gym
+ import gymnasium as gym
import torch as th
from torch import nn
@@ -287,7 +288,7 @@ If your task requires even more granular control over the policy/value architect
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
- import gym
+ import gymnasium as gym
import torch as th
from torch import nn
diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index a5b56b249..7ff6033f6 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -64,19 +64,19 @@ In the following example, we will train, save and load a DQN model on the Lunar
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
- env = gym.make('LunarLander-v2')
+ env = gym.make("LunarLander-v2")
# Instantiate the agent
- model = DQN('MlpPolicy', env, verbose=1)
- # Train the agent
- model.learn(total_timesteps=int(2e5))
+ model = DQN("MlpPolicy", env, verbose=1)
+ # Train the agent and display a progress bar
+ model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")
del model # delete trained model to demonstrate loading
@@ -94,11 +94,12 @@ In the following example, we will train, save and load a DQN model on the Lunar
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
- obs = env.reset()
+ vec_env = model.get_env()
+ obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
- obs, rewards, dones, info = env.step(action)
- env.render()
+ obs, rewards, dones, info = vec_env.step(action)
+ vec_env.render()
Multiprocessing: Unleashing the Power of Vectorized Environments
@@ -114,7 +115,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
.. code-block:: python
- import gym
+ import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
@@ -138,7 +139,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
set_random_seed(seed)
return _init
- if __name__ == '__main__':
+ if __name__ == "__main__":
env_id = "CartPole-v1"
num_cpu = 4 # Number of processes to use
# Create the vectorized environment
@@ -149,7 +150,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
# You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
- model = PPO('MlpPolicy', env, verbose=1)
+ model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25_000)
obs = env.reset()
@@ -172,7 +173,7 @@ Multiprocessing with off-policy algorithms
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
@@ -182,7 +183,7 @@ Multiprocessing with off-policy algorithms
# We collect 4 transitions per call to `env.step()`
# and perform 2 gradient steps per call to `env.step()`
# if gradient_steps=-1, then we would do 4 gradient steps per call to `env.step()`
- model = SAC('MlpPolicy', env, train_freq=1, gradient_steps=2, verbose=1)
+ model = SAC("MlpPolicy", env, train_freq=1, gradient_steps=2, verbose=1)
model.learn(total_timesteps=10_000)
@@ -228,7 +229,7 @@ If your callback returns False, training is aborted early.
import os
- import gym
+ import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
@@ -248,13 +249,13 @@ If your callback returns False, training is aborted early.
:param check_freq:
:param log_dir: Path to the folder where the model will be saved.
It must contain the file created by the ``Monitor`` wrapper.
- :param verbose: Verbosity level.
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
self.check_freq = check_freq
self.log_dir = log_dir
- self.save_path = os.path.join(log_dir, 'best_model')
+ self.save_path = os.path.join(log_dir, "best_model")
self.best_mean_reward = -np.inf
def _init_callback(self) -> None:
@@ -266,11 +267,11 @@ If your callback returns False, training is aborted early.
if self.n_calls % self.check_freq == 0:
# Retrieve training reward
- x, y = ts2xy(load_results(self.log_dir), 'timesteps')
+ x, y = ts2xy(load_results(self.log_dir), "timesteps")
if len(x) > 0:
# Mean training reward over the last 100 episodes
mean_reward = np.mean(y[-100:])
- if self.verbose > 0:
+ if self.verbose >= 1:
print(f"Num timesteps: {self.num_timesteps}")
print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
@@ -278,7 +279,7 @@ If your callback returns False, training is aborted early.
if mean_reward > self.best_mean_reward:
self.best_mean_reward = mean_reward
# Example for saving best model
- if self.verbose > 0:
+ if self.verbose >= 1:
print(f"Saving new best model to {self.save_path}")
self.model.save(self.save_path)
@@ -289,14 +290,14 @@ If your callback returns False, training is aborted early.
os.makedirs(log_dir, exist_ok=True)
# Create and wrap the environment
- env = gym.make('LunarLanderContinuous-v2')
+ env = gym.make("LunarLanderContinuous-v2")
env = Monitor(env, log_dir)
# Add some action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# Because we use parameter noise, we should use a MlpPolicy with layer normalization
- model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
+ model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=0)
# Create the callback: check every 1000 steps
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
# Train the agent
@@ -336,11 +337,11 @@ and multiprocessing for you. To install the Atari environments, run the command
# There already exists an environment generator
# that will make and wrap atari environments correctly.
# Here we are also multi-worker training (n_envs=4 => 4 environments)
- env = make_atari_env('PongNoFrameskip-v4', n_envs=4, seed=0)
+ env = make_atari_env("PongNoFrameskip-v4", n_envs=4, seed=0)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
- model = A2C('CnnPolicy', env, verbose=1)
+ model = A2C("CnnPolicy", env, verbose=1)
model.learn(total_timesteps=25_000)
obs = env.reset()
@@ -371,7 +372,7 @@ will compute a running average and standard deviation of input features (it can
.. code-block:: python
import os
- import gym
+ import gymnasium as gym
import pybullet_envs
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
@@ -382,7 +383,7 @@ will compute a running average and standard deviation of input features (it can
env = VecNormalize(env, norm_obs=True, norm_reward=True,
clip_obs=10.)
- model = PPO('MlpPolicy', env)
+ model = PPO("MlpPolicy", env)
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
@@ -429,7 +430,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
.. code-block:: python
- import gym
+ import gymnasium as gym
import highway_env
import numpy as np
@@ -470,19 +471,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
# HER must be loaded with the env
model = SAC.load("her_sac_highway", env=env)
- obs = env.reset()
+ obs, info = env.reset()
# Evaluate the agent
episode_reward = 0
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
- obs, reward, done, info = env.step(action)
+ obs, reward, done, truncated, info = env.step(action)
env.render()
episode_reward += reward
- if done or info.get("is_success", False):
+ if done or truncated or info.get("is_success", False):
print("Reward:", episode_reward, "Success?", info.get("is_success", False))
episode_reward = 0.0
- obs = env.reset()
+ obs, info = env.reset()
Learning Rate Schedule
@@ -532,19 +533,12 @@ linear and constant schedules.
Advanced Saving and Loading
---------------------------------
-In this example, we show how to use some advanced features of Stable-Baselines3 (SB3):
-how to easily create a test environment to evaluate an agent periodically,
-use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
+In this example, we show how to use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images).
However, SB3 provides a ``save_replay_buffer()`` and ``load_replay_buffer()`` method to save it separately.
-Stable-Baselines3 automatic creation of an environment for evaluation.
-For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent.
-Behind the scene, SB3 uses an :ref:`EvalCallback `.
-
-
.. note::
For training model after loading it, we recommend loading the replay buffer to ensure stable learning (for off-policy algorithms).
@@ -562,14 +556,12 @@ Behind the scene, SB3 uses an :ref:`EvalCallback `.
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.sac.policies import MlpPolicy
- # Create the model, the training environment
- # and the test environment (for evaluation)
- model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1,
- learning_rate=1e-3, create_eval_env=True)
+ # Create the model and the training environment
+ model = SAC("MlpPolicy", "Pendulum-v1", verbose=1,
+ learning_rate=1e-3)
- # Evaluate the model every 1000 steps on 5 test episodes
- # and save the evaluation to the "logs/" folder
- model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")
+ # train the model
+ model.learn(total_timesteps=6000)
# save the model
model.save("sac_pendulum")
@@ -633,7 +625,7 @@ A2C policy gradient updates on the model.
from typing import Dict
- import gym
+ import gymnasium as gym
import numpy as np
import torch as th
@@ -717,7 +709,7 @@ to keep track of the agent progress.
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
# ProcgenEnv is already vectorized
- venv = ProcgenEnv(num_envs=2, env_name='starpilot')
+ venv = ProcgenEnv(num_envs=2, env_name="starpilot")
# To use only part of the observation:
# venv = VecExtractDictObs(venv, "rgb")
@@ -729,6 +721,16 @@ to keep track of the agent progress.
model.learn(10_000)
+SB3 with EnvPool or Isaac Gym
+-----------------------------
+
+Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://developer.nvidia.com/isaac-gym>`_ accelerate the environment by
+already providing a vectorized implementation.
+
+To use SB3 with those tools, you must wrap the env with the tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3.
+You can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772>`_.
+
+
Record a Video
--------------
@@ -740,11 +742,11 @@ Record a mp4 video (here using a random agent).
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
- env_id = 'CartPole-v1'
- video_folder = 'logs/videos/'
+ env_id = "CartPole-v1"
+ video_folder = "logs/videos/"
video_length = 100
env = DummyVecEnv([lambda: gym.make(env_id)])
@@ -782,11 +784,11 @@ Bonus: Make a GIF of a Trained Agent
images = []
obs = model.env.reset()
- img = model.env.render(mode='rgb_array')
+ img = model.env.render(mode="rgb_array")
for i in range(350):
images.append(img)
action, _ = model.predict(obs)
obs, _, _ ,_ = model.env.step(action)
- img = model.env.render(mode='rgb_array')
+ img = model.env.render(mode="rgb_array")
- imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
+ imageio.mimsave("lander_a2c.gif", [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
diff --git a/docs/guide/export.rst b/docs/guide/export.rst
index b6884c19d..3a2174979 100644
--- a/docs/guide/export.rst
+++ b/docs/guide/export.rst
@@ -46,29 +46,40 @@ For PPO, assuming a shared feature extractor.
.. code-block:: python
+ import torch as th
+
from stable_baselines3 import PPO
- import torch
- class OnnxablePolicy(torch.nn.Module):
- def __init__(self, extractor, action_net, value_net):
- super(OnnxablePolicy, self).__init__()
- self.extractor = extractor
- self.action_net = action_net
- self.value_net = value_net
- def forward(self, observation):
- # NOTE: You may have to process (normalize) observation in the correct
- # way before using this. See `common.preprocessing.preprocess_obs`
- action_hidden, value_hidden = self.extractor(observation)
- return self.action_net(action_hidden), self.value_net(value_hidden)
+ class OnnxablePolicy(th.nn.Module):
+ def __init__(self, extractor, action_net, value_net):
+ super().__init__()
+ self.extractor = extractor
+ self.action_net = action_net
+ self.value_net = value_net
+
+ def forward(self, observation):
+ # NOTE: You may have to process (normalize) observation in the correct
+ # way before using this. See `common.preprocessing.preprocess_obs`
+ action_hidden, value_hidden = self.extractor(observation)
+ return self.action_net(action_hidden), self.value_net(value_hidden)
- # Example: model = PPO("MlpPolicy", "Pendulum-v1")
- model = PPO.load("PathToTrainedModel.zip")
- model.policy.to("cpu")
- onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
- dummy_input = torch.randn(1, observation_size)
- torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
+ # Example: model = PPO("MlpPolicy", "Pendulum-v1")
+ model = PPO.load("PathToTrainedModel.zip", device="cpu")
+ onnxable_model = OnnxablePolicy(
+ model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
+ )
+
+ observation_size = model.observation_space.shape
+ dummy_input = th.randn(1, *observation_size)
+ th.onnx.export(
+ onnxable_model,
+ dummy_input,
+ "my_ppo_model.onnx",
+ opset_version=9,
+ input_names=["input"],
+ )
##### Load and test with onnx
@@ -76,48 +87,97 @@ For PPO, assuming a shared feature extractor.
import onnxruntime as ort
import numpy as np
+ onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
- observation = np.zeros((1, observation_size)).astype(np.float32)
+ observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
- action, value = ort_sess.run(None, {'input.1': observation})
+ action, value = ort_sess.run(None, {"input": observation})
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
.. code-block:: python
+ import torch as th
+
from stable_baselines3 import SAC
- import torch
- class OnnxablePolicy(torch.nn.Module):
- def __init__(self, actor):
- super(OnnxablePolicy, self).__init__()
- # Removing the flatten layer because it can't be onnxed
- self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
+ class OnnxablePolicy(th.nn.Module):
+ def __init__(self, actor: th.nn.Module):
+ super().__init__()
+ # Removing the flatten layer because it can't be onnxed
+ self.actor = th.nn.Sequential(
+ actor.latent_pi,
+ actor.mu,
+ # For gSDE
+ # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
+ # Squash the output
+ th.nn.Tanh(),
+ )
+
+ def forward(self, observation: th.Tensor) -> th.Tensor:
+ # NOTE: You may have to process (normalize) observation in the correct
+ # way before using this. See `common.preprocessing.preprocess_obs`
+ return self.actor(observation)
- def forward(self, observation):
- # NOTE: You may have to process (normalize) observation in the correct
- # way before using this. See `common.preprocessing.preprocess_obs`
- return self.actor(observation)
- model = SAC.load("PathToTrainedModel.zip")
+ # Example: model = SAC("MlpPolicy", "Pendulum-v1")
+ model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
- dummy_input = torch.randn(1, observation_size)
- onnxable_model.policy.to("cpu")
- torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
+ observation_size = model.observation_space.shape
+ dummy_input = th.randn(1, *observation_size)
+ th.onnx.export(
+ onnxable_model,
+ dummy_input,
+ "my_sac_actor.onnx",
+ opset_version=9,
+ input_names=["input"],
+ )
+
+ ##### Load and test with onnx
+
+ import onnxruntime as ort
+ import numpy as np
+
+ onnx_path = "my_sac_actor.onnx"
+
+ observation = np.zeros((1, *observation_size)).astype(np.float32)
+ ort_sess = ort.InferenceSession(onnx_path)
+ action = ort_sess.run(None, {"input": observation})
For more discussion around the topic refer to this `issue. `_
-Export to C++
------------------
+Trace/Export to C++
+-------------------
+
+You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications
+(for instance inference code written in C++).
+
+There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228
+
+.. code-block:: python
+
+ # See "ONNX export" for imports and OnnxablePolicy
+ jit_path = "sac_traced.pt"
+
+ # Trace and optimize the module
+ traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
+ frozen_module = th.jit.freeze(traced_module)
+ frozen_module = th.jit.optimize_for_inference(frozen_module)
+ th.jit.save(frozen_module, jit_path)
+
+ ##### Load and test with torch
+
+ import torch as th
-(using PyTorch JIT)
-TODO: help is welcomed!
+ dummy_input = th.randn(1, *observation_size)
+ loaded_module = th.jit.load(jit_path)
+ action_jit = loaded_module(dummy_input)
Export to tensorflowjs / ONNX-JS
diff --git a/docs/guide/imitation.rst b/docs/guide/imitation.rst
index df7895c2e..c4a072611 100644
--- a/docs/guide/imitation.rst
+++ b/docs/guide/imitation.rst
@@ -10,46 +10,10 @@ imitation learning algorithms on top of Stable-Baselines3, including:
- `DAgger `_ with synthetic examples
- `Adversarial Inverse Reinforcement Learning `_ (AIRL)
- `Generative Adversarial Imitation Learning `_ (GAIL)
+ - `Deep RL from Human Preferences `_ (DRLHP)
-
-It also provides `CLI scripts <#cli-quickstart>`_ for training and saving
-demonstrations from RL experts, and for training imitation learners on these demonstrations.
-
-
-Installation
-------------
-
-Installation requires Python 3.7+:
-
-::
-
- pip install imitation
-
-
-CLI Quickstart
----------------------
-
-::
-
- # Train PPO agent on cartpole and collect expert demonstrations
- python -m imitation.scripts.expert_demos with fast cartpole log_dir=quickstart
-
- # Train GAIL from demonstrations
- python -m imitation.scripts.train_adversarial with fast gail cartpole rollout_path=quickstart/rollouts/final.pkl
-
- # Train AIRL from demonstrations
- python -m imitation.scripts.train_adversarial with fast airl cartpole rollout_path=quickstart/rollouts/final.pkl
-
-
-.. note::
-
- You can remove the ``fast`` option to run training to completion. For more CLI options
- and information on reading Tensorboard plots, see the
- `README `_.
-
-
-Python Interface Quickstart
----------------------------
-
-This `example script `_
-uses the Python API to train BC, GAIL, and AIRL models on CartPole data.
+You can install imitation with ``pip install imitation``. The `imitation
+documentation `_ has more details
+on how to use the library, including `a quick start guide
+`_
+for the impatient.
diff --git a/docs/guide/install.rst b/docs/guide/install.rst
index 3b2692787..016949564 100644
--- a/docs/guide/install.rst
+++ b/docs/guide/install.rst
@@ -54,6 +54,17 @@ Bleeding-edge version
pip install git+https://github.com/DLR-RM/stable-baselines3
+.. note::
+
+ If you want to use the latest gym version (0.24+), you have to use
+
+ .. code-block:: bash
+
+ pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests
+
+ See `PR #780 `_ for more information.
+
+
Development version
-------------------
diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst
index 9007ade51..9e0bbaf73 100644
--- a/docs/guide/integrations.rst
+++ b/docs/guide/integrations.rst
@@ -13,7 +13,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
.. code-block:: python
- import gym
+ import gymnasium as gym
import wandb
from wandb.integration.sb3 import WandbCallback
@@ -47,21 +47,38 @@ Hugging Face 🤗
===============
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
-You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3
+You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?library=stable-baselines3
+Most of them are available via the RL Zoo.
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3
-We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T
+We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3
+`here `_.
+
Installation
-------------
.. code-block:: bash
- pip install huggingface_hub
pip install huggingface_sb3
+.. note::
+
+ If you use the `RL Zoo `_, pushing/loading models from the hub are already integrated:
+
+ .. code-block:: bash
+
+ # Download model and save it into the logs/ folder
+ python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
+ # Test the agent
+ python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/
+ # Push model, config and hyperparameters to the hub
+ python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v2 -f logs/ -orga sb3 -m "Initial commit"
+
+
+
Download a model from the Hub
-----------------------------
You need to copy the repo-id that contains your saved model.
@@ -69,7 +86,7 @@ For instance ``sb3/demo-hf-CartPole-v1``:
.. code-block:: python
- import gym
+ import gymnasium as gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
@@ -80,7 +97,7 @@ For instance ``sb3/demo-hf-CartPole-v1``:
## filename = name of the model zip file from the repository
checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
- filename="ppo-CartPole-v1",
+ filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)
@@ -91,11 +108,22 @@ For instance ``sb3/demo-hf-CartPole-v1``:
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
+You need to define two parameters:
+
+- ``repo-id``: the name of the Hugging Face repo you want to download.
+- ``filename``: the file you want to download.
Upload a model to the Hub
-------------------------
+You can easily upload your models using two different functions:
+
+1. ``package_to_hub()``: saves the model, evaluates it, generates a model card and records a replay video of your agent before pushing the complete repo to the Hub.
+
+2. ``push_to_hub()``: simply pushes a file to the Hub.
+
+
First, you need to be logged in to Hugging Face to upload a model:
- If you're using Colab/Jupyter Notebooks:
@@ -106,34 +134,146 @@ First, you need to be logged in to Hugging Face to upload a model:
notebook_login()
-- Otheriwse:
+- Otherwise:
.. code-block:: bash
huggingface-cli login
+
Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo ``sb3/demo-hf-CartPole-v1``
+With ``package_to_hub()``
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
.. code-block:: python
- from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO
+ from stable_baselines3.common.env_util import make_vec_env
+
+ from huggingface_sb3 import package_to_hub
+
+ # Create the environment
+ env_id = "CartPole-v1"
+ env = make_vec_env(env_id, n_envs=1)
+
+ # Create the evaluation environment
+ eval_env = make_vec_env(env_id, n_envs=1)
+
+ # Instantiate the agent
+ model = PPO("MlpPolicy", env, verbose=1)
+
+ # Train the agent
+ model.learn(total_timesteps=int(5000))
+
+ # This method saves, evaluates, generates a model card and records a replay video of your agent before pushing the repo to the hub
+ package_to_hub(model=model,
+ model_name="ppo-CartPole-v1",
+ model_architecture="PPO",
+ env_id=env_id,
+ eval_env=eval_env,
+ repo_id="sb3/demo-hf-CartPole-v1",
+ commit_message="Test commit")
+
+You need to define seven parameters:
+
+- ``model``: your trained model.
+- ``model_architecture``: name of the architecture of your model (DQN, PPO, A2C, SAC…).
+- ``env_id``: name of the environment.
+- ``eval_env``: environment used to evaluate the agent.
+- ``repo-id``: the name of the Hugging Face repo you want to create or update. It is ``{organization}/{repo_name}``.
+- ``commit-message``.
+- ``filename``: the file you want to push to the Hub.
+
+With ``push_to_hub()``
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+
+ from stable_baselines3 import PPO
+ from stable_baselines3.common.env_util import make_vec_env
- # Define a PPO model with MLP policy network
- model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
+ from huggingface_sb3 import push_to_hub
- # Train it for 10000 timesteps
- model.learn(total_timesteps=10_000)
+ # Create the environment
+ env_id = "CartPole-v1"
+ env = make_vec_env(env_id, n_envs=1)
+
+ # Instantiate the agent
+ model = PPO("MlpPolicy", env, verbose=1)
+
+ # Train the agent
+ model.learn(total_timesteps=int(5000))
# Save the model
model.save("ppo-CartPole-v1")
- # Push this saved model to the hf repo
+ # Push this saved model .zip file to the hf repo
# If this repo does not exists it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
push_to_hub(
- repo_id="sb3/demo-hf-CartPole-v1",
- filename="ppo-CartPole-v1",
- commit_message="Added Cartpole-v1 model trained with PPO",
+ repo_id="sb3/demo-hf-CartPole-v1",
+ filename="ppo-CartPole-v1.zip",
+ commit_message="Added CartPole-v1 model trained with PPO",
+ )
+
+You need to define three parameters:
+
+- ``repo-id``: the name of the Hugging Face repo you want to create or update. It is ``{organization}/{repo_name}``.
+- ``filename``: the file you want to push to the Hub.
+- ``commit-message``.
+
+MLflow
+======
+
+If you want to use `MLflow `_ to track your SB3 experiments,
+you can adapt the following code which defines a custom logger output:
+
+.. code-block:: python
+
+ import sys
+ from typing import Any, Dict, Tuple, Union
+
+ import mlflow
+ import numpy as np
+
+ from stable_baselines3 import SAC
+ from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
+
+
+ class MLflowOutputFormat(KVWriter):
+ """
+ Dumps key/value pairs into MLflow's numeric format.
+ """
+
+ def write(
+ self,
+ key_values: Dict[str, Any],
+ key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
+ step: int = 0,
+ ) -> None:
+
+ for (key, value), (_, excluded) in zip(
+ sorted(key_values.items()), sorted(key_excluded.items())
+ ):
+
+ if excluded is not None and "mlflow" in excluded:
+ continue
+
+ if isinstance(value, np.ScalarType):
+ if not isinstance(value, str):
+ mlflow.log_metric(key, value, step)
+
+
+ loggers = Logger(
+ folder=None,
+ output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
)
+
+ with mlflow.start_run():
+ model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
+ # Set custom logger
+ model.set_logger(loggers)
+ model.learn(total_timesteps=10000, log_interval=1)
diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst
index 879a5fb4e..571cf4528 100644
--- a/docs/guide/migration.rst
+++ b/docs/guide/migration.rst
@@ -141,7 +141,7 @@ DQN
^^^
Only the vanilla DQN is implemented right now but extensions will follow.
-Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
+Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
DDPG
^^^^
@@ -206,8 +206,6 @@ New Features (SB3 vs SB2)
- Independent saving/loading/predict for policies
- A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default)
- Generalized State-Dependent Exploration (gSDE) exploration is available for A2C/PPO/SAC. It allows to use RL directly on real robots (cf https://arxiv.org/abs/2005.05719)
-- Proper evaluation (using separate env) is included in the base class (using ``EvalCallback``),
- if you pass the environment as a string, you can pass ``create_eval_env=True`` to the algorithm constructor.
- Better saving/loading: optimizers are now included in the saved parameters and there is two new methods ``save_replay_buffer`` and ``load_replay_buffer`` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3)
- You can pass ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
customize optimizers
diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst
index 064139d25..2d3af54f7 100644
--- a/docs/guide/quickstart.rst
+++ b/docs/guide/quickstart.rst
@@ -10,22 +10,24 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import A2C
- env = gym.make('CartPole-v1')
+ env = gym.make("CartPole-v1")
- model = A2C('MlpPolicy', env, verbose=1)
- model.learn(total_timesteps=10000)
+ model = A2C("MlpPolicy", env, verbose=1)
+ model.learn(total_timesteps=10_000)
- obs = env.reset()
+ vec_env = model.get_env()
+ obs = vec_env.reset()
for i in range(1000):
action, _state = model.predict(obs, deterministic=True)
- obs, reward, done, info = env.step(action)
- env.render()
- if done:
- obs = env.reset()
+ obs, reward, done, info = vec_env.step(action)
+ vec_env.render()
+ # VecEnv resets automatically
+ # if done:
+ # obs = vec_env.reset()
.. note::
@@ -40,4 +42,4 @@ the policy is registered:
from stable_baselines3 import A2C
- model = A2C('MlpPolicy', 'CartPole-v1').learn(10000)
+ model = A2C("MlpPolicy", "CartPole-v1").learn(10000)
diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst
index 031f947c4..aa06a1156 100644
--- a/docs/guide/rl_tips.rst
+++ b/docs/guide/rl_tips.rst
@@ -183,6 +183,16 @@ Some basic advice:
- start with shaped reward (i.e. informative reward) and simplified version of your problem
- debug with random actions to check that your environment works and follows the gym interface:
+Two important things to keep in mind when creating a custom environment are to avoid breaking the Markov assumption
+and to properly handle termination due to a timeout (maximum number of steps in an episode).
+For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations
+as input.
+
+Termination due to timeout (max number of steps per episode) needs to be handled separately: you should set the corresponding key in the info dict, ``info["TimeLimit.truncated"] = True``.
+If you are using the gym ``TimeLimit`` wrapper, this will be done automatically.
+You can read `Time Limit in RL `_ or take a look at the `RL Tips and Tricks video `_
+for more details.
+
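+For illustration, here is a minimal sketch of how a custom env could set that key itself (a hypothetical toy env, using the classic ``gym`` step API this section refers to):
+
+.. code-block:: python
+
+    import gym
+    import numpy as np
+    from gym import spaces
+
+
+    class TimeoutEnv(gym.Env):
+        """Toy env (for illustration only) that enforces its own step limit."""
+
+        def __init__(self, max_episode_steps: int = 100):
+            super().__init__()
+            self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
+            self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
+            self.max_episode_steps = max_episode_steps
+            self.current_step = 0
+
+        def reset(self):
+            self.current_step = 0
+            return self.observation_space.sample()
+
+        def step(self, action):
+            self.current_step += 1
+            obs = self.observation_space.sample()
+            reward = 0.0
+            # The episode ends because of the step limit, not because of a real terminal state
+            done = self.current_step >= self.max_episode_steps
+            info = {}
+            if done:
+                # Tell SB3 that this "done" is a timeout, so bootstrapping is handled correctly
+                info["TimeLimit.truncated"] = True
+            return obs, reward, done, info
+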
We provide a helper to check that your environment runs without error:
@@ -241,12 +251,15 @@ We *recommend following those steps to have a working RL algorithm*:
1. Read the original paper several times
2. Read existing implementations (if available)
3. Try to have some "sign of life" on toy problems
-4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo)
- You usually need to run hyperparameter optimization for that step.
+4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo).
+ You usually need to run hyperparameter optimization for that step.
-You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf `issue #75 `_)
+You need to be particularly careful about the shape of the different objects you are manipulating (a broadcast mistake will fail silently, cf. `issue #75 `_)
and when to stop the gradient propagation.
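+For instance, a broadcast mistake can look like this (a minimal numpy sketch, not taken from the issue itself):
+
+.. code-block:: python
+
+    import numpy as np
+
+    returns = np.ones(4)  # shape (4,)
+    values = np.ones((4, 1))  # shape (4, 1), e.g. the raw output of a value network
+    advantages = returns - values  # silently broadcasts to shape (4, 4)
+    assert advantages.shape == (4, 4)  # instead of the expected (4,)
+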
+Don't forget to handle termination due to timeout separately (see the remark in the custom environment section above);
+you can also take a look at `Issue #284 `_ and `Issue #633 `_.
+
A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions:
1. Pendulum (easy to solve)
diff --git a/docs/guide/rl_zoo.rst b/docs/guide/rl_zoo.rst
index 9b255733e..ea1583291 100644
--- a/docs/guide/rl_zoo.rst
+++ b/docs/guide/rl_zoo.rst
@@ -33,6 +33,11 @@ Installation
You can remove the ``--recursive`` option if you don't want to download the trained agents
+.. note::
+
+ If you only need the training/plotting scripts and additional callbacks/wrappers from the RL Zoo, you can also install it via pip: ``pip install rl_zoo3``
+
+
2. Install dependencies
::
@@ -57,7 +62,7 @@ For example (with evaluation and checkpoints):
::
- python train.py --algo ppo2 --env CartPole-v1 --eval-freq 10000 --save-freq 50000
+ python train.py --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
Continue training (here, load pretrained agent for Breakout and continue
diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst
index 1dfa91263..445832c59 100644
--- a/docs/guide/sb3_contrib.rst
+++ b/docs/guide/sb3_contrib.rst
@@ -8,7 +8,7 @@ We implement experimental features in a separate contrib repository:
`SB3-Contrib`_
This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
-providing the latest features, like Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
+providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
Quantile Regression DQN (QR-DQN).
Why create this repository?
@@ -38,9 +38,11 @@ See documentation for the full list of included features.
- `Augmented Random Search (ARS) `_
- `Quantile Regression DQN (QR-DQN)`_
+- `PPO with invalid action masking (Maskable PPO) `_
+- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) `_
-- `PPO with invalid action masking (Maskable PPO) `_
+
**Gym Wrappers**:
diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst
index dc39d8bc1..74f55ff5a 100644
--- a/docs/guide/tensorboard.rst
+++ b/docs/guide/tensorboard.rst
@@ -12,7 +12,7 @@ To use Tensorboard with stable baselines3, you simply need to pass the location
from stable_baselines3 import A2C
- model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
+ model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)
@@ -22,14 +22,24 @@ You can also define custom logging name when training (by default it is the algo
from stable_baselines3 import A2C
- model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
+ model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
+ # Keep tb_log_name constant to have continuous curve (see note below)
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
+.. note::
+ If you specify different ``tb_log_name`` in subsequent runs, you will have split graphs, like in the figure below.
+ If you want them to be continuous, you must keep the same ``tb_log_name`` (see `issue #975 `_).
+ If your graphs still end up split for other reasons, you can merge them by putting the tensorboard log files into the same folder.
+
+ .. image:: ../_static/img/split_graph.png
+ :width: 330
+ :alt: split_graph
+
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
.. code-block:: bash
@@ -81,7 +91,7 @@ Here is a simple example on how to log both additional tensor or arbitrary scala
def _on_step(self) -> bool:
# Log scalar value (here a random variable)
value = np.random.random()
- self.logger.record('random_value', value)
+ self.logger.record("random_value", value)
return True
@@ -180,7 +190,7 @@ Here is an example of how to render an episode and log the resulting video to Te
from typing import Any, Dict
- import gym
+ import gymnasium as gym
import torch as th
from stable_baselines3 import A2C
@@ -239,6 +249,55 @@ Here is an example of how to render an episode and log the resulting video to Te
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
+Logging Hyperparameters
+-----------------------
+
+TensorBoard supports logging of hyperparameters in its HPARAMS tab, which makes it easier to compare training runs of different agents.
+
+.. warning::
+ To display hyperparameters in the HPARAMS section, a ``metric_dict`` must be given (as well as a ``hparam_dict``).
+
+
+Here is an example of how to save hyperparameters in TensorBoard:
+
+.. code-block:: python
+
+ from stable_baselines3 import A2C
+ from stable_baselines3.common.callbacks import BaseCallback
+ from stable_baselines3.common.logger import HParam
+
+
+ class HParamCallback(BaseCallback):
+ def __init__(self):
+ """
+ Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
+ """
+ super().__init__()
+
+ def _on_training_start(self) -> None:
+ hparam_dict = {
+ "algorithm": self.model.__class__.__name__,
+ "learning rate": self.model.learning_rate,
+ "gamma": self.model.gamma,
+ }
+ # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
+ # Tensorboard will find & display metrics from the `SCALARS` tab
+ metric_dict = {
+ "rollout/ep_len_mean": 0,
+ "train/value_loss": 0,
+ }
+ self.logger.record(
+ "hparams",
+ HParam(hparam_dict, metric_dict),
+ exclude=("stdout", "log", "json", "csv"),
+ )
+
+ def _on_step(self) -> bool:
+ return True
+
+
+ model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
+ model.learn(total_timesteps=int(5e4), callback=HParamCallback())
Directly Accessing The Summary Writer
-------------------------------------
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 8f372650b..131509b85 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,20 +4,159 @@ Changelog
==========
-Release 1.5.1a4 (WIP)
+Release 1.7.0a1 (WIP)
+--------------------------
+
+.. warning::
+
+ This version will be the last one supporting ``gym``; we recommend switching to `gymnasium `_.
+ You can find a migration guide here: TODO
+
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
+ please use an ``EvalCallback`` instead
+- Removed deprecated ``sde_net_arch`` parameter
+- Removed ``ret`` attributes in ``VecNormalize``, please use ``returns`` instead
+- Switched minimum Gym version to 0.26 (@carlosluis, @arjun-kg, @tlpss)
+
+
+New Features:
+^^^^^^^^^^^^^
+
+SB3-Contrib
+^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+- Fix return type of ``evaluate_actions`` in ``ActorCriticPolicy`` to reflect that entropy is an optional tensor (@Rocamonde)
+- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm``
+- Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+- Used issue forms instead of issue templates
+
+Documentation:
+^^^^^^^^^^^^^^
+- Updated Hugging Face Integration page (@simoninithomas)
+
+Release 1.6.2 (2022-10-10)
+--------------------------
+
+**Progress bar in the learn() method, RL Zoo3 is now a package**
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+
+New Features:
+^^^^^^^^^^^^^
+- Added ``progress_bar`` argument in the ``learn()`` method, displayed using TQDM and rich packages
+- Added progress bar callback
+- The `RL Zoo `_ can now be installed as a package (``pip install rl_zoo3``)
+
+SB3-Contrib
+^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+- ``self.num_timesteps`` was initialized properly only after the first call to ``on_step()`` for callbacks
+- Set importlib-metadata version to ``~=4.13`` to be compatible with ``gym=0.21``
+
+Deprecations:
+^^^^^^^^^^^^^
+- Added deprecation warning if parameters ``eval_env``, ``eval_freq`` or ``create_eval_env`` are used (see #925) (@tobirohrer)
+
+Others:
+^^^^^^^
+- Fixed type hint of the ``env_id`` parameter in ``make_vec_env`` and ``make_atari_env`` (@AlexPasqua)
+
+Documentation:
+^^^^^^^^^^^^^^
+- Extended docstring of the ``wrapper_class`` parameter in ``make_vec_env`` (@AlexPasqua)
+
+Release 1.6.1 (2022-09-29)
---------------------------
+**Bug fix release**
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Switched minimum tensorboard version to 2.9.1
+
+New Features:
+^^^^^^^^^^^^^
+- Support logging hyperparameters to tensorboard (@timothe-chaumont)
+- Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala)
+- Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
+- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
+
+SB3-Contrib
+^^^^^^^^^^^
+- Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel)
+
+Bug Fixes:
+^^^^^^^^^^
+- Fixed issue where ``PPO`` gives NaN if rollout buffer provides a batch of size 1 (@hughperkins)
+- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
+- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
+- Added multidimensional action space support (@qgallouedec)
+- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
+- Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
+- Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde)
+- Fixed loading a saved model with a different number of environments
+- Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde)
+- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde)
+- Fixed an issue where keys with different tags but the same key raised an error in ``common.logger.HumanOutputFormat`` (@Rocamonde and @AdamGleave)
+- Set importlib-metadata version to ``~=4.13``
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec)
+- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec)
+- Updated ``ResultsWriter`` (used internally by ``Monitor`` wrapper) to automatically create missing directories when ``filename`` is a path (@dominicgkerr)
+
+Documentation:
+^^^^^^^^^^^^^^
+- Added an example of callback that logs hyperparameters to tensorboard. (@timothe-chaumont)
+- Fixed typo in docstring "nature" -> "Nature" (@Melanol)
+- Added info on split tensorboard logs into the tensorboard documentation (@Melanol)
+- Fixed typo in ppo doc (@francescoluciano)
+- Fixed typo in install doc (@jlp-ue)
+- Clarified and standardized verbosity documentation
+- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
+- Update doc on exporting models (fixes and added torch jit)
+- Fixed typos (@Akhilez)
+- Standardized the use of ``"`` for string representation in documentation
+
+Release 1.6.0 (2022-07-11)
+---------------------------
+
+**Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3**
+
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
- SB3 now requires PyTorch >= 1.11
+- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with SAC or DDPG/TD3,
+ ``share_features_extractor`` is now set to False by default and the ``net_arch=[256, 256]`` (instead of ``net_arch=[]`` that was before)
New Features:
^^^^^^^^^^^^^
+- ``noop_max`` and ``frame_skip`` are now allowed to be equal to zero when using ``AtariWrapper``
SB3-Contrib
^^^^^^^^^^^
+- Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53
+
Bug Fixes:
^^^^^^^^^^
@@ -25,16 +164,32 @@ Bug Fixes:
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP)
+- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled
+- Added a check for unbounded actions
+- Fixed issues due to newer version of protobuf (tensorboard) and sphinx
+- Fix exception causes all over the codebase (@cool-RR)
+- Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede)
+- Fixed a bug in ``kl_divergence`` check that would fail when using numpy arrays with MultiCategorical distribution
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
+- Upgraded to Python 3.7+ syntax using ``pyupgrade``
+- Updated docker base image to Ubuntu 20.04 and cuda 11.3
+- Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG)
Documentation:
^^^^^^^^^^^^^^
- Added link to gym doc and gym env checker
+- Fix typo in PPO doc (@bcollazo)
+- Added link to PPO ICLR blog post
+- Added remark about breaking Markov assumption and timeout handling
+- Added doc about MLflow integration via custom logger (@git-thor)
+- Updated Huggingface integration doc
+- Added copy button for code snippets
+- Added doc about EnvPool and Isaac Gym support
Release 1.5.0 (2022-03-25)
@@ -44,7 +199,7 @@ Release 1.5.0 (2022-03-25)
Breaking Changes:
^^^^^^^^^^^^^^^^^
-- Switched minimum Gym version to 0.21.0.
+- Switched minimum Gym version to 0.21.0
New Features:
^^^^^^^^^^^^^
@@ -930,7 +1085,8 @@ Maintainers
-----------
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
-`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
+`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_), `Anssi Kanervisto`_ (aka `@Miffyli`_)
+and `Quentin Gallouédec`_ (aka @qgallouedec).
.. _Ashley Hill: https://github.com/hill-a
.. _Antonin Raffin: https://araffin.github.io/
@@ -940,6 +1096,8 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
.. _@AdamGleave: https://github.com/adamgleave
.. _Anssi Kanervisto: https://github.com/Miffyli
.. _@Miffyli: https://github.com/Miffyli
+.. _Quentin Gallouédec: https://gallouedec.com/
+.. _@qgallouedec: https://github.com/qgallouedec
@@ -964,4 +1122,7 @@ And all the contributors:
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
-@Gregwar @ycheng517 @quantitative-technologies
+@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
+@carlosluis @arjun-kg @tlpss
+@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
+@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer
diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst
index e871424e7..9325e54a2 100644
--- a/docs/modules/a2c.rst
+++ b/docs/modules/a2c.rst
@@ -53,7 +53,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst
index c484a1c93..4ac28ccb3 100644
--- a/docs/modules/ddpg.rst
+++ b/docs/modules/ddpg.rst
@@ -61,7 +61,7 @@ This example is only to demonstrate the use of the library and its functions, an
.. code-block:: python
- import gym
+ import gymnasium as gym
import numpy as np
from stable_baselines3 import DDPG
diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst
index ce4385502..d69da499a 100644
--- a/docs/modules/dqn.rst
+++ b/docs/modules/dqn.rst
@@ -56,7 +56,7 @@ This example is only to demonstrate the use of the library and its functions, an
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import DQN
diff --git a/docs/modules/her.rst b/docs/modules/her.rst
index 82bf745d6..817a991cf 100644
--- a/docs/modules/her.rst
+++ b/docs/modules/her.rst
@@ -19,10 +19,12 @@ It creates "virtual" transitions by relabeling transitions (changing the desired
but a replay buffer class ``HerReplayBuffer`` that must be passed to an off-policy algorithm
when using ``MultiInputPolicy`` (to have Dict observation support).
-
.. warning::
- HER requires the environment to inherits from `gym.GoalEnv `_
+ HER requires the environment to follow the legacy `gym_robotics.GoalEnv interface `_.
+ In short, the ``gym.Env`` must have:
+
+ - a vectorized implementation of ``compute_reward()``
+ - a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal``
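+
+For illustration, such a goal-based interface could look like the following minimal sketch (a hypothetical env, not part of SB3; the success threshold is arbitrary):
+
+.. code-block:: python
+
+    import numpy as np
+    from gymnasium import spaces
+
+    # Dict observation space with the three required keys
+    observation_space = spaces.Dict(
+        {
+            "observation": spaces.Box(-1.0, 1.0, shape=(10,), dtype=np.float32),
+            "achieved_goal": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
+            "desired_goal": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
+        }
+    )
+
+
+    def compute_reward(achieved_goal: np.ndarray, desired_goal: np.ndarray, info) -> np.ndarray:
+        # Must accept batches of goals (vectorized), as HER relabels many transitions at once
+        distances = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
+        # Sparse reward with an arbitrary distance threshold
+        return -(distances > 0.05).astype(np.float32)
+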
.. warning::
@@ -73,7 +75,7 @@ This example is only to demonstrate the use of the library and its functions, an
env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)
# Available strategies (cf paper): future, final, episode
- goal_selection_strategy = 'future' # equivalent to GoalSelectionStrategy.FUTURE
+ goal_selection_strategy = "future" # equivalent to GoalSelectionStrategy.FUTURE
# If True the HER transitions will get sampled online
online_sampling = True
@@ -101,7 +103,7 @@ This example is only to demonstrate the use of the library and its functions, an
model.save("./her_bit_env")
# Because it needs access to `env.compute_reward()`
# HER must be loaded with the env
- model = model_class.load('./her_bit_env', env=env)
+ model = model_class.load("./her_bit_env", env=env)
obs = env.reset()
for _ in range(100):
diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst
index 3aab653fa..a822cb436 100644
--- a/docs/modules/ppo.rst
+++ b/docs/modules/ppo.rst
@@ -8,14 +8,14 @@ PPO
The `Proximal Policy Optimization `_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).
-The main idea is that after an update, the new policy should be not too far form the old policy.
+The main idea is that after an update, the new policy should not be too far from the old policy.
For that, ppo uses clipping to avoid too large update.
.. note::
PPO contains several modifications from the original algorithm not documented
- by OpenAI: advantages are normalized and value function can be also clipped .
+ by OpenAI: advantages are normalized and the value function can also be clipped.
Notes
@@ -25,11 +25,22 @@ Notes
- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
- Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html
+- 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
Can I use?
----------
+.. note::
+
+ A recurrent version of PPO is available in our contrib repo: https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html
+
+ However, we advise users to start with simple frame-stacking as a simpler, faster
+ and usually competitive alternative; more info in our report: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4
+ See also `Procgen paper appendix Fig 11. `_.
+ In practice, you can stack multiple observations using ``VecFrameStack`` (see the sketch below).
+
+
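+As an illustration, here is a minimal frame-stacking sketch (assuming the ``Pendulum-v1`` environment and default hyperparameters):
+
+.. code-block:: python
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.env_util import make_vec_env
+    from stable_baselines3.common.vec_env import VecFrameStack
+
+    # Create 4 parallel envs and stack the last 4 observations
+    vec_env = make_vec_env("Pendulum-v1", n_envs=4)
+    vec_env = VecFrameStack(vec_env, n_stack=4)
+
+    model = PPO("MlpPolicy", vec_env, verbose=1)
+    model.learn(total_timesteps=10_000)
+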
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
@@ -54,7 +65,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
.. code-block:: python
- import gym
+ import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst
index e7f9057d5..0e9bb3f64 100644
--- a/docs/modules/sac.rst
+++ b/docs/modules/sac.rst
@@ -68,7 +68,7 @@ This example is only to demonstrate the use of the library and its functions, an
.. code-block:: python
- import gym
+ import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC
diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst
index d039ae71c..7c17e644d 100644
--- a/docs/modules/td3.rst
+++ b/docs/modules/td3.rst
@@ -61,7 +61,7 @@ This example is only to demonstrate the use of the library and its functions, an
.. code-block:: python
- import gym
+ import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh
index 13ac86b17..3f0d5ae7c 100755
--- a/scripts/build_docker.sh
+++ b/scripts/build_docker.sh
@@ -1,14 +1,14 @@
#!/bin/bash
-CPU_PARENT=ubuntu:18.04
-GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04
+CPU_PARENT=ubuntu:20.04
+GPU_PARENT=nvidia/cuda:11.3.1-base-ubuntu20.04
TAG=stablebaselines/stable-baselines3
VERSION=$(cat ./stable_baselines3/version.txt)
if [[ ${USE_GPU} == "True" ]]; then
PARENT=${GPU_PARENT}
- PYTORCH_DEPS="cudatoolkit=10.1"
+ PYTORCH_DEPS="cudatoolkit=11.3"
else
PARENT=${CPU_PARENT}
PYTORCH_DEPS="cpuonly"
diff --git a/setup.cfg b/setup.cfg
index 5bc66c20c..ed46c73b6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,9 @@
[metadata]
# This includes the license file in the wheel.
-license_file = LICENSE
+license_files = LICENSE
+project_urls =
+ Code = https://github.com/DLR-RM/stable-baselines3
+ Documentation = https://stable-baselines3.readthedocs.io/
[tool:pytest]
# Deterministic ordering for tests; useful for pytest-xdist.
@@ -10,11 +13,7 @@ filterwarnings =
# Tensorboard warnings
ignore::DeprecationWarning:tensorboard
# Gym warnings
- ignore:Parameters to load are deprecated.:DeprecationWarning
- ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
- ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
- ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning
markers =
expensive: marks tests as expensive (deselect with '-m "not expensive"')
diff --git a/setup.py b/setup.py
index 3664bbc59..7d9ac62db 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
from setuptools import find_packages, setup
-with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler:
+with open(os.path.join("stable_baselines3", "version.txt")) as file_handler:
__version__ = file_handler.read().strip()
@@ -39,30 +39,33 @@
Here is a quick example of how to train and run PPO on a cartpole environment:
```python
-import gym
+import gymnasium as gym
from stable_baselines3 import PPO
-env = gym.make('CartPole-v1')
+env = gym.make("CartPole-v1")
-model = PPO('MlpPolicy', env, verbose=1)
-model.learn(total_timesteps=10000)
+model = PPO("MlpPolicy", env, verbose=1)
+model.learn(total_timesteps=10_000)
-obs = env.reset()
+vec_env = model.get_env()
+obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
- obs, reward, done, info = env.step(action)
- env.render()
- if done:
- obs = env.reset()
+ obs, reward, done, info = vec_env.step(action)
+ vec_env.render()
+ # VecEnv resets automatically
+ # if done:
+ # obs = vec_env.reset()
+
```
-Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
+Or just train a model with a one-liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
```python
from stable_baselines3 import PPO
-model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
+model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
""" # noqa:E501
@@ -73,7 +76,7 @@
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
- "gym==0.21", # Fixed version due to breaking changes in 0.22
+ "gymnasium==0.26.3",
"numpy",
"torch>=1.11",
# For saving models
@@ -82,6 +85,8 @@
"pandas",
# Plotting learning curves
"matplotlib",
+ # gym and flake8 not compatible with importlib-metadata>5.0
+ "importlib-metadata~=4.13",
],
extras_require={
"tests": [
@@ -100,8 +105,6 @@
"isort>=5.0",
# Reformat
"black",
- # For toy text Gym envs
- "scipy>=1.4.1",
],
"docs": [
"sphinx",
@@ -111,18 +114,24 @@
"sphinxcontrib.spelling",
# Type hints support
"sphinx-autodoc-typehints",
+ # Copy button for code snippets
+ "sphinx_copybutton",
],
"extra": [
# For render
"opencv-python",
+ "pygame",
# For atari games,
- "ale-py~=0.7.4",
+ "ale-py~=0.8.0",
"autorom[accept-rom-license]~=0.4.2",
"pillow",
# Tensorboard support
- "tensorboard>=2.2.0",
+ "tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
+ # For progress bar callback
+ "tqdm",
+ "rich",
],
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py
index 4e31c5b3b..d73f5f095 100644
--- a/stable_baselines3/__init__.py
+++ b/stable_baselines3/__init__.py
@@ -11,7 +11,7 @@
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
-with open(version_file, "r") as file_handler:
+with open(version_file) as file_handler:
__version__ = file_handler.read().strip()
diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index eeeb670c3..95fca770b 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -1,7 +1,7 @@
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, TypeVar, Union
import torch as th
-from gym import spaces
+from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
@@ -9,6 +9,8 @@
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
+A2CSelf = TypeVar("A2CSelf", bound="A2C")
+
class A2C(OnPolicyAlgorithm):
"""
@@ -41,10 +43,9 @@ class A2C(OnPolicyAlgorithm):
Default: -1 (only sample at the beginning of the rollout)
:param normalize_advantage: Whether to normalize or not the advantage
:param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -74,7 +75,6 @@ def __init__(
sde_sample_freq: int = -1,
normalize_advantage: bool = False,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
@@ -82,7 +82,7 @@ def __init__(
_init_setup_model: bool = True,
):
- super(A2C, self).__init__(
+ super().__init__(
policy,
env,
learning_rate=learning_rate,
@@ -98,7 +98,6 @@ def __init__(
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
- create_eval_env=create_eval_env,
seed=seed,
_init_setup_model=False,
supported_action_spaces=(
@@ -182,26 +181,20 @@ def train(self) -> None:
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def learn(
- self,
+ self: A2CSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "A2C",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> "A2C":
+ progress_bar: bool = False,
+ ) -> A2CSelf:
- return super(A2C, self).learn(
+ return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
+ progress_bar=progress_bar,
)
diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py
index 832ad9f23..bf89102da 100644
--- a/stable_baselines3/common/atari_wrappers.py
+++ b/stable_baselines3/common/atari_wrappers.py
@@ -1,6 +1,8 @@
-import gym
+from typing import Dict, Tuple
+
+import gymnasium as gym
import numpy as np
-from gym import spaces
+from gymnasium import spaces
try:
import cv2 # pytype:disable=import-error
@@ -9,7 +11,7 @@
except ImportError:
cv2 = None
-from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
+from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn
class NoopResetEnv(gym.Wrapper):
@@ -28,19 +30,20 @@ def __init__(self, env: gym.Env, noop_max: int = 30):
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
- def reset(self, **kwargs) -> np.ndarray:
+ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
- noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
+ noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
assert noops > 0
obs = np.zeros(0)
+ info = {}
for _ in range(noops):
- obs, _, done, _ = self.env.step(self.noop_action)
- if done:
- obs = self.env.reset(**kwargs)
- return obs
+ obs, _, terminated, truncated, info = self.env.step(self.noop_action)
+ if terminated or truncated:
+ obs, info = self.env.reset(**kwargs)
+ return obs, info
class FireResetEnv(gym.Wrapper):
@@ -55,15 +58,15 @@ def __init__(self, env: gym.Env):
assert env.unwrapped.get_action_meanings()[1] == "FIRE"
assert len(env.unwrapped.get_action_meanings()) >= 3
- def reset(self, **kwargs) -> np.ndarray:
+ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
self.env.reset(**kwargs)
- obs, _, done, _ = self.env.step(1)
- if done:
+ obs, _, terminated, truncated, _ = self.env.step(1)
+ if terminated or truncated:
self.env.reset(**kwargs)
- obs, _, done, _ = self.env.step(2)
- if done:
+ obs, _, terminated, truncated, _ = self.env.step(2)
+ if terminated or truncated:
self.env.reset(**kwargs)
- return obs
+ return obs, {}
class EpisodicLifeEnv(gym.Wrapper):
@@ -79,21 +82,21 @@ def __init__(self, env: gym.Env):
self.lives = 0
self.was_real_done = True
- def step(self, action: int) -> GymStepReturn:
- obs, reward, done, info = self.env.step(action)
- self.was_real_done = done
+ def step(self, action: int) -> Gym26StepReturn:
+ obs, reward, terminated, truncated, info = self.env.step(action)
+ self.was_real_done = terminated or truncated
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if 0 < lives < self.lives:
- # for Qbert sometimes we stay in lives == 0 condtion for a few frames
+ # for Qbert sometimes we stay in lives == 0 condition for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
- done = True
+ terminated = True
self.lives = lives
- return obs, reward, done, info
+ return obs, reward, terminated, truncated, info
- def reset(self, **kwargs) -> np.ndarray:
+ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
"""
Calls the Gym environment reset, only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
@@ -103,12 +106,12 @@ def reset(self, **kwargs) -> np.ndarray:
:return: the first observation of the environment
"""
if self.was_real_done:
- obs = self.env.reset(**kwargs)
+ obs, info = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
- obs, _, _, _ = self.env.step(0)
+ obs, _, _, _, info = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
- return obs
+ return obs, info
class MaxAndSkipEnv(gym.Wrapper):
@@ -125,7 +128,7 @@ def __init__(self, env: gym.Env, skip: int = 4):
self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
self._skip = skip
- def step(self, action: int) -> GymStepReturn:
+ def step(self, action: int) -> Gym26StepReturn:
"""
Step the environment with the given action
Repeat action, sum reward, and max over last observations.
@@ -134,9 +137,10 @@ def step(self, action: int) -> GymStepReturn:
:return: observation, reward, done, information
"""
total_reward = 0.0
- done = None
+ terminated = truncated = False
for i in range(self._skip):
- obs, reward, done, info = self.env.step(action)
+ obs, reward, terminated, truncated, info = self.env.step(action)
+ done = terminated or truncated
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
@@ -148,9 +152,9 @@ def step(self, action: int) -> GymStepReturn:
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)
- return max_frame, total_reward, done, info
+ return max_frame, total_reward, terminated, truncated, info
- def reset(self, **kwargs) -> GymObs:
+ def reset(self, **kwargs) -> Gym26ResetReturn:
return self.env.reset(**kwargs)
@@ -235,8 +239,10 @@ def __init__(
terminal_on_life_loss: bool = True,
clip_reward: bool = True,
):
- env = NoopResetEnv(env, noop_max=noop_max)
- env = MaxAndSkipEnv(env, skip=frame_skip)
+ if noop_max > 0:
+ env = NoopResetEnv(env, noop_max=noop_max)
+ if frame_skip > 0:
+ env = MaxAndSkipEnv(env, skip=frame_skip)
if terminal_on_life_loss:
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
@@ -245,4 +251,4 @@ def __init__(
if clip_reward:
env = ClipRewardEnv(env)
- super(AtariWrapper, self).__init__(env)
+ super().__init__(env)
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 14570be30..149ff1cf9 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -5,14 +5,14 @@
import time
from abc import ABC, abstractmethod
from collections import deque
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common import utils
-from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
+from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
@@ -23,6 +23,7 @@
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import (
check_for_correct_spaces,
+ compat_gym_seed,
get_device,
get_schedule_fn,
get_system_info,
@@ -43,7 +44,7 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE
"""If env is a string, make the environment; otherwise, return env.
:param env: The environment to learn from.
- :param verbose: logging verbosity
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating if the environment is created
:return A Gym (vector) environment.
"""
if isinstance(env, str):
@@ -53,25 +54,27 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE
return env
+BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm")
+
+
class BaseAlgorithm(ABC):
"""
The base of RL algorithms
- :param policy: Policy object
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param device: Device on which the code should run.
By default, it will try to use a Cuda compatible device and fallback to cpu
if it is not possible.
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: Seed for the pseudo random generators
@@ -87,7 +90,7 @@ class BaseAlgorithm(ABC):
def __init__(
self,
- policy: Type[BasePolicy],
+ policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str, None],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,
@@ -95,7 +98,6 @@ def __init__(
verbose: int = 0,
device: Union[th.device, str] = "auto",
support_multi_env: bool = False,
- create_eval_env: bool = False,
monitor_wrapper: bool = True,
seed: Optional[int] = None,
use_sde: bool = False,
@@ -108,7 +110,7 @@ def __init__(
self.policy_class = policy
self.device = get_device(device)
- if verbose > 0:
+ if verbose >= 1:
print(f"Using {self.device} device")
self.env = None # type: Optional[GymEnv]
@@ -124,7 +126,6 @@ def __init__(
self._total_timesteps = 0
# Used for computing fps, it is updated at each call of learn()
self._num_timesteps_at_start = 0
- self.eval_env = None
self.seed = seed
self.action_noise = None # type: Optional[ActionNoise]
self.start_time = None
@@ -155,10 +156,6 @@ def __init__(
# Create and wrap the env if needed
if env is not None:
- if isinstance(env, str):
- if create_eval_env:
- self.eval_env = maybe_make_env(env, self.verbose)
-
env = maybe_make_env(env, self.verbose)
env = self._wrap_env(env, self.verbose, monitor_wrapper)
@@ -185,6 +182,11 @@ def __init__(
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
+ if isinstance(self.action_space, gym.spaces.Box):
+ assert np.all(
+ np.isfinite(np.array([self.action_space.low, self.action_space.high]))
+ ), "Continuous action space must have a finite lower and upper bound"
+
@staticmethod
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
""" "
@@ -193,7 +195,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
or to re-order the image channels.
:param env:
- :param verbose:
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating wrappers used
:param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
:return: The wrapped environment.
"""
@@ -209,11 +211,6 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
# Make sure that dict-spaces are not nested (not supported)
check_for_nested_spaces(env.observation_space)
- if isinstance(env.observation_space, gym.spaces.Dict):
- for space in env.observation_space.spaces.values():
- if isinstance(space, gym.spaces.Dict):
- raise ValueError("Nested observation spaces are not supported (Dict spaces inside Dict space).")
-
if not is_vecenv_wrapped(env, VecTransposeImage):
wrap_with_vectranspose = False
if isinstance(env.observation_space, gym.spaces.Dict):
@@ -259,21 +256,6 @@ def logger(self) -> Logger:
"""Getter for the logger object."""
return self._logger
- def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
- """
- Return the environment that will be used for evaluation.
-
- :param eval_env:)
- :return:
- """
- if eval_env is None:
- eval_env = self.eval_env
-
- if eval_env is not None:
- eval_env = self._wrap_env(eval_env, self.verbose)
- assert eval_env.num_envs == 1
- return eval_env
-
def _setup_lr_schedule(self) -> None:
"""Transform to callable if needed."""
self.lr_schedule = get_schedule_fn(self.learning_rate)
@@ -316,7 +298,6 @@ def _excluded_save_params(self) -> List[str]:
"policy",
"device",
"env",
- "eval_env",
"replay_buffer",
"rollout_buffer",
"_vec_normalize_env",
@@ -363,17 +344,11 @@ def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _init_callback(
self,
callback: MaybeCallback,
- eval_env: Optional[VecEnv] = None,
- eval_freq: int = 10000,
- n_eval_episodes: int = 5,
- log_path: Optional[str] = None,
+ progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
- :param eval_freq: How many steps between evaluations; if None, do not evaluate.
- :param n_eval_episodes: How many episodes to play per evaluation
- :param n_eval_episodes: Number of episodes to rollout during evaluation.
- :param log_path: Path to a folder where the evaluations will be saved
+ :param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
@@ -384,16 +359,9 @@ def _init_callback(
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
- # Create eval callback in charge of the evaluation
- if eval_env is not None:
- eval_callback = EvalCallback(
- eval_env,
- best_model_save_path=log_path,
- log_path=log_path,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
- )
- callback = CallbackList([callback, eval_callback])
+ # Add progress bar callback
+ if progress_bar:
+ callback = CallbackList([callback, ProgressBarCallback()])
callback.init_callback(self)
return callback
@@ -401,28 +369,22 @@ def _init_callback(
def _setup_learn(
self,
total_timesteps: int,
- eval_env: Optional[GymEnv],
callback: MaybeCallback = None,
- eval_freq: int = 10000,
- n_eval_episodes: int = 5,
- log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
+ progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
- :param eval_env: Environment to use for evaluation.
:param callback: Callback(s) called at every step with state of the algorithm.
- :param eval_freq: How many steps between evaluations
- :param n_eval_episodes: How many episodes to play per evaluation
- :param log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
- :return:
+ :param progress_bar: Display a progress bar using tqdm and rich.
+ :return: Total timesteps and callback(s)
"""
- self.start_time = time.time()
+ self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
@@ -449,17 +411,12 @@ def _setup_learn(
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
- if eval_env is not None and self.seed is not None:
- eval_env.seed(self.seed)
-
- eval_env = self._get_eval_env(eval_env)
-
# Configure logger's outputs if no logger was passed
if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
- callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
+ callback = self._init_callback(callback, progress_bar)
return total_timesteps, callback
@@ -514,6 +471,11 @@ def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
# if it is not a VecEnv, make it a VecEnv
# and do other transformations (dict obs, image transpose) if needed
env = self._wrap_env(env, self.verbose)
+ assert env.num_envs == self.n_envs, (
+ "The number of environments to be set is different from the number of environments in the model: "
+ f"({env.num_envs} != {self.n_envs}), whereas `set_env` requires them to be the same. To load a model with "
+ f"a different number of environments, you must use `{self.__class__.__name__}.load(path, env)` instead"
+ )
# Check that the observation spaces match
check_for_correct_spaces(env, self.observation_space, self.action_space)
# Update VecNormalize object
@@ -530,17 +492,14 @@ def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
@abstractmethod
def learn(
- self,
+ self: BaseAlgorithmSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "run",
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> "BaseAlgorithm":
+ progress_bar: bool = False,
+ ) -> BaseAlgorithmSelf:
"""
Return a trained model.
@@ -548,11 +507,8 @@ def learn(
:param callback: callback(s) called at every step with state of the algorithm.
:param log_interval: The number of timesteps before logging.
:param tb_log_name: the name of the run for TensorBoard logging
- :param eval_env: Environment that will be used to evaluate the agent
- :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
- :param n_eval_episodes: Number of episode to evaluate the agent
- :param eval_log_path: Path to a folder where the evaluations will be saved
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
+ :param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model
"""
@@ -589,10 +545,9 @@ def set_random_seed(self, seed: Optional[int] = None) -> None:
return
set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
self.action_space.seed(seed)
+ # self.env is always a VecEnv
if self.env is not None:
self.env.seed(seed)
- if self.eval_env is not None:
- self.eval_env.seed(seed)
def set_parameters(
self,
@@ -628,11 +583,11 @@ def set_parameters(
attr = None
try:
attr = recursive_getattr(self, name)
- except Exception:
+ except Exception as e:
# What errors recursive_getattr could throw? KeyError, but
# possible something else too (e.g. if key is an int?).
# Catch anything for now.
- raise ValueError(f"Key {name} is an invalid object name.")
+ raise ValueError(f"Key {name} is an invalid object name.") from e
if isinstance(attr, th.optim.Optimizer):
# Optimizers do not support "strict" keyword...
@@ -664,7 +619,7 @@ def set_parameters(
@classmethod
def load(
- cls,
+ cls: Type[BaseAlgorithmSelf],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
@@ -672,7 +627,7 @@ def load(
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
- ) -> "BaseAlgorithm":
+ ) -> BaseAlgorithmSelf:
"""
Load the model from a zip-file.
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
@@ -702,7 +657,10 @@ def load(
get_system_info()
data, params, pytorch_variables = load_from_zip_file(
- path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
+ path,
+ device=device,
+ custom_objects=custom_objects,
+ print_system_info=print_system_info,
)
# Remove stored device information and replace with ours
@@ -728,6 +686,9 @@ def load(
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset and data is not None:
data["_last_obs"] = None
+ # `n_envs` must be updated. See issue https://github.com/DLR-RM/stable-baselines3/issues/1018
+ if data is not None:
+ data["n_envs"] = env.num_envs
else:
# Use stored env, if one exists. If not, continue as is (can be used for predict)
if "env" in data:
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index bba2272a5..5442a0358 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -4,7 +4,7 @@
import numpy as np
import torch as th
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
@@ -13,6 +13,7 @@
ReplayBufferSamples,
RolloutBufferSamples,
)
+from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
try:
@@ -39,10 +40,10 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
- device: Union[th.device, str] = "cpu",
+ device: Union[th.device, str] = "auto",
n_envs: int = 1,
):
- super(BaseBuffer, self).__init__()
+ super().__init__()
self.buffer_size = buffer_size
self.observation_space = observation_space
self.action_space = action_space
@@ -51,7 +52,7 @@ def __init__(
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
- self.device = device
+ self.device = get_device(device)
self.n_envs = n_envs
@staticmethod
@@ -157,13 +158,14 @@ class ReplayBuffer(BaseBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
- :param device:
+ :param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
of the replay buffer which reduces by almost a factor two the memory used,
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
+ Cannot be used in combination with handle_timeout_termination.
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
separately and treat the task as infinite horizon task.
https://github.com/DLR-RM/stable-baselines3/issues/284
@@ -174,12 +176,12 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
- device: Union[th.device, str] = "cpu",
+ device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
- super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+ super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
# Adjust buffer size
self.buffer_size = max(buffer_size // n_envs, 1)
@@ -188,6 +190,13 @@ def __init__(
if psutil is not None:
mem_available = psutil.virtual_memory().available
+ # there is a bug if both optimize_memory_usage and handle_timeout_termination are true
+ # see https://github.com/DLR-RM/stable-baselines3/issues/934
+ if optimize_memory_usage and handle_timeout_termination:
+ raise ValueError(
+ "ReplayBuffer does not support optimize_memory_usage = True "
+ "and handle_timeout_termination = True simultaneously."
+ )
self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
@@ -239,8 +248,7 @@ def add(
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
# Same, for actions
- if isinstance(self.action_space, spaces.Discrete):
- action = action.reshape((self.n_envs, self.action_dim))
+ action = action.reshape((self.n_envs, self.action_dim))
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()
@@ -321,7 +329,7 @@ class RolloutBuffer(BaseBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
- :param device:
+ :param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
@@ -333,13 +341,13 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
- device: Union[th.device, str] = "cpu",
+ device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
- super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+ super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
@@ -358,7 +366,7 @@ def reset(self) -> None:
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
- super(RolloutBuffer, self).reset()
+ super().reset()
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
@@ -425,6 +433,9 @@ def add(
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs,) + self.obs_shape)
+ # Same reshape, for actions
+ action = action.reshape((self.n_envs, self.action_dim))
+
self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
@@ -483,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
- :param device:
+ :param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
@@ -497,7 +508,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
- device: Union[th.device, str] = "cpu",
+ device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
@@ -563,7 +574,7 @@ def add(
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
- ) -> None:
+ ) -> None: # pytype: disable=signature-mismatch
# Copy to avoid modification by reference
for key in self.observations.keys():
# Reshape needed when using multiple envs with discrete observations
@@ -578,8 +589,7 @@ def add(
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
# Same reshape, for actions
- if isinstance(self.action_space, spaces.Discrete):
- action = action.reshape((self.n_envs, self.action_dim))
+ action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
@@ -649,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
- :param device:
+ :param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to Monte-Carlo advantage estimate when set to 1.
:param gamma: Discount factor
@@ -661,7 +671,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
- device: Union[th.device, str] = "cpu",
+ device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
@@ -701,7 +711,7 @@ def add(
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
- ) -> None:
+ ) -> None: # pytype: disable=signature-mismatch
"""
:param obs: Observation
:param action: Action
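The buffer changes switch the default `device` to `"auto"` (resolved through `get_device`) and forbid combining `optimize_memory_usage` with `handle_timeout_termination`. A small sketch of what that means for user code (not part of the diff; the spaces are placeholders):

```python
# Sketch only; the observation/action spaces below are placeholders.
import numpy as np
from gymnasium import spaces

from stable_baselines3.common.buffers import ReplayBuffer

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
action_space = spaces.Discrete(2)

# device now defaults to "auto" and is resolved through get_device().
buffer = ReplayBuffer(1_000, obs_space, action_space)

# The memory-optimized layout is incompatible with timeout handling,
# so the latter must be disabled explicitly (otherwise a ValueError is raised).
buffer = ReplayBuffer(
    1_000,
    obs_space,
    action_space,
    optimize_memory_usage=True,
    handle_timeout_termination=False,
)
```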
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py
index 27ce5e639..11f431d3c 100644
--- a/stable_baselines3/common/callbacks.py
+++ b/stable_baselines3/common/callbacks.py
@@ -3,9 +3,20 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
-import gym
+import gymnasium as gym
import numpy as np
+try:
+ from tqdm import TqdmExperimentalWarning
+
+ # Remove experimental warning
+ warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
+ from tqdm.rich import tqdm
+except ImportError:
+ # Rich not installed, we only throw an error
+ # if the progress bar is used
+ tqdm = None
+
from stable_baselines3.common import base_class # pytype: disable=pyi-error
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
@@ -15,11 +26,11 @@ class BaseCallback(ABC):
"""
Base class for callback.
- :param verbose:
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, verbose: int = 0):
- super(BaseCallback, self).__init__()
+ super().__init__()
# The RL model
self.model = None # type: Optional[base_class.BaseAlgorithm]
# An alias for self.model.get_env(), the environment used for training
@@ -54,6 +65,8 @@ def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -
# Those are reference and will be updated automatically
self.locals = locals_
self.globals = globals_
+ # Update num_timesteps in case training was done before
+ self.num_timesteps = self.model.num_timesteps
self._on_training_start()
def _on_training_start(self) -> None:
@@ -82,7 +95,6 @@ def on_step(self) -> bool:
:return: If the callback returns False, training is aborted early.
"""
self.n_calls += 1
- # timesteps start at zero
self.num_timesteps = self.model.num_timesteps
return self._on_step()
@@ -123,18 +135,18 @@ class EventCallback(BaseCallback):
:param callback: Callback that will be called
when an event is triggered.
- :param verbose:
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
- super(EventCallback, self).__init__(verbose=verbose)
+ super().__init__(verbose=verbose)
self.callback = callback
# Give access to the parent
if callback is not None:
self.callback.parent = self
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
- super(EventCallback, self).init_callback(model)
+ super().init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)
@@ -169,7 +181,7 @@ class CallbackList(BaseCallback):
"""
def __init__(self, callbacks: List[BaseCallback]):
- super(CallbackList, self).__init__()
+ super().__init__()
assert isinstance(callbacks, list)
self.callbacks = callbacks
@@ -214,6 +226,10 @@ class CheckpointCallback(BaseCallback):
"""
Callback for saving a model every ``save_freq`` calls
to ``env.step()``.
+ By default, it only saves model checkpoints,
+ you need to pass ``save_replay_buffer=True``,
+ and ``save_vecnormalize=True`` to also save replay buffer checkpoints
+ and normalization statistics checkpoints.
.. warning::
@@ -221,29 +237,67 @@ class CheckpointCallback(BaseCallback):
will effectively correspond to ``n_envs`` steps.
To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``
- :param save_freq:
+ :param save_freq: Save checkpoints every ``save_freq`` call of the callback.
:param save_path: Path to the folder where the model will be saved.
:param name_prefix: Common prefix to the saved models
- :param verbose:
+ :param save_replay_buffer: Save the model replay buffer
+ :param save_vecnormalize: Save the ``VecNormalize`` statistics
+ :param verbose: Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint
"""
- def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
- super(CheckpointCallback, self).__init__(verbose)
+ def __init__(
+ self,
+ save_freq: int,
+ save_path: str,
+ name_prefix: str = "rl_model",
+ save_replay_buffer: bool = False,
+ save_vecnormalize: bool = False,
+ verbose: int = 0,
+ ):
+ super().__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
self.name_prefix = name_prefix
+ self.save_replay_buffer = save_replay_buffer
+ self.save_vecnormalize = save_vecnormalize
def _init_callback(self) -> None:
# Create folder if needed
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
+ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str:
+ """
+ Helper to get checkpoint path for each type of checkpoint.
+
+ :param checkpoint_type: empty for the model, "replay_buffer_"
+ or "vecnormalize_" for the other checkpoints.
+ :param extension: Checkpoint file extension (zip for model, pkl for others)
+ :return: Path to the checkpoint
+ """
+ return os.path.join(self.save_path, f"{self.name_prefix}_{checkpoint_type}{self.num_timesteps}_steps.{extension}")
+
def _on_step(self) -> bool:
if self.n_calls % self.save_freq == 0:
- path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
- self.model.save(path)
- if self.verbose > 1:
- print(f"Saving model checkpoint to {path}")
+ model_path = self._checkpoint_path(extension="zip")
+ self.model.save(model_path)
+ if self.verbose >= 2:
+ print(f"Saving model checkpoint to {model_path}")
+
+ if self.save_replay_buffer and hasattr(self.model, "replay_buffer") and self.model.replay_buffer is not None:
+ # If model has a replay buffer, save it too
+ replay_buffer_path = self._checkpoint_path("replay_buffer_", extension="pkl")
+ self.model.save_replay_buffer(replay_buffer_path)
+ if self.verbose > 1:
+ print(f"Saving model replay buffer checkpoint to {replay_buffer_path}")
+
+ if self.save_vecnormalize and self.model.get_vec_normalize_env() is not None:
+ # Save the VecNormalize statistics
+ vec_normalize_path = self._checkpoint_path("vecnormalize_", extension="pkl")
+ self.model.get_vec_normalize_env().save(vec_normalize_path)
+ if self.verbose >= 2:
+ print(f"Saving model VecNormalize to {vec_normalize_path}")
+
return True
@@ -252,11 +306,11 @@ class ConvertCallback(BaseCallback):
Convert functional callback (old-style) to object.
:param callback:
- :param verbose:
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
- super(ConvertCallback, self).__init__(verbose)
+ super().__init__(verbose)
self.callback = callback
def _on_step(self) -> bool:
@@ -288,7 +342,7 @@ class EvalCallback(EventCallback):
:param deterministic: Whether the evaluation should
use a stochastic or deterministic actions.
:param render: Whether to render or not the environment during evaluation
- :param verbose:
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
wrapped with a Monitor wrapper)
"""
@@ -307,7 +361,7 @@ def __init__(
verbose: int = 1,
warn: bool = True,
):
- super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose)
+ super().__init__(callback_after_eval, verbose=verbose)
self.callback_on_new_best = callback_on_new_best
if self.callback_on_new_best is not None:
@@ -380,12 +434,12 @@ def _on_step(self) -> bool:
if self.model.get_vec_normalize_env() is not None:
try:
sync_envs_normalization(self.training_env, self.eval_env)
- except AttributeError:
+ except AttributeError as e:
raise AssertionError(
"Training and eval env are not wrapped the same way, "
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
"and warning above."
- )
+ ) from e
# Reset success rate buffer
self._is_success_buffer = []
@@ -424,7 +478,7 @@ def _on_step(self) -> bool:
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
- if self.verbose > 0:
+ if self.verbose >= 1:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
# Add to current Logger
@@ -433,7 +487,7 @@ def _on_step(self) -> bool:
if len(self._is_success_buffer) > 0:
success_rate = np.mean(self._is_success_buffer)
- if self.verbose > 0:
+ if self.verbose >= 1:
print(f"Success rate: {100 * success_rate:.2f}%")
self.logger.record("eval/success_rate", success_rate)
@@ -442,7 +496,7 @@ def _on_step(self) -> bool:
self.logger.dump(self.num_timesteps)
if mean_reward > self.best_mean_reward:
- if self.verbose > 0:
+ if self.verbose >= 1:
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
@@ -476,18 +530,19 @@ class StopTrainingOnRewardThreshold(BaseCallback):
:param reward_threshold: Minimum expected reward per episode
to stop training.
- :param verbose:
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because the episodic reward
+ threshold was reached
"""
def __init__(self, reward_threshold: float, verbose: int = 0):
- super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose)
+ super().__init__(verbose=verbose)
self.reward_threshold = reward_threshold
def _on_step(self) -> bool:
assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used " "with an ``EvalCallback``"
# Convert np.bool_ to bool, otherwise callback() is False won't work
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
- if self.verbose > 0 and not continue_training:
+ if self.verbose >= 1 and not continue_training:
print(
f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
f" is above the threshold {self.reward_threshold}"
@@ -505,7 +560,7 @@ class EveryNTimesteps(EventCallback):
"""
def __init__(self, n_steps: int, callback: BaseCallback):
- super(EveryNTimesteps, self).__init__(callback)
+ super().__init__(callback)
self.n_steps = n_steps
self.last_time_trigger = 0
@@ -524,11 +579,12 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
and in total for ``max_episodes * n_envs`` episodes.
:param max_episodes: Maximum number of episodes to stop training.
- :param verbose: Select whether to print information about when training ended by reaching ``max_episodes``
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended by
+ reaching ``max_episodes``
"""
def __init__(self, max_episodes: int, verbose: int = 0):
- super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose)
+ super().__init__(verbose=verbose)
self.max_episodes = max_episodes
self._total_max_episodes = max_episodes
self.n_episodes = 0
@@ -544,7 +600,7 @@ def _on_step(self) -> bool:
continue_training = self.n_episodes < self._total_max_episodes
- if self.verbose > 0 and not continue_training:
+ if self.verbose >= 1 and not continue_training:
mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
mean_ep_str = (
f"with an average of {mean_episodes_per_env:.2f} episodes per env" if self.training_env.num_envs > 1 else ""
@@ -569,11 +625,11 @@ class StopTrainingOnNoModelImprovement(BaseCallback):
:param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
:param min_evals: Number of evaluations before start to count evaluations without improvements.
- :param verbose: Verbosity of the output (set to 1 for info messages)
+ :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
"""
def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
- super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose)
+ super().__init__(verbose=verbose)
self.max_no_improvement_evals = max_no_improvement_evals
self.min_evals = min_evals
self.last_best_mean_reward = -np.inf
@@ -594,9 +650,40 @@ def _on_step(self) -> bool:
self.last_best_mean_reward = self.parent.best_mean_reward
- if self.verbose > 0 and not continue_training:
+ if self.verbose >= 1 and not continue_training:
print(
f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
)
return continue_training
+
+
+class ProgressBarCallback(BaseCallback):
+ """
+ Display a progress bar when training SB3 agent
+ using tqdm and rich packages.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ if tqdm is None:
+ raise ImportError(
+ "You must install tqdm and rich in order to use the progress bar callback. "
+ "It is included if you install stable-baselines with the extra packages: "
+ "`pip install stable-baselines3[extra]`"
+ )
+ self.pbar = None
+
+ def _on_training_start(self) -> None:
+ # Initialize progress bar
+ # Remove timesteps that were done in previous training sessions
+ self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps)
+
+ def _on_step(self) -> bool:
+ # Update progress bar, we do num_envs steps per call to `env.step()`
+ self.pbar.update(self.training_env.num_envs)
+ return True
+
+ def _on_training_end(self) -> None:
+ # Close progress bar
+ self.pbar.close()
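With the extended `CheckpointCallback`, replay-buffer and `VecNormalize` snapshots become opt-in alongside the model checkpoints, and `ProgressBarCallback` can be attached via `progress_bar=True`. A usage sketch (not part of the diff; SAC and Pendulum-v1 are placeholder choices):

```python
# Sketch only, assuming SAC and Pendulum-v1 as placeholders (not part of this diff).
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=1_000,
    save_path="./checkpoints/",
    name_prefix="rl_model",
    save_replay_buffer=True,  # also writes rl_model_replay_buffer_<steps>_steps.pkl
    save_vecnormalize=True,   # also writes rl_model_vecnormalize_<steps>_steps.pkl when VecNormalize is used
)

model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=5_000, callback=checkpoint_callback)
```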
diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py
index 1c0e54a88..19d132ffb 100644
--- a/stable_baselines3/common/distributions.py
+++ b/stable_baselines3/common/distributions.py
@@ -3,9 +3,10 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
+import numpy as np
import torch as th
-from gym import spaces
+from gymnasium import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
@@ -16,7 +17,7 @@ class Distribution(ABC):
"""Abstract base class for distributions."""
def __init__(self):
- super(Distribution, self).__init__()
+ super().__init__()
self.distribution = None
@abstractmethod
@@ -120,7 +121,7 @@ class DiagGaussianDistribution(Distribution):
"""
def __init__(self, action_dim: int):
- super(DiagGaussianDistribution, self).__init__()
+ super().__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
@@ -201,13 +202,13 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
"""
def __init__(self, action_dim: int, epsilon: float = 1e-6):
- super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
+ super().__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_actions = None
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
- super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
+ super().proba_distribution(mean_actions, log_std)
return self
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
@@ -219,7 +220,7 @@ def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = N
gaussian_actions = TanhBijector.inverse(actions)
# Log likelihood for a Gaussian distribution
- log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions)
+ log_prob = super().log_prob(gaussian_actions)
# Squash correction (from original SAC implementation)
# this comes from the fact that tanh is bijective and differentiable
log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
@@ -254,7 +255,7 @@ class CategoricalDistribution(Distribution):
"""
def __init__(self, action_dim: int):
- super(CategoricalDistribution, self).__init__()
+ super().__init__()
self.action_dim = action_dim
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@@ -305,7 +306,7 @@ class MultiCategoricalDistribution(Distribution):
"""
def __init__(self, action_dims: List[int]):
- super(MultiCategoricalDistribution, self).__init__()
+ super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@@ -360,7 +361,7 @@ class BernoulliDistribution(Distribution):
"""
def __init__(self, action_dims: int):
- super(BernoulliDistribution, self).__init__()
+ super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@@ -433,7 +434,7 @@ def __init__(
learn_features: bool = False,
epsilon: float = 1e-6,
):
- super(StateDependentNoiseDistribution, self).__init__()
+ super().__init__()
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
@@ -577,10 +578,10 @@ def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
return th.mm(latent_sde, self.exploration_mat)
# Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features)
- latent_sde = latent_sde.unsqueeze(1)
+ latent_sde = latent_sde.unsqueeze(dim=1)
# (batch_size, 1, n_actions)
noise = th.bmm(latent_sde, self.exploration_matrices)
- return noise.squeeze(1)
+ return noise.squeeze(dim=1)
def actions_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
@@ -597,7 +598,7 @@ def log_prob_from_params(
return actions, log_prob
-class TanhBijector(object):
+class TanhBijector:
"""
Bijective transformation of a probability distribution
using a squashing function (tanh)
@@ -607,7 +608,7 @@ class TanhBijector(object):
"""
def __init__(self, epsilon: float = 1e-6):
- super(TanhBijector, self).__init__()
+ super().__init__()
self.epsilon = epsilon
@staticmethod
@@ -657,7 +658,6 @@ def make_proba_distribution(
dist_kwargs = {}
if isinstance(action_space, spaces.Box):
- assert len(action_space.shape) == 1, "Error: the action space must be a vector"
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
@@ -688,7 +688,7 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
- assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
+ assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py
index c4e566991..71f89270e 100644
--- a/stable_baselines3/common/env_checker.py
+++ b/stable_baselines3/common/env_checker.py
@@ -1,9 +1,9 @@
import warnings
-from typing import Union
+from typing import Any, Dict, Union
-import gym
+import gymnasium as gym
import numpy as np
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
@@ -93,7 +93,65 @@ def _check_nan(env: gym.Env) -> None:
_, _, _, _ = vec_env.step(action)
-def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
+def _is_goal_env(env: gym.Env) -> bool:
+ """
+ Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface)
+ """
+ return hasattr(env, "compute_reward")
+
+
+def _check_goal_env_obs(obs: dict, observation_space: spaces.Space, method_name: str) -> None:
+ """
+ Check that an environment implementing the `compute_rewards()` method
+ (previously known as GoalEnv in gym) contains three elements,
+ namely `observation`, `desired_goal`, and `achieved_goal`.
+ """
+ assert len(observation_space.spaces) == 3, (
+ "A goal conditioned env must contain 3 observation keys: `observation`, `desired_goal`, and `achieved_goal`."
+ f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}"
+ )
+
+ for key in ["observation", "achieved_goal", "desired_goal"]:
+ if key not in observation_space.spaces:
+ raise AssertionError(
+ f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' "
+ "key to be part of the observation dictionary. "
+ f"Current keys are {list(observation_space.spaces.keys())}"
+ )
+
+
+def _check_goal_env_compute_reward(
+ obs: Dict[str, Union[np.ndarray, int]],
+ env: gym.Env,
+ reward: float,
+ info: Dict[str, Any],
+):
+ """
+ Check that reward is computed with `compute_reward`
+ and that the implementation is vectorized.
+ """
+ achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"]
+ assert reward == env.compute_reward(
+ achieved_goal, desired_goal, info
+ ), "The reward was not computed with `compute_reward()`"
+
+ achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal)
+ batch_achieved_goals = np.array([achieved_goal, achieved_goal])
+ batch_desired_goals = np.array([desired_goal, desired_goal])
+ if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0:
+ batch_achieved_goals = batch_achieved_goals.reshape(2, 1)
+ batch_desired_goals = batch_desired_goals.reshape(2, 1)
+ batch_infos = np.array([info, info])
+ rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos)
+ assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)"
+ assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"
+
+
+def _check_obs(
+ obs: Union[tuple, dict, np.ndarray, int],
+ observation_space: spaces.Space,
+ method_name: str,
+) -> None:
"""
Check that the observation returned by the environment
correspond to the declared one.
@@ -139,15 +197,28 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
Check the returned values by the env when calling `.reset()` or `.step()` methods.
"""
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
- obs = env.reset()
-
- if isinstance(observation_space, spaces.Dict):
+ reset_returns = env.reset()
+ assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)"
+ assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}"
+ obs, info = reset_returns
+ assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary"
+
+ if _is_goal_env(env):
+ _check_goal_env_obs(obs, observation_space, "reset")
+ elif isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary"
+
+ if not obs.keys() == observation_space.spaces.keys():
+ raise AssertionError(
+ "The observation keys returned by `reset()` must match the observation "
+ f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
+ )
+
for key in observation_space.spaces.keys():
try:
_check_obs(obs[key], observation_space.spaces[key], "reset")
except AssertionError as e:
- raise AssertionError(f"Error while checking key={key}: " + str(e))
+ raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "reset")
@@ -155,36 +226,48 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
action = action_space.sample()
data = env.step(action)
- assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info"
+ assert len(data) == 5, "The `step()` method must return five values: obs, reward, terminated, truncated, info"
# Unpack
- obs, reward, done, info = data
+ obs, reward, terminated, truncated, info = data
- if isinstance(observation_space, spaces.Dict):
+ if _is_goal_env(env):
+ _check_goal_env_obs(obs, observation_space, "step")
+ _check_goal_env_compute_reward(obs, env, reward, info)
+ elif isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary"
+
+ if not obs.keys() == observation_space.spaces.keys():
+ raise AssertionError(
+ "The observation keys returned by `step()` must match the observation "
+ f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
+ )
+
for key in observation_space.spaces.keys():
try:
_check_obs(obs[key], observation_space.spaces[key], "step")
except AssertionError as e:
- raise AssertionError(f"Error while checking key={key}: " + str(e))
+ raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "step")
# We also allow int because the reward will be cast to float
assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
- assert isinstance(done, bool), "The `done` signal must be a boolean"
+ assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
+ assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"
- if isinstance(env, gym.GoalEnv):
- # For a GoalEnv, the keys are checked at reset
+ # Goal conditioned env
+ if hasattr(env, "compute_reward"):
assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)
def _check_spaces(env: gym.Env) -> None:
"""
- Check that the observation and action spaces are defined
- and inherit from gym.spaces.Space.
+ Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For
+ envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface), we check
+ that the observation space is gym.spaces.Dict.
"""
# Helper to link to the code, because gym has no proper documentation
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
@@ -195,6 +278,11 @@ def _check_spaces(env: gym.Env) -> None:
assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces
+ if _is_goal_env(env):
+ assert isinstance(
+ env.observation_space, spaces.Dict
+ ), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gym.spaces.Dict"
+
# Check render cannot be covered by CI
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover
@@ -274,6 +362,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
"cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
)
+ if isinstance(action_space, spaces.Box):
+ assert np.all(
+ np.isfinite(np.array([action_space.low, action_space.high]))
+ ), "Continuous action space must have a finite lower and upper bound"
+
if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32):
warnings.warn(
f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py
index 520c50a5f..a981d5fe4 100644
--- a/stable_baselines3/common/env_util.py
+++ b/stable_baselines3/common/env_util.py
@@ -1,10 +1,11 @@
import os
from typing import Any, Callable, Dict, Optional, Type, Union
-import gym
+import gymnasium as gym
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.utils import compat_gym_seed
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
@@ -36,7 +37,7 @@ def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
def make_vec_env(
- env_id: Union[str, Type[gym.Env]],
+ env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
@@ -53,7 +54,7 @@ def make_vec_env(
By default it uses a ``DummyVecEnv`` which is usually faster
than a ``SubprocVecEnv``.
- :param env_id: the environment ID or the environment class
+ :param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
@@ -62,6 +63,9 @@ def make_vec_env(
in a Monitor wrapper to provide additional information about training.
:param wrapper_class: Additional wrapper to use on the environment.
This can also be a function with single argument that wraps the environment in many things.
+ Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
+ In some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior.
+ See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
@@ -81,7 +85,7 @@ def _init():
else:
env = env_id(**env_kwargs)
if seed is not None:
- env.seed(seed + rank)
+ compat_gym_seed(env, seed=seed + rank)
env.action_space.seed(seed + rank)
# Wrap the env in a Monitor wrapper
# to have additional training information
@@ -106,7 +110,7 @@ def _init():
def make_atari_env(
- env_id: Union[str, Type[gym.Env]],
+ env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
@@ -121,7 +125,7 @@ def make_atari_env(
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
- :param env_id: the environment ID or the environment class
+ :param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
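`make_vec_env` now documents that `env_id` may also be a callable returning an env, and that `wrapper_class` is applied after the `Monitor` wrapper. A short sketch (not part of the diff; CartPole-v1 is a stand-in):

```python
# Sketch only; CartPole-v1 is a stand-in environment.
import gymnasium as gym

from stable_baselines3.common.env_util import make_vec_env

# env_id can be an env ID string, an env class, or any callable returning a gym.Env.
vec_env = make_vec_env(lambda: gym.make("CartPole-v1"), n_envs=4, seed=0)
obs = vec_env.reset()
```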
diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py
index c5d713aa2..ea9d723da 100644
--- a/stable_baselines3/common/envs/bit_flipping_env.py
+++ b/stable_baselines3/common/envs/bit_flipping_env.py
@@ -1,14 +1,14 @@
from collections import OrderedDict
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
-from gym import GoalEnv, spaces
+from gymnasium import Env, spaces
from gym.envs.registration import EnvSpec
-from stable_baselines3.common.type_aliases import GymStepReturn
+from stable_baselines3.common.type_aliases import Gym26StepReturn
-class BitFlippingEnv(GoalEnv):
+class BitFlippingEnv(Env):
"""
Simple bit flipping env, useful to test HER.
The goal is to flip all the bits to get a vector of ones.
@@ -25,7 +25,7 @@ class BitFlippingEnv(GoalEnv):
:param channel_first: Whether to use channel-first or last image.
"""
- spec = EnvSpec("BitFlippingEnv-v0")
+ spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point")
def __init__(
self,
@@ -36,7 +36,7 @@ def __init__(
image_obs_space: bool = False,
channel_first: bool = True,
):
- super(BitFlippingEnv, self).__init__()
+ super().__init__()
# Shape of the observation when using image space
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
# The achieved goal is determined by the current state
@@ -115,7 +115,7 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
if self.discrete_obs_space:
# The internal state is the binary representation of the
# observed one
- return int(sum([state[i] * 2**i for i in range(len(state))]))
+ return int(sum(state[i] * 2**i for i in range(len(state))))
if self.image_obs_space:
size = np.prod(self.image_shape)
@@ -135,7 +135,7 @@ def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int)
if isinstance(state, int):
state = np.array(state).reshape(batch_size, -1)
# Convert to binary representation
- state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int)
+ state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
elif self.image_obs_space:
state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
else:
@@ -157,12 +157,20 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
]
)
- def reset(self) -> Dict[str, Union[int, np.ndarray]]:
+ def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
+ if seed is not None:
+ self.obs_space.seed(seed)
self.current_step = 0
self.state = self.obs_space.sample()
- return self._get_obs()
+ return self._get_obs(), {}
- def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
+ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn:
+ """
+ Step into the env.
+
+ :param action:
+ :return:
+ """
if self.continuous:
self.state[action > 0] = 1 - self.state[action > 0]
else:
@@ -173,8 +181,9 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
self.current_step += 1
# Episode terminate when we reached the goal or the max number of steps
info = {"is_success": done}
- done = done or self.current_step >= self.max_steps
- return obs, reward, done, info
+ truncated = self.current_step >= self.max_steps
+ done = done or truncated
+ return obs, reward, done, truncated, info
def compute_reward(
self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
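`BitFlippingEnv` now follows the five-value step API and accepts a `seed` in `reset()`. A stepping sketch (not part of the diff; the `n_bits` constructor argument is assumed, since the full constructor is not shown in this hunk):

```python
# Sketch only; the n_bits argument is assumed (the constructor is not fully shown in this hunk).
from stable_baselines3.common.envs import BitFlippingEnv

env = BitFlippingEnv(n_bits=4)
obs, info = env.reset(seed=0)  # reset() now takes a seed and returns (obs, info)
terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()
    # step() now returns five values, with truncation reported separately.
    obs, reward, terminated, truncated, info = env.step(action)
```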
diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py
index 8f6ccd2dc..03fc1dbd2 100644
--- a/stable_baselines3/common/envs/identity_env.py
+++ b/stable_baselines3/common/envs/identity_env.py
@@ -1,10 +1,10 @@
-from typing import Optional, Union
+from typing import Dict, Optional, Tuple, Union
import numpy as np
-from gym import Env, Space
+from gymnasium import Env, Space
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
-from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
+from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn
class IdentityEnv(Env):
@@ -32,18 +32,20 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_
self.num_resets = -1 # Becomes 0 after __init__ exits.
self.reset()
- def reset(self) -> GymObs:
+ def reset(self, seed: Optional[int] = None) -> Gym26ResetReturn:
+ if seed is not None:
+ super().reset(seed=seed)
self.current_step = 0
self.num_resets += 1
self._choose_next_state()
- return self.state
+ return self.state, {}
- def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
+ def step(self, action: Union[int, np.ndarray]) -> Gym26StepReturn:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
- done = self.current_step >= self.ep_length
- return self.state, reward, done, {}
+ done = truncated = self.current_step >= self.ep_length
+ return self.state, reward, done, truncated, {}
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()
@@ -69,12 +71,12 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l
super().__init__(ep_length=ep_length, space=space)
self.eps = eps
- def step(self, action: np.ndarray) -> GymStepReturn:
+ def step(self, action: np.ndarray) -> Gym26StepReturn:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
- done = self.current_step >= self.ep_length
- return self.state, reward, done, {}
+ done = truncated = self.current_step >= self.ep_length
+ return self.state, reward, done, truncated, {}
def _get_reward(self, action: np.ndarray) -> float:
return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
@@ -136,15 +138,17 @@ def __init__(
self.ep_length = 10
self.current_step = 0
- def reset(self) -> np.ndarray:
+ def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]:
+ if seed is not None:
+ super().reset(seed=seed)
self.current_step = 0
- return self.observation_space.sample()
+ return self.observation_space.sample(), {}
- def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
+ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn:
reward = 0.0
self.current_step += 1
- done = self.current_step >= self.ep_length
- return self.observation_space.sample(), reward, done, {}
+ done = truncated = self.current_step >= self.ep_length
+ return self.observation_space.sample(), reward, done, truncated, {}
def render(self, mode: str = "human") -> None:
pass
diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py
index 177a64166..f4a275f06 100644
--- a/stable_baselines3/common/envs/multi_input_envs.py
+++ b/stable_baselines3/common/envs/multi_input_envs.py
@@ -1,9 +1,9 @@
-from typing import Dict, Union
+from typing import Dict, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
-from stable_baselines3.common.type_aliases import GymStepReturn
+from stable_baselines3.common.type_aliases import Gym26StepReturn
class SimpleMultiObsEnv(gym.Env):
@@ -42,7 +42,7 @@ def __init__(
discrete_actions: bool = True,
channel_last: bool = True,
):
- super(SimpleMultiObsEnv, self).__init__()
+ super().__init__()
self.vector_size = 5
if channel_last:
@@ -120,7 +120,7 @@ def init_possible_transitions(self) -> None:
self.right_possible = [0, 1, 2, 12, 13, 14]
self.up_possible = [4, 8, 12, 7, 11, 15]
- def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
+ def step(self, action: Union[int, float, np.ndarray]) -> Gym26StepReturn:
"""
Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
@@ -152,11 +152,12 @@ def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn:
got_to_end = self.state == self.max_state
reward = 1 if got_to_end else reward
- done = self.count > self.max_count or got_to_end
+ truncated = self.count > self.max_count
+ done = got_to_end or truncated
self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"
- return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end}
+ return self.get_state_mapping(), reward, done, truncated, {"got_to_end": got_to_end}
def render(self, mode: str = "human") -> None:
"""
@@ -166,15 +167,18 @@ def render(self, mode: str = "human") -> None:
"""
print(self.log)
- def reset(self) -> Dict[str, np.ndarray]:
+ def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, np.ndarray], Dict]:
"""
Resets the environment state and step count and returns reset observation.
+ :param seed: Optional seed forwarded to ``super().reset()``
:return: observation dict {'vec': ..., 'img': ...}
"""
+ if seed is not None:
+ super().reset(seed=seed)
self.count = 0
if not self.random_start:
self.state = 0
else:
self.state = np.random.randint(0, self.max_state)
- return self.state_mapping[self.state]
+ return self.state_mapping[self.state], {}
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index e3f14d3f8..33592b44a 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -1,7 +1,7 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
from stable_baselines3.common import base_class
diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py
index 6493a3e0d..51b6f6b55 100644
--- a/stable_baselines3/common/logger.py
+++ b/stable_baselines3/common/logger.py
@@ -14,9 +14,15 @@
try:
from torch.utils.tensorboard import SummaryWriter
+ from torch.utils.tensorboard.summary import hparams
except ImportError:
SummaryWriter = None
+try:
+ from tqdm import tqdm
+except ImportError:
+ tqdm = None
+
DEBUG = 10
INFO = 20
WARN = 30
@@ -24,7 +30,7 @@
DISABLED = 50
-class Video(object):
+class Video:
"""
Video data class storing the video frames and the frame per seconds
@@ -37,7 +43,7 @@ def __init__(self, frames: th.Tensor, fps: Union[float, int]):
self.fps = fps
-class Figure(object):
+class Figure:
"""
Figure data class storing a matplotlib figure and whether to close the figure after logging it
@@ -50,7 +56,7 @@ def __init__(self, figure: plt.figure, close: bool):
self.close = close
-class Image(object):
+class Image:
"""
Image data class storing an image and data format
@@ -65,6 +71,22 @@ def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
self.dataformats = dataformats
+class HParam:
+ """
+ Hyperparameter data class storing hyperparameters and metrics in dictionaries
+
+ :param hparam_dict: key-value pairs of hyperparameters to log
+ :param metric_dict: key-value pairs of metrics to log
+ A non-empty metrics dict is required to display hyperparameters in the corresponding Tensorboard section.
+ """
+
+ def __init__(self, hparam_dict: Dict[str, Union[bool, str, float, int, None]], metric_dict: Dict[str, Union[float, int]]):
+ self.hparam_dict = hparam_dict
+ if not metric_dict:
+ raise Exception("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
+ self.metric_dict = metric_dict
+
+
class FormatUnsupportedError(NotImplementedError):
"""
Custom error to display informative message when
@@ -80,13 +102,13 @@ def __init__(self, unsupported_formats: Sequence[str], value_description: str):
format_str = f"formats {', '.join(unsupported_formats)} are"
else:
format_str = f"format {unsupported_formats[0]} is"
- super(FormatUnsupportedError, self).__init__(
+ super().__init__(
f"The {format_str} not supported for the {value_description} value logged.\n"
f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
)
-class KVWriter(object):
+class KVWriter:
"""
Key Value writer
"""
@@ -108,7 +130,7 @@ def close(self) -> None:
raise NotImplementedError
-class SeqWriter(object):
+class SeqWriter:
"""
sequence writer
"""
@@ -164,6 +186,9 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
elif isinstance(value, Image):
raise FormatUnsupportedError(["stdout", "log"], "image")
+ elif isinstance(value, HParam):
+ raise FormatUnsupportedError(["stdout", "log"], "hparam")
+
elif isinstance(value, float):
# Align left
value_str = f"{value:<8.3g}"
@@ -172,35 +197,41 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
if key.find("/") > 0: # Find tag and add it to the dict
tag = key[: key.find("/") + 1]
- key2str[self._truncate(tag)] = ""
+ key2str[(tag, self._truncate(tag))] = ""
# Remove tag from key
if tag is not None and tag in key:
key = str(" " + key[len(tag) :])
truncated_key = self._truncate(key)
- if truncated_key in key2str:
+ if (tag, truncated_key) in key2str:
raise ValueError(
f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
)
- key2str[truncated_key] = self._truncate(value_str)
+ key2str[(tag, truncated_key)] = self._truncate(value_str)
# Find max widths
if len(key2str) == 0:
warnings.warn("Tried to write empty key-value dict")
return
else:
- key_width = max(map(len, key2str.keys()))
+ tagless_keys = map(lambda x: x[1], key2str.keys())
+ key_width = max(map(len, tagless_keys))
val_width = max(map(len, key2str.values()))
# Write out the data
dashes = "-" * (key_width + val_width + 7)
lines = [dashes]
- for key, value in key2str.items():
+ for (_, key), value in key2str.items():
key_space = " " * (key_width - len(key))
val_space = " " * (val_width - len(value))
lines.append(f"| {key}{key_space} | {value}{val_space} |")
lines.append(dashes)
- self.file.write("\n".join(lines) + "\n")
+
+ if tqdm is not None and hasattr(self.file, "name") and self.file.name == "<stdout>":
+ # Do not mess up with progress bar
+ tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="")
+ else:
+ self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()
@@ -246,12 +277,13 @@ def is_excluded(key: str) -> bool:
class JSONOutputFormat(KVWriter):
- def __init__(self, filename: str):
- """
- log to a file, in the JSON format
+ """
+ Log to a file, in the JSON format
- :param filename: the file to write the log to
- """
+ :param filename: the file to write the log to
+ """
+
+ def __init__(self, filename: str):
self.file = open(filename, "wt")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
@@ -262,6 +294,8 @@ def cast_to_json_serializable(value: Any):
raise FormatUnsupportedError(["json"], "figure")
if isinstance(value, Image):
raise FormatUnsupportedError(["json"], "image")
+ if isinstance(value, HParam):
+ raise FormatUnsupportedError(["json"], "hparam")
if hasattr(value, "dtype"):
if value.shape == () or len(value) == 1:
# if value is a dimensionless numpy array or of length 1, serialize as a float
@@ -287,13 +321,13 @@ def close(self) -> None:
class CSVOutputFormat(KVWriter):
- def __init__(self, filename: str):
- """
- log to a file, in a CSV format
+ """
+ Log to a file, in a CSV format
- :param filename: the file to write the log to
- """
+ :param filename: the file to write the log to
+ """
+ def __init__(self, filename: str):
self.file = open(filename, "w+t")
self.keys = []
self.separator = ","
@@ -331,6 +365,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
elif isinstance(value, Image):
raise FormatUnsupportedError(["csv"], "image")
+ elif isinstance(value, HParam):
+ raise FormatUnsupportedError(["csv"], "hparam")
+
elif isinstance(value, str):
# escape quotechars by prepending them with another quotechar
value = value.replace(self.quotechar, self.quotechar + self.quotechar)
@@ -351,12 +388,13 @@ def close(self) -> None:
class TensorBoardOutputFormat(KVWriter):
- def __init__(self, folder: str):
- """
- Dumps key/value pairs into TensorBoard's numeric format.
+ """
+ Dumps key/value pairs into TensorBoard's numeric format.
- :param folder: the folder to write the log to
- """
+ :param folder: the folder to write the log to
+ """
+
+ def __init__(self, folder: str):
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
self.writer = SummaryWriter(log_dir=folder)
@@ -386,6 +424,13 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
if isinstance(value, Image):
self.writer.add_image(key, value.image, step, dataformats=value.dataformats)
+ if isinstance(value, HParam):
+ # we don't use `self.writer.add_hparams` to have control over the log_dir
+ experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
+ self.writer.file_writer.add_summary(experiment)
+ self.writer.file_writer.add_summary(session_start_info)
+ self.writer.file_writer.add_summary(session_end_info)
+
# Flush the output to the file
self.writer.flush()
@@ -427,7 +472,7 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr
# ================================================================
-class Logger(object):
+class Logger:
"""
The logger class.
@@ -623,7 +668,7 @@ def read_json(filename: str) -> pandas.DataFrame:
:return: the data in the json
"""
data = []
- with open(filename, "rt") as file_handler:
+ with open(filename) as file_handler:
for line in file_handler:
data.append(json.loads(line))
return pandas.DataFrame(data)
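
A minimal sketch of how the new ``HParam`` value could be logged from a custom callback; the hyperparameter and metric names are illustrative, and the text formats are excluded because they raise ``FormatUnsupportedError`` for hparam values (see above).

```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam


class HParamCallback(BaseCallback):
    def _on_training_start(self) -> None:
        # Illustrative hyperparameters; gamma is a plain float on SB3 algorithms
        hparam_dict = {"algorithm": self.model.__class__.__name__, "gamma": self.model.gamma}
        # A non-empty metric_dict is required for the HPARAMS tab to show up
        metric_dict = {"rollout/ep_rew_mean": 0.0}
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True


model = PPO("MlpPolicy", "CartPole-v1", tensorboard_log="./tb_logs")
model.learn(total_timesteps=1_000, callback=HParamCallback())
```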
diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py
index 04cda2242..83204f68d 100644
--- a/stable_baselines3/common/monitor.py
+++ b/stable_baselines3/common/monitor.py
@@ -7,11 +7,11 @@
from glob import glob
from typing import Dict, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
import pandas
-from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
+from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn
class Monitor(gym.Wrapper):
@@ -24,6 +24,8 @@ class Monitor(gym.Wrapper):
:param reset_keywords: extra keywords for the reset call,
if extra parameters are needed at reset
:param info_keywords: extra information to log, from the information return of env.step()
+ :param override_existing: appends to the file if ``filename`` exists and this is ``False``,
+ otherwise overwrites any existing file (default)
"""
EXT = "monitor.csv"
@@ -35,14 +37,16 @@ def __init__(
allow_early_resets: bool = True,
reset_keywords: Tuple[str, ...] = (),
info_keywords: Tuple[str, ...] = (),
+ override_existing: bool = True,
):
- super(Monitor, self).__init__(env=env)
+ super().__init__(env=env)
self.t_start = time.time()
if filename is not None:
self.results_writer = ResultsWriter(
filename,
header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
extra_keys=reset_keywords + info_keywords,
+ override_existing=override_existing,
)
else:
self.results_writer = None
@@ -57,7 +61,7 @@ def __init__(
self.total_steps = 0
self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
- def reset(self, **kwargs) -> GymObs:
+ def reset(self, **kwargs) -> Gym26ResetReturn:
"""
Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
@@ -78,7 +82,7 @@ def reset(self, **kwargs) -> GymObs:
self.current_reset_info[key] = value
return self.env.reset(**kwargs)
- def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
+ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn:
"""
Step the environment with the given action
@@ -87,9 +91,9 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
"""
if self.needs_reset:
raise RuntimeError("Tried to step environment that needs reset")
- observation, reward, done, info = self.env.step(action)
+ observation, reward, done, truncated, info = self.env.step(action)
self.rewards.append(reward)
- if done:
+ if done or truncated:
self.needs_reset = True
ep_rew = sum(self.rewards)
ep_len = len(self.rewards)
@@ -104,13 +108,13 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
self.results_writer.write_row(ep_info)
info["episode"] = ep_info
self.total_steps += 1
- return observation, reward, done, info
+ return observation, reward, done, truncated, info
def close(self) -> None:
"""
Closes the environment
"""
- super(Monitor, self).close()
+ super().close()
if self.results_writer is not None:
self.results_writer.close()
@@ -159,10 +163,13 @@ class ResultsWriter:
"""
A result writer that saves the data from the `Monitor` class
- :param filename: the location to save a log file, can be None for no log
+ :param filename: the location to save a log file. When it does not end in
+ the string ``"monitor.csv"``, this suffix will be appended to it
:param header: the header dictionary object of the saved csv
- :param reset_keywords: the extra information to log, typically is composed of
+ :param extra_keys: the extra information to log, typically is composed of
``reset_keywords`` and ``info_keywords``
+ :param override_existing: appends to the file if ``filename`` exists and this is ``False``,
+ otherwise overwrites any existing file (default)
"""
def __init__(
@@ -170,6 +177,7 @@ def __init__(
filename: str = "",
header: Optional[Dict[str, Union[float, str]]] = None,
extra_keys: Tuple[str, ...] = (),
+ override_existing: bool = True,
):
if header is None:
header = {}
@@ -178,11 +186,18 @@ def __init__(
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
+ filename = os.path.realpath(filename)
+ # Create (if any) missing filename directories
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ # Append mode when not overriding existing file
+ mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
- self.file_handler = open(filename, "wt", newline="\n")
- self.file_handler.write("#%s\n" % json.dumps(header))
+ self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
- self.logger.writeheader()
+ if override_existing:
+ self.file_handler.write(f"#{json.dumps(header)}\n")
+ self.logger.writeheader()
+
self.file_handler.flush()
def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
@@ -224,7 +239,7 @@ def load_results(path: str) -> pandas.DataFrame:
raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
data_frames, headers = [], []
for file_name in monitor_files:
- with open(file_name, "rt") as file_handler:
+ with open(file_name) as file_handler:
first_line = file_handler.readline()
assert first_line[0] == "#"
header = json.loads(first_line[1:])
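
A minimal sketch (paths and env id are illustrative) of the new ``override_existing`` flag: keep previous episode statistics by appending to an existing ``monitor.csv`` instead of overwriting it when training is resumed.

```python
import gymnasium as gym

from stable_baselines3.common.monitor import Monitor

# First run: writes the header and creates ./logs/run.monitor.csv (default behaviour)
env = Monitor(gym.make("CartPole-v1"), filename="./logs/run")

# Resumed run: append new episodes to the existing file instead of overwriting it
env = Monitor(gym.make("CartPole-v1"), filename="./logs/run", override_existing=False)
```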
diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py
index b1db6f4f2..baa72e9a7 100644
--- a/stable_baselines3/common/noise.py
+++ b/stable_baselines3/common/noise.py
@@ -11,7 +11,7 @@ class ActionNoise(ABC):
"""
def __init__(self):
- super(ActionNoise, self).__init__()
+ super().__init__()
def reset(self) -> None:
"""
@@ -35,7 +35,7 @@ class NormalActionNoise(ActionNoise):
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
self._mu = mean
self._sigma = sigma
- super(NormalActionNoise, self).__init__()
+ super().__init__()
def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma)
@@ -72,7 +72,7 @@ def __init__(
self.initial_noise = initial_noise
self.noise_prev = np.zeros_like(self._mu)
self.reset()
- super(OrnsteinUhlenbeckActionNoise, self).__init__()
+ super().__init__()
def __call__(self) -> np.ndarray:
noise = (
@@ -105,8 +105,8 @@ def __init__(self, base_noise: ActionNoise, n_envs: int):
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
- except (TypeError, AssertionError):
- raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")
+ except (TypeError, AssertionError) as e:
+ raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
self.base_noise = base_noise
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 5905deec3..f263cb241 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -1,11 +1,12 @@
import io
import pathlib
+import sys
import time
import warnings
from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
@@ -20,12 +21,14 @@
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
+OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm")
+
class OffPolicyAlgorithm(BaseAlgorithm):
"""
The base for Off-Policy algorithms (ex: SAC/TD3)
- :param policy: Policy object
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
@@ -50,14 +53,13 @@ class OffPolicyAlgorithm(BaseAlgorithm):
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param device: Device on which the code should run.
By default, it will try to use a Cuda compatible device and fallback to cpu
if it is not possible.
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: Seed for the pseudo random generators
@@ -73,7 +75,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
def __init__(
self,
- policy: Type[BasePolicy],
+ policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
buffer_size: int = 1_000_000, # 1e6
@@ -84,7 +86,7 @@ def __init__(
train_freq: Union[int, Tuple[int, str]] = (1, "step"),
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
- replay_buffer_class: Optional[ReplayBuffer] = None,
+ replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
@@ -92,7 +94,6 @@ def __init__(
verbose: int = 0,
device: Union[th.device, str] = "auto",
support_multi_env: bool = False,
- create_eval_env: bool = False,
monitor_wrapper: bool = True,
seed: Optional[int] = None,
use_sde: bool = False,
@@ -102,7 +103,7 @@ def __init__(
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
- super(OffPolicyAlgorithm, self).__init__(
+ super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
@@ -111,7 +112,6 @@ def __init__(
verbose=verbose,
device=device,
support_multi_env=support_multi_env,
- create_eval_env=create_eval_env,
monitor_wrapper=monitor_wrapper,
seed=seed,
use_sde=use_sde,
@@ -157,8 +157,10 @@ def _convert_train_freq(self) -> None:
try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
- except ValueError:
- raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
+ except ValueError as e:
+ raise ValueError(
+ f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
+ ) from e
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
@@ -263,13 +265,10 @@ def load_replay_buffer(
def _setup_learn(
self,
total_timesteps: int,
- eval_env: Optional[GymEnv],
callback: MaybeCallback = None,
- eval_freq: int = 10000,
- n_eval_episodes: int = 5,
- log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
+ progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
cf `BaseAlgorithm`.
@@ -305,37 +304,28 @@ def _setup_learn(
return super()._setup_learn(
total_timesteps,
- eval_env,
callback,
- eval_freq,
- n_eval_episodes,
- log_path,
reset_num_timesteps,
tb_log_name,
+ progress_bar,
)
def learn(
- self,
+ self: OffPolicyAlgorithmSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "run",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> "OffPolicyAlgorithm":
+ progress_bar: bool = False,
+ ) -> OffPolicyAlgorithmSelf:
total_timesteps, callback = self._setup_learn(
total_timesteps,
- eval_env,
callback,
- eval_freq,
- n_eval_episodes,
- eval_log_path,
reset_num_timesteps,
tb_log_name,
+ progress_bar,
)
callback.on_training_start(locals(), globals())
@@ -425,8 +415,8 @@ def _dump_logs(self) -> None:
"""
Write log.
"""
- time_elapsed = time.time() - self.start_time
- fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
+ time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+ fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
@@ -612,7 +602,6 @@ def collect_rollouts(
# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
self._dump_logs()
-
callback.on_rollout_end()
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
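
A minimal sketch (algorithm and env id are illustrative) of the updated ``learn()`` call after the ``eval_env``/``eval_freq`` arguments were removed above: periodic evaluation moves to an ``EvalCallback``, and the new ``progress_bar`` flag enables the tqdm-based progress bar.

```python
import gymnasium as gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

eval_callback = EvalCallback(
    Monitor(gym.make("Pendulum-v1")),  # evaluation env, replaces the old eval_env argument
    eval_freq=1_000,
    n_eval_episodes=5,
)
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=5_000, callback=eval_callback, progress_bar=True)
```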
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 281758c0b..7ae2c53c0 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -1,7 +1,8 @@
+import sys
import time
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
@@ -13,6 +14,8 @@
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
+OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm")
+
class OnPolicyAlgorithm(BaseAlgorithm):
"""
@@ -35,12 +38,11 @@ class OnPolicyAlgorithm(BaseAlgorithm):
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -62,7 +64,6 @@ def __init__(
use_sde: bool,
sde_sample_freq: int,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
monitor_wrapper: bool = True,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
@@ -72,7 +73,7 @@ def __init__(
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
- super(OnPolicyAlgorithm, self).__init__(
+ super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
@@ -81,7 +82,6 @@ def __init__(
device=device,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
- create_eval_env=create_eval_env,
support_multi_env=True,
seed=seed,
tensorboard_log=tensorboard_log,
@@ -139,7 +139,7 @@ def collect_rollouts(
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
- :param n_steps: Number of experiences to collect per environment
+ :param n_rollout_steps: Number of experiences to collect per environment
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
@@ -223,21 +223,22 @@ def train(self) -> None:
raise NotImplementedError
def learn(
- self,
+ self: OnPolicyAlgorithmSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "OnPolicyAlgorithm",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> "OnPolicyAlgorithm":
+ progress_bar: bool = False,
+ ) -> OnPolicyAlgorithmSelf:
iteration = 0
total_timesteps, callback = self._setup_learn(
- total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
+ total_timesteps,
+ callback,
+ reset_num_timesteps,
+ tb_log_name,
+ progress_bar,
)
callback.on_training_start(locals(), globals())
@@ -254,13 +255,14 @@ def learn(
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
- fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+ time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+ fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
- self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
+ self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index c322dc6f1..e802f48c2 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -2,12 +2,11 @@
import collections
import copy
-import warnings
from abc import ABC, abstractmethod
from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
from torch import nn
@@ -33,8 +32,10 @@
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
+BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel")
-class BaseModel(nn.Module, ABC):
+
+class BaseModel(nn.Module):
"""
The base model object: makes predictions in response to observations.
@@ -67,7 +68,7 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
- super(BaseModel, self).__init__()
+ super().__init__()
if optimizer_kwargs is None:
optimizer_kwargs = {}
@@ -87,10 +88,6 @@ def __init__(
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
- @abstractmethod
- def forward(self, *args, **kwargs):
- pass
-
def _update_features_extractor(
self,
net_kwargs: Dict[str, Any],
@@ -162,7 +159,7 @@ def save(self, path: str) -> None:
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
@classmethod
- def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
+ def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf:
"""
Load model from path.
@@ -173,14 +170,6 @@ def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
device = get_device(device)
saved_variables = th.load(path, map_location=device)
- # Allow to load policy saved with older version of SB3
- if "sde_net_arch" in saved_variables["data"]:
- warnings.warn(
- "sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.",
- DeprecationWarning,
- )
- del saved_variables["data"]["sde_net_arch"]
-
# Create policy object
model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
# Load weights
@@ -255,7 +244,7 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -
return observation, vectorized_env
-class BasePolicy(BaseModel):
+class BasePolicy(BaseModel, ABC):
"""The base policy object.
Parameters are mostly the same as `BaseModel`; additions are documented below.
@@ -267,7 +256,7 @@ class BasePolicy(BaseModel):
"""
def __init__(self, *args, squash_output: bool = False, **kwargs):
- super(BasePolicy, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self._squash_output = squash_output
@staticmethod
@@ -336,8 +325,8 @@ def predict(
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
- # Convert to numpy
- actions = actions.cpu().numpy()
+ # Convert to numpy, and reshape to the original action shape
+ actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
@@ -350,7 +339,7 @@ def predict(
# Remove batch dimension if needed
if not vectorized_env:
- actions = actions[0]
+ actions = actions.squeeze(axis=0)
return actions, state
@@ -391,9 +380,6 @@ class ActorCriticPolicy(BasePolicy):
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -421,7 +407,6 @@ def __init__(
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@@ -437,7 +422,7 @@ def __init__(
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5
- super(ActorCriticPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor_class,
@@ -473,9 +458,6 @@ def __init__(
"learn_features": False,
}
- if sde_net_arch is not None:
- warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
-
self.use_sde = use_sde
self.dist_kwargs = dist_kwargs
@@ -592,6 +574,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
+ actions = actions.reshape((-1,) + self.action_space.shape)
return actions, values, log_prob
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
@@ -629,7 +612,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)
- def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
@@ -645,7 +628,8 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tenso
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
- return values, log_prob, distribution.entropy()
+ entropy = distribution.entropy()
+ return values, log_prob, entropy
def get_distribution(self, obs: th.Tensor) -> Distribution:
"""
@@ -685,9 +669,6 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -715,7 +696,6 @@ def __init__(
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
@@ -724,7 +704,7 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
- super(ActorCriticCnnPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
@@ -734,7 +714,6 @@ def __init__(
use_sde,
log_std_init,
full_std,
- sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
@@ -760,9 +739,6 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -790,7 +766,6 @@ def __init__(
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
@@ -799,7 +774,7 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
- super(MultiInputActorCriticPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
@@ -809,7 +784,6 @@ def __init__(
use_sde,
log_std_init,
full_std,
- sde_net_arch,
use_expln,
squash_output,
features_extractor_class,
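
A minimal sketch (model and env id are illustrative) of the shape guarantee added to ``predict()`` above: actions are reshaped to the action space shape, and the batch dimension is only squeezed for non-vectorized observations.

```python
import gymnasium as gym

from stable_baselines3 import PPO

env = gym.make("Pendulum-v1")
model = PPO("MlpPolicy", env)
obs, _ = env.reset(seed=0)
action, _state = model.predict(obs, deterministic=True)
assert action.shape == env.action_space.shape  # e.g. (1,) for Pendulum
```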
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py
index 01422aafc..c8fcbecd3 100644
--- a/stable_baselines3/common/preprocessing.py
+++ b/stable_baselines3/common/preprocessing.py
@@ -3,7 +3,7 @@
import numpy as np
import torch as th
-from gym import spaces
+from gymnasium import spaces
from torch.nn import functional as F
diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py
index fb3ae8bd5..b48f9223c 100644
--- a/stable_baselines3/common/running_mean_std.py
+++ b/stable_baselines3/common/running_mean_std.py
@@ -3,7 +3,7 @@
import numpy as np
-class RunningMeanStd(object):
+class RunningMeanStd:
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
"""
Calulates the running mean and std of a data stream
diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py
index e0b104f12..facc55ac8 100644
--- a/stable_baselines3/common/save_util.py
+++ b/stable_baselines3/common/save_util.py
@@ -162,13 +162,15 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
try:
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
- except (RuntimeError, TypeError):
+ except (RuntimeError, TypeError, AttributeError) as e:
warnings.warn(
f"Could not deserialize object {data_key}. "
- + "Consider using `custom_objects` argument to replace "
- + "this object."
+ "Consider using `custom_objects` argument to replace "
+ "this object.\n"
+ f"Exception: {e}"
)
- return_data[data_key] = deserialized_object
+ else:
+ return_data[data_key] = deserialized_object
else:
# Read as it is
return_data[data_key] = data_item
@@ -186,14 +188,14 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read"
it checks that it exists, if it doesn't exist it attempts to read path.suffix if a suffix is provided.
If the mode is "write" and the path does not exist, it creates all the parent folders. If the path
- points to a folder, it changes the path to path_2. If the path already exists and verbose == 2,
+ points to a folder, it changes the path to path_2. If the path already exists and verbose >= 2,
it raises a warning.
:param path: the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
@@ -206,8 +208,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
- except KeyError:
- raise ValueError("Expected mode to be either 'w' or 'r'.")
+ except KeyError as e:
+ raise ValueError("Expected mode to be either 'w' or 'r'.") from e
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
e1 = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {e1} file.")
@@ -223,7 +225,7 @@ def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str]
:param path: the path to open. If mode is "w" then it ensures that the path exists
by creating the necessary folders and renaming path if it points to a folder.
:param mode: how to open the file. "w" for writing, "r" for reading.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
@@ -242,7 +244,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
ensures that the path exists by creating the necessary folders and
renaming path if it points to a folder.
:param mode: how to open the file. "w" for writing, "r" for reading.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param verbose: Verbosity level: 0 for no output, 2 to warn when the path without suffix is not found and mode is "r"
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
@@ -257,7 +259,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
except FileNotFoundError as error:
if suffix is not None and suffix != "":
newpath = pathlib.Path(f"{path}.{suffix}")
- if verbose == 2:
+ if verbose >= 2:
warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
path, suffix = newpath, None
else:
@@ -266,7 +268,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
try:
if path.suffix == "" and suffix is not None and suffix != "":
path = pathlib.Path(f"{path}.{suffix}")
- if path.exists() and path.is_file() and verbose == 2:
+ if path.exists() and path.is_file() and verbose >= 2:
warnings.warn(f"Path '{path}' exists, will overwrite it.")
path = path.open("wb")
except IsADirectoryError:
@@ -300,7 +302,7 @@ def save_to_zip_file(
:param params: Model parameters being stored expected to contain an entry for every
state_dict with its name and the state_dict.
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
save_path = open_path(save_path, "w", verbose=0, suffix="zip")
# data/params can be None, so do not
@@ -336,7 +338,7 @@ def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, ver
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param obj: The object to save.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
# Use protocol>=4 to support saving replay buffers >= 4Gb
@@ -352,7 +354,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
:param path: the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
return pickle.load(file_handler)
@@ -379,7 +381,7 @@ def load_from_zip_file(
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:param device: Device on which the code should run.
- :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param print_system_info: Whether to print or not the system info
about the saved model.
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
@@ -441,7 +443,7 @@ def load_from_zip_file(
# State dicts. Store into params dictionary
# with same name as in .zip file (without .pth)
params[os.path.splitext(file_path)[0]] = th_object
- except zipfile.BadZipFile:
+ except zipfile.BadZipFile as e:
# load_path wasn't a zip file
- raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
+ raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
return data, params, pytorch_variables
diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
index ba70a5f63..377b7f604 100644
--- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
+++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
@@ -54,21 +54,21 @@ def __init__(
centered: bool = False,
):
if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
+ raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
+ raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= momentum:
- raise ValueError("Invalid momentum value: {}".format(momentum))
+ raise ValueError(f"Invalid momentum value: {momentum}")
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= alpha:
- raise ValueError("Invalid alpha value: {}".format(alpha))
+ raise ValueError(f"Invalid alpha value: {alpha}")
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
- super(RMSpropTFLike, self).__init__(params, defaults)
+ super().__init__(params, defaults)
def __setstate__(self, state: Dict[str, Any]) -> None:
- super(RMSpropTFLike, self).__setstate__(state)
+ super().__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)
diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py
index 589d12eb1..a93ef6b67 100644
--- a/stable_baselines3/common/torch_layers.py
+++ b/stable_baselines3/common/torch_layers.py
@@ -1,7 +1,7 @@
from itertools import zip_longest
from typing import Dict, List, Tuple, Type, Union
-import gym
+import gymnasium as gym
import torch as th
from torch import nn
@@ -19,7 +19,7 @@ class BaseFeaturesExtractor(nn.Module):
"""
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
- super(BaseFeaturesExtractor, self).__init__()
+ super().__init__()
assert features_dim > 0
self._observation_space = observation_space
self._features_dim = features_dim
@@ -41,7 +41,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
"""
def __init__(self, observation_space: gym.Space):
- super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
+ super().__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
def forward(self, observations: th.Tensor) -> th.Tensor:
@@ -50,7 +50,7 @@ def forward(self, observations: th.Tensor) -> th.Tensor:
class NatureCNN(BaseFeaturesExtractor):
"""
- CNN from DQN nature paper:
+ CNN from DQN Nature paper:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529-533.
@@ -61,7 +61,7 @@ class NatureCNN(BaseFeaturesExtractor):
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
- super(NatureCNN, self).__init__(observation_space, features_dim)
+ super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space, check_channels=False), (
@@ -169,7 +169,7 @@ def __init__(
activation_fn: Type[nn.Module],
device: Union[th.device, str] = "auto",
):
- super(MlpExtractor, self).__init__()
+ super().__init__()
device = get_device(device)
shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
@@ -250,7 +250,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
- super(CombinedExtractor, self).__init__(observation_space, features_dim=1)
+ super().__init__(observation_space, features_dim=1)
extractors = {}
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index 7e69d398e..48c07d54e 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -3,7 +3,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
@@ -11,7 +11,9 @@
GymEnv = Union[gym.Env, vec_env.VecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
+Gym26ResetReturn = Tuple[GymObs, Dict]
GymStepReturn = Tuple[GymObs, float, bool, Dict]
+Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict]
TensorDict = Dict[Union[str, int], th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
@@ -50,7 +52,7 @@ class ReplayBufferSamples(NamedTuple):
class DictReplayBufferSamples(ReplayBufferSamples):
observations: TensorDict
actions: th.Tensor
- next_observations: th.Tensor
+ next_observations: TensorDict
dones: th.Tensor
rewards: th.Tensor
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index 94cd65827..70abc2579 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -3,10 +3,11 @@
import platform
import random
from collections import deque
+from inspect import signature
from itertools import zip_longest
-from typing import Dict, Iterable, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
@@ -67,8 +68,8 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
Update the learning rate for a given optimizer.
Useful when doing linear schedule.
- :param optimizer:
- :param learning_rate:
+ :param optimizer: Pytorch optimizer
+ :param learning_rate: New learning rate value
"""
for param_group in optimizer.param_groups:
param_group["lr"] = learning_rate
@@ -79,8 +80,8 @@ def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
Transform (if needed) learning rate and clip range (for PPO)
to callable.
- :param value_schedule:
- :return:
+ :param value_schedule: Constant value of schedule function
+ :return: Schedule function (can return constant value)
"""
# If the passed schedule is a float
# create a constant function
@@ -104,7 +105,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
:params end_fraction: fraction of ``progress_remaining``
where end is reached e.g 0.1 then end is reached after 10%
of the complete training process.
- :return:
+ :return: Linear schedule function.
"""
def func(progress_remaining: float) -> float:
@@ -121,8 +122,8 @@ def constant_fn(val: float) -> Schedule:
Create a function that returns a constant
It is useful for learning rate schedule (to avoid code duplication)
- :param val:
- :return:
+ :param val: constant value
+ :return: Constant schedule function.
"""
def func(_):
@@ -139,7 +140,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
By default, it tries to use the gpu.
:param device: One for 'auto', 'cuda', 'cpu'
- :return:
+ :return: Supported Pytorch device
"""
# Cuda by default
if device == "auto":
@@ -182,7 +183,7 @@ def configure_logger(
"""
Configure the logger's outputs.
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 to add the standard output to the logger outputs
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param tb_log_name: tensorboard log
:param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
@@ -386,12 +387,25 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
- :param arr:
+ :param arr: Numpy array or list of values
:return:
"""
return np.nan if len(arr) == 0 else np.mean(arr)
+def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
+ """
+ Extract parameters from the state dict of ``model``
+ if the name contains one of the strings in ``included_names``.
+
+ :param model: the model where the parameters come from.
+ :param included_names: substrings of names to include.
+ :return: List of parameters values (Pytorch tensors)
+ that matches the queried names.
+ """
+ return [param for name, param in model.state_dict().items() if any([key in name for key in included_names])]
+
+
def zip_strict(*iterables: Iterable) -> Iterable:
r"""
``zip()`` function but enforces that iterables are of equal length.
@@ -411,8 +425,8 @@ def zip_strict(*iterables: Iterable) -> Iterable:
def polyak_update(
- params: Iterable[th.nn.Parameter],
- target_params: Iterable[th.nn.Parameter],
+ params: Iterable[th.Tensor],
+ target_params: Iterable[th.Tensor],
tau: float,
) -> None:
"""
@@ -506,3 +520,18 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
if print_info:
print(env_info_str)
return env_info, env_info_str
+
+
+def compat_gym_seed(env: GymEnv, seed: int) -> None:
+ """
+ Compatibility helper to seed Gym envs.
+
+ :param env: The Gym environment.
+ :param seed: The seed for the pseudo random generator
+ """
+ if isinstance(env, gym.Env) and "seed" in signature(env.unwrapped.reset).parameters:
+ # gym >= 0.23.1
+ env.reset(seed=seed)
+ else:
+ # VecEnv and backward compatibility
+ env.seed(seed)
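
A minimal sketch exercising the two helpers added above; the env id and the toy module are illustrative. ``compat_gym_seed`` only falls back to ``env.seed`` for old-style envs and VecEnvs, while ``get_parameters_by_name`` is handy for collecting e.g. batch-norm running statistics.

```python
import gymnasium as gym
import torch as th

from stable_baselines3.common.utils import compat_gym_seed, get_parameters_by_name

env = gym.make("CartPole-v1")
compat_gym_seed(env, seed=42)  # calls env.reset(seed=42) since reset() accepts a seed

model = th.nn.Sequential(th.nn.Linear(4, 16), th.nn.BatchNorm1d(16))
# Tensors whose state_dict name contains "running_" (running_mean, running_var)
running_stats = get_parameters_by_name(model, ["running_"])
```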
diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py
index 37ebc364d..3880fbd53 100644
--- a/stable_baselines3/common/vec_env/__init__.py
+++ b/stable_baselines3/common/vec_env/__init__.py
@@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
- eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
+ # Only synchronize if observation normalization exists
+ if hasattr(env_tmp, "obs_rms"):
+ eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv
diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py
index d3e624af9..584159d13 100644
--- a/stable_baselines3/common/vec_env/base_vec_env.py
+++ b/stable_baselines3/common/vec_env/base_vec_env.py
@@ -4,7 +4,7 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
import cloudpickle
-import gym
+import gymnasium as gym
import numpy as np
# Define type aliases here to avoid circular import
@@ -59,6 +59,7 @@ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_sp
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
+ self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method
@abstractmethod
def reset(self) -> VecEnvObs:
@@ -305,7 +306,7 @@ def __getattr__(self, name: str) -> Any:
own_class = f"{type(self).__module__}.{type(self).__name__}"
error_str = (
f"Error: Recursive attribute lookup for {name} from {own_class} is "
- "ambiguous and hides attribute from {blocked_class}"
+ f"ambiguous and hides attribute from {blocked_class}"
)
raise AttributeError(error_str)
diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py
index c0efc8caf..0af11f4a2 100644
--- a/stable_baselines3/common/vec_env/dummy_vec_env.py
+++ b/stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -2,7 +2,7 @@
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Type, Union
-import gym
+import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
@@ -39,28 +39,36 @@ def step_async(self, actions: np.ndarray) -> None:
self.actions = actions
def step_wait(self) -> VecEnvStepReturn:
for env_idx in range(self.num_envs):
- obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
+ obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
self.actions[env_idx]
)
+ # convert to SB3 VecEnv api
+ self.buf_dones[env_idx] = terminated or truncated
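+ # The "TimeLimit.truncated" key lets downstream code (e.g. replay buffer timeout handling) distinguish time-limit truncations from true terminal states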
+ self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated
+
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]["terminal_observation"] = obs
- obs = self.envs[env_idx].reset()
+ obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ # Avoid circular import
+ from stable_baselines3.common.utils import compat_gym_seed
+
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
seeds = []
for idx, env in enumerate(self.envs):
- seeds.append(env.seed(seed + idx))
+ seeds.append(compat_gym_seed(env, seed=seed + idx))
return seeds
def reset(self) -> VecEnvObs:
for env_idx in range(self.num_envs):
- obs = self.envs[env_idx].reset()
+ obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return self._obs_from_buf()
diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py
index affd7756e..c5dc3d67f 100644
--- a/stable_baselines3/common/vec_env/stacked_observations.py
+++ b/stable_baselines3/common/vec_env/stacked_observations.py
@@ -2,12 +2,12 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
-class StackedObservations(object):
+class StackedObservations:
"""
Frame stacking wrapper for data.
@@ -199,7 +199,7 @@ def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict
spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype)
return spaces.Dict(spaces=spaces_dict)
- def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch
"""
Resets the stacked observations, adds the reset observation to the stack, and returns the stack
@@ -219,7 +219,7 @@ def update(
observations: Dict[str, np.ndarray],
dones: np.ndarray,
infos: List[Dict[str, Any]],
- ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]:
+ ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch
"""
Adds the observations to the stack and uses the dones to update the infos.
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py
index 04f5d0c58..155857294 100644
--- a/stable_baselines3/common/vec_env/subproc_vec_env.py
+++ b/stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -2,7 +2,7 @@
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
-import gym
+import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import (
@@ -14,29 +14,36 @@
)
-def _worker(
- remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper
+def _worker( # noqa: C901
+ remote: mp.connection.Connection,
+ parent_remote: mp.connection.Connection,
+ env_fn_wrapper: CloudpickleWrapper,
) -> None:
# Import here to avoid a circular import
from stable_baselines3.common.env_util import is_wrapped
+ from stable_baselines3.common.utils import compat_gym_seed
parent_remote.close()
env = env_fn_wrapper.var()
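+ # Holds the info dict returned by the Gymnasium ``reset()`` so it can be sent back alongside step/reset results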
+ reset_info = {}
while True:
try:
cmd, data = remote.recv()
if cmd == "step":
- observation, reward, done, info = env.step(data)
+ observation, reward, terminated, truncated, info = env.step(data)
+ # convert to SB3 VecEnv api
+ done = terminated or truncated
+ info["TimeLimit.truncated"] = truncated
if done:
# save final observation where user can get it, then reset
info["terminal_observation"] = observation
- observation = env.reset()
- remote.send((observation, reward, done, info))
+ observation, reset_info = env.reset()
+ remote.send((observation, reward, done, info, reset_info))
elif cmd == "seed":
- remote.send(env.seed(data))
+ remote.send(compat_gym_seed(env, seed=data))
elif cmd == "reset":
- observation = env.reset()
- remote.send(observation)
+ observation, reset_info = env.reset()
+ remote.send((observation, reset_info))
elif cmd == "render":
remote.send(env.render(data))
elif cmd == "close":
@@ -119,7 +126,7 @@ def step_async(self, actions: np.ndarray) -> None:
def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
- obs, rews, dones, infos = zip(*results)
+ obs, rews, dones, infos, self.reset_infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
@@ -132,7 +139,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
def reset(self) -> VecEnvObs:
for remote in self.remotes:
remote.send(("reset", None))
- obs = [remote.recv() for remote in self.remotes]
+ results = [remote.recv() for remote in self.remotes]
+ obs, self.reset_infos = zip(*results)
return _flatten_obs(obs, self.observation_space)
def close(self) -> None:
@@ -217,6 +225,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
- return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
+ return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))
else:
return np.stack(obs)
diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py
index 859f1ec95..cb047ee0d 100644
--- a/stable_baselines3/common/vec_env/util.py
+++ b/stable_baselines3/common/vec_env/util.py
@@ -4,7 +4,7 @@
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
-import gym
+import gymnasium as gym
import numpy as np
from stable_baselines3.common.preprocessing import check_for_nested_spaces
@@ -37,7 +37,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
return obs_dict
elif isinstance(obs_space, gym.spaces.Tuple):
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
- return tuple((obs_dict[i] for i in range(len(obs_space.spaces))))
+ return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
else:
assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
return obs_dict[None]
diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py
index e06d5125e..0c5c03145 100644
--- a/stable_baselines3/common/vec_env/vec_frame_stack.py
+++ b/stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
@@ -55,8 +55,7 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
- observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
-
+ observation = self.venv.reset()
observation = self.stackedobs.reset(observation)
return observation
diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py
index f3ee588ab..079ff4e21 100644
--- a/stable_baselines3/common/vec_env/vec_normalize.py
+++ b/stable_baselines3/common/vec_env/vec_normalize.py
@@ -1,9 +1,8 @@
import pickle
-import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
-import gym
+import gymnasium as gym
import numpy as np
from stable_baselines3.common import utils
@@ -289,8 +288,3 @@ def save(self, save_path: str) -> None:
"""
with open(save_path, "wb") as file_handler:
pickle.dump(self, file_handler)
-
- @property
- def ret(self) -> np.ndarray:
- warnings.warn("`VecNormalize` `ret` attribute is deprecated. Please use `returns` instead.", DeprecationWarning)
- return self.returns
diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py
index e6f728bec..beb603961 100644
--- a/stable_baselines3/common/vec_env/vec_transpose.py
+++ b/stable_baselines3/common/vec_env/vec_transpose.py
@@ -2,7 +2,7 @@
from typing import Dict, Union
import numpy as np
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
@@ -26,7 +26,7 @@ def __init__(self, venv: VecEnv, skip: bool = False):
self.skip = skip
# Do nothing
if skip:
- super(VecTransposeImage, self).__init__(venv)
+ super().__init__(venv)
return
if isinstance(venv.observation_space, spaces.dict.Dict):
@@ -39,7 +39,7 @@ def __init__(self, venv: VecEnv, skip: bool = False):
observation_space.spaces[key] = self.transpose_space(space, key)
else:
observation_space = self.transpose_space(venv.observation_space)
- super(VecTransposeImage, self).__init__(venv, observation_space=observation_space)
+ super().__init__(venv, observation_space=observation_space)
@staticmethod
def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py
index 70d74ebe4..90c4aa491 100644
--- a/stable_baselines3/common/vec_env/vec_video_recorder.py
+++ b/stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -1,7 +1,7 @@
import os
from typing import Callable
-from gym.wrappers.monitoring import video_recorder
+from gymnasium.wrappers.monitoring import video_recorder
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index 14293ca3b..993a8c22a 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -1,14 +1,15 @@
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
-from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
+DDPGSelf = TypeVar("DDPGSelf", bound="DDPG")
+
class DDPG(TD3):
"""
@@ -43,10 +44,9 @@ class DDPG(TD3):
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -66,11 +66,10 @@ def __init__(
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
gradient_steps: int = -1,
action_noise: Optional[ActionNoise] = None,
- replay_buffer_class: Optional[ReplayBuffer] = None,
+ replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
@@ -78,7 +77,7 @@ def __init__(
_init_setup_model: bool = True,
):
- super(DDPG, self).__init__(
+ super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
@@ -96,7 +95,6 @@ def __init__(
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
- create_eval_env=create_eval_env,
seed=seed,
optimize_memory_usage=optimize_memory_usage,
# Remove all tricks from TD3 to obtain DDPG:
@@ -115,26 +113,20 @@ def __init__(
self._setup_model()
def learn(
- self,
+ self: DDPGSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "DDPG",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> OffPolicyAlgorithm:
+ progress_bar: bool = False,
+ ) -> DDPGSelf:
- return super(DDPG, self).learn(
+ return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
+ progress_bar=progress_bar,
)
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index ed6073b25..c89cf4dec 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -1,7 +1,7 @@
import warnings
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
from torch.nn import functional as F
@@ -11,16 +11,18 @@
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
+from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
+DQNSelf = TypeVar("DQNSelf", bound="DQN")
+
class DQN(OffPolicyAlgorithm):
"""
Deep Q-Network (DQN)
Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
- Default hyperparameters are taken from the nature paper,
+ Default hyperparameters are taken from the Nature paper,
except for the optimizer and learning rate that were taken from Stable Baselines defaults.
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
@@ -50,10 +52,9 @@ class DQN(OffPolicyAlgorithm):
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping
:param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -78,7 +79,7 @@ def __init__(
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
- replay_buffer_class: Optional[ReplayBuffer] = None,
+ replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
@@ -87,7 +88,6 @@ def __init__(
exploration_final_eps: float = 0.05,
max_grad_norm: float = 10,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
@@ -95,7 +95,7 @@ def __init__(
_init_setup_model: bool = True,
):
- super(DQN, self).__init__(
+ super().__init__(
policy,
env,
learning_rate,
@@ -113,7 +113,6 @@ def __init__(
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
- create_eval_env=create_eval_env,
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
@@ -138,8 +137,11 @@ def __init__(
self._setup_model()
def _setup_model(self) -> None:
- super(DQN, self)._setup_model()
+ super()._setup_model()
self._create_aliases()
+ # Copy running stats, see GH issue #996
+ self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
+ self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps,
self.exploration_final_eps,
@@ -170,6 +172,8 @@ def _on_step(self) -> None:
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
+ # Copy running stats, see GH issue #996
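+ # tau=1.0 performs a hard copy of the running statistics to the target network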
+ polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
self.logger.record("rollout/exploration_rate", self.exploration_rate)
@@ -249,32 +253,26 @@ def predict(
return action, state
def learn(
- self,
+ self: DQNSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "DQN",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> OffPolicyAlgorithm:
+ progress_bar: bool = False,
+ ) -> DQNSelf:
- return super(DQN, self).learn(
+ return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
+ progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:
- return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]
+ return super()._excluded_save_params() + ["q_net", "q_net_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]
diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py
index ea00b5cb5..1e7a0f27d 100644
--- a/stable_baselines3/dqn/policies.py
+++ b/stable_baselines3/dqn/policies.py
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Type
-import gym
+import gymnasium as gym
import torch as th
from torch import nn
@@ -37,7 +37,7 @@ def __init__(
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
- super(QNetwork, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
@@ -118,7 +118,7 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
- super(DQNPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor_class,
@@ -239,7 +239,7 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
- super(CnnPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
@@ -284,7 +284,7 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
- super(MultiInputPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py
index f61a78641..151843632 100644
--- a/stable_baselines3/her/her_replay_buffer.py
+++ b/stable_baselines3/her/her_replay_buffer.py
@@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in
if current_max_episode_length is None:
raise AttributeError
# if not available check if a valid value was passed as an argument
- except AttributeError:
+ except AttributeError as e:
raise ValueError(
"The max episode length could not be inferred.\n"
"You must specify a `max_episode_steps` when registering the environment,\n"
"use a `gym.wrappers.TimeLimit` wrapper "
"or pass `max_episode_length` to the model constructor"
- )
+ ) from e
return current_max_episode_length
@@ -73,7 +73,7 @@ def __init__(
self,
env: VecEnv,
buffer_size: int,
- device: Union[th.device, str] = "cpu",
+ device: Union[th.device, str] = "auto",
replay_buffer: Optional[DictReplayBuffer] = None,
max_episode_length: Optional[int] = None,
n_sampled_goal: int = 4,
@@ -82,7 +82,7 @@ def __init__(
handle_timeout_termination: bool = True,
):
- super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
+ super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
# convert goal_selection_strategy into GoalSelectionStrategy if string
if isinstance(goal_selection_strategy, str):
@@ -226,7 +226,7 @@ def _sample_offline(
maybe_vec_env=None,
online_sampling=False,
n_sampled_goal=n_sampled_goal,
- )
+ ) # pytype: disable=bad-return-type
def sample_goals(
self,
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 0d05b4c45..3e75b5cde 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -1,9 +1,9 @@
import warnings
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, TypeVar, Union
import numpy as np
import torch as th
-from gym import spaces
+from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
@@ -11,6 +11,8 @@
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
+PPOSelf = TypeVar("PPOSelf", bound="PPO")
+
class PPO(OnPolicyAlgorithm):
"""
@@ -19,7 +21,7 @@ class PPO(OnPolicyAlgorithm):
Paper: https://arxiv.org/abs/1707.06347
Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
- and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
+ Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
@@ -55,10 +57,9 @@ class PPO(OnPolicyAlgorithm):
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div.
:param tensorboard_log: the log location for tensorboard (if None, no logging)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -91,7 +92,6 @@ def __init__(
sde_sample_freq: int = -1,
target_kl: Optional[float] = None,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
@@ -99,7 +99,7 @@ def __init__(
_init_setup_model: bool = True,
):
- super(PPO, self).__init__(
+ super().__init__(
policy,
env,
learning_rate=learning_rate,
@@ -115,7 +115,6 @@ def __init__(
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
- create_eval_env=create_eval_env,
seed=seed,
_init_setup_model=False,
supported_action_spaces=(
@@ -137,8 +136,8 @@ def __init__(
# Check that `n_steps * n_envs > 1` to avoid NaN
# when doing advantage normalization
buffer_size = self.env.num_envs * self.n_steps
- assert (
- buffer_size > 1
+ assert buffer_size > 1 or (
+ not normalize_advantage
), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
# Check that the rollout buffer size is a multiple of the mini-batch size
untruncated_batches = buffer_size // batch_size
@@ -162,7 +161,7 @@ def __init__(
self._setup_model()
def _setup_model(self) -> None:
- super(PPO, self)._setup_model()
+ super()._setup_model()
# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
@@ -210,7 +209,8 @@ def train(self) -> None:
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
- if self.normalize_advantage:
+ # Normalization does not make sense if mini batch size == 1, see GH issue #325
+ if self.normalize_advantage and len(advantages) > 1:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
@@ -230,7 +230,7 @@ def train(self) -> None:
# No clipping
values_pred = values
else:
- # Clip the different between old and new value
+ # Clip the difference between old and new value
# NOTE: this depends on the reward scaling
values_pred = rollout_data.old_values + th.clamp(
values - rollout_data.old_values, -clip_range_vf, clip_range_vf
@@ -295,26 +295,20 @@ def train(self) -> None:
self.logger.record("train/clip_range_vf", clip_range_vf)
def learn(
- self,
+ self: PPOSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "PPO",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> "PPO":
+ progress_bar: bool = False,
+ ) -> PPOSelf:
- return super(PPO, self).learn(
+ return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
+ progress_bar=progress_bar,
)
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index cb6a61c11..2f45e1446 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -1,7 +1,6 @@
-import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
-import gym
+import gymnasium as gym
import torch as th
from torch import nn
@@ -38,9 +37,6 @@ class Actor(BasePolicy):
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE.
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -60,12 +56,11 @@ def __init__(
use_sde: bool = False,
log_std_init: float = -3,
full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
normalize_images: bool = True,
):
- super(Actor, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
@@ -80,14 +75,10 @@ def __init__(
self.features_dim = features_dim
self.activation_fn = activation_fn
self.log_std_init = log_std_init
- self.sde_net_arch = sde_net_arch
self.use_expln = use_expln
self.full_std = full_std
self.clip_mean = clip_mean
- if sde_net_arch is not None:
- warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
-
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
@@ -196,9 +187,6 @@ class SACPolicy(BasePolicy):
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -226,7 +214,6 @@ def __init__(
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
@@ -235,9 +222,9 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
- share_features_extractor: bool = True,
+ share_features_extractor: bool = False,
):
- super(SACPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor_class,
@@ -248,10 +235,7 @@ def __init__(
)
if net_arch is None:
- if features_extractor_class == NatureCNN:
- net_arch = []
- else:
- net_arch = [256, 256]
+ net_arch = [256, 256]
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
@@ -266,9 +250,6 @@ def __init__(
}
self.actor_kwargs = self.net_args.copy()
- if sde_net_arch is not None:
- warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
-
sde_kwargs = {
"use_sde": use_sde,
"log_std_init": log_std_init,
@@ -385,9 +366,6 @@ class CnnPolicy(SACPolicy):
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -413,7 +391,6 @@ def __init__(
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
@@ -422,9 +399,9 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
- share_features_extractor: bool = True,
+ share_features_extractor: bool = False,
):
- super(CnnPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
@@ -432,7 +409,6 @@ def __init__(
activation_fn,
use_sde,
log_std_init,
- sde_net_arch,
use_expln,
clip_mean,
features_extractor_class,
@@ -456,9 +432,6 @@ class MultiInputPolicy(SACPolicy):
:param activation_fn: Activation function
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
- :param sde_net_arch: Network architecture for extracting features
- when using gSDE. If None, the latent features from the policy will be used.
- Pass an empty list to use the states as features.
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
@@ -484,7 +457,6 @@ def __init__(
activation_fn: Type[nn.Module] = nn.ReLU,
use_sde: bool = False,
log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
@@ -493,9 +465,9 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
- share_features_extractor: bool = True,
+ share_features_extractor: bool = False,
):
- super(MultiInputPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
@@ -503,7 +475,6 @@ def __init__(
activation_fn,
use_sde,
log_std_init,
- sde_net_arch,
use_expln,
clip_mean,
features_extractor_class,
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index 3703b730b..365927ee9 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -1,6 +1,6 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
from torch.nn import functional as F
@@ -10,9 +10,11 @@
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
+from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+SACSelf = TypeVar("SACSelf", bound="SAC")
+
class SAC(OffPolicyAlgorithm):
"""
@@ -63,10 +65,9 @@ class SAC(OffPolicyAlgorithm):
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -92,7 +93,7 @@ def __init__(
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
- replay_buffer_class: Optional[ReplayBuffer] = None,
+ replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
ent_coef: Union[str, float] = "auto",
@@ -102,7 +103,6 @@ def __init__(
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
@@ -110,7 +110,7 @@ def __init__(
_init_setup_model: bool = True,
):
- super(SAC, self).__init__(
+ super().__init__(
policy,
env,
learning_rate,
@@ -128,7 +128,6 @@ def __init__(
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
- create_eval_env=create_eval_env,
seed=seed,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
@@ -150,8 +149,11 @@ def __init__(
self._setup_model()
def _setup_model(self) -> None:
- super(SAC, self)._setup_model()
+ super()._setup_model()
self._create_aliases()
+ # Running mean and running var
+ self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
+ self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
# automatically set target entropy if needed
@@ -248,7 +250,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
current_q_values = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
- critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
+ critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
critic_losses.append(critic_loss.item())
# Optimize the critic
@@ -258,7 +260,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Compute actor loss
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
- # Mean over all critic networks
+ # Min over all critic networks
q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
@@ -272,6 +274,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
+ # Copy running stats, see GH issue #996
+ polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
self._n_updates += gradient_steps
@@ -283,32 +287,26 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
def learn(
- self,
+ self: SACSelf,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "SAC",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> OffPolicyAlgorithm:
+ progress_bar: bool = False,
+ ) -> SACSelf:
- return super(SAC, self).learn(
+ return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
+ progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:
- return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
+ return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py
index ce91a0f91..6b554851f 100644
--- a/stable_baselines3/td3/policies.py
+++ b/stable_baselines3/td3/policies.py
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Type, Union
-import gym
+import gymnasium as gym
import torch as th
from torch import nn
@@ -42,7 +42,7 @@ def __init__(
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
- super(Actor, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
@@ -119,9 +119,9 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
- share_features_extractor: bool = True,
+ share_features_extractor: bool = False,
):
- super(TD3Policy, self).__init__(
+ super().__init__(
observation_space,
action_space,
features_extractor_class,
@@ -134,7 +134,7 @@ def __init__(
# Default network architecture, from the original paper
if net_arch is None:
if features_extractor_class == NatureCNN:
- net_arch = []
+ net_arch = [256, 256]
else:
net_arch = [400, 300]
@@ -281,9 +281,9 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
- share_features_extractor: bool = True,
+ share_features_extractor: bool = False,
):
- super(CnnPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
@@ -335,9 +335,9 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
- share_features_extractor: bool = True,
+ share_features_extractor: bool = False,
):
- super(MultiInputPolicy, self).__init__(
+ super().__init__(
observation_space,
action_space,
lr_schedule,
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index d31720b67..a37df473c 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -1,6 +1,6 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
-import gym
+import gymnasium as gym
import numpy as np
import torch as th
from torch.nn import functional as F
@@ -10,9 +10,11 @@
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
+from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
+TD3Self = TypeVar("TD3Self", bound="TD3")
+
class TD3(OffPolicyAlgorithm):
"""
@@ -51,10 +53,9 @@ class TD3(OffPolicyAlgorithm):
:param target_policy_noise: Standard deviation of Gaussian noise added to target policy
(smoothing noise)
:param target_noise_clip: Limit for absolute value of target policy smoothing noise.
- :param create_eval_env: Whether to create a second environment that will be
- used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
- :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+ debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
@@ -80,14 +81,13 @@ def __init__(
train_freq: Union[int, Tuple[int, str]] = (1, "episode"),
gradient_steps: int = -1,
action_noise: Optional[ActionNoise] = None,
- replay_buffer_class: Optional[ReplayBuffer] = None,
+ replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_delay: int = 2,
target_policy_noise: float = 0.2,
target_noise_clip: float = 0.5,
tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
@@ -95,7 +95,7 @@ def __init__(
_init_setup_model: bool = True,
):
- super(TD3, self).__init__(
+ super().__init__(
policy,
env,
learning_rate,
@@ -113,7 +113,6 @@ def __init__(
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
- create_eval_env=create_eval_env,
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
@@ -129,8 +128,13 @@ def __init__(
self._setup_model()
def _setup_model(self) -> None:
- super(TD3, self)._setup_model()
+ super()._setup_model()
self._create_aliases()
+ # Running mean and running var
+ self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
+ self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
+ self.actor_batch_norm_stats_target = get_parameters_by_name(self.actor_target, ["running_"])
+ self.critic_batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
def _create_aliases(self) -> None:
self.actor = self.policy.actor
@@ -146,7 +150,6 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
actor_losses, critic_losses = [], []
-
for _ in range(gradient_steps):
self._n_updates += 1
@@ -168,7 +171,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
current_q_values = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
- critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
+ critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
critic_losses.append(critic_loss.item())
# Optimize the critics
@@ -189,6 +192,9 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
+ # Copy running stats, see GH issue #996
+ polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
+ polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
if len(actor_losses) > 0:
@@ -196,32 +202,26 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
self.logger.record("train/critic_loss", np.mean(critic_losses))
def learn(
- self,
+ self: TD3Self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
tb_log_name: str = "TD3",
- eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
- ) -> OffPolicyAlgorithm:
+ progress_bar: bool = False,
+ ) -> TD3Self:
- return super(TD3, self).learn(
+ return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
- eval_env=eval_env,
- eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
- eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
+ progress_bar=progress_bar,
)
def _excluded_save_params(self) -> List[str]:
- return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
+ return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index d6a9f8c61..35a785a76 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-1.5.1a4
+2.0.0a0
diff --git a/tests/test_buffers.py b/tests/test_buffers.py
index 45c5e6aa3..bdf320c7a 100644
--- a/tests/test_buffers.py
+++ b/tests/test_buffers.py
@@ -1,12 +1,13 @@
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
-from gym import spaces
+from gymnasium import spaces
-from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
+from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
@@ -26,15 +27,15 @@ def __init__(self):
def reset(self):
self._t = 0
obs = self._observations[0]
- return obs
+ return obs, {}
def step(self, action):
self._t += 1
index = self._t % len(self._observations)
obs = self._observations[index]
- done = self._t >= self._ep_length
+ done = truncated = self._t >= self._ep_length
reward = self._rewards[index]
- return obs, reward, done, {}
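+ # Follow the Gymnasium 5-tuple step API: (obs, reward, terminated, truncated, info)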
+ return obs, reward, done, truncated, {}
class DummyDictEnv(gym.Env):
@@ -54,15 +55,15 @@ def __init__(self):
def reset(self):
self._t = 0
obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()}
- return obs
+ return obs, {}
def step(self, action):
self._t += 1
index = self._t % len(self._observations)
obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()}
- done = self._t >= self._ep_length
+ done = truncated = self._t >= self._ep_length
reward = self._rewards[index]
- return obs, reward, done, {}
+ return obs, reward, done, truncated, {}
@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
@@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls):
env = make_vec_env(env)
env = VecNormalize(env)
- buffer = replay_buffer_cls(100, env.observation_space, env.action_space)
+ buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")
# Interract and store transitions
env.reset()
@@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls):
assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
# Test reward normalization
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
+
+
+@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
+@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
+def test_device_buffer(replay_buffer_cls, device):
+ if device == "cuda" and not th.cuda.is_available():
+ pytest.skip("CUDA not available")
+
+ env = {
+ RolloutBuffer: DummyEnv,
+ DictRolloutBuffer: DummyDictEnv,
+ ReplayBuffer: DummyEnv,
+ DictReplayBuffer: DummyDictEnv,
+ }[replay_buffer_cls]
+ env = make_vec_env(env)
+
+ buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)
+
+ # Interact and store transitions
+ obs = env.reset()
+ for _ in range(100):
+ action = env.action_space.sample()
+ next_obs, reward, done, info = env.step(action)
+ if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
+ episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
+ buffer.add(obs, action, reward, episode_start, values, log_prob)
+ else:
+ buffer.add(obs, next_obs, action, reward, done, info)
+ obs = next_obs
+
+ # Get data from the buffer
+ if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
+ data = buffer.get(50)
+ elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
+ data = buffer.sample(50)
+
+ # Check that all data are on the desired device
+ desired_device = get_device(device).type
+ for value in list(data):
+ if isinstance(value, dict):
+ for key in value.keys():
+ assert value[key].device.type == desired_device
+ elif isinstance(value, th.Tensor):
+ assert value.device.type == desired_device
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 6576f7dc3..1177bc87d 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -1,7 +1,7 @@
import os
import shutil
-import gym
+import gymnasium as gym
import numpy as np
import pytest
@@ -71,8 +71,8 @@ def test_callbacks(tmp_path, model_class):
assert event_callback.n_calls == model.num_timesteps
model.learn(500, callback=None)
- # Transform callback into a callback list automatically
- model.learn(500, callback=[checkpoint_callback, eval_callback])
+ # Transform callback into a callback list automatically and use progress bar
+ model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True)
# Automatic wrapping, old way of doing callbacks
model.learn(500, callback=lambda _locals, _globals: True)
@@ -102,7 +102,7 @@ def test_callbacks(tmp_path, model_class):
def select_env(model_class) -> str:
if model_class is DQN:
- return "CartPole-v0"
+ return "CartPole-v1"
else:
return "Pendulum-v1"
@@ -203,3 +203,29 @@ def test_eval_friendly_error():
with pytest.warns(Warning):
with pytest.raises(AssertionError):
model.learn(100, callback=eval_callback)
+
+
+def test_checkpoint_additional_info(tmp_path):
+ # tests if the replay buffer and the VecNormalize stats are saved with every checkpoint
+ dummy_vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
+ env = VecNormalize(dummy_vec_env)
+
+ checkpoint_dir = tmp_path / "checkpoints"
+ checkpoint_callback = CheckpointCallback(
+ save_freq=200,
+ save_path=checkpoint_dir,
+ save_replay_buffer=True,
+ save_vecnormalize=True,
+ verbose=2,
+ )
+
+ model = DQN("MlpPolicy", env, learning_starts=100, buffer_size=500, seed=0)
+ model.learn(200, callback=checkpoint_callback)
+
+ assert os.path.exists(checkpoint_dir / "rl_model_200_steps.zip")
+ assert os.path.exists(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
+ assert os.path.exists(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl")
+ # Check that checkpoints can be properly loaded
+ model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip")
+ model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
+ VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env)
diff --git a/tests/test_cnn.py b/tests/test_cnn.py
index 03f089db9..a230fa445 100644
--- a/tests/test_cnn.py
+++ b/tests/test_cnn.py
@@ -4,7 +4,7 @@
import numpy as np
import pytest
import torch as th
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.envs import FakeImageEnv
@@ -35,7 +35,7 @@ def test_cnn(tmp_path, model_class):
# FakeImageEnv is channel last by default and should be wrapped
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)
- obs = env.reset()
+ obs, _ = env.reset()
# Test stochastic predict with channel last input
if model_class == DQN:
@@ -238,7 +238,7 @@ def test_channel_first_env(tmp_path):
assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage)
- obs = env.reset()
+ obs, _ = env.reset()
action, _ = model.predict(obs, deterministic=True)
diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py
index 93b13b40e..4aea04c17 100644
--- a/tests/test_dict_env.py
+++ b/tests/test_dict_env.py
@@ -1,7 +1,9 @@
-import gym
+from typing import Optional
+
+import gymnasium as gym
import numpy as np
import pytest
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
@@ -66,14 +68,16 @@ def seed(self, seed=None):
def step(self, action):
reward = 0.0
- done = False
- return self.observation_space.sample(), reward, done, {}
+ done = truncated = False
+ return self.observation_space.sample(), reward, done, truncated, {}
def compute_reward(self, achieved_goal, desired_goal, info):
return np.zeros((len(achieved_goal),))
- def reset(self):
- return self.observation_space.sample()
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ self.observation_space.seed(seed)
+ return self.observation_space.sample(), {}
def render(self, mode="human"):
pass
@@ -105,7 +109,7 @@ def test_consistency(model_class):
dict_env = gym.wrappers.TimeLimit(dict_env, 100)
env = gym.wrappers.FlattenObservation(dict_env)
dict_env.seed(10)
- obs = dict_env.reset()
+ obs, _ = dict_env.reset()
kwargs = {}
n_steps = 256
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index 3652b1850..d041ca528 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -1,7 +1,7 @@
from copy import deepcopy
from typing import Tuple
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -77,6 +77,8 @@ def test_get_distribution(dummy_model_distribution_obs_and_actions):
distribution = model.policy.get_distribution(observations)
log_prob_2 = distribution.log_prob(actions)
entropy_2 = distribution.entropy()
+ assert entropy_1 is not None
+ assert entropy_2 is not None
assert th.allclose(log_prob_1, log_prob_2)
assert th.allclose(entropy_1, entropy_2)
@@ -163,7 +165,9 @@ def test_categorical(dist, CAT_ACTIONS):
BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
- MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
+ MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution(
+ th.rand(1, sum([N_ACTIONS, N_ACTIONS]))
+ ),
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])
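The distribution test now passes the per-dimension action counts as a NumPy array instead of a plain list. A small sketch of that usage, assuming the rest of the `proba_distribution` call is unchanged:

```python
import numpy as np
import torch as th

from stable_baselines3.common.distributions import MultiCategoricalDistribution

n_actions = 3
dist = MultiCategoricalDistribution(np.array([n_actions, n_actions]))
# One logit per choice of each sub-action, batched
dist = dist.proba_distribution(th.rand(1, 2 * n_actions))
actions = dist.sample()
log_prob = dist.log_prob(actions)
```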
diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py
index 0b0a82d8f..5a20a91d2 100644
--- a/tests/test_env_checker.py
+++ b/tests/test_env_checker.py
@@ -1,4 +1,4 @@
-import gym
+import gymnasium as gym
import numpy as np
import pytest
from gym.spaces import Box, Dict, Discrete
@@ -14,11 +14,12 @@ def step(self, action):
observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)
reward = 1
done = True
+ truncated = False
info = {}
- return observation, reward, done, info
+ return observation, reward, done, truncated, info
def reset(self):
- return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)
+ return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {}
def render(self, mode="human"):
pass
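The checker tests above reflect the new environment contract (tuple-returning `reset`, 5-tuple `step`). A minimal custom environment that satisfies it and can be validated with `check_env` might look like the sketch below:

```python
from typing import Optional

import gymnasium as gym
import numpy as np

from stable_baselines3.common.env_checker import check_env


class MinimalEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed: Optional[int] = None, options=None):
        super().reset(seed=seed)
        return np.zeros(3, dtype=np.float32), {}

    def step(self, action):
        obs = np.zeros(3, dtype=np.float32)
        terminated, truncated = False, False
        return obs, 0.0, terminated, truncated, {}


check_env(MinimalEnv())
```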
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 671e2a5e6..0f4a8efff 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,10 +1,10 @@
import types
import warnings
-import gym
+import gymnasium as gym
import numpy as np
import pytest
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import (
@@ -28,7 +28,7 @@
]
-@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"])
+@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_env(env_id):
"""
Check that environments integrated in Gym pass the test.
    Check that environments integrated in Gym pass the test.
@@ -75,6 +75,17 @@ def test_bit_flipping(kwargs):
# No warnings for custom envs
assert len(record) == 0
+ # Remove a key, must throw an error
+ obs_space = env.observation_space.spaces["observation"]
+ del env.observation_space.spaces["observation"]
+ with pytest.raises(AssertionError):
+ check_env(env)
+
+ # Rename a key, must throw an error
+ env.observation_space.spaces["obs"] = obs_space
+ with pytest.raises(AssertionError):
+ check_env(env)
+
def test_high_dimension_action_space():
"""
@@ -87,7 +98,7 @@ def test_high_dimension_action_space():
# Patch to avoid error
def patched_step(_action):
- return env.observation_space.sample(), 0.0, False, {}
+ return env.observation_space.sample(), 0.0, False, False, {}
env.step = patched_step
check_env(env)
@@ -116,10 +127,10 @@ def test_non_default_spaces(new_obs_space):
env = FakeImageEnv()
env.observation_space = new_obs_space
# Patch methods to avoid errors
- env.reset = new_obs_space.sample
+ env.reset = lambda: (new_obs_space.sample(), {})
def patched_step(_action):
- return new_obs_space.sample(), 0.0, False, {}
+ return new_obs_space.sample(), 0.0, False, False, {}
env.step = patched_step
with pytest.warns(UserWarning):
@@ -141,6 +152,8 @@ def patched_step(_action):
spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
# Same boundaries
spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
+ # Unbounded action space
+ spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
# Almost good, except for one dim
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
],
@@ -153,11 +166,23 @@ def test_non_default_action_spaces(new_action_space):
# No warnings for custom envs
assert len(record) == 0
+
# Change the action space
env.action_space = new_action_space
- with pytest.warns(UserWarning):
- check_env(env)
+ low, high = new_action_space.low[0], new_action_space.high[0]
+ # Unbounded action space throws an error,
+ # the rest only warning
+ if not np.all(np.isfinite(env.action_space.low)):
+ with pytest.raises(AssertionError), pytest.warns(UserWarning):
+ check_env(env)
+ # numpy >= 1.21 raises a ValueError
+ elif int(np.__version__.split(".")[1]) >= 21 and (low > high):
+ with pytest.raises(ValueError), pytest.warns(UserWarning):
+ check_env(env)
+ else:
+ with pytest.warns(UserWarning):
+ check_env(env)
def check_reset_assert_error(env, new_reset_return):
@@ -168,7 +193,7 @@ def check_reset_assert_error(env, new_reset_return):
"""
def wrong_reset():
- return new_reset_return
+ return new_reset_return, {}
# Patch the reset method with a wrong one
env.reset = wrong_reset
@@ -186,14 +211,27 @@ def test_common_failures_reset():
# The observation is not a numpy array
check_reset_assert_error(env, 1)
+ # Return only obs (gym < 0.26)
+ env.reset = env.observation_space.sample
+ with pytest.raises(AssertionError):
+ check_env(env)
+
# Return not only the observation
check_reset_assert_error(env, (env.observation_space.sample(), False))
env = SimpleMultiObsEnv()
- obs = env.reset()
+
+ # Observation keys and observation space keys must match
+ wrong_obs = env.observation_space.sample()
+ wrong_obs.pop("img")
+ check_reset_assert_error(env, wrong_obs)
+ wrong_obs = {**env.observation_space.sample(), "extra_key": None}
+ check_reset_assert_error(env, wrong_obs)
+
+ obs, _ = env.reset()
def wrong_reset(self):
- return {"img": obs["img"], "vec": obs["img"]}
+ return {"img": obs["img"], "vec": obs["img"]}, {}
env.reset = types.MethodType(wrong_reset, env)
with pytest.raises(AssertionError) as excinfo:
@@ -226,25 +264,38 @@ def test_common_failures_step():
env = IdentityEnvBox()
# Wrong shape for the observation
- check_step_assert_error(env, (np.ones((4,)), 1.0, False, {}))
+ check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {}))
# Obs is not a numpy array
- check_step_assert_error(env, (1, 1.0, False, {}))
+ check_step_assert_error(env, (1, 1.0, False, False, {}))
# Return a wrong reward
- check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {}))
+ check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, False, {}))
# Info dict is not returned
- check_step_assert_error(env, (env.observation_space.sample(), 0.0, False))
+ check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, False))
+
+ # Truncated is not returned (gym < 0.26)
+ check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, {}))
# Done is not a boolean
- check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {}))
- check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {}))
+ check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, False, {}))
+ check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, False, {}))
+ # Truncated is not a boolean
+ check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, 1.0, {}))
env = SimpleMultiObsEnv()
- obs = env.reset()
+
+ # Observation keys and observation space keys must match
+ wrong_obs = env.observation_space.sample()
+ wrong_obs.pop("img")
+ check_step_assert_error(env, (wrong_obs, 0.0, False, False, {}))
+ wrong_obs = {**env.observation_space.sample(), "extra_key": None}
+ check_step_assert_error(env, (wrong_obs, 0.0, False, False, {}))
+
+ obs, _ = env.reset()
def wrong_step(self, action):
- return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {}
+ return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {}
env.step = types.MethodType(wrong_step, env)
with pytest.raises(AssertionError) as excinfo:
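As a usage note for the stricter check exercised above, an action space with a non-finite bound now makes `check_env` fail outright instead of only warning. A rough sketch, assuming `FakeImageEnv` behaves as in the surrounding tests:

```python
import numpy as np
import pytest
from gymnasium import spaces

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import FakeImageEnv

env = FakeImageEnv()
# Declaring an unbounded continuous action space is now an error
env.action_space = spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32)
with pytest.raises(AssertionError):
    check_env(env)
```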
diff --git a/tests/test_gae.py b/tests/test_gae.py
index 54e03b8b1..be2d0c3e3 100644
--- a/tests/test_gae.py
+++ b/tests/test_gae.py
@@ -1,4 +1,6 @@
-import gym
+from typing import Optional
+
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -10,7 +12,7 @@
class CustomEnv(gym.Env):
def __init__(self, max_steps=8):
- super(CustomEnv, self).__init__()
+ super().__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.max_steps = max_steps
@@ -19,20 +21,26 @@ def __init__(self, max_steps=8):
def seed(self, seed):
self.observation_space.seed(seed)
- def reset(self):
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ self.observation_space.seed(seed)
self.n_steps = 0
- return self.observation_space.sample()
+ return self.observation_space.sample(), {}
def step(self, action):
self.n_steps += 1
- done = False
+ done = truncated = False
reward = 0.0
if self.n_steps >= self.max_steps:
reward = 1.0
done = True
+ # To simplify GAE computation checks,
+ # we do not consider truncation here.
+ # Truncations are checked in InfiniteHorizonEnv
+ truncated = False
- return self.observation_space.sample(), reward, done, {}
+ return self.observation_space.sample(), reward, done, truncated, {}
class InfiniteHorizonEnv(gym.Env):
@@ -43,18 +51,21 @@ def __init__(self, n_states=4):
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.current_state = 0
- def reset(self):
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ super().reset(seed=seed)
+
self.current_state = 0
- return self.current_state
+ return self.current_state, {}
def step(self, action):
self.current_state = (self.current_state + 1) % self.n_states
- return self.current_state, 1.0, False, {}
+ return self.current_state, 1.0, False, False, {}
class CheckGAECallback(BaseCallback):
def __init__(self):
- super(CheckGAECallback, self).__init__(verbose=0)
+ super().__init__(verbose=0)
def _on_rollout_end(self):
buffer = self.model.rollout_buffer
@@ -99,7 +110,7 @@ class CustomPolicy(ActorCriticPolicy):
"""Custom Policy with a constant value function"""
def __init__(self, *args, **kwargs):
- super(CustomPolicy, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.constant_value = 0.0
def forward(self, obs, deterministic=False):
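The GAE checks above use a callback that reads `model.rollout_buffer` after each rollout; the same pattern works outside the tests to inspect advantages or returns. A toy sketch:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback


class InspectRolloutCallback(BaseCallback):
    """Print the mean advantage after every rollout (illustrative only)."""

    def _on_rollout_end(self) -> None:
        buffer = self.model.rollout_buffer
        print("mean advantage:", float(buffer.advantages.mean()))

    def _on_step(self) -> bool:
        return True


model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64)
model.learn(128, callback=InspectRolloutCallback())
```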
diff --git a/tests/test_her.py b/tests/test_her.py
index 0f6d75f6f..2e385d51e 100644
--- a/tests/test_her.py
+++ b/tests/test_her.py
@@ -3,7 +3,7 @@
import warnings
from copy import deepcopy
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -143,7 +143,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling):
model.learn(total_timesteps=150)
- obs = env.reset()
+ obs, _ = env.reset()
observations = {key: [] for key in obs.keys()}
for _ in range(10):
@@ -156,7 +156,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling):
params = deepcopy(model.policy.state_dict())
# Modify all parameters to be random values
- random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+ random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values
model.policy.load_state_dict(random_params)
@@ -237,7 +237,7 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la
train_freq=4,
buffer_size=int(2e4),
policy_kwargs=dict(net_arch=[64]),
- seed=1,
+ seed=0,
)
model.learn(200)
if online_sampling:
diff --git a/tests/test_identity.py b/tests/test_identity.py
index f5bbc4946..d4b584016 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -15,21 +15,17 @@
def test_discrete(model_class, env):
env_ = DummyVecEnv([lambda: env])
kwargs = {}
- n_steps = 3000
+ n_steps = 2500
if model_class == DQN:
kwargs = dict(learning_starts=0)
- n_steps = 4000
# DQN only support discrete actions
if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
return
- elif model_class == A2C:
- # slightly higher budget
- n_steps = 3500
- model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps)
+ model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps)
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False)
- obs = env.reset()
+ obs, _ = env.reset()
assert np.shape(model.predict(obs)[0]) == np.shape(obs)
@@ -38,14 +34,17 @@ def test_discrete(model_class, env):
def test_continuous(model_class):
env = IdentityEnvBox(eps=0.5)
- n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class]
+ n_steps = {A2C: 2000, PPO: 2000, SAC: 400, TD3: 400, DDPG: 400}[model_class]
kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95)
+
if model_class in [TD3]:
n_actions = 1
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
kwargs["action_noise"] = action_noise
+ elif model_class == PPO:
+ kwargs = dict(n_steps=512, n_epochs=5)
- model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)
+ model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
diff --git a/tests/test_logger.py b/tests/test_logger.py
index 6fe536f93..849a6fd8b 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -1,8 +1,10 @@
import os
+import sys
import time
from typing import Sequence
+from unittest import mock
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -16,6 +18,7 @@
CSVOutputFormat,
Figure,
FormatUnsupportedError,
+ HParam,
HumanOutputFormat,
Image,
Logger,
@@ -295,6 +298,19 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
writer.close()
+@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
+def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
+ writer = make_output_format(unsupported_format, tmp_path)
+
+ with pytest.raises(FormatUnsupportedError) as exec_info:
+ hparam_dict = {"learning rate": np.random.random()}
+ metric_dict = {"train/value_loss": 0}
+ hparam = HParam(hparam_dict=hparam_dict, metric_dict=metric_dict)
+ writer.write({"hparam": hparam}, key_excluded={"hparam": ()})
+ assert unsupported_format in str(exec_info.value)
+ writer.close()
+
+
def test_key_length(tmp_path):
writer = make_output_format("stdout", tmp_path)
assert writer.max_length == 36
@@ -338,12 +354,12 @@ def __init__(self, delay: float = 0.01):
self.action_space = gym.spaces.Discrete(2)
def reset(self):
- return self.observation_space.sample()
+ return self.observation_space.sample(), {}
def step(self, action):
time.sleep(self.delay)
obs = self.observation_space.sample()
- return obs, 0.0, True, {}
+ return obs, 0.0, True, False, {}
class InMemoryLogger(Logger):
@@ -381,3 +397,24 @@ def test_fps_logger(tmp_path, algo):
# third time, FPS should be the same
model.learn(100, log_interval=1, reset_num_timesteps=False)
assert max_fps / 10 <= logger.name_to_value["time/fps"] <= max_fps
+
+
+@pytest.mark.parametrize("algo", [A2C, DQN])
+def test_fps_no_div_zero(algo):
+ """Set time to constant and train algorithm to check no division by zero error.
+
+ Time can appear to be constant during short runs on platforms with low-precision
+ timers. We should avoid division by zero errors e.g. when computing FPS in
+ this situation."""
+ with mock.patch("time.time", lambda: 42.0):
+ with mock.patch("time.time_ns", lambda: 42.0):
+ model = algo("MlpPolicy", "CartPole-v1")
+ model.learn(total_timesteps=100)
+
+
+def test_human_output_format_no_crash_on_same_keys_different_tags():
+ o = HumanOutputFormat(sys.stdout, max_length=60)
+ o.write(
+ {"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"},
+ {"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None},
+ )
diff --git a/tests/test_monitor.py b/tests/test_monitor.py
index d3d041b4d..d847a926b 100644
--- a/tests/test_monitor.py
+++ b/tests/test_monitor.py
@@ -2,7 +2,7 @@
import os
import uuid
-import gym
+import gymnasium as gym
import pandas
from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results
@@ -13,8 +13,8 @@ def test_monitor(tmp_path):
Test the monitor wrapper
"""
env = gym.make("CartPole-v1")
- env.seed(0)
- monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
+ env.reset(seed=0)
+ monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
monitor_env = Monitor(env, monitor_file)
monitor_env.reset()
total_steps = 1000
@@ -22,10 +22,10 @@ def test_monitor(tmp_path):
ep_lengths = []
ep_len, ep_reward = 0, 0
for _ in range(total_steps):
- _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample())
+ _, reward, done, truncated, _ = monitor_env.step(monitor_env.action_space.sample())
ep_len += 1
ep_reward += reward
- if done:
+ if done or truncated:
ep_rewards.append(ep_reward)
ep_lengths.append(ep_len)
monitor_env.reset()
@@ -37,7 +37,7 @@ def test_monitor(tmp_path):
assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
_ = monitor_env.get_episode_times()
- with open(monitor_file, "rt") as file_handler:
+ with open(monitor_file) as file_handler:
first_line = file_handler.readline()
assert first_line.startswith("#")
metadata = json.loads(first_line[1:])
@@ -48,6 +48,15 @@ def test_monitor(tmp_path):
assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
os.remove(monitor_file)
+ # Check missing filename directories are created
+ monitor_dir = os.path.join(str(tmp_path), "missing-dir")
+ monitor_file = os.path.join(monitor_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
+ assert os.path.exists(monitor_dir) is False
+ _ = Monitor(env, monitor_file)
+ assert os.path.exists(monitor_dir) is True
+ os.remove(monitor_file)
+ os.rmdir(monitor_dir)
+
def test_monitor_load_results(tmp_path):
"""
@@ -55,8 +64,8 @@ def test_monitor_load_results(tmp_path):
"""
tmp_path = str(tmp_path)
env1 = gym.make("CartPole-v1")
- env1.seed(0)
- monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
+ env1.reset(seed=0)
+ monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
monitor_env1 = Monitor(env1, monitor_file1)
monitor_files = get_monitor_files(tmp_path)
@@ -66,8 +75,8 @@ def test_monitor_load_results(tmp_path):
monitor_env1.reset()
episode_count1 = 0
for _ in range(1000):
- _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample())
- if done:
+ _, _, done, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample())
+ if done or truncated:
episode_count1 += 1
monitor_env1.reset()
@@ -75,21 +84,24 @@ def test_monitor_load_results(tmp_path):
assert results_size1 == episode_count1
env2 = gym.make("CartPole-v1")
- env2.seed(0)
- monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
+ env2.reset(seed=0)
+ monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
monitor_env2 = Monitor(env2, monitor_file2)
monitor_files = get_monitor_files(tmp_path)
assert len(monitor_files) == 2
assert monitor_file1 in monitor_files
assert monitor_file2 in monitor_files
- monitor_env2.reset()
episode_count2 = 0
- for _ in range(1000):
- _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample())
- if done:
- episode_count2 += 1
- monitor_env2.reset()
+ for _ in range(2):
+ # Test appending to existing file
+ monitor_env2 = Monitor(env2, monitor_file2, override_existing=False)
+ monitor_env2.reset()
+ for _ in range(1000):
+ _, _, done, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample())
+ if done or truncated:
+ episode_count2 += 1
+ monitor_env2.reset()
results_size2 = len(load_results(os.path.join(tmp_path)).index)
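The `Monitor` changes above (5-tuple step handling and auto-creation of missing log directories) can be summarized in a short usage sketch; the paths are illustrative:

```python
import gymnasium as gym

from stable_baselines3.common.monitor import Monitor, load_results

# Missing parent directories of the log file are now created automatically
env = Monitor(gym.make("CartPole-v1"), "./monitor_logs/run")
obs, _ = env.reset(seed=0)
for _ in range(500):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, _ = env.reset()

print(load_results("./monitor_logs")[["r", "l", "t"]].tail())
```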
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 853f4d11d..75787d7a0 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -1,4 +1,4 @@
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -29,25 +29,23 @@ def __init__(self):
self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
def reset(self):
- return self.observation_space.sample()
+ return self.observation_space.sample(), {}
def step(self, action):
- return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, {}
+ return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {}
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_auto_wrap(model_class):
- # test auto wrapping of env into a VecEnv
-
+ """Test auto wrapping of env into a VecEnv."""
# Use different environment for DQN
if model_class is DQN:
- env_name = "CartPole-v0"
+ env_name = "CartPole-v1"
else:
env_name = "Pendulum-v1"
env = gym.make(env_name)
- eval_env = gym.make(env_name)
model = model_class("MlpPolicy", env)
- model.learn(100, eval_env=eval_env)
+ model.learn(100)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@@ -71,13 +69,15 @@ def test_predict(model_class, env_id, device):
env = gym.make(env_id)
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
- obs = env.reset()
+ obs, _ = env.reset()
action, _ = model.predict(obs)
+ assert isinstance(action, np.ndarray)
assert action.shape == env.action_space.shape
assert env.action_space.contains(action)
vec_env_obs = vec_env.reset()
action, _ = model.predict(vec_env_obs)
+ assert isinstance(action, np.ndarray)
assert action.shape[0] == vec_env_obs.shape[0]
# Special case for DQN to check the epsilon greedy exploration
@@ -95,7 +95,7 @@ def test_dqn_epsilon_greedy():
env = IdentityEnv(2)
model = DQN("MlpPolicy", env)
model.exploration_rate = 1.0
- obs = env.reset()
+ obs, _ = env.reset()
# is vectorized should not crash with discrete obs
action, _ = model.predict(obs, deterministic=False)
assert env.action_space.contains(action)
@@ -106,5 +106,5 @@ def test_subclassed_space_env(model_class):
env = CustomSubClassedSpaceEnv()
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32]))
model.learn(300)
- obs = env.reset()
+ obs, _ = env.reset()
env.step(model.predict(obs))
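`predict()` keeps working on a single, unvectorized Gymnasium observation, and the tests now also assert that the returned action is a plain NumPy array. A short sketch:

```python
import gymnasium as gym
import numpy as np

from stable_baselines3 import A2C

model = A2C("MlpPolicy", "Pendulum-v1").learn(64)

env = gym.make("Pendulum-v1")
obs, _ = env.reset()
action, _state = model.predict(obs, deterministic=True)
assert isinstance(action, np.ndarray)
assert env.action_space.contains(action)
```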
diff --git a/tests/test_run.py b/tests/test_run.py
index e4e8a2e45..2071f2bb1 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,4 +1,4 @@
-import gym
+import gymnasium as gym
import numpy as np
import pytest
@@ -10,7 +10,10 @@
@pytest.mark.parametrize("model_class", [TD3, DDPG])
-@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
+@pytest.mark.parametrize(
+ "action_noise",
+ [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))],
+)
def test_deterministic_pg(model_class, action_noise):
"""
Test for DDPG and variants (TD3).
@@ -21,17 +24,16 @@ def test_deterministic_pg(model_class, action_noise):
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
- create_eval_env=True,
buffer_size=250,
action_noise=action_noise,
)
- model.learn(total_timesteps=300, eval_freq=250)
+ model.learn(total_timesteps=200)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_a2c(env_id):
- model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
- model.learn(total_timesteps=1000, eval_freq=500)
+ model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1)
+ model.learn(total_timesteps=64)
@pytest.mark.parametrize("model_class", [A2C, PPO])
@@ -53,7 +55,6 @@ def test_ppo(env_id, clip_range_vf):
seed=0,
policy_kwargs=dict(net_arch=[16]),
verbose=1,
- create_eval_env=True,
clip_range_vf=clip_range_vf,
)
else:
@@ -64,10 +65,10 @@ def test_ppo(env_id, clip_range_vf):
seed=0,
policy_kwargs=dict(net_arch=[16]),
verbose=1,
- create_eval_env=True,
clip_range_vf=clip_range_vf,
+ n_epochs=2,
)
- model.learn(total_timesteps=1000, eval_freq=500)
+ model.learn(total_timesteps=1000)
@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
@@ -78,12 +79,11 @@ def test_sac(ent_coef):
policy_kwargs=dict(net_arch=[64, 64]),
learning_starts=100,
verbose=1,
- create_eval_env=True,
buffer_size=250,
ent_coef=ent_coef,
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
)
- model.learn(total_timesteps=300, eval_freq=250)
+ model.learn(total_timesteps=200)
@pytest.mark.parametrize("n_critics", [1, 3])
@@ -97,7 +97,7 @@ def test_n_critics(n_critics):
buffer_size=10000,
verbose=1,
)
- model.learn(total_timesteps=300)
+ model.learn(total_timesteps=200)
def test_dqn():
@@ -109,9 +109,8 @@ def test_dqn():
buffer_size=500,
learning_rate=3e-4,
verbose=1,
- create_eval_env=True,
)
- model.learn(total_timesteps=500, eval_freq=250)
+ model.learn(total_timesteps=200)
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
@@ -210,3 +209,29 @@ def test_warn_dqn_multi_env():
buffer_size=100,
target_update_interval=1,
)
+
+
+def test_ppo_warnings():
+ """Test that PPO warns and errors correctly on
+ problematic rollout buffer sizes"""
+
+ # Only 1 step: advantage normalization will return NaN
+ with pytest.raises(AssertionError):
+ PPO("MlpPolicy", "Pendulum-v1", n_steps=1)
+
+ # batch_size of 1 is allowed when normalize_advantage=False
+ model = PPO("MlpPolicy", "Pendulum-v1", n_steps=1, batch_size=1, normalize_advantage=False)
+ model.learn(4)
+
+ # Truncated mini-batch
+ # Batch size 1 yields NaN with normalized advantage because
+ # torch.std(some_length_1_tensor) == NaN
+ # advantage normalization is automatically deactivated
+ # in that case
+ with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
+ model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
+ model.learn(64)
+
+ loss = model.logger.name_to_value["train/loss"]
+ assert loss > 0
+ assert not np.isnan(loss) # check not nan (since nan does not equal nan)
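The new `test_ppo_warnings` documents the constraints on the rollout size: `n_steps * n_envs` should be a multiple of `batch_size` (otherwise the last mini-batch is truncated and a warning is emitted), and a single-step rollout is rejected because advantage normalization would produce NaNs. Choosing compatible values avoids both, e.g.:

```python
from stable_baselines3 import PPO

# 64 transitions per rollout, split into two mini-batches of 32: no warning
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=32)
model.learn(128)
```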
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
index 452e6fbdc..aa32be361 100644
--- a/tests/test_save_load.py
+++ b/tests/test_save_load.py
@@ -1,11 +1,14 @@
+import base64
import io
+import json
import os
import pathlib
import warnings
+import zipfile
from collections import OrderedDict
from copy import deepcopy
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -27,7 +30,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env:
if model_class == DQN:
return IdentityEnv(10)
else:
- return IdentityEnvBox(10)
+ return IdentityEnvBox(-10, 10)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@@ -64,7 +67,7 @@ def test_save_load(tmp_path, model_class):
model.set_parameters(invalid_object_params, exact_match=False)
# Test that exact_match catches when something was missed.
- missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
+ missing_object_params = {k: v for k, v in list(original_params.items())[:-1]}
with pytest.raises(ValueError):
model.set_parameters(missing_object_params, exact_match=True)
@@ -174,6 +177,7 @@ def test_set_env(tmp_path, model_class):
env = DummyVecEnv([lambda: select_env(model_class)])
env2 = DummyVecEnv([lambda: select_env(model_class)])
env3 = select_env(model_class)
+ env4 = DummyVecEnv([lambda: select_env(model_class) for _ in range(2)])
kwargs = {}
if model_class in {DQN, DDPG, SAC, TD3}:
@@ -199,6 +203,10 @@ def test_set_env(tmp_path, model_class):
# learn again
model.learn(total_timesteps=64)
+ # num_env must be the same
+ with pytest.raises(AssertionError):
+ model.set_env(env4)
+
# Keep the same env, disable reset
model.set_env(model.get_env(), force_reset=False)
assert model._last_obs is not None
@@ -223,6 +231,11 @@ def test_set_env(tmp_path, model_class):
model.learn(total_timesteps=64, reset_num_timesteps=False)
assert model.num_timesteps == 3 * 64
+ del model
+ # Load the model with a different number of environments
+ model = model_class.load(tmp_path / "test_save.zip", env=env4)
+ model.learn(total_timesteps=64)
+
# Clear saved file
os.remove(tmp_path / "test_save.zip")
@@ -375,6 +388,9 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
select_env(model_class),
buffer_size=100,
optimize_memory_usage=optimize_memory_usage,
+ # we cannot use optimize_memory_usage and handle_timeout_termination
+ # at the same time
+ replay_buffer_kwargs={"handle_timeout_termination": not optimize_memory_usage},
policy_kwargs=dict(net_arch=[64]),
learning_starts=10,
)
@@ -446,7 +462,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
params = deepcopy(policy.state_dict())
# Modify all parameters to be random values
- random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+ random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values
policy.load_state_dict(random_params)
@@ -537,7 +553,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
params = deepcopy(q_net.state_dict())
# Modify all parameters to be random values
- random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
+ random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values
q_net.load_state_dict(random_params)
@@ -677,3 +693,33 @@ def test_save_load_large_model(tmp_path):
# clear file from os
os.remove(tmp_path / "test_save.zip")
+
+
+def test_load_invalid_object(tmp_path):
+ # See GH Issue #1122 for an example
+ # of invalid object loading
+ path = str(tmp_path / "ppo_pendulum.zip")
+ PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda _: 1.0).save(path)
+
+ with zipfile.ZipFile(path, mode="r") as archive:
+ json_data = json.loads(archive.read("data").decode())
+
+ # Intentionally corrupt the data
+ serialization = json_data["learning_rate"][":serialized:"]
+ base64_object = base64.b64decode(serialization.encode())
+ new_bytes = base64_object.replace(b"CodeType", b"CodeTyps")
+ base64_encoded = base64.b64encode(new_bytes).decode()
+ json_data["learning_rate"][":serialized:"] = base64_encoded
+ serialized_data = json.dumps(json_data, indent=4)
+
+ with open(tmp_path / "data", "w") as f:
+ f.write(serialized_data)
+ # Replace with the corrupted file
+ # probably doesn't work on windows
+ os.system(f"cd {tmp_path}; zip ppo_pendulum.zip data")
+ with pytest.warns(UserWarning, match=r"custom_objects"):
+ PPO.load(path)
+ # Load with custom object, no warnings
+ with warnings.catch_warnings(record=True) as record:
+ PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
+ assert len(record) == 0
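The corrupted-archive test above shows the recommended escape hatch: when a serialized object (here a lambda learning-rate schedule) cannot be de-serialized, it can be replaced at load time via `custom_objects`. The path below is a placeholder:

```python
from stable_baselines3 import PPO

# Override the object that failed to de-serialize instead of loading it
model = PPO.load(
    "ppo_pendulum.zip",
    custom_objects={"learning_rate": lambda _: 1.0},
)
```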
diff --git a/tests/test_sde.py b/tests/test_sde.py
index 0a650a57c..4fc16fc39 100644
--- a/tests/test_sde.py
+++ b/tests/test_sde.py
@@ -68,12 +68,11 @@ def test_state_dependent_noise(model_class, use_expln):
"Pendulum-v1",
use_sde=True,
seed=None,
- create_eval_env=True,
verbose=1,
policy_kwargs=dict(log_std_init=-2, use_expln=use_expln, net_arch=[64]),
**kwargs,
)
- model.learn(total_timesteps=255, eval_freq=250)
+ model.learn(total_timesteps=255)
model.policy.reset_noise()
if model_class == SAC:
model.policy.actor.get_std()
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index 54994b2b5..c8bc443ed 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -1,4 +1,6 @@
-import gym
+from typing import Optional
+
+import gymnasium as gym
import numpy as np
import pytest
@@ -9,28 +11,45 @@
class DummyMultiDiscreteSpace(gym.Env):
def __init__(self, nvec):
- super(DummyMultiDiscreteSpace, self).__init__()
+ super().__init__()
self.observation_space = gym.spaces.MultiDiscrete(nvec)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
- def reset(self):
- return self.observation_space.sample()
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ super().reset(seed=seed)
+ return self.observation_space.sample(), {}
def step(self, action):
- return self.observation_space.sample(), 0.0, False, {}
+ return self.observation_space.sample(), 0.0, False, False, {}
class DummyMultiBinary(gym.Env):
def __init__(self, n):
- super(DummyMultiBinary, self).__init__()
+ super().__init__()
self.observation_space = gym.spaces.MultiBinary(n)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ super().reset(seed=seed)
+ return self.observation_space.sample(), {}
+
+ def step(self, action):
+ return self.observation_space.sample(), 0.0, False, False, {}
+
+
+class DummyMultidimensionalAction(gym.Env):
+ def __init__(self):
+ super().__init__()
+ self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
+ self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
+
def reset(self):
- return self.observation_space.sample()
+ return self.observation_space.sample(), {}
def step(self, action):
- return self.observation_space.sample(), 0.0, False, {}
+ return self.observation_space.sample(), 0.0, False, False, {}
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@@ -53,22 +72,39 @@ def test_identity_spaces(model_class, env):
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
-@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
+@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
def test_action_spaces(model_class, env):
+ kwargs = {}
if model_class in [SAC, DDPG, TD3]:
- supported_action_space = env == "Pendulum-v1"
+ supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
+ kwargs["learning_starts"] = 2
+ kwargs["train_freq"] = 32
elif model_class == DQN:
supported_action_space = env == "CartPole-v1"
elif model_class in [A2C, PPO]:
supported_action_space = True
+ kwargs["n_steps"] = 64
if supported_action_space:
- model_class("MlpPolicy", env)
+ model = model_class("MlpPolicy", env, **kwargs)
+ if isinstance(env, DummyMultidimensionalAction):
+ model.learn(64)
else:
with pytest.raises(AssertionError):
model_class("MlpPolicy", env)
+def test_sde_multi_dim():
+ SAC(
+ "MlpPolicy",
+ DummyMultidimensionalAction(),
+ learning_starts=10,
+ use_sde=True,
+ sde_sample_freq=2,
+ use_sde_at_warmup=True,
+ ).learn(20)
+
+
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", ["Taxi-v3"])
def test_discrete_obs_space(model_class, env):
diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py
index 6dccf4105..8aa864df8 100644
--- a/tests/test_tensorboard.py
+++ b/tests/test_tensorboard.py
@@ -3,6 +3,8 @@
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.logger import HParam
from stable_baselines3.common.utils import get_latest_run_id
MODEL_DICT = {
@@ -15,6 +17,34 @@
N_STEPS = 100
+class HParamCallback(BaseCallback):
+ def __init__(self):
+ """
+ Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
+ """
+ super().__init__()
+
+ def _on_training_start(self) -> None:
+ hparam_dict = {
+ "algorithm": self.model.__class__.__name__,
+ "learning rate": self.model.learning_rate,
+ "gamma": self.model.gamma,
+ }
+ # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
+ # TensorBoard will find & display metrics from the `SCALARS` tab

+ metric_dict = {
+ "rollout/ep_len_mean": 0,
+ }
+ self.logger.record(
+ "hparams",
+ HParam(hparam_dict, metric_dict),
+ exclude=("stdout", "log", "json", "csv"),
+ )
+
+ def _on_step(self) -> bool:
+ return True
+
+
@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
# Skip if no tensorboard installed
@@ -22,8 +52,13 @@ def test_tensorboard(tmp_path, model_name):
logname = model_name.upper()
algo, env_id = MODEL_DICT[model_name]
- model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)
- model.learn(N_STEPS)
+ kwargs = {}
+ if model_name == "ppo":
+ kwargs["n_steps"] = 64
+ elif model_name in {"sac", "td3"}:
+ kwargs["train_freq"] = 2
+ model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **kwargs)
+ model.learn(N_STEPS, callback=HParamCallback())
model.learn(N_STEPS, reset_num_timesteps=False)
assert os.path.isdir(tmp_path / str(logname + "_1"))
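To get the hyperparameters into the TensorBoard `HPARAMS` tab outside the test suite, a callback like the `HParamCallback` above is passed to `learn()` on a model configured with `tensorboard_log`. A condensed, self-contained variant (class name and log folder are illustrative):

```python
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam


class MyHParamCallback(BaseCallback):
    def _on_training_start(self) -> None:
        # Record hyperparameters once, linked to an existing scalar metric
        self.logger.record(
            "hparams",
            HParam(
                {"algorithm": self.model.__class__.__name__, "gamma": self.model.gamma},
                {"rollout/ep_len_mean": 0},
            ),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="./tb_logs")
model.learn(1000, callback=MyHParamCallback())
```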
diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py
index 1ea2efe67..dcbda74e1 100644
--- a/tests/test_train_eval_mode.py
+++ b/tests/test_train_eval_mode.py
@@ -1,6 +1,6 @@
from typing import Union
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
@@ -28,7 +28,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
"""
def __init__(self, observation_space: gym.Space):
- super(FlattenBatchNormDropoutExtractor, self).__init__(
+ super().__init__(
observation_space,
get_flattened_obs_dim(observation_space),
)
@@ -143,7 +143,8 @@ def test_dqn_train_with_batch_norm():
policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
learning_starts=0,
seed=1,
- tau=0, # do not clone the target
+ tau=0.0, # do not clone the target
+ target_update_interval=100, # Copy the stats to the target
)
(
@@ -154,6 +155,9 @@ def test_dqn_train_with_batch_norm():
) = clone_dqn_batch_norm_stats(model)
model.learn(total_timesteps=200)
+ # Force stats copy
+ model.target_update_interval = 1
+ model._on_step()
(
q_net_bias_after,
@@ -165,8 +169,12 @@ def test_dqn_train_with_batch_norm():
assert ~th.isclose(q_net_bias_before, q_net_bias_after).all()
assert ~th.isclose(q_net_running_mean_before, q_net_running_mean_after).all()
+ # No weight update
+ assert th.isclose(q_net_bias_before, q_net_target_bias_after).all()
assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all()
- assert th.isclose(q_net_target_running_mean_before, q_net_target_running_mean_after).all()
+ # Running stat should be copied even when tau=0
+ assert th.isclose(q_net_running_mean_before, q_net_target_running_mean_before).all()
+ assert th.isclose(q_net_running_mean_after, q_net_target_running_mean_after).all()
def test_td3_train_with_batch_norm():
@@ -210,10 +218,12 @@ def test_td3_train_with_batch_norm():
assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
assert th.isclose(actor_target_bias_before, actor_target_bias_after).all()
- assert th.isclose(actor_target_running_mean_before, actor_target_running_mean_after).all()
+ # Running stat should be copied even when tau=0
+ assert th.isclose(actor_running_mean_after, actor_target_running_mean_after).all()
assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
- assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
+ # Running stat should be copied even when tau=0
+ assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
def test_sac_train_with_batch_norm():
@@ -250,10 +260,12 @@ def test_sac_train_with_batch_norm():
assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()
assert ~th.isclose(critic_bias_before, critic_bias_after).all()
- assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
+ # Running stat should be copied even when tau=0
+ assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()
assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
- assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
+ # Running stat should be copied even when tau=0
+ assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()
@pytest.mark.parametrize("model_class", [A2C, PPO])
@@ -320,7 +332,7 @@ def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
bias_before, running_mean_before = clone_on_policy_batch_norm(model)
- total_timesteps, callback = model._setup_learn(total_timesteps=2 * 64, eval_env=model.get_env())
+ total_timesteps, callback = model._setup_learn(total_timesteps=2 * 64)
for _ in range(2):
model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)
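The batch-norm tests above rely on the new behaviour that target networks receive a hard copy of the running statistics even when `tau=0` is used for the weights. The mechanism can be sketched with `get_parameters_by_name` and `polyak_update` on a toy module (this is not the actual SB3 internals, only an illustration):

```python
import torch as th

from stable_baselines3.common.utils import get_parameters_by_name, polyak_update

net = th.nn.Sequential(th.nn.Linear(4, 4), th.nn.BatchNorm1d(4))
target = th.nn.Sequential(th.nn.Linear(4, 4), th.nn.BatchNorm1d(4))
net(th.ones(8, 4))  # update the running stats of the online network

# Soft update for the weights...
polyak_update(net.parameters(), target.parameters(), tau=0.005)
# ...but a hard copy (tau=1.0) for the batch-norm running statistics
polyak_update(
    get_parameters_by_name(net, ["running_"]),
    get_parameters_by_name(target, ["running_"]),
    tau=1.0,
)
assert th.allclose(net[1].running_mean, target[1].running_mean)
```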
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b07bbe931..7989368d5 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,20 +1,26 @@
import os
import shutil
-import gym
+import gymnasium as gym
import numpy as np
import pytest
import torch as th
-from gym import spaces
+from gymnasium import spaces
import stable_baselines3 as sb3
-from stable_baselines3 import A2C, PPO
+from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
-from stable_baselines3.common.utils import get_system_info, is_vectorized_observation, polyak_update, zip_strict
+from stable_baselines3.common.utils import (
+ get_parameters_by_name,
+ get_system_info,
+ is_vectorized_observation,
+ polyak_update,
+ zip_strict,
+)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
@@ -180,22 +186,23 @@ class AlwaysDoneWrapper(gym.Wrapper):
# Pretends that environment only has single step for each
# episode.
def __init__(self, env):
- super(AlwaysDoneWrapper, self).__init__(env)
+ super().__init__(env)
self.last_obs = None
self.needs_reset = True
def step(self, action):
- obs, reward, done, info = self.env.step(action)
- self.needs_reset = done
+ obs, reward, done, truncated, info = self.env.step(action)
+ self.needs_reset = done or truncated
self.last_obs = obs
- return obs, reward, True, info
+ return obs, reward, True, truncated, info
def reset(self, **kwargs):
+ info = {}
if self.needs_reset:
- obs = self.env.reset(**kwargs)
+ obs, info = self.env.reset(**kwargs)
self.last_obs = obs
self.needs_reset = False
- return self.last_obs
+ return self.last_obs, info
@pytest.mark.parametrize("n_envs", [1, 2, 5, 7])
@@ -229,7 +236,7 @@ def test_evaluate_policy_monitors(vec_env_class):
# Also test VecEnvs
n_eval_episodes = 3
n_envs = 2
- env_id = "CartPole-v0"
+ env_id = "CartPole-v1"
model = A2C("MlpPolicy", env_id, seed=0)
def make_eval_env(with_monitor, wrapper_class=gym.Wrapper):
@@ -322,6 +329,22 @@ def test_vec_noise():
assert len(vec.noises) == num_envs
+def test_get_parameters_by_name():
+ model = th.nn.Sequential(th.nn.Linear(5, 5), th.nn.BatchNorm1d(5))
+ # Initialize stats
+ model(th.ones(3, 5))
+ included_names = ["weight", "bias", "running_"]
+ # 2 x weight, 2 x bias, 1 x running_mean, 1 x running_var; Ignore num_batches_tracked.
+ parameters = get_parameters_by_name(model, included_names)
+ assert len(parameters) == 6
+ assert th.allclose(parameters[4], model[1].running_mean)
+ assert th.allclose(parameters[5], model[1].running_var)
+ parameters = get_parameters_by_name(model, ["running_"])
+ assert len(parameters) == 2
+ assert th.allclose(parameters[0], model[1].running_mean)
+ assert th.allclose(parameters[1], model[1].running_var)
+
+
def test_polyak():
param1, param2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))
target1, target2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))
@@ -366,19 +389,6 @@ def test_is_wrapped():
assert unwrap_wrapper(env, Monitor) == monitor_env
-def test_ppo_warnings():
- """Test that PPO warns and errors correctly on
- problematic rollour buffer sizes"""
-
- # Only 1 step: advantage normalization will return NaN
- with pytest.raises(AssertionError):
- PPO("MlpPolicy", "Pendulum-v1", n_steps=1)
-
- # Truncated mini-batch
- with pytest.warns(UserWarning):
- PPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8)
-
-
def test_get_system_info():
info, info_str = get_system_info(print_info=True)
assert info["Stable-Baselines3"] == str(sb3.__version__)
diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py
index 265da2ed9..d80b44c8d 100644
--- a/tests/test_vec_check_nan.py
+++ b/tests/test_vec_check_nan.py
@@ -1,7 +1,7 @@
-import gym
+import gymnasium as gym
import numpy as np
import pytest
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
@@ -12,7 +12,7 @@ class NanAndInfEnv(gym.Env):
metadata = {"render.modes": ["human"]}
def __init__(self):
- super(NanAndInfEnv, self).__init__()
+ super().__init__()
self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
@@ -24,11 +24,11 @@ def step(action):
obs = float("inf")
else:
obs = 0
- return [obs], 0.0, False, {}
+ return [obs], 0.0, False, False, {}
@staticmethod
def reset():
- return [0.0]
+ return [0.0], {}
def render(self, mode="human", close=False):
pass
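The env above produces NaN/inf on demand to exercise `VecCheckNan`. Outside the tests, the wrapper is typically used as in this minimal sketch:

```python
import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan

env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
# Raise as soon as a NaN or inf shows up in actions, observations or rewards
env = VecCheckNan(env, raise_exception=True)
env.reset()
env.step(np.array([[0.0]], dtype=np.float32))  # fine
# Passing a NaN action would raise a ValueError:
# env.step(np.array([[np.nan]], dtype=np.float32))
```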
diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py
index 93ea348b1..9572a3216 100644
--- a/tests/test_vec_envs.py
+++ b/tests/test_vec_envs.py
@@ -2,8 +2,9 @@
import functools
import itertools
import multiprocessing
+from typing import Optional
-import gym
+import gymnasium as gym
import numpy as np
import pytest
@@ -25,17 +26,19 @@ def __init__(self, space):
self.current_step = 0
self.ep_length = 4
- def reset(self):
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ self.seed(seed)
self.current_step = 0
self._choose_next_state()
- return self.state
+ return self.state, {}
def step(self, action):
reward = float(np.random.rand())
self._choose_next_state()
self.current_step += 1
- done = self.current_step >= self.ep_length
- return self.state, reward, done, {}
+ done = truncated = self.current_step >= self.ep_length
+ return self.state, reward, done, truncated, {}
def _choose_next_state(self):
self.state = self.observation_space.sample()
@@ -144,13 +147,13 @@ def __init__(self, max_steps):
def reset(self):
self.current_step = 0
- return np.array([self.current_step], dtype="int")
+ return np.array([self.current_step], dtype="int"), {}
def step(self, action):
prev_step = self.current_step
self.current_step += 1
- done = self.current_step >= self.max_steps
- return np.array([prev_step], dtype="int"), 0.0, done, {}
+ done = truncated = self.current_step >= self.max_steps
+ return np.array([prev_step], dtype="int"), 0.0, done, truncated, {}
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
@@ -444,6 +447,23 @@ def make_monitored_env():
assert vec_env.env_is_wrapped(Monitor) == [False, True]
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
+def test_backward_compat_seed(vec_env_class):
+ def make_env():
+ env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
+ # Patch reset function to remove seed param
+ env.reset = lambda: (env.observation_space.sample(), {})
+ env.seed = env.observation_space.seed
+ return env
+
+ vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
+ vec_env.seed(3)
+ obs = vec_env.reset()
+ vec_env.seed(3)
+ new_obs = vec_env.reset()
+ assert np.allclose(new_obs, obs)
+
+
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_seeding(vec_env_class):
def make_env():
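The seeding tests above pin down the VecEnv contract: `seed()` stores the per-env seeds and they are applied at the next `reset()`, so seeding twice with the same value reproduces the observations. A short sketch of that contract, assuming it holds for registered Gymnasium envs as it does for the custom envs tested here:

```python
import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
vec_env.seed(3)
obs = vec_env.reset()
vec_env.seed(3)
new_obs = vec_env.reset()
assert np.allclose(new_obs, obs)
```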
diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py
index 15074425e..0039cf6e1 100644
--- a/tests/test_vec_extract_dict_obs.py
+++ b/tests/test_vec_extract_dict_obs.py
@@ -1,5 +1,5 @@
import numpy as np
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py
index 974202b31..f5a91d06e 100644
--- a/tests/test_vec_monitor.py
+++ b/tests/test_vec_monitor.py
@@ -3,7 +3,7 @@
import os
import uuid
-import gym
+import gymnasium as gym
import pandas
import pytest
@@ -36,7 +36,7 @@ def test_vec_monitor(tmp_path):
monitor_env.close()
- with open(monitor_file, "rt") as file_handler:
+ with open(monitor_file) as file_handler:
first_line = file_handler.readline()
assert first_line.startswith("#")
metadata = json.loads(first_line[1:])
@@ -66,7 +66,7 @@ def test_vec_monitor_info_keywords(tmp_path):
monitor_env.close()
- with open(monitor_file, "rt") as f:
+ with open(monitor_file) as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0 or i == 1:
@@ -132,15 +132,15 @@ def test_vec_monitor_ppo(recwarn):
"""
Test the `VecMonitor` with PPO
"""
- env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
- env.seed(0)
+ env = DummyVecEnv([lambda: gym.make("CartPole-v1", disable_env_checker=True)])
+ env.seed(seed=0)
monitor_env = VecMonitor(env)
model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu")
model.learn(total_timesteps=250)
# No warnings because using `VecMonitor`
evaluate_policy(model, monitor_env)
- assert len(recwarn) == 0
+ assert len(recwarn) == 0, f"{[str(warning) for warning in recwarn]}"
def test_vec_monitor_warn():
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index 07ad77f22..c2d314c4c 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -1,10 +1,10 @@
import operator
-import warnings
+from typing import Optional
-import gym
+import gymnasium as gym
import numpy as np
import pytest
-from gym import spaces
+from gymnasium import spaces
from stable_baselines3 import SAC, TD3, HerReplayBuffer
from stable_baselines3.common.monitor import Monitor
@@ -34,20 +34,23 @@ def step(self, action):
self.t += 1
index = (self.t + self.return_reward_idx) % len(self.returned_rewards)
returned_value = self.returned_rewards[index]
- return np.array([returned_value]), returned_value, self.t == len(self.returned_rewards), {}
+ done = truncated = self.t == len(self.returned_rewards)
+ return np.array([returned_value]), returned_value, done, truncated, {}
- def reset(self):
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ super().reset(seed=seed)
self.t = 0
- return np.array([self.returned_rewards[self.return_reward_idx]])
+ return np.array([self.returned_rewards[self.return_reward_idx]]), {}
-class DummyDictEnv(gym.GoalEnv):
+class DummyDictEnv(gym.Env):
"""
Dummy gym goal env for testing purposes
"""
def __init__(self):
- super(DummyDictEnv, self).__init__()
+ super().__init__()
self.observation_space = spaces.Dict(
{
"observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
@@ -57,14 +60,16 @@ def __init__(self):
)
self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
- def reset(self):
- return self.observation_space.sample()
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ super().reset(seed=seed)
+ return self.observation_space.sample(), {}
def step(self, action):
obs = self.observation_space.sample()
reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {})
done = np.random.rand() > 0.8
- return obs, reward, done, {}
+ return obs, reward, done, False, {}
def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32:
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
@@ -87,13 +92,15 @@ def __init__(self):
)
self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
- def reset(self):
- return self.observation_space.sample()
+ def reset(self, seed: Optional[int] = None):
+ if seed is not None:
+ super().reset(seed=seed)
+ return self.observation_space.sample(), {}
def step(self, action):
obs = self.observation_space.sample()
done = np.random.rand() > 0.8
- return obs, 0.0, done, {}
+ return obs, 0.0, done, False, {}
def allclose(obs_1, obs_2):
@@ -118,15 +125,6 @@ def make_dict_env():
return Monitor(DummyDictEnv())
-def test_deprecation():
- venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
- venv = VecNormalize(venv)
- with warnings.catch_warnings(record=True) as record:
- assert np.allclose(venv.ret, venv.returns)
- # Deprecation warning when using .ret
- assert len(record) == 1
-
-
def check_rms_equal(rmsa, rmsb):
if isinstance(rmsa, dict):
for key in rmsa.keys():
@@ -380,19 +378,18 @@ def test_offpolicy_normalization(model_class, online_sampling):
assert model.get_vec_normalize_env() is eval_env
model.learn(total_timesteps=10)
model.set_env(env)
-
- model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75)
+ model.learn(total_timesteps=150)
# Check getter
assert isinstance(model.get_vec_normalize_env(), VecNormalize)
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_sync_vec_normalize(make_env):
- env = DummyVecEnv([make_env])
+ original_env = DummyVecEnv([make_env])
- assert unwrap_vec_normalize(env) is None
+ assert unwrap_vec_normalize(original_env) is None
- env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
+ env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
@@ -433,6 +430,17 @@ def test_sync_vec_normalize(make_env):
assert allclose(obs, eval_env.normalize_obs(original_obs))
assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
+ # Check synchronization when only reward is normalized
+ env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0)
+ eval_env = DummyVecEnv([make_env])
+ eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False)
+ env.reset()
+ env.step([env.action_space.sample()])
+ assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
+ sync_envs_normalization(env, eval_env)
+ assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
+ assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var)
+
def test_discrete_obs():
with pytest.raises(ValueError, match=".*only supports.*"):