forked from real-stanford/diffusion_policy
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathenv_robosuite.py
375 lines (326 loc) · 13.9 KB
/
env_robosuite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
"""
This file contains the robosuite environment wrapper that is used
to provide a standardized environment API for training policies and interacting
with metadata present in datasets.
"""
import json
import numpy as np
from copy import deepcopy
import mujoco_py
import robosuite
from robosuite.utils.mjcf_utils import postprocess_model_xml
import robomimic.utils.obs_utils as ObsUtils
import robomimic.envs.env_base as EB
class EnvRobosuite(EB.EnvBase):
"""Wrapper class for robosuite environments (https://github.com/ARISE-Initiative/robosuite)"""
def __init__(
self,
env_name,
render=False,
render_offscreen=False,
use_image_obs=False,
postprocess_visual_obs=True,
**kwargs,
):
"""
Args:
env_name (str): name of environment. Only needs to be provided if making a different
environment from the one in @env_meta.
render (bool): if True, environment supports on-screen rendering
render_offscreen (bool): if True, environment supports off-screen rendering. This
is forced to be True if @env_meta["use_images"] is True.
use_image_obs (bool): if True, environment is expected to render rgb image observations
on every env.step call. Set this to False for efficiency reasons, if image
observations are not required.
postprocess_visual_obs (bool): if True, postprocess image observations
to prepare for learning. This should only be False when extracting observations
for saving to a dataset (to save space on RGB images for example).
"""
self.postprocess_visual_obs = postprocess_visual_obs
# robosuite version check
self._is_v1 = (robosuite.__version__.split(".")[0] == "1")
if self._is_v1:
assert (int(robosuite.__version__.split(".")[1]) >= 2), "only support robosuite v0.3 and v1.2+"
kwargs = deepcopy(kwargs)
# update kwargs based on passed arguments
update_kwargs = dict(
has_renderer=render,
has_offscreen_renderer=(render_offscreen or use_image_obs),
ignore_done=True,
use_object_obs=True,
use_camera_obs=use_image_obs,
camera_depths=False,
)
kwargs.update(update_kwargs)
if self._is_v1:
if kwargs["has_offscreen_renderer"]:
# ensure that we select the correct GPU device for rendering by testing for EGL rendering
# NOTE: this package should be installed from this link (https://github.com/StanfordVL/egl_probe)
'''
import egl_probe
valid_gpu_devices = egl_probe.get_available_devices()
if len(valid_gpu_devices) > 0:
print('Valid gpu devices: ', valid_gpu_devices)
'''
import os
dev_id = os.environ.get('EGL_DEVICE_ID', '0')
if dev_id.isdigit():
idx = int(dev_id)
else:
idx = int(dev_id.split(',')[0])
kwargs["render_gpu_device_id"] = idx # valid_gpu_devices[idx % len(valid_gpu_devices)]
else:
# make sure gripper visualization is turned off (we almost always want this for learning)
kwargs["gripper_visualization"] = False
del kwargs["camera_depths"]
kwargs["camera_depth"] = False # rename kwarg
self._env_name = env_name
self._init_kwargs = deepcopy(kwargs)
self.env = robosuite.make(self._env_name, **kwargs)
if self._is_v1:
# Make sure joint position observations and eef vel observations are active
for ob_name in self.env.observation_names:
if ("joint_pos" in ob_name) or ("eef_vel" in ob_name):
self.env.modify_observable(observable_name=ob_name, attribute="active", modifier=True)
def step(self, action):
"""
Step in the environment with an action.
Args:
action (np.array): action to take
Returns:
observation (dict): new observation dictionary
reward (float): reward for this step
done (bool): whether the task is done
info (dict): extra information
"""
obs, r, done, info = self.env.step(action)
obs = self.get_observation(obs)
return obs, r, self.is_done(), info
def reset(self):
"""
Reset environment.
Returns:
observation (dict): initial observation dictionary.
"""
di = self.env.reset()
return self.get_observation(di)
def reset_to(self, state):
"""
Reset to a specific simulator state.
Args:
state (dict): current simulator state that contains one or more of:
- states (np.ndarray): initial state of the mujoco environment
- model (str): mujoco scene xml
Returns:
observation (dict): observation dictionary after setting the simulator state (only
if "states" is in @state)
"""
should_ret = False
if "model" in state:
self.reset()
xml = postprocess_model_xml(state["model"])
self.env.reset_from_xml_string(xml)
self.env.sim.reset()
if not self._is_v1:
# hide teleop visualization after restoring from model
self.env.sim.model.site_rgba[self.env.eef_site_id] = np.array([0., 0., 0., 0.])
self.env.sim.model.site_rgba[self.env.eef_cylinder_id] = np.array([0., 0., 0., 0.])
if "states" in state:
self.env.sim.set_state_from_flattened(state["states"])
self.env.sim.forward()
should_ret = True
if "goal" in state:
self.set_goal(**state["goal"])
if should_ret:
# only return obs if we've done a forward call - otherwise the observations will be garbage
return self.get_observation()
return None
def render(self, mode="human", height=None, width=None, camera_name="agentview"):
"""
Render from simulation to either an on-screen window or off-screen to RGB array.
Args:
mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering
height (int): height of image to render - only used if mode is "rgb_array"
width (int): width of image to render - only used if mode is "rgb_array"
camera_name (str): camera name to use for rendering
"""
if mode == "human":
cam_id = self.env.sim.model.camera_name2id(camera_name)
self.env.viewer.set_camera(cam_id)
return self.env.render()
elif mode == "rgb_array":
return self.env.sim.render(height=height, width=width, camera_name=camera_name)[::-1]
else:
raise NotImplementedError("mode={} is not implemented".format(mode))
def get_observation(self, di=None):
"""
Get current environment observation dictionary.
Args:
di (dict): current raw observation dictionary from robosuite to wrap and provide
as a dictionary. If not provided, will be queried from robosuite.
"""
if di is None:
di = self.env._get_observations(force_update=True) if self._is_v1 else self.env._get_observation()
ret = {}
for k in di:
if (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(key=k, obs_modality="rgb"):
ret[k] = di[k][::-1]
if self.postprocess_visual_obs:
ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)
# "object" key contains object information
ret["object"] = np.array(di["object-state"])
if self._is_v1:
for robot in self.env.robots:
# add all robot-arm-specific observations. Note the (k not in ret) check
# ensures that we don't accidentally add robot wrist images a second time
pf = robot.robot_model.naming_prefix
for k in di:
if k.startswith(pf) and (k not in ret) and (not k.endswith("proprio-state")):
ret[k] = np.array(di[k])
else:
# minimal proprioception for older versions of robosuite
ret["proprio"] = np.array(di["robot-state"])
ret["eef_pos"] = np.array(di["eef_pos"])
ret["eef_quat"] = np.array(di["eef_quat"])
ret["gripper_qpos"] = np.array(di["gripper_qpos"])
return ret
def get_state(self):
"""
Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
"""
xml = self.env.sim.model.get_xml() # model xml file
state = np.array(self.env.sim.get_state().flatten()) # simulator state
return dict(model=xml, states=state)
def get_reward(self):
"""
Get current reward.
"""
return self.env.reward()
def get_goal(self):
"""
Get goal observation. Not all environments support this.
"""
return self.get_observation(self.env._get_goal())
def set_goal(self, **kwargs):
"""
Set goal observation with external specification. Not all environments support this.
"""
return self.env.set_goal(**kwargs)
def is_done(self):
"""
Check if the task is done (not necessarily successful).
"""
# Robosuite envs always rollout to fixed horizon.
return False
def is_success(self):
"""
Check if the task condition(s) is reached. Should return a dictionary
{ str: bool } with at least a "task" key for the overall task success,
and additional optional keys corresponding to other task criteria.
"""
succ = self.env._check_success()
if isinstance(succ, dict):
assert "task" in succ
return succ
return { "task" : succ }
@property
def action_dimension(self):
"""
Returns dimension of actions (int).
"""
return self.env.action_spec[0].shape[0]
@property
def name(self):
"""
Returns name of environment name (str).
"""
return self._env_name
@property
def type(self):
"""
Returns environment type (int) for this kind of environment.
This helps identify this env class.
"""
return EB.EnvType.ROBOSUITE_TYPE
def serialize(self):
"""
Save all information needed to re-instantiate this environment in a dictionary.
This is the same as @env_meta - environment metadata stored in hdf5 datasets,
and used in utils/env_utils.py.
"""
return dict(env_name=self.name, type=self.type, env_kwargs=deepcopy(self._init_kwargs))
@classmethod
def create_for_data_processing(
cls,
env_name,
camera_names,
camera_height,
camera_width,
reward_shaping,
**kwargs,
):
"""
Create environment for processing datasets, which includes extracting
observations, labeling dense / sparse rewards, and annotating dones in
transitions.
Args:
env_name (str): name of environment
camera_names (list of str): list of camera names that correspond to image observations
camera_height (int): camera height for all cameras
camera_width (int): camera width for all cameras
reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
"""
is_v1 = (robosuite.__version__.split(".")[0] == "1")
has_camera = (len(camera_names) > 0)
new_kwargs = {
"reward_shaping": reward_shaping,
}
if has_camera:
if is_v1:
new_kwargs["camera_names"] = list(camera_names)
new_kwargs["camera_heights"] = camera_height
new_kwargs["camera_widths"] = camera_width
else:
assert len(camera_names) == 1
if has_camera:
new_kwargs["camera_name"] = camera_names[0]
new_kwargs["camera_height"] = camera_height
new_kwargs["camera_width"] = camera_width
kwargs.update(new_kwargs)
# also initialize obs utils so it knows which modalities are image modalities
image_modalities = list(camera_names)
if is_v1:
image_modalities = ["{}_image".format(cn) for cn in camera_names]
elif has_camera:
# v0.3 only had support for one image, and it was named "rgb"
assert len(image_modalities) == 1
image_modalities = ["rgb"]
obs_modality_specs = {
"obs": {
"low_dim": [], # technically unused, so we don't have to specify all of them
"rgb": image_modalities,
}
}
ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs)
# note that @postprocess_visual_obs is False since this env's images will be written to a dataset
return cls(
env_name=env_name,
render=False,
render_offscreen=has_camera,
use_image_obs=has_camera,
postprocess_visual_obs=False,
**kwargs,
)
@property
def rollout_exceptions(self):
"""
Return tuple of exceptions to except when doing rollouts. This is useful to ensure
that the entire training run doesn't crash because of a bad policy that causes unstable
simulation computations.
"""
return (mujoco_py.builder.MujocoException)
def __repr__(self):
"""
Pretty-print env description.
"""
return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4)