51 tests refactor #62

Merged
merged 3 commits on Jul 6, 2024
4 changes: 2 additions & 2 deletions run_configurations/meta.run.xml
@@ -9,12 +9,12 @@
      <env name="ANNOTATIONS_ON" value="True" />
    </envs>
    <option name="SDK_HOME" value="" />
    <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/src/rlai/meta/" />
    <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/src/rlai/" />
    <option name="IS_MODULE_SDK" value="true" />
    <option name="ADD_CONTENT_ROOTS" value="true" />
    <option name="ADD_SOURCE_ROOTS" value="true" />
    <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
    <option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/rlai/meta/__init__.py" />
    <option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/rlai/meta.py" />
    <option name="PARAMETERS" value="" />
    <option name="SHOW_COMMAND_LINE" value="false" />
    <option name="EMULATE_TERMINAL" value="false" />
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -4,13 +4,14 @@
import pytest
from numpy.random import RandomState

from rlai.core import Reward, Action, MdpState, Monitor
from rlai.core import Reward, Action, MdpState, Monitor, State
from rlai.core.environments.gamblers_problem import GamblersProblem
from rlai.core.environments.gridworld import Gridworld
from rlai.core.environments.mdp import PrioritizedSweepingMdpPlanningEnvironment, StochasticEnvironmentModel
from rlai.gpi.dynamic_programming.iteration import iterate_value_v_pi
from rlai.gpi.state_action_value import ActionValueMdpAgent
from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator
from rlai.utils import sample_list_item


def test_gamblers_problem():
@@ -125,3 +126,46 @@ def test_check_marginal_probabilities():

    with pytest.raises(ValueError, match='Expected next-state/next-reward marginal probability of 1.0, but got 2.0'):
        gridworld.check_marginal_probabilities()


def test_stochastic_environment_model():
    """
    Test.
    """

    random_state = RandomState(12345)

    model = StochasticEnvironmentModel()

    actions = [
        Action(i)
        for i in range(5)
    ]

    states = [
        State(i, actions)
        for i in range(5)
    ]

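    # build the model from 1000 randomly sampled (state, action, next-state, reward) transitions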
    for t in range(1000):
        state = sample_list_item(states, None, random_state)
        action = sample_list_item(state.AA, None, random_state)
        next_state = sample_list_item(states, None, random_state)
        reward = Reward(None, random_state.randint(10))
        model.update(state, action, next_state, reward)

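    # sample 1000 (next state, reward) pairs from the learned model and record the sequence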
    environment_sequence = []
    for i in range(1000):
        state = model.sample_state(random_state)
        action = model.sample_action(state, random_state)
        next_state, reward = model.sample_next_state_and_reward(state, action, random_state)
        environment_sequence.append((next_state, reward))

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'wb') as file:
    #     pickle.dump(environment_sequence, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'rb') as file:
        environment_sequence_fixture = pickle.load(file)

    assert environment_sequence == environment_sequence_fixture
File renamed without changes.
File renamed without changes.
Binary file not shown.
30 changes: 0 additions & 30 deletions test/rlai/gpi/improvement_test.py

This file was deleted.

@@ -1,7 +1,6 @@
import pytest

from rlai.gpi.state_action_value.function_approximation import ApproximateStateActionValueEstimator
from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator


# noinspection PyTypeChecker
@@ -10,8 +9,5 @@ def test_invalid_epsilon():
    Test.
    """

    with pytest.raises(ValueError, match='epsilon must be >= 0'):
        TabularStateActionValueEstimator(None, -1, None)

    with pytest.raises(ValueError, match='epsilon must be >= 0'):
        ApproximateStateActionValueEstimator(None, -1, None, None, None, False, None, None)
68 changes: 68 additions & 0 deletions test/rlai/gpi/state_action_value/tabular_test.py
@@ -0,0 +1,68 @@
import numpy as np
import pytest
from numpy.random import RandomState

from rlai.core.environments.gridworld import Gridworld
from rlai.gpi.state_action_value import ActionValueMdpAgent
from rlai.gpi.state_action_value.tabular import TabularPolicy
from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator


def test_invalid_get_state_i():
    """
    Test.
    """

    policy = TabularPolicy(None, None)

    with pytest.raises(ValueError, match='Attempted to discretize a continuous state without a resolution.'):
        policy.get_state_i(np.array([[1, 2, 3]]))

    with pytest.raises(ValueError, match=f'Unknown state space type: {type(3)}'):
        # noinspection PyTypeChecker
        policy.get_state_i(3)


def test_policy_not_equal():
    """
    Test.
    """

    policy_1 = TabularPolicy(None, None)
    policy_2 = TabularPolicy(None, None)

    assert not (policy_1 != policy_2)


# noinspection PyTypeChecker
def test_invalid_epsilon():
    """
    Test.
    """

    with pytest.raises(ValueError, match='epsilon must be >= 0'):
        TabularStateActionValueEstimator(None, -1, None)


def test_invalid_improve_policy_with_q_pi():
    """
    Test.
    """

    random_state = RandomState(12345)
    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)
    epsilon = 0.0
    mdp_agent = ActionValueMdpAgent(
        'test',
        random_state,
        1,
        TabularStateActionValueEstimator(mdp_environment, epsilon, None)
    )

    assert isinstance(mdp_agent.pi, TabularPolicy)

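    # improving the policy with a negative epsilon should raise a ValueError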
    with pytest.raises(ValueError, match='Epsilon must be >= 0'):
        mdp_agent.pi.improve_with_q_pi(
            {},
            -1.0
        )
89 changes: 89 additions & 0 deletions test/rlai/gpi/temporal_difference/iteration_test.py
@@ -10,6 +10,7 @@
import pytest
from numpy.random import RandomState

from rlai.core import MdpState
from rlai.core.environments.gridworld import Gridworld, GridworldFeatureExtractor
from rlai.core.environments.mdp import TrajectorySamplingMdpPlanningEnvironment, StochasticEnvironmentModel
from rlai.gpi.state_action_value import ActionValueMdpAgent
@@ -724,3 +725,91 @@ def test_q_learning_iterate_value_q_pi_tabular_policy_ne():
    assert q_S_A_1 != q_S_A_2
    assert q_S_A_1[test_state] != q_S_A_2[test_state]
    assert q_S_A_1[test_state][test_action] != q_S_A_2[test_state][test_action]


def test_policy_overrides():
    """
    Test.
    """

    random_state = RandomState(12345)

    mdp_environment: Gridworld = Gridworld.example_4_1(random_state, 20)

    epsilon = 0.05

    q_S_A = ApproximateStateActionValueEstimator(
        mdp_environment,
        epsilon,
        SKLearnSGD(BaseSKLearnSGD(random_state=random_state)),
        GridworldFeatureExtractor(mdp_environment),
        None,
        False,
        None,
        None
    )

    mdp_agent = ActionValueMdpAgent(
        'test',
        random_state,
        1,
        q_S_A
    )

    iterate_value_q_pi(
        agent=mdp_agent,
        environment=mdp_environment,
        num_improvements=10,
        num_episodes_per_improvement=20,
        num_updates_per_improvement=None,
        alpha=None,
        mode=Mode.Q_LEARNING,
        n_steps=None,
        planning_environment=None,
        make_final_policy_greedy=True
    )

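    # rebuild the environment and agent from the same random seed and run the same training again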
    random_state = RandomState(12345)

    mdp_environment_2: Gridworld = Gridworld.example_4_1(random_state, 20)

    q_S_A_2 = ApproximateStateActionValueEstimator(
        mdp_environment_2,
        epsilon,
        SKLearnSGD(BaseSKLearnSGD(random_state=random_state)),
        GridworldFeatureExtractor(mdp_environment_2),
        None,
        False,
        None,
        None
    )

    mdp_agent_2 = ActionValueMdpAgent(
        'test',
        random_state,
        1,
        q_S_A_2
    )

    iterate_value_q_pi(
        agent=mdp_agent_2,
        environment=mdp_environment_2,
        num_improvements=10,
        num_episodes_per_improvement=20,
        num_updates_per_improvement=None,
        alpha=None,
        mode=Mode.Q_LEARNING,
        n_steps=None,
        planning_environment=None,
        make_final_policy_greedy=True
    )

    assert isinstance(mdp_agent_2.most_recent_state, MdpState) and mdp_agent_2.most_recent_state in mdp_agent_2.pi

    with pytest.raises(ValueError, match='Attempted to check for None in policy.'):
        # noinspection PyTypeChecker
        if None in mdp_agent_2.pi:  # pragma no cover
            pass

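    # identically seeded training runs should produce equal policies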
    assert mdp_agent.pi == mdp_agent_2.pi
    assert not (mdp_agent.pi != mdp_agent_2.pi)
6 changes: 3 additions & 3 deletions test/rlai/gpi/utils_test.py
@@ -13,7 +13,7 @@
from test.rlai.utils import start_virtual_display_if_headless


def test_resume_gym_valid_environment():
def test_resume_from_checkpoint():
    """
    Test.
    """
@@ -49,10 +49,10 @@ def train_function_args_callback(
    resume_environment.close()

    # uncomment the following lines and run the test to update the fixture
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'wb') as file:
    # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_from_checkpoint.pickle', 'wb') as file:
    #     pickle.dump(agent.pi, file)

    with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'rb') as file:
    with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_from_checkpoint.pickle', 'rb') as file:
        pi_fixture = pickle.load(file)

    assert agent.pi == pi_fixture
40 changes: 39 additions & 1 deletion test/rlai/models/feature_extraction_test.py
@@ -1,4 +1,10 @@
from rlai.models.feature_extraction import OneHotCategory
import numpy as np
import pytest
from numpy.random import RandomState

from rlai.core import MdpState, Action
from rlai.core.environments.gridworld import Gridworld, GridworldFeatureExtractor
from rlai.models.feature_extraction import OneHotCategory, OneHotCategoricalFeatureInteracter


def test_one_hot_category():
@@ -12,3 +18,35 @@ def test_one_hot_category():

    ohc_2 = OneHotCategory(*booleans)
    assert ohc_1 == ohc_2


def test_check_state_and_action_lists():
    """
    Test.
    """

    random = RandomState(12345)
    gw = Gridworld.example_4_1(random, T=None)
    fex = GridworldFeatureExtractor(gw)

    states = [MdpState(i=None, AA=[], terminal=False, truncated=False)]
    actions = [Action(0)]
    fex.check_state_and_action_lists(states, actions)

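    # clearing the action list creates a length mismatch, which should raise a ValueError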
    with pytest.raises(ValueError, match='Expected '):
        actions.clear()
        fex.check_state_and_action_lists(states, actions)


def test_bad_interact():
    """
    Test.
    """

    cats = [1, 2]
    interacter = OneHotCategoricalFeatureInteracter(cats)
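    # interacting a two-row feature matrix with a single category value should raise a ValueError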
    with pytest.raises(ValueError, match='Expected '):
        interacter.interact(np.array([
            [1, 2, 3],
            [4, 5, 6]
        ]), [1])