Skip to content

Commit

Permalink
Finish refactoring to match src layout.
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewGerber committed Jul 6, 2024
1 parent e69d2a4 commit e543c1b
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 33 deletions.
File renamed without changes.
File renamed without changes.
Binary file not shown.
30 changes: 0 additions & 30 deletions test/rlai/gpi/improvement_test.py

This file was deleted.

27 changes: 27 additions & 0 deletions test/rlai/gpi/state_action_value/tabular_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest
from numpy.random import RandomState

from rlai.core.environments.gridworld import Gridworld
from rlai.gpi.state_action_value import ActionValueMdpAgent
from rlai.gpi.state_action_value.tabular import TabularPolicy
from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator

Expand Down Expand Up @@ -39,3 +42,27 @@ def test_invalid_epsilon():

with pytest.raises(ValueError, match='epsilon must be >= 0'):
TabularStateActionValueEstimator(None, -1, None)


def test_invalid_improve_policy_with_q_pi():
"""
Test.
"""

random_state = RandomState(12345)
mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)
epsilon = 0.0
mdp_agent = ActionValueMdpAgent(
'test',
random_state,
1,
TabularStateActionValueEstimator(mdp_environment, epsilon, None)
)

assert isinstance(mdp_agent.pi, TabularPolicy)

with pytest.raises(ValueError, match='Epsilon must be >= 0'):
mdp_agent.pi.improve_with_q_pi(
{},
-1.0
)
6 changes: 3 additions & 3 deletions test/rlai/gpi/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from test.rlai.utils import start_virtual_display_if_headless


def test_resume_gym_valid_environment():
def test_resume_from_checkpoint():
"""
Test.
"""
Expand Down Expand Up @@ -49,10 +49,10 @@ def train_function_args_callback(
resume_environment.close()

# uncomment the following line and run test to update fixture
# with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'wb') as file:
# with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_from_checkpoint.pickle', 'wb') as file:
# pickle.dump(agent.pi, file)

with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'rb') as file:
with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_from_checkpoint.pickle', 'rb') as file:
pi_fixture = pickle.load(file)

assert agent.pi == pi_fixture
Empty file removed test/rlai/rewards/__init__.py
Empty file.
Empty file removed test/rlai/states/__init__.py
Empty file.

0 comments on commit e543c1b

Please sign in to comment.