Skip to content

Commit

Permalink
Add gym_kwargs option to suite_atari
Browse files Browse the repository at this point in the history
  • Loading branch information
msmith93 committed Oct 28, 2023
1 parent 7b67210 commit 441df4b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tf_agents/environments/suite_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

from typing import Dict, Optional, Sequence, Text
from typing import Dict, Optional, Sequence, Text, Any

import ale_py # pylint: disable=unused-import
import gin
Expand Down Expand Up @@ -84,13 +84,15 @@ def load(
] = DEFAULT_ATARI_GYM_WRAPPERS,
env_wrappers: Sequence[types.PyEnvWrapper] = (),
spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None,
gym_kwargs: Optional[Dict[str, Any]] = None,
) -> py_environment.PyEnvironment:
"""Loads the selected environment and wraps it with the specified wrappers."""
if spec_dtype_map is None:
spec_dtype_map = {gym.spaces.Box: np.uint8}

gym_kwargs = gym_kwargs if gym_kwargs else {}
gym_spec = gym.spec(environment_name)
gym_env = gym_spec.make()
gym_env = gym_spec.make(**gym_kwargs)

if max_episode_steps is None and gym_spec.max_episode_steps is not None:
max_episode_steps = gym_spec.max_episode_steps
Expand Down

0 comments on commit 441df4b

Please sign in to comment.