diff --git a/run_configurations/meta.run.xml b/run_configurations/meta.run.xml
index a704151e..af9a8a26 100644
--- a/run_configurations/meta.run.xml
+++ b/run_configurations/meta.run.xml
@@ -9,12 +9,12 @@
-
+
-
+
diff --git a/test/rlai/actions/__init__.py b/test/rlai/core/__init__.py
similarity index 100%
rename from test/rlai/actions/__init__.py
rename to test/rlai/core/__init__.py
diff --git a/test/rlai/actions/action_test.py b/test/rlai/core/action_test.py
similarity index 100%
rename from test/rlai/actions/action_test.py
rename to test/rlai/core/action_test.py
diff --git a/test/rlai/agents/agent_test.py b/test/rlai/core/agent_test.py
similarity index 100%
rename from test/rlai/agents/agent_test.py
rename to test/rlai/core/agent_test.py
diff --git a/test/rlai/agents/__init__.py b/test/rlai/core/environments/__init__.py
similarity index 100%
rename from test/rlai/agents/__init__.py
rename to test/rlai/core/environments/__init__.py
diff --git a/test/rlai/environments/bandit_test.py b/test/rlai/core/environments/bandit_test.py
similarity index 100%
rename from test/rlai/environments/bandit_test.py
rename to test/rlai/core/environments/bandit_test.py
diff --git a/test/rlai/environments/fixtures/Cube_3d_printing_sample.stl b/test/rlai/core/environments/fixtures/Cube_3d_printing_sample.stl
similarity index 100%
rename from test/rlai/environments/fixtures/Cube_3d_printing_sample.stl
rename to test/rlai/core/environments/fixtures/Cube_3d_printing_sample.stl
diff --git a/test/rlai/environments/fixtures/test_continuous_learn.pickle b/test/rlai/core/environments/fixtures/test_continuous_learn.pickle
similarity index 100%
rename from test/rlai/environments/fixtures/test_continuous_learn.pickle
rename to test/rlai/core/environments/fixtures/test_continuous_learn.pickle
diff --git a/test/rlai/environments/fixtures/test_gamblers_problem.pickle b/test/rlai/core/environments/fixtures/test_gamblers_problem.pickle
similarity index 100%
rename from test/rlai/environments/fixtures/test_gamblers_problem.pickle
rename to test/rlai/core/environments/fixtures/test_gamblers_problem.pickle
diff --git a/test/rlai/environments/fixtures/test_gym.pickle b/test/rlai/core/environments/fixtures/test_gym.pickle
similarity index 100%
rename from test/rlai/environments/fixtures/test_gym.pickle
rename to test/rlai/core/environments/fixtures/test_gym.pickle
diff --git a/test/rlai/environments/fixtures/test_mancala.pickle b/test/rlai/core/environments/fixtures/test_mancala.pickle
similarity index 100%
rename from test/rlai/environments/fixtures/test_mancala.pickle
rename to test/rlai/core/environments/fixtures/test_mancala.pickle
diff --git a/test/rlai/environments/fixtures/test_robocode.pickle b/test/rlai/core/environments/fixtures/test_robocode.pickle
similarity index 100%
rename from test/rlai/environments/fixtures/test_robocode.pickle
rename to test/rlai/core/environments/fixtures/test_robocode.pickle
diff --git a/test/rlai/environments/fixtures/test_run.pickle b/test/rlai/core/environments/fixtures/test_run.pickle
similarity index 100%
rename from test/rlai/environments/fixtures/test_run.pickle
rename to test/rlai/core/environments/fixtures/test_run.pickle
diff --git a/test/rlai/planning/fixtures/test_stochastic_environment_model.pickle b/test/rlai/core/environments/fixtures/test_stochastic_environment_model.pickle
similarity index 100%
rename from test/rlai/planning/fixtures/test_stochastic_environment_model.pickle
rename to test/rlai/core/environments/fixtures/test_stochastic_environment_model.pickle
diff --git a/test/rlai/environments/openai_gym_test.py b/test/rlai/core/environments/gymnasium_test.py
similarity index 100%
rename from test/rlai/environments/openai_gym_test.py
rename to test/rlai/core/environments/gymnasium_test.py
diff --git a/test/rlai/environments/mancala_test.py b/test/rlai/core/environments/mancala_test.py
similarity index 100%
rename from test/rlai/environments/mancala_test.py
rename to test/rlai/core/environments/mancala_test.py
diff --git a/test/rlai/environments/mdp_test.py b/test/rlai/core/environments/mdp_test.py
similarity index 71%
rename from test/rlai/environments/mdp_test.py
rename to test/rlai/core/environments/mdp_test.py
index 2dfbbd3a..2538a024 100644
--- a/test/rlai/environments/mdp_test.py
+++ b/test/rlai/core/environments/mdp_test.py
@@ -4,13 +4,14 @@
import pytest
from numpy.random import RandomState
-from rlai.core import Reward, Action, MdpState, Monitor
+from rlai.core import Reward, Action, MdpState, Monitor, State
from rlai.core.environments.gamblers_problem import GamblersProblem
from rlai.core.environments.gridworld import Gridworld
from rlai.core.environments.mdp import PrioritizedSweepingMdpPlanningEnvironment, StochasticEnvironmentModel
from rlai.gpi.dynamic_programming.iteration import iterate_value_v_pi
from rlai.gpi.state_action_value import ActionValueMdpAgent
from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator
+from rlai.utils import sample_list_item
def test_gamblers_problem():
@@ -125,3 +126,48 @@ def test_check_marginal_probabilities():
with pytest.raises(ValueError, match='Expected next-state/next-reward marginal probability of 1.0, but got 2.0'):
gridworld.check_marginal_probabilities()
+
+
+def test_stochastic_environment_model():
+ """
+ Test.
+ """
+
+ random_state = RandomState(12345)
+
+ model = StochasticEnvironmentModel()
+
+ actions = [
+ Action(i)
+ for i in range(5)
+ ]
+
+ states = [
+ State(i, actions)
+ for i in range(5)
+ ]
+
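+    # update the model with uniformly sampled transitions and random integer rewards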
+ for t in range(1000):
+ state = sample_list_item(states, None, random_state)
+ action = sample_list_item(state.AA, None, random_state)
+ next_state = sample_list_item(states, None, random_state)
+ reward = Reward(None, random_state.randint(10))
+ model.update(state, action, next_state, reward)
+
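+    # sample (next_state, reward) pairs from the learned model and compare with the pickled fixture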
+ environment_sequence = []
+ for i in range(1000):
+ state = model.sample_state(random_state)
+ action = model.sample_action(state, random_state)
+ next_state, reward = model.sample_next_state_and_reward(state, action, random_state)
+ environment_sequence.append((next_state, reward))
+
+    # uncomment the following lines and run the test to update the fixture
+ # with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'wb') as file:
+ # pickle.dump(environment_sequence, file)
+
+ with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'rb') as file:
+ environment_sequence_fixture = pickle.load(file)
+
+ assert environment_sequence == environment_sequence_fixture
diff --git a/test/rlai/environments/mujoco_test.py b/test/rlai/core/environments/mujoco_test.py
similarity index 100%
rename from test/rlai/environments/mujoco_test.py
rename to test/rlai/core/environments/mujoco_test.py
diff --git a/test/rlai/environments/robocode_continuous_test.py b/test/rlai/core/environments/robocode_continuous_test.py
similarity index 100%
rename from test/rlai/environments/robocode_continuous_test.py
rename to test/rlai/core/environments/robocode_continuous_test.py
diff --git a/test/rlai/environments/robocode_test.py b/test/rlai/core/environments/robocode_test.py
similarity index 100%
rename from test/rlai/environments/robocode_test.py
rename to test/rlai/core/environments/robocode_test.py
diff --git a/test/rlai/rewards/reward_test.py b/test/rlai/core/reward_test.py
similarity index 100%
rename from test/rlai/rewards/reward_test.py
rename to test/rlai/core/reward_test.py
diff --git a/test/rlai/states/test_states.py b/test/rlai/core/state_test.py
similarity index 100%
rename from test/rlai/states/test_states.py
rename to test/rlai/core/state_test.py
diff --git a/test/rlai/gpi/fixtures/test_resume_gym_valid_environment.pickle b/test/rlai/gpi/fixtures/test_resume_from_checkpoint.pickle
similarity index 100%
rename from test/rlai/gpi/fixtures/test_resume_gym_valid_environment.pickle
rename to test/rlai/gpi/fixtures/test_resume_from_checkpoint.pickle
diff --git a/test/rlai/gpi/fixtures/test_resume_gym_invalid_environment.pickle b/test/rlai/gpi/fixtures/test_resume_gym_invalid_environment.pickle
deleted file mode 100644
index 3e2ef5fd..00000000
Binary files a/test/rlai/gpi/fixtures/test_resume_gym_invalid_environment.pickle and /dev/null differ
diff --git a/test/rlai/gpi/improvement_test.py b/test/rlai/gpi/improvement_test.py
deleted file mode 100644
index 7172e1cf..00000000
--- a/test/rlai/gpi/improvement_test.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import pytest
-from numpy.random import RandomState
-
-from rlai.core.environments.gridworld import Gridworld
-from rlai.gpi.state_action_value import ActionValueMdpAgent
-from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator, TabularPolicy
-
-
-def test_invalid_improve_policy_with_q_pi():
- """
- Test.
- """
-
- random_state = RandomState(12345)
- mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)
- epsilon = 0.0
- mdp_agent = ActionValueMdpAgent(
- 'test',
- random_state,
- 1,
- TabularStateActionValueEstimator(mdp_environment, epsilon, None)
- )
-
- assert isinstance(mdp_agent.pi, TabularPolicy)
-
- with pytest.raises(ValueError, match='Epsilon must be >= 0'):
- mdp_agent.pi.improve_with_q_pi(
- {},
- -1.0
- )
diff --git a/test/rlai/environments/__init__.py b/test/rlai/gpi/state_action_value/__init__.py
similarity index 100%
rename from test/rlai/environments/__init__.py
rename to test/rlai/gpi/state_action_value/__init__.py
diff --git a/test/rlai/planning/__init__.py b/test/rlai/gpi/state_action_value/function_approximation/__init__.py
similarity index 100%
rename from test/rlai/planning/__init__.py
rename to test/rlai/gpi/state_action_value/function_approximation/__init__.py
diff --git a/test/rlai/policies/__init__.py b/test/rlai/gpi/state_action_value/function_approximation/models/__init__.py
similarity index 100%
rename from test/rlai/policies/__init__.py
rename to test/rlai/gpi/state_action_value/function_approximation/models/__init__.py
diff --git a/test/rlai/q_S_A/function_approximation/models/models_test.py b/test/rlai/gpi/state_action_value/function_approximation/models/sklearn_test.py
similarity index 100%
rename from test/rlai/q_S_A/function_approximation/models/models_test.py
rename to test/rlai/gpi/state_action_value/function_approximation/models/sklearn_test.py
diff --git a/test/rlai/q_S_A/value_estimator_test.py b/test/rlai/gpi/state_action_value/function_approximation/test_function_approximation.py
similarity index 63%
rename from test/rlai/q_S_A/value_estimator_test.py
rename to test/rlai/gpi/state_action_value/function_approximation/test_function_approximation.py
index 1167aa74..d9b3eaee 100644
--- a/test/rlai/q_S_A/value_estimator_test.py
+++ b/test/rlai/gpi/state_action_value/function_approximation/test_function_approximation.py
@@ -1,7 +1,6 @@
import pytest
from rlai.gpi.state_action_value.function_approximation import ApproximateStateActionValueEstimator
-from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator
# noinspection PyTypeChecker
@@ -10,8 +9,5 @@ def test_invalid_epsilon():
Test.
"""
- with pytest.raises(ValueError, match='epsilon must be >= 0'):
- TabularStateActionValueEstimator(None, -1, None)
-
with pytest.raises(ValueError, match='epsilon must be >= 0'):
ApproximateStateActionValueEstimator(None, -1, None, None, None, False, None, None)
diff --git a/test/rlai/gpi/state_action_value/tabular_test.py b/test/rlai/gpi/state_action_value/tabular_test.py
new file mode 100644
index 00000000..35a80f93
--- /dev/null
+++ b/test/rlai/gpi/state_action_value/tabular_test.py
@@ -0,0 +1,69 @@
+import numpy as np
+import pytest
+from numpy.random import RandomState
+
+from rlai.core.environments.gridworld import Gridworld
+from rlai.gpi.state_action_value import ActionValueMdpAgent
+from rlai.gpi.state_action_value.tabular import TabularPolicy
+from rlai.gpi.state_action_value.tabular import TabularStateActionValueEstimator
+
+
+def test_invalid_get_state_i():
+ """
+ Test.
+ """
+
+ policy = TabularPolicy(None, None)
+
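+    # a continuous (ndarray) state cannot be discretized without a resolution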
+ with pytest.raises(ValueError, match='Attempted to discretize a continuous state without a resolution.'):
+ policy.get_state_i(np.array([[1, 2, 3]]))
+
+ with pytest.raises(ValueError, match=f'Unknown state space type: {type(3)}'):
+ # noinspection PyTypeChecker
+ policy.get_state_i(3)
+
+
+def test_policy_not_equal():
+ """
+ Test.
+ """
+
+ policy_1 = TabularPolicy(None, None)
+ policy_2 = TabularPolicy(None, None)
+
+ assert not (policy_1 != policy_2)
+
+
+# noinspection PyTypeChecker
+def test_invalid_epsilon():
+ """
+ Test.
+ """
+
+ with pytest.raises(ValueError, match='epsilon must be >= 0'):
+ TabularStateActionValueEstimator(None, -1, None)
+
+
+def test_invalid_improve_policy_with_q_pi():
+ """
+ Test.
+ """
+
+ random_state = RandomState(12345)
+ mdp_environment: Gridworld = Gridworld.example_4_1(random_state, None)
+ epsilon = 0.0
+ mdp_agent = ActionValueMdpAgent(
+ 'test',
+ random_state,
+ 1,
+ TabularStateActionValueEstimator(mdp_environment, epsilon, None)
+ )
+
+ assert isinstance(mdp_agent.pi, TabularPolicy)
+
+ with pytest.raises(ValueError, match='Epsilon must be >= 0'):
+ mdp_agent.pi.improve_with_q_pi(
+ {},
+ -1.0
+ )
diff --git a/test/rlai/gpi/temporal_difference/iteration_test.py b/test/rlai/gpi/temporal_difference/iteration_test.py
index bee5cc1e..aea36fe2 100644
--- a/test/rlai/gpi/temporal_difference/iteration_test.py
+++ b/test/rlai/gpi/temporal_difference/iteration_test.py
@@ -10,6 +10,7 @@
import pytest
from numpy.random import RandomState
+from rlai.core import MdpState
from rlai.core.environments.gridworld import Gridworld, GridworldFeatureExtractor
from rlai.core.environments.mdp import TrajectorySamplingMdpPlanningEnvironment, StochasticEnvironmentModel
from rlai.gpi.state_action_value import ActionValueMdpAgent
@@ -724,3 +725,93 @@ def test_q_learning_iterate_value_q_pi_tabular_policy_ne():
assert q_S_A_1 != q_S_A_2
assert q_S_A_1[test_state] != q_S_A_2[test_state]
assert q_S_A_1[test_state][test_action] != q_S_A_2[test_state][test_action]
+
+
+def test_policy_overrides():
+ """
+ Test.
+ """
+
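+    # train two agents from identical random seeds, so that their final policies should be equal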
+ random_state = RandomState(12345)
+
+ mdp_environment: Gridworld = Gridworld.example_4_1(random_state, 20)
+
+ epsilon = 0.05
+
+ q_S_A = ApproximateStateActionValueEstimator(
+ mdp_environment,
+ epsilon,
+ SKLearnSGD(BaseSKLearnSGD(random_state=random_state)),
+ GridworldFeatureExtractor(mdp_environment),
+ None,
+ False,
+ None,
+ None
+ )
+
+ mdp_agent = ActionValueMdpAgent(
+ 'test',
+ random_state,
+ 1,
+ q_S_A
+ )
+
+ iterate_value_q_pi(
+ agent=mdp_agent,
+ environment=mdp_environment,
+ num_improvements=10,
+ num_episodes_per_improvement=20,
+ num_updates_per_improvement=None,
+ alpha=None,
+ mode=Mode.Q_LEARNING,
+ n_steps=None,
+ planning_environment=None,
+ make_final_policy_greedy=True
+ )
+
+ random_state = RandomState(12345)
+
+ mdp_environment_2: Gridworld = Gridworld.example_4_1(random_state, 20)
+
+ q_S_A_2 = ApproximateStateActionValueEstimator(
+ mdp_environment_2,
+ epsilon,
+ SKLearnSGD(BaseSKLearnSGD(random_state=random_state)),
+ GridworldFeatureExtractor(mdp_environment_2),
+ None,
+ False,
+ None,
+ None
+ )
+
+ mdp_agent_2 = ActionValueMdpAgent(
+ 'test',
+ random_state,
+ 1,
+ q_S_A_2
+ )
+
+ iterate_value_q_pi(
+ agent=mdp_agent_2,
+ environment=mdp_environment_2,
+ num_improvements=10,
+ num_episodes_per_improvement=20,
+ num_updates_per_improvement=None,
+ alpha=None,
+ mode=Mode.Q_LEARNING,
+ n_steps=None,
+ planning_environment=None,
+ make_final_policy_greedy=True
+ )
+
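+    # exercise the policy's __contains__, __eq__, and __ne__ overrides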
+ assert isinstance(mdp_agent_2.most_recent_state, MdpState) and mdp_agent_2.most_recent_state in mdp_agent_2.pi
+
+ with pytest.raises(ValueError, match='Attempted to check for None in policy.'):
+ # noinspection PyTypeChecker
+ if None in mdp_agent_2.pi: # pragma no cover
+ pass
+
+ assert mdp_agent.pi == mdp_agent_2.pi
+ assert not (mdp_agent.pi != mdp_agent_2.pi)
diff --git a/test/rlai/gpi/utils_test.py b/test/rlai/gpi/utils_test.py
index 927efc0c..52f19d0a 100644
--- a/test/rlai/gpi/utils_test.py
+++ b/test/rlai/gpi/utils_test.py
@@ -13,7 +13,7 @@
from test.rlai.utils import start_virtual_display_if_headless
-def test_resume_gym_valid_environment():
+def test_resume_from_checkpoint():
"""
Test.
"""
@@ -49,10 +49,10 @@ def train_function_args_callback(
resume_environment.close()
# uncomment the following line and run test to update fixture
- # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'wb') as file:
+ # with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_from_checkpoint.pickle', 'wb') as file:
# pickle.dump(agent.pi, file)
- with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_gym_valid_environment.pickle', 'rb') as file:
+ with open(f'{os.path.dirname(__file__)}/fixtures/test_resume_from_checkpoint.pickle', 'rb') as file:
pi_fixture = pickle.load(file)
assert agent.pi == pi_fixture
diff --git a/test/rlai/models/feature_extraction_test.py b/test/rlai/models/feature_extraction_test.py
index ada827ce..7a0b2936 100644
--- a/test/rlai/models/feature_extraction_test.py
+++ b/test/rlai/models/feature_extraction_test.py
@@ -1,4 +1,10 @@
-from rlai.models.feature_extraction import OneHotCategory
+import numpy as np
+import pytest
+from numpy.random import RandomState
+
+from rlai.core import MdpState, Action
+from rlai.core.environments.gridworld import Gridworld, GridworldFeatureExtractor
+from rlai.models.feature_extraction import OneHotCategory, OneHotCategoricalFeatureInteracter
def test_one_hot_category():
@@ -12,3 +18,37 @@ def test_one_hot_category():
ohc_2 = OneHotCategory(*booleans)
assert ohc_1 == ohc_2
+
+
+def test_check_state_and_action_lists():
+ """
+ Test.
+ """
+
+ random = RandomState(12345)
+ gw = Gridworld.example_4_1(random, T=None)
+ fex = GridworldFeatureExtractor(gw)
+
+ states = [MdpState(i=None, AA=[], terminal=False, truncated=False)]
+ actions = [Action(0)]
+ fex.check_state_and_action_lists(states, actions)
+
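+    # clearing the actions creates a length mismatch between the state and action lists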
+ with pytest.raises(ValueError, match='Expected '):
+ actions.clear()
+ fex.check_state_and_action_lists(states, actions)
+
+
+def test_bad_interact():
+ """
+ Test.
+ """
+
+ cats = [1, 2]
+ interacter = OneHotCategoricalFeatureInteracter(cats)
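+    # two feature rows paired with a single categorical value is a length mismatch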
+ with pytest.raises(ValueError, match='Expected '):
+ interacter.interact(np.array([
+ [1, 2, 3],
+ [4, 5, 6]
+ ]), [1])
diff --git a/test/rlai/planning/environment_models_test.py b/test/rlai/planning/environment_models_test.py
deleted file mode 100644
index 2e865186..00000000
--- a/test/rlai/planning/environment_models_test.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import os
-import pickle
-
-from numpy.random import RandomState
-
-from rlai.core import Reward, Action, State
-from rlai.core.environments.mdp import StochasticEnvironmentModel
-from rlai.utils import sample_list_item
-
-
-def test_stochastic_environment_model():
- """
- Test.
- """
-
- random_state = RandomState(12345)
-
- model = StochasticEnvironmentModel()
-
- actions = [
- Action(i)
- for i in range(5)
- ]
-
- states = [
- State(i, actions)
- for i in range(5)
- ]
-
- for t in range(1000):
- state = sample_list_item(states, None, random_state)
- action = sample_list_item(state.AA, None, random_state)
- next_state = sample_list_item(states, None, random_state)
- reward = Reward(None, random_state.randint(10))
- model.update(state, action, next_state, reward)
-
- environment_sequence = []
- for i in range(1000):
- state = model.sample_state(random_state)
- action = model.sample_action(state, random_state)
- next_state, reward = model.sample_next_state_and_reward(state, action, random_state)
- environment_sequence.append((next_state, reward))
-
- # uncomment the following line and run test to update fixture
- # with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'wb') as file:
- # pickle.dump(environment_sequence, file)
-
- with open(f'{os.path.dirname(__file__)}/fixtures/test_stochastic_environment_model.pickle', 'rb') as file:
- environment_sequence_fixture = pickle.load(file)
-
- assert environment_sequence == environment_sequence_fixture
diff --git a/test/rlai/policies/function_approximation_test.py b/test/rlai/policies/function_approximation_test.py
deleted file mode 100644
index 370a48f3..00000000
--- a/test/rlai/policies/function_approximation_test.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import pytest
-from numpy.random import RandomState
-
-from rlai.core import MdpState
-from rlai.core.environments.gridworld import Gridworld, GridworldFeatureExtractor
-from rlai.gpi.state_action_value import ActionValueMdpAgent
-from rlai.gpi.state_action_value.function_approximation import ApproximateStateActionValueEstimator
-from rlai.gpi.state_action_value.function_approximation.models.sklearn import SKLearnSGD
-from rlai.gpi.temporal_difference.evaluation import Mode
-from rlai.gpi.temporal_difference.iteration import iterate_value_q_pi
-from rlai.models.sklearn import SKLearnSGD as BaseSKLearnSGD
-
-
-def test_policy_overrides():
- """
- Test.
- """
-
- random_state = RandomState(12345)
-
- mdp_environment: Gridworld = Gridworld.example_4_1(random_state, 20)
-
- epsilon = 0.05
-
- q_S_A = ApproximateStateActionValueEstimator(
- mdp_environment,
- epsilon,
- SKLearnSGD(BaseSKLearnSGD(random_state=random_state)),
- GridworldFeatureExtractor(mdp_environment),
- None,
- False,
- None,
- None
- )
-
- mdp_agent = ActionValueMdpAgent(
- 'test',
- random_state,
- 1,
- q_S_A
- )
-
- iterate_value_q_pi(
- agent=mdp_agent,
- environment=mdp_environment,
- num_improvements=10,
- num_episodes_per_improvement=20,
- num_updates_per_improvement=None,
- alpha=None,
- mode=Mode.Q_LEARNING,
- n_steps=None,
- planning_environment=None,
- make_final_policy_greedy=True,
- )
-
- random_state = RandomState(12345)
-
- mdp_environment_2: Gridworld = Gridworld.example_4_1(random_state, 20)
-
- q_S_A_2 = ApproximateStateActionValueEstimator(
- mdp_environment_2,
- epsilon,
- SKLearnSGD(BaseSKLearnSGD(random_state=random_state)),
- GridworldFeatureExtractor(mdp_environment_2),
- None,
- False,
- None,
- None
- )
-
- mdp_agent_2 = ActionValueMdpAgent(
- 'test',
- random_state,
- 1,
- q_S_A_2
- )
-
- iterate_value_q_pi(
- agent=mdp_agent_2,
- environment=mdp_environment_2,
- num_improvements=10,
- num_episodes_per_improvement=20,
- num_updates_per_improvement=None,
- alpha=None,
- mode=Mode.Q_LEARNING,
- n_steps=None,
- planning_environment=None,
- make_final_policy_greedy=True
- )
-
- assert isinstance(mdp_agent_2.most_recent_state, MdpState) and mdp_agent_2.most_recent_state in mdp_agent_2.pi
-
- with pytest.raises(ValueError, match='Attempted to check for None in policy.'):
- # noinspection PyTypeChecker
- if None in mdp_agent_2.pi: # pragma no cover
- pass
-
- assert mdp_agent.pi == mdp_agent_2.pi
- assert not (mdp_agent.pi != mdp_agent_2.pi)
diff --git a/test/rlai/policies/tabular_test.py b/test/rlai/policies/tabular_test.py
deleted file mode 100644
index b473276c..00000000
--- a/test/rlai/policies/tabular_test.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import numpy as np
-import pytest
-
-from rlai.gpi.state_action_value.tabular import TabularPolicy
-
-
-def test_invalid_get_state_i():
- """
- Test.
- """
-
- policy = TabularPolicy(None, None)
-
- with pytest.raises(ValueError, match='Attempted to discretize a continuous state without a resolution.'):
- policy.get_state_i(np.array([[1, 2, 3]]))
-
- with pytest.raises(ValueError, match=f'Unknown state space type: {type(3)}'):
- # noinspection PyTypeChecker
- policy.get_state_i(3)
-
-
-def test_policy_not_equal():
- """
- Test.
- """
-
- policy_1 = TabularPolicy(None, None)
- policy_2 = TabularPolicy(None, None)
-
- assert not (policy_1 != policy_2)
diff --git a/test/rlai/policies/parameterized/__init__.py b/test/rlai/policy_gradient/policies/__init__.py
similarity index 100%
rename from test/rlai/policies/parameterized/__init__.py
rename to test/rlai/policy_gradient/policies/__init__.py
diff --git a/test/rlai/policies/parameterized/continuous_action_test.py b/test/rlai/policy_gradient/policies/continuous_action_test.py
similarity index 100%
rename from test/rlai/policies/parameterized/continuous_action_test.py
rename to test/rlai/policy_gradient/policies/continuous_action_test.py
diff --git a/test/rlai/q_S_A/__init__.py b/test/rlai/q_S_A/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/rlai/q_S_A/function_approximation/__init__.py b/test/rlai/q_S_A/function_approximation/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/rlai/q_S_A/function_approximation/models/__init__.py b/test/rlai/q_S_A/function_approximation/models/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/rlai/q_S_A/function_approximation/models/feature_extraction_test.py b/test/rlai/q_S_A/function_approximation/models/feature_extraction_test.py
deleted file mode 100644
index 3f351cc6..00000000
--- a/test/rlai/q_S_A/function_approximation/models/feature_extraction_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import numpy as np
-import pytest
-from numpy.random import RandomState
-
-from rlai.core import Action, MdpState
-from rlai.core.environments.gridworld import GridworldFeatureExtractor, Gridworld
-from rlai.models.feature_extraction import OneHotCategoricalFeatureInteracter
-
-
-def test_check_state_and_action_lists():
- """
- Test.
- """
-
- random = RandomState(12345)
- gw = Gridworld.example_4_1(random, T=None)
- fex = GridworldFeatureExtractor(gw)
-
- states = [MdpState(i=None, AA=[], terminal=False, truncated=False)]
- actions = [Action(0)]
- fex.check_state_and_action_lists(states, actions)
-
- with pytest.raises(ValueError, match='Expected '):
- actions.clear()
- fex.check_state_and_action_lists(states, actions)
-
-
-def test_bad_interact():
- """
- Test.
- """
-
- cats = [1, 2]
- interacter = OneHotCategoricalFeatureInteracter(cats)
- with pytest.raises(ValueError, match='Expected '):
- interacter.interact(np.array([
- [1, 2, 3],
- [4, 5, 6]
- ]), [1])
diff --git a/test/rlai/q_S_A/function_approximation/models/fixtures/test_nonstationary_feature_scaler.pickle b/test/rlai/q_S_A/function_approximation/models/fixtures/test_nonstationary_feature_scaler.pickle
deleted file mode 100644
index e5f469b9..00000000
Binary files a/test/rlai/q_S_A/function_approximation/models/fixtures/test_nonstationary_feature_scaler.pickle and /dev/null differ
diff --git a/test/rlai/rewards/__init__.py b/test/rlai/rewards/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/test/rlai/states/__init__.py b/test/rlai/states/__init__.py
deleted file mode 100644
index e69de29b..00000000