forked from N1-314/Intrinsic-Curiosity-Module---PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path wrappers_icm_specific.py
33 lines (24 loc) · 1.1 KB
/
wrappers_icm_specific.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# https://github.com/openai/gym/blob/master/gym/wrappers/gray_scale_observation.py
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/transform_observation.py
import gym
import numpy as np
class GrayScaleObservation(gym.wrappers.GrayScaleObservation):
    """Convert an RGB image observation to grayscale using NumPy only.

    Only ``observation`` is overridden here; the grayscale
    ``observation_space`` and the ``keep_dim`` attribute read below come from
    the parent wrapper's ``__init__``. (NOTE(review): the override presumably
    exists to avoid the parent's own conversion backend — confirm.)

    Channels are combined with fixed weights (0.2125, 0.7154, 0.0721), which
    sum to 1.0, so the result stays within the 0-255 range of the input.

    Example:
        >>> env = gym.make("SuperMarioBros-v0")
        >>> env.observation_space
        Box(0, 255, (240, 256, 3), uint8)
        >>> env = GrayScaleObservation(gym.make("SuperMarioBros-v0"))
        >>> env.observation_space
        Box(0, 255, (240, 256), uint8)
        >>> env = GrayScaleObservation(gym.make("SuperMarioBros-v0"), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (240, 256, 1), uint8)
    """

    # Per-channel (R, G, B) weights, built once instead of on every step.
    _LUMA_WEIGHTS = np.array([0.2125, 0.7154, 0.0721])

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Wrap *env*.

        Args:
            env: Environment whose observations are RGB images.
            keep_dim: If true, keep a trailing channel axis of size 1 instead
                of dropping the channel axis.
        """
        super().__init__(env, keep_dim)

    def observation(self, observation):
        """Return the grayscale version of an RGB *observation*.

        Args:
            observation: Array-like whose last axis holds the R, G, B channels
                (presumably a uint8 ``(H, W, 3)`` frame — confirm against the env).

        Returns:
            A ``uint8`` array with the channel axis contracted away, or kept as
            a size-1 trailing axis when ``self.keep_dim`` is true.
        """
        # Weighted sum over the last (channel) axis in one contraction, without
        # materializing the per-channel float product array.
        gray = np.asarray(observation) @ self._LUMA_WEIGHTS
        # Round rather than let ``astype`` truncate: a value that lands a hair
        # below an integer through float error (e.g. a gray pixel (v, v, v)
        # computing to v - 1e-14) would otherwise be knocked down to v - 1.
        # Clip so the uint8 cast can never wrap around.
        gray = np.clip(np.rint(gray), 0, 255).astype(np.uint8)
        if self.keep_dim:
            gray = np.expand_dims(gray, -1)
        return gray