Features/openai hacks #35

Open · wants to merge 17 commits into base: master
Changes from 15 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
data/
*/*/mjkey.txt
**/.DS_STORE
**/*.pyc
7 changes: 5 additions & 2 deletions docs/HER.md
@@ -2,9 +2,12 @@
Some notes on the implementation of
[Hindsight Experience Replay](https://arxiv.org/abs/1707.01495).
## Expected Results
If you run the [Fetch example](examples/her/her_td3_gym_fetch_reach.py), then
If you run the [Fetch reach example](examples/her/her_td3_gym_fetch_reach.py), then
you should get results like this:
![Fetch HER results](images/FetchReach-v1_HER-TD3.png)
![Fetch HER Reach results](images/FetchReach-v1_HER-TD3.png)

If you run the [Fetch pick-and-place example](examples/her/her_td3_gym_fetch_pnp.py), then you should get results like this: ![Fetch HER PNP results](images/FetchPickAndPlace-v1_HER-TD3.png)


If you run the [Sawyer example](examples/her/her_td3_multiworld_sawyer_reach.py)
, then you should get results like this:
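As context for the HER.md changes above: the relabeling these examples rely on replaces a transition's desired goal with a goal actually achieved later in the same trajectory. Below is a minimal sketch of that idea, assuming dict-style transitions and a hypothetical `compute_reward` callback; rlkit's `ObsDictRelabelingBuffer` (used in the new example) handles this internally and differently.

```python
# Illustrative sketch of HER "future" relabeling, not rlkit's
# ObsDictRelabelingBuffer. Transitions are plain dicts and
# compute_reward is a hypothetical callback (e.g. env.compute_reward).
import random
from copy import deepcopy


def relabel_with_future_goals(trajectory, compute_reward, k=4):
    """For each transition, add k copies whose desired goal is an achieved
    goal taken from the same or a later step of the trajectory."""
    relabeled = []
    for t, transition in enumerate(trajectory):
        for _ in range(k):
            future = random.randint(t, len(trajectory) - 1)
            new = deepcopy(transition)
            new['desired_goal'] = trajectory[future]['achieved_goal']
            new['reward'] = compute_reward(new['achieved_goal'],
                                           new['desired_goal'])
            relabeled.append(new)
    return trajectory + relabeled
```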
Binary file added docs/images/FetchPickAndPlace-v1_HER-TD3.png
149 changes: 149 additions & 0 deletions examples/her/her_td3_gym_fetch_pnp.py
@@ -0,0 +1,149 @@
import gym

import rlkit.torch.pytorch_util as ptu
from rlkit.exploration_strategies.base import (
PolicyWrappedWithExplorationStrategy
)
from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import (
GaussianAndEpsilonStrategy
)
from rlkit.torch.her.her import HerTd3
import rlkit.samplers.rollout_functions as rf


from rlkit.torch.networks import FlattenMlp, MlpPolicy, QNormalizedFlattenMlp, CompositeNormalizedMlpPolicy
from rlkit.torch.data_management.normalizer import CompositeNormalizer


def experiment(variant):
try:
import robotics_recorder
except ImportError as e:
print(e)

env = gym.make(variant['env_id'])
es = GaussianAndEpsilonStrategy(
action_space=env.action_space,
max_sigma=.2,
min_sigma=.2, # constant sigma
epsilon=.3,
)
obs_dim = env.observation_space.spaces['observation'].low.size
goal_dim = env.observation_space.spaces['desired_goal'].low.size
action_dim = env.action_space.low.size

shared_normalizer = CompositeNormalizer(obs_dim + goal_dim, action_dim, obs_clip_range=5)

qf1 = QNormalizedFlattenMlp(
input_size=obs_dim + goal_dim + action_dim,
output_size=1,
hidden_sizes=[400, 300],
composite_normalizer=shared_normalizer
)
qf2 = QNormalizedFlattenMlp(
input_size=obs_dim + goal_dim + action_dim,
output_size=1,
hidden_sizes=[400, 300],
composite_normalizer=shared_normalizer
)
import torch
policy = CompositeNormalizedMlpPolicy(
input_size=obs_dim + goal_dim,
output_size=action_dim,
hidden_sizes=[400, 300],
composite_normalizer=shared_normalizer,
output_activation=torch.tanh
)
exploration_policy = PolicyWrappedWithExplorationStrategy(
exploration_strategy=es,
policy=policy,
)

from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer

observation_key = 'observation'
desired_goal_key = 'desired_goal'
achieved_goal_key = desired_goal_key.replace("desired", "achieved")

replay_buffer = ObsDictRelabelingBuffer(
env=env,
observation_key=observation_key,
desired_goal_key=desired_goal_key,
achieved_goal_key=achieved_goal_key,
**variant['replay_buffer_kwargs']
)

algorithm = HerTd3(
her_kwargs=dict(
observation_key='observation',
desired_goal_key='desired_goal'
),
td3_kwargs = dict(
env=env,
qf1=qf1,
qf2=qf2,
policy=policy,
exploration_policy=exploration_policy
),
replay_buffer=replay_buffer,
**variant['algo_kwargs']
)

if variant.get("save_video", True):
rollout_function = rf.create_rollout_function(
rf.multitask_rollout,
max_path_length=algorithm.max_path_length,
observation_key=algorithm.observation_key,
desired_goal_key=algorithm.desired_goal_key,
)
video_func = get_video_save_func(
rollout_function,
env,
algorithm.eval_policy,
variant,
)
algorithm.post_epoch_funcs.append(video_func)

algorithm.to(ptu.device)
algorithm.train()


if __name__ == "__main__":
variant = dict(
algo_kwargs=dict(
num_epochs=5000,
num_steps_per_epoch=1000,
num_steps_per_eval=500,
max_path_length=50,
batch_size=128,
discount=0.98,
save_algorithm=True,
),
replay_buffer_kwargs=dict(
max_size=100000,
fraction_goals_rollout_goals=0.2, # equal to k = 4 in HER paper
fraction_goals_env_goals=0.0,
),
render=False,
env_id="FetchPickAndPlace-v1",
doodad_docker_image="", # Set
gpu_doodad_docker_image="", # Set
save_video=False,
save_video_period=50,
)

from rlkit.launchers.launcher_util import run_experiment

run_experiment(
experiment,
exp_prefix="her_td3_gym_fetch_pnp_test", # Make sure no spaces...
region="us-east-2",
mode='here_no_doodad',
variant=variant,
use_gpu=True, # Note: online normalization is very slow without GPU.
spot_price=.5,
snapshot_mode='gap_and_last',
snapshot_gap=100,
num_exps_per_instance=2
)
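A note on the sizes used above: the qf and policy networks take inputs of size obs_dim + goal_dim, which assumes the Fetch dict observations are flattened by concatenating the observation and desired_goal entries (rlkit's rollout and replay-buffer code does this via observation_key and desired_goal_key). A minimal sketch of that flattening; `flatten_goal_obs` is a hypothetical helper, not part of this diff.

```python
# Hypothetical helper illustrating how a Fetch dict observation is turned
# into the obs_dim + goal_dim vector the networks above expect.
import numpy as np


def flatten_goal_obs(obs_dict,
                     observation_key='observation',
                     desired_goal_key='desired_goal'):
    # Concatenate the raw observation with the desired goal, giving a
    # vector of length obs_dim + goal_dim.
    return np.concatenate([obs_dict[observation_key],
                           obs_dict[desired_goal_key]], axis=-1)


# Rough usage (assumes a goal-conditioned gym env such as FetchPickAndPlace-v1):
#     flat_obs = flatten_goal_obs(env.reset())
#     action, _ = policy.get_action(flat_obs)
```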

14 changes: 11 additions & 3 deletions examples/her/her_td3_gym_fetch_reach.py
@@ -14,16 +14,17 @@
PolicyWrappedWithExplorationStrategy
)
from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import (
GaussianAndEpislonStrategy
GaussianAndEpsilonStrategy
)
from rlkit.launchers.launcher_util import setup_logger
from rlkit.torch.her.her import HerTd3
from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy

from rlkit.launchers.launcher_util import run_experiment

def experiment(variant):
env = gym.make('FetchReach-v1')
es = GaussianAndEpislonStrategy(
es = GaussianAndEpsilonStrategy(
action_space=env.action_space,
max_sigma=.2,
min_sigma=.2, # constant sigma
@@ -91,4 +92,11 @@ def experiment(variant):
),
)
setup_logger('her-td3-fetch-experiment', variant=variant)
experiment(variant)
run_experiment(
experiment,
exp_prefix="rlkit-her_td3_gym_fetch",
mode='local_docker',
variant=variant,
use_gpu=False,
spot_price=.03
)
4 changes: 2 additions & 2 deletions examples/her/her_td3_multiworld_sawyer_reach.py
@@ -14,7 +14,7 @@
from rlkit.exploration_strategies.base import \
PolicyWrappedWithExplorationStrategy
from rlkit.exploration_strategies.gaussian_and_epsilon_strategy import (
GaussianAndEpislonStrategy
GaussianAndEpsilonStrategy
)
from rlkit.launchers.launcher_util import setup_logger
from rlkit.torch.her.her import HerTd3
@@ -23,7 +23,7 @@

def experiment(variant):
env = gym.make('SawyerReachXYZEnv-v0')
es = GaussianAndEpislonStrategy(
es = GaussianAndEpsilonStrategy(
action_space=env.action_space,
max_sigma=.2,
min_sigma=.2, # constant sigma
2 changes: 1 addition & 1 deletion examples/rig/pointmass/rig.py
@@ -93,5 +93,5 @@
exp_prefix='rlkit-pointmass-rig-example',
mode='here_no_doodad',
variant=variant,
# use_gpu=True, # Turn on if you have a GPU
use_gpu=True, # Turn on if you have a GPU
)
2 changes: 1 addition & 1 deletion rlkit/core/rl_algorithm.py
@@ -425,7 +425,7 @@ def get_extra_data_to_save(self, epoch):
:return:
"""
if self.render:
self.training_env.render(close=True)
self.training_env.close()
data_to_save = dict(
epoch=epoch,
)
rlkit/exploration_strategies/gaussian_and_epsilon_strategy.py
@@ -4,7 +4,7 @@
import numpy as np


class GaussianAndEpislonStrategy(RawExplorationStrategy, Serializable):
Collaborator: Thanks for this fix!

class GaussianAndEpsilonStrategy(RawExplorationStrategy, Serializable):
"""
With probability epsilon, take a completely random action.
with probability 1-epsilon, add Gaussian noise to the action taken by a
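For readers skimming the rename above, the strategy behaves as the docstring says: with probability epsilon it takes a uniformly random action, otherwise it adds Gaussian noise to the policy's action. A standalone sketch of that rule, assuming a gym Box action space; `gaussian_and_epsilon_action` is a hypothetical function, not the rlkit class itself.

```python
# Standalone sketch of the epsilon + Gaussian exploration rule described in
# the docstring above; action_space is assumed to be a gym Box.
import numpy as np


def gaussian_and_epsilon_action(policy_action, action_space,
                                sigma=0.2, epsilon=0.3):
    if np.random.rand() < epsilon:
        # With probability epsilon, take a completely random action.
        return action_space.sample()
    # Otherwise add Gaussian noise to the (deterministic) policy action and
    # keep the result inside the action bounds.
    noisy = policy_action + sigma * np.random.randn(*policy_action.shape)
    return np.clip(noisy, action_space.low, action_space.high)
```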
1 change: 1 addition & 0 deletions rlkit/torch/her/her.py
@@ -45,6 +45,7 @@ def __init__(
self,
observation_key=None,
desired_goal_key=None,
render=False,
):
self.observation_key = observation_key
self.desired_goal_key = desired_goal_key
64 changes: 63 additions & 1 deletion rlkit/torch/networks.py
@@ -10,7 +10,7 @@
from rlkit.policies.base import Policy
from rlkit.torch import pytorch_util as ptu
from rlkit.torch.core import PyTorchModule
from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
from rlkit.torch.data_management.normalizer import TorchFixedNormalizer, TorchNormalizer, CompositeNormalizer
from rlkit.torch.modules import LayerNorm


@@ -89,6 +89,49 @@ def forward(self, *inputs, **kwargs):
return super().forward(flat_inputs, **kwargs)


class QNormalizedFlattenMlp(FlattenMlp):
def __init__(
self,
*args,
composite_normalizer: CompositeNormalizer = None,
**kwargs
):
self.save_init_params(locals())
super().__init__(*args, **kwargs)
assert composite_normalizer is not None
self.composite_normalizer = composite_normalizer

def forward(
self,
observations,
actions,
return_preactivations=False):
obs, _ = self.composite_normalizer.normalize_all(observations, None)
flat_input = torch.cat((obs, actions), dim=1)
return super().forward(flat_input, return_preactivations=return_preactivations)


class VNormalizedFlattenMlp(FlattenMlp):
def __init__(
self,
*args,
composite_normalizer: CompositeNormalizer = None,
**kwargs
):
self.save_init_params(locals())
super().__init__(*args, **kwargs)
assert composite_normalizer is not None
self.composite_normalizer = composite_normalizer

def forward(
self,
observations,
return_preactivations=False):
obs, _ = self.composite_normalizer.normalize_all(observations, None)
flat_input = obs
return super().forward(flat_input, return_preactivations=return_preactivations)


class MlpPolicy(Mlp, Policy):
"""
A simpler interface for creating policies.
@@ -117,10 +160,29 @@ def get_actions(self, obs):
return self.eval_np(obs)


class CompositeNormalizedMlpPolicy(MlpPolicy):
def __init__(
self,
*args,
composite_normalizer: CompositeNormalizer = None,
Collaborator: Seems like we can just make this a required argument rather than kwarg.

**kwargs
):
assert composite_normalizer is not None
self.save_init_params(locals())
super().__init__(*args, **kwargs)
self.composite_normalizer = composite_normalizer

def forward(self, obs, **kwargs):
if self.composite_normalizer:
Collaborator: This check seems a bit redundant given the assert statement in __init__.

obs, _ = self.composite_normalizer.normalize_all(obs, None)
return super().forward(obs, **kwargs)


class TanhMlpPolicy(MlpPolicy):
"""
A helper class since most policies have a tanh output activation.
"""
def __init__(self, *args, **kwargs):
self.save_init_params(locals())
super().__init__(*args, output_activation=torch.tanh, **kwargs)
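One way to act on the review comment above about CompositeNormalizedMlpPolicy — a sketch only, not part of this diff, assuming the imports already present in rlkit/torch/networks.py — is to make composite_normalizer a required argument so the assert and the truthiness check in forward() disappear:

```python
# Hypothetical revision of CompositeNormalizedMlpPolicy following the review
# comment: composite_normalizer becomes a required argument, so the assert in
# __init__ and the check in forward() are no longer needed.
class CompositeNormalizedMlpPolicy(MlpPolicy):
    def __init__(
            self,
            composite_normalizer: CompositeNormalizer,
            *args,
            **kwargs
    ):
        self.save_init_params(locals())
        super().__init__(*args, **kwargs)
        self.composite_normalizer = composite_normalizer

    def forward(self, obs, **kwargs):
        obs, _ = self.composite_normalizer.normalize_all(obs, None)
        return super().forward(obs, **kwargs)
```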

20 changes: 18 additions & 2 deletions rlkit/torch/td3/td3.py
@@ -34,7 +34,9 @@ def __init__(
tau=0.005,
qf_criterion=None,
optimizer_class=optim.Adam,

policy_preactivation_loss=True,
policy_preactivation_coefficient=1.0,
clip_q=True,
**kwargs
):
super().__init__(
@@ -71,6 +73,9 @@ def __init__(
self.policy.parameters(),
lr=policy_learning_rate,
)
self.clip_q = clip_q
self.policy_preactivation_penalty = policy_preactivation_loss
self.policy_preactivation_coefficient = policy_preactivation_coefficient

def _do_training(self):
batch = self.get_batch()
@@ -99,6 +104,14 @@ def _do_training(self):
target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
target_q_values = torch.min(target_q1_values, target_q2_values)

if self.clip_q:
target_q_values = torch.clamp(
target_q_values,
-1/(1-self.discount),
Collaborator: Can you make this a parameter rather than hard-coding it? It could be something like:

if max_q_value is None:
    max_q_value = -1/(1-self.discount)   # for HER sparse rewards.

in __init__.

0
)

q_target = rewards + (1. - terminals) * self.discount * target_q_values
q_target = q_target.detach()

@@ -123,9 +136,12 @@ def _do_training(self):

policy_actions = policy_loss = None
if self._n_train_steps_total % self.policy_and_target_update_period == 0:
policy_actions = self.policy(obs)
policy_actions, policy_preactivations = self.policy(obs, return_preactivations=True)
q_output = self.qf1(obs, policy_actions)

policy_loss = - q_output.mean()
if self.policy_preactivation_penalty:
policy_loss += self.policy_preactivation_coefficient * (policy_preactivations ** 2).mean()

self.policy_optimizer.zero_grad()
policy_loss.backward()
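A sketch of the parameterization requested in the review comment on the Q-target clipping above. The names clip_q_targets, min_q_value, and max_q_value are hypothetical; the defaults reproduce the hard-coded bounds in the diff, which match HER-style sparse rewards in [-1, 0].

```python
# Hypothetical helper sketching the reviewer's suggestion: make the Q-target
# clipping bounds configurable instead of hard-coding -1/(1-discount) and 0.
import torch


def clip_q_targets(target_q_values, discount,
                   min_q_value=None, max_q_value=None):
    """Clamp bootstrapped Q targets; defaults match HER-style sparse rewards
    in [-1, 0], whose returns lie in [-1/(1-discount), 0]."""
    if min_q_value is None:
        min_q_value = -1.0 / (1.0 - discount)
    if max_q_value is None:
        max_q_value = 0.0
    return torch.clamp(target_q_values, min_q_value, max_q_value)
```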
Empty file added scripts/__init__.py