Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
kurtamohler committed Dec 17, 2024
1 parent b3d0833 commit f164641
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 8 deletions.
106 changes: 106 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
from torchrl.envs import (
CatFrames,
CatTensors,
ChessEnv,
DoubleToFloat,
EnvBase,
EnvCreator,
Expand Down Expand Up @@ -3380,6 +3381,111 @@ def test_partial_rest(self, batched):
assert s["next", "string"] == ["6", "6"]


# fen strings for board positions generated with:
# https://lichess.org/editor
@pytest.mark.parametrize("stateful", [False, True])
class TestChessEnv:
    """Tests for ChessEnv, run in both stateful and stateless modes."""

    def test_env(self, stateful):
        env = ChessEnv(stateful=stateful)
        check_env_specs(env)

    def test_rollout(self, stateful):
        env = ChessEnv(stateful=stateful)
        env.rollout(5000)

    def test_reset_white_to_move(self, stateful):
        env = ChessEnv(stateful=stateful)
        fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
        td = env.reset(TensorDict({"fen": fen}))
        assert td["fen"] == fen
        assert td["turn"] == env.lib.WHITE
        assert not td["done"]

    def test_reset_black_to_move(self, stateful):
        env = ChessEnv(stateful=stateful)
        fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
        td = env.reset(TensorDict({"fen": fen}))
        assert td["fen"] == fen
        assert td["turn"] == env.lib.BLACK
        assert not td["done"]

    def test_reset_done_error(self, stateful):
        env = ChessEnv(stateful=stateful)
        # Back-rank mate position: resetting to a finished game must fail.
        fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
        with pytest.raises(ValueError) as e_info:
            env.reset(TensorDict({"fen": fen}))

        assert "Cannot reset to a fen that is a gameover state" in str(e_info)

    @pytest.mark.parametrize("reset_without_fen", [False, True])
    @pytest.mark.parametrize(
        "endstate",
        [
            "white win",
            "black win",
            "stalemate",
            "50 move",
            "insufficient",
            # BUGFIX: the "not_done" branch below existed but was never
            # parametrized, so the non-terminal reward path went untested.
            "not_done",
        ],
    )
    def test_reward(self, stateful, reset_without_fen, endstate):
        """Check reward/done/turn after one move for each kind of end state."""
        if stateful and reset_without_fen:
            pytest.skip("reset_without_fen is only used for stateless env")

        env = ChessEnv(stateful=stateful)

        if endstate == "white win":
            fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
            expected_turn = env.lib.WHITE
            move = "Rb8#"
            expected_reward = 1
            expected_done = True

        elif endstate == "black win":
            fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
            expected_turn = env.lib.BLACK
            move = "Rg1#"
            expected_reward = -1
            expected_done = True

        elif endstate == "stalemate":
            fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
            expected_turn = env.lib.BLACK
            move = "Rb7"
            expected_reward = 0
            expected_done = True

        elif endstate == "insufficient":
            fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
            expected_turn = env.lib.WHITE
            move = "Kxd4"
            expected_reward = 0
            expected_done = True

        elif endstate == "50 move":
            # Halfmove clock is at 99, so one more quiet move hits the rule.
            fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
            expected_turn = env.lib.BLACK
            move = "Kf7"
            expected_reward = 0
            expected_done = True

        elif endstate == "not_done":
            fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
            expected_turn = env.lib.WHITE
            move = "e4"
            expected_reward = 0
            expected_done = False

        else:
            raise RuntimeError(f"endstate not supported: {endstate}")

        if reset_without_fen:
            td = TensorDict({"fen": fen})
        else:
            td = env.reset(TensorDict({"fen": fen}))
            assert td["turn"] == expected_turn

        moves = env.get_legal_moves(None if stateful else td)
        td["action"] = moves.index(move)
        td = env.step(td)["next"]
        assert td["done"] == expected_done
        assert td["reward"] == expected_reward
        # Turn always flips after a move, even a game-ending one.
        assert td["turn"] == (not expected_turn)


class TestCustomEnvs:
def test_tictactoe_env(self):
torch.manual_seed(0)
Expand Down
59 changes: 51 additions & 8 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,13 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
self._set_action_space(tensordict)
return super().rand_action(tensordict)

def _is_done(self, board):
return board.is_game_over() | board.is_fifty_moves()

def _reset(self, tensordict=None):
fen = None
if tensordict is not None:
fen = self._get_fen(tensordict)
fen = self._get_fen(tensordict).data
dest = tensordict.empty()
else:
dest = TensorDict()
Expand All @@ -139,7 +142,11 @@ def _reset(self, tensordict=None):
self.board.reset()
fen = self.board.fen()
else:
self.board.set_fen(fen.data)
self.board.set_fen(fen)
if self._is_done(self.board):
raise ValueError(
"Cannot reset to a fen that is a gameover state." f" fen: {fen}"
)

hashing = hash(fen)

Expand All @@ -162,16 +169,47 @@ def _get_fen(cls, tensordict):
fen = cls._hash_table.get(hashing.item())
return fen

def get_legal_moves(self, tensordict=None, uci=False):
    """List the legal moves in a position.

    To choose one of the actions, the "action" key can be set to the index
    of the move in this list.

    Args:
        tensordict (TensorDict, optional): Tensordict containing the fen
            string of a position. Required if not stateful. If stateful,
            this argument is ignored and the current state of the env is
            used instead.
        uci (bool, optional): If ``False``, moves are given in SAN format.
            If ``True``, moves are given in UCI format. Default is
            ``False``.
    """
    board = self.board
    if not self.stateful:
        # A stateless env has no current position; it must be supplied.
        if tensordict is None:
            raise ValueError(
                "tensordict must be given since this env is not stateful"
            )
        board.set_fen(self._get_fen(tensordict).data)
    # Choose the rendering function once, then format every legal move.
    notation = board.uci if uci else board.san
    return [notation(move) for move in board.legal_moves]

def _step(self, tensordict):
# action
action = tensordict.get("action")
board = self.board
if not self.stateful:
fen = self._get_fen(tensordict).data
board.set_fen(fen)
action = str(list(board.legal_moves)[action])
# assert chess.Move.from_uci(action) in board.legal_moves
board.push_san(action)
action = list(board.legal_moves)[action]
board.push(action)
self._set_action_space()

# Collect data
Expand All @@ -181,10 +219,15 @@ def _step(self, tensordict):
dest.set("fen", fen)
dest.set("hashing", hashing)

done = board.is_checkmate()
turn = torch.tensor(board.turn)
reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
done = done | board.is_stalemate() | board.is_game_over()
if board.is_checkmate():
# turn flips after every move, even if the game is over
winner = not turn
reward_val = 1 if winner == self.lib.WHITE else -1
else:
reward_val = 0
reward = torch.tensor([reward_val], dtype=torch.int32)
done = self._is_done(board)
dest.set("reward", reward)
dest.set("turn", turn)
dest.set("done", [done])
Expand Down

0 comments on commit f164641

Please sign in to comment.