diff --git a/test/test_env.py b/test/test_env.py index 04bf18c7c8c..1f95a55c2c7 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3405,16 +3405,16 @@ def test_tictactoe_env_single(self): ) assert r.shape == (5, 100) - def test_pendulum_env(self): - env = PendulumEnv(device=None) - assert env.device is None - env = PendulumEnv(device="cpu") - assert env.device == torch.device("cpu") + @pytest.mark.parametrize("device", [None, *get_default_devices()]) + def test_pendulum_env(self, device): + env = PendulumEnv(device=device) + assert env.device == device check_env_specs(env) + for _ in range(10): r = env.rollout(10) assert r.shape == torch.Size((10,)) - r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device)) assert r.shape == torch.Size((5, 10)) diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index e2007227127..579faecc3c6 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -220,7 +220,7 @@ class PendulumEnv(EnvBase): def __init__(self, td_params=None, seed=None, device=None): if td_params is None: - td_params = self.gen_params() + td_params = self.gen_params(device=self.device) super().__init__(device=device) self._make_spec(td_params) @@ -273,7 +273,7 @@ def _reset(self, tensordict): # if no ``tensordict`` is passed, we generate a single set of hyperparameters # Otherwise, we assume that the input ``tensordict`` contains all the relevant # parameters to get started. - tensordict = self.gen_params(batch_size=batch_size) + tensordict = self.gen_params(batch_size=batch_size, device=self.device) high_th = torch.tensor(self.DEFAULT_X, device=self.device) high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device) @@ -355,12 +355,12 @@ def make_composite_from_td(td): return composite def _set_seed(self, seed: int): - rng = torch.Generator() + rng = torch.Generator(device=self.device) rng.manual_seed(seed) self.rng = rng @staticmethod - def gen_params(g=10.0, batch_size=None) -> TensorDictBase: + def gen_params(g=10.0, batch_size=None, device=None) -> TensorDictBase: """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits.""" if batch_size is None: batch_size = [] @@ -379,6 +379,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase: ) }, [], + device=device, ) if batch_size: td = td.expand(batch_size).contiguous()