Skip to content

Commit

Permalink
[BugFix] Fix pendulum device
Browse files Browse the repository at this point in the history
ghstack-source-id: bcaf20de6e317d4bda0e1511e0b1e46653a6f352
Pull Request resolved: #2516
  • Loading branch information
vmoens committed Oct 30, 2024
1 parent c851e16 commit 6799a7f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
12 changes: 6 additions & 6 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3405,16 +3405,16 @@ def test_tictactoe_env_single(self):
)
assert r.shape == (5, 100)

def test_pendulum_env(self):
env = PendulumEnv(device=None)
assert env.device is None
env = PendulumEnv(device="cpu")
assert env.device == torch.device("cpu")
@pytest.mark.parametrize("device", [None, *get_default_devices()])
def test_pendulum_env(self, device):
env = PendulumEnv(device=device)
assert env.device == device
check_env_specs(env)

for _ in range(10):
r = env.rollout(10)
assert r.shape == torch.Size((10,))
r = env.rollout(10, tensordict=TensorDict(batch_size=[5]))
r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device))
assert r.shape == torch.Size((5, 10))


Expand Down
9 changes: 5 additions & 4 deletions torchrl/envs/custom/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class PendulumEnv(EnvBase):

def __init__(self, td_params=None, seed=None, device=None):
if td_params is None:
td_params = self.gen_params()
td_params = self.gen_params(device=self.device)

super().__init__(device=device)
self._make_spec(td_params)
Expand Down Expand Up @@ -273,7 +273,7 @@ def _reset(self, tensordict):
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
# parameters to get started.
tensordict = self.gen_params(batch_size=batch_size)
tensordict = self.gen_params(batch_size=batch_size, device=self.device)

high_th = torch.tensor(self.DEFAULT_X, device=self.device)
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
Expand Down Expand Up @@ -355,12 +355,12 @@ def make_composite_from_td(td):
return composite

def _set_seed(self, seed: int):
rng = torch.Generator()
rng = torch.Generator(device=self.device)
rng.manual_seed(seed)
self.rng = rng

@staticmethod
def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
def gen_params(g=10.0, batch_size=None, device=None) -> TensorDictBase:
"""Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
if batch_size is None:
batch_size = []
Expand All @@ -379,6 +379,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
)
},
[],
device=device,
)
if batch_size:
td = td.expand(batch_size).contiguous()
Expand Down

0 comments on commit 6799a7f

Please sign in to comment.