Skip to content

Commit

Permalink
Merge pull request #1 from carlosluis/fix_tests
Browse files Browse the repository at this point in the history
Fix tests
  • Loading branch information
tlpss authored Oct 2, 2022
2 parents 18b29a6 + 6ed3079 commit 0851440
Show file tree
Hide file tree
Showing 63 changed files with 718 additions and 342 deletions.
7 changes: 4 additions & 3 deletions .github/ISSUE_TEMPLATE/custom_env.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,20 @@ from stable_baselines3.common.env_checker import check_env
class CustomEnv(gym.Env):

def __init__(self):
super(CustomEnv, self).__init__()
super().__init__()
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))

def reset(self):
return self.observation_space.sample()
return self.observation_space.sample(), {}

def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
done = False
truncated = False
info = {}
return obs, reward, done, info
return obs, reward, done, truncated, info

env = CustomEnv()
check_env(env)
Expand Down
2 changes: 1 addition & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:1.4.1a0
image: stablebaselines/stable-baselines3-cpu:1.5.1a6

type-check:
script:
Expand Down
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ FROM $PARENT_IMAGE
ARG PYTORCH_DEPS=cpuonly
ARG PYTHON_VERSION=3.7

# for tzdata
ENV DEBIAN_FRONTEND="noninteractive" TZ="Europe/Paris"

RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
Expand All @@ -20,7 +23,7 @@ RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \
/opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \
/opt/conda/bin/conda install -y pytorch=1.11 $PYTORCH_DEPS -c pytorch && \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH

Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ check-codestyle:
commit-checks: format type lint

doc:
cd docs && make html
# Prevent weird error due to protobuf
cd docs && PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp make html

spelling:
cd docs && make spelling
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

obs = env.reset()
obs, info = env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
obs, reward, done, truncated, info = env.step(action)
env.render()
if done:
if done or truncated:
obs = env.reset()

env.close()
Expand Down
6 changes: 3 additions & 3 deletions docs/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ channels:
- defaults
dependencies:
- cpuonly=1.0=0
- pip=21.1
- pip=22.1.1
- python=3.7
- pytorch=1.11=py3.7_cpu_0
- pytorch=1.11.0=py3.7_cpu_0
- pip:
- gym==0.21
- gym==0.26
- cloudpickle
- opencv-python-headless
- pandas
Expand Down
7 changes: 4 additions & 3 deletions docs/guide/custom_policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
.. note::

By default the feature extractor is shared between the actor and the critic to save computation (when applicable).
However, this can be changed by defining a custom policy for on-policy algorithms or setting
``share_features_extractor=False`` in the ``policy_kwargs`` for off-policy algorithms
(and when applicable).
However, this can be changed by defining a custom policy for on-policy algorithms
(see `issue #1066 <https://github.com/DLR-RM/stable-baselines3/issues/1066#issuecomment-1246866844>`_
for more information) or setting ``share_features_extractor=False`` in the
``policy_kwargs`` for off-policy algorithms (and when applicable).


.. code-block:: python
Expand Down
15 changes: 8 additions & 7 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ In the following example, we will train, save and load a DQN model on the Lunar
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
obs = env.reset()
vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
action, _states = model.predict(obs, deterministic=True)
obs, rewards, dones, info = env.step(action)
env.render()
obs, rewards, dones, info = vec_env.step(action)
vec_env.render()
Multiprocessing: Unleashing the Power of Vectorized Environments
Expand Down Expand Up @@ -470,19 +471,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
# HER must be loaded with the env
model = SAC.load("her_sac_highway", env=env)
obs = env.reset()
obs, info = env.reset()
# Evaluate the agent
episode_reward = 0
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
obs, reward, done, truncated, info = env.step(action)
env.render()
episode_reward += reward
if done or info.get("is_success", False):
if done or truncated or info.get("is_success", False):
print("Reward:", episode_reward, "Success?", info.get("is_success", False))
episode_reward = 0.0
obs = env.reset()
obs, info = env.reset()
Learning Rate Schedule
Expand Down
136 changes: 98 additions & 38 deletions docs/guide/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,78 +46,138 @@ For PPO, assuming a shared feature extactor.

.. code-block:: python
import torch as th
from stable_baselines3 import PPO
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, extractor, action_net, value_net):
super(OnnxablePolicy, self).__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
class OnnxablePolicy(th.nn.Module):
def __init__(self, extractor, action_net, value_net):
super().__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
model = PPO.load("PathToTrainedModel.zip")
model.policy.to("cpu")
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
dummy_input = torch.randn(1, observation_size)
torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_ppo_model.onnx",
opset_version=9,
input_names=["input"],
)
##### Load and test with onnx
import onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, observation_size)).astype(np.float32)
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {'input.1': observation})
action, value = ort_sess.run(None, {"input": observation})
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.

.. code-block:: python
import torch as th
from stable_baselines3 import SAC
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, actor):
super(OnnxablePolicy, self).__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
class OnnxablePolicy(th.nn.Module):
def __init__(self, actor: th.nn.Module):
super().__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = th.nn.Sequential(
actor.latent_pi,
actor.mu,
# For gSDE
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
# Squash the output
th.nn.Tanh(),
)
def forward(self, observation: th.Tensor) -> th.Tensor:
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
model = SAC.load("PathToTrainedModel.zip")
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
dummy_input = torch.randn(1, observation_size)
onnxable_model.policy.to("cpu")
torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=9,
input_names=["input"],
)
##### Load and test with onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_sac_actor.onnx"
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {"input": observation})
For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_

Export to C++
-----------------
Trace/Export to C++
-------------------

You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications
(for instance inference code written in C++).

There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228

.. code-block:: python
# See "ONNX export" for imports and OnnxablePolicy
jit_path = "sac_traced.pt"
# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)
##### Load and test with torch
import torch as th
(using PyTorch JIT)
TODO: help is welcomed!
dummy_input = th.randn(1, *observation_size)
loaded_module = th.jit.load(jit_path)
action_jit = loaded_module(dummy_input)
Export to tensorflowjs / ONNX-JS
Expand Down
20 changes: 13 additions & 7 deletions docs/guide/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@ Here is a quick example of how to train and run A2C on a CartPole environment:
from stable_baselines3 import A2C
env = gym.make('CartPole-v1')
env = gym.make("CartPole-v1")
model = A2C('MlpPolicy', env, verbose=1)
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
obs = env.reset()
# Note: Gym 0.26+ reset() returns a tuple
# where SB3 VecEnv only return an observation
obs, info = env.reset()
for i in range(1000):
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
# Note: Gym 0.26+ step() returns an additional boolean
# "truncated" where SB3 store truncation information
# in info["TimeLimit.truncated"]
obs, reward, done, truncated, info = env.step(action)
env.render()
if done:
obs = env.reset()
# Note: reset is automated in SB3 VecEnv
if done or truncated:
obs, info = env.reset()
.. note::

Expand All @@ -40,4 +46,4 @@ the policy is registered:
from stable_baselines3 import A2C
model = A2C('MlpPolicy', 'CartPole-v1').learn(10000)
model = A2C("MlpPolicy", "CartPole-v1").learn(10000)
Loading

0 comments on commit 0851440

Please sign in to comment.