Skip to content

Commit

Permalink
Merge pull request #913 from b-marks:batched-env-attrs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 602051767
Change-Id: Idbe71f469cd2cc8afd355df546e477ef0d93aac5
  • Loading branch information
copybara-github committed Jan 27, 2024
2 parents 27b851f + 97de036 commit 0113998
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tf_agents/environments/batched_py_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from multiprocessing import dummy as mp_threads
from multiprocessing import pool
# pylint: enable=line-too-long
from typing import Sequence, Optional
from typing import Any, Optional, Sequence

import gin
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
Expand Down Expand Up @@ -182,6 +182,21 @@ def _step(self, actions):
)
return nest_utils.stack_nested_arrays(time_steps)

def seed(self, seed: types.Seed) -> Any:
  """Seeds every wrapped environment with the same `seed` value.

  Args:
    seed: Value forwarded unchanged to each wrapped environment's `seed`.

  Returns:
    The result of running the per-environment `seed` calls through
    `self._execute`.
  """

  def _seed_env(env):
    return env.seed(seed)

  return self._execute(_seed_env, self._envs)

def get_state(self) -> Any:
  """Collects the `state` of every wrapped environment.

  Returns:
    The result of running each environment's `get_state` through
    `self._execute`, iterating over `self._envs`.
  """

  def _read_state(env):
    return env.get_state()

  return self._execute(_read_state, self._envs)

def set_state(self, state: Sequence[Any]) -> None:
  """Restores each wrapped environment to its entry in `state`.

  Args:
    state: One state per wrapped environment, in the same order as
      `self._envs`; the i-th entry is passed to the i-th environment's
      `set_state`.

  Raises:
    ValueError: If the number of states differs from the number of wrapped
      environments.
  """
  # Materialize once so the count can be checked even for one-shot iterables.
  states = list(state)
  if len(states) != len(self._envs):
    # `zip` would otherwise truncate silently and leave some environments
    # un-restored (or drop surplus states) without any signal to the caller.
    raise ValueError(
        'Expected one state per environment: got %d states for %d '
        'environments.' % (len(states), len(self._envs))
    )
  self._execute(
      lambda env_state: env_state[0].set_state(env_state[1]),
      zip(self._envs, states),
  )

def render(self, mode="rgb_array") -> Optional[types.NestedArray]:
if self._num_envs == 1:
img = self._envs[0].render(mode)
Expand Down
37 changes: 37 additions & 0 deletions tf_agents/environments/batched_py_environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,21 @@ class GymWrapperEnvironmentMock(random_py_environment.RandomPyEnvironment):
def __init__(self, *args, **kwargs):
  """Builds the mock, forwarding all arguments to the parent constructor.

  After the parent is initialized, the mock starts with a state dict whose
  'seed' entry is 0 and an empty info dict.
  """
  super(GymWrapperEnvironmentMock, self).__init__(*args, **kwargs)
  # The two attributes are independent of each other.
  self._state = {'seed': 0}
  self._info = {}

def get_info(self):
  """Returns the info dict held in `self._info` (the same object, not a copy)."""
  info = self._info
  return info

def seed(self, seed):
  """Records `seed` under the 'seed' key of the state, then seeds the parent.

  Args:
    seed: Value stored in `self._state['seed']` and forwarded to the parent
      class's `seed`.

  Returns:
    Whatever the parent class's `seed` returns.
  """
  self._state.update(seed=seed)
  parent = super(GymWrapperEnvironmentMock, self)
  return parent.seed(seed)

def get_state(self):
  """Returns the object currently held in `self._state` (not a copy)."""
  current_state = self._state
  return current_state

def set_state(self, state):
  """Replaces `self._state` with `state`; the object is stored as-is."""
  setattr(self, '_state', state)

def _step(self, action):
  """Records `action` as the 'last_action' info entry, then steps the parent.

  Args:
    action: Stored in `self._info['last_action']` and forwarded to the parent
      class's `_step`.

  Returns:
    Whatever the parent class's `_step` returns.
  """
  self._info.update(last_action=action)
  parent = super(GymWrapperEnvironmentMock, self)
  return parent._step(action)
Expand Down Expand Up @@ -116,6 +127,32 @@ def test_get_info_gym_env(self, multithreading):
self.assertAllEqual(info['last_action'], action)
gym_env.close()

@parameterized.parameters(*COMMON_PARAMETERS)
def test_seed_gym_env(self, multithreading):
  """Seeding the batched env is reflected in each env's reported state."""
  num_envs = 5
  expected_seeds = [42] * num_envs
  gym_env = self._make_batched_mock_gym_py_environment(
      multithreading, num_envs=num_envs
  )

  gym_env.seed(42)
  states = gym_env.get_state()

  # Each per-environment state is expected to expose the seed it received.
  self.assertEqual([state['seed'] for state in states], expected_seeds)
  gym_env.close()

@parameterized.parameters(*COMMON_PARAMETERS)
def test_state_gym_env(self, multithreading):
  """States set on the batched env are returned unchanged by `get_state`."""
  num_envs = 5
  gym_env = self._make_batched_mock_gym_py_environment(
      multithreading, num_envs=num_envs
  )
  # A distinct state per environment, so a reordering would be detected.
  state = [{'value': 10 * index} for index in range(num_envs)]

  gym_env.set_state(state)
  restored = gym_env.get_state()

  self.assertEqual(restored, state)
  gym_env.close()

@parameterized.parameters(*COMMON_PARAMETERS)
def test_step(self, multithreading):
num_envs = 5
Expand Down

0 comments on commit 0113998

Please sign in to comment.