Cartpole working again, after not scaling the intercept.
MatthewGerber committed Dec 27, 2023
1 parent 6d57d04 commit c4d4323
Showing 6 changed files with 73 additions and 25 deletions.
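The commit message above refers to no longer scaling the intercept. As a hedged illustration of that idea (hypothetical names, not code taken from this diff): when feature vectors for an SGD model are standardized, a constant intercept column has zero variance, so standardizing it zeroes it out and silently removes the model's bias term. The intercept column has to pass through unscaled.

import numpy as np

def scale_features(feature_matrix: np.ndarray, means: np.ndarray, stds: np.ndarray) -> np.ndarray:
    # Hypothetical sketch: column 0 is assumed to be a constant intercept of 1.0.
    # Standardize every column except the intercept; guard against zero variance.
    scaled = feature_matrix.astype(float).copy()
    safe_stds = np.where(stds[1:] == 0.0, 1.0, stds[1:])
    scaled[:, 1:] = (feature_matrix[:, 1:] - means[1:]) / safe_stds
    return scaled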
2 changes: 1 addition & 1 deletion run_configurations/CartPole run FA.run.xml
@@ -14,7 +14,7 @@
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/rlai/runners/agent_in_environment.py" />
<option name="PARAMETERS" value="--agent ../../../trained_agents/cartpole/parametric/cartpole_agent.pickle --environment rlai.core.environments.gymnasium.Gym --gym-id CartPole-v1 --plot-environment --render-every-nth-episode 1 --n-runs 1 --T 10000 --plot" />
<option name="PARAMETERS" value="--agent ~/Desktop/cartpole_agent.pickle --environment rlai.core.environments.gymnasium.Gym --gym-id CartPole-v1 --plot-environment --render-every-nth-episode 1 --n-runs 1 --T 10000 --plot" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
2 changes: 1 addition & 1 deletion run_configurations/CartPole train FA.run.xml
@@ -14,7 +14,7 @@
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/rlai/runners/trainer.py" />
<option name="PARAMETERS" value="--random-seed 12345 --agent rlai.gpi.state_action_value.ActionValueMdpAgent --gamma 0.95 --environment rlai.core.environments.gymnasium.Gym --gym-id CartPole-v1 --T 1000 --render-every-nth-episode 100 --video-directory ~/Desktop/cartpole_videos --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 5000 --num-episodes-per-improvement 1 --epsilon 0.05 --q-S-A rlai.gpi.state_action_value.function_approximation.ApproximateStateActionValueEstimator --plot-model --plot-model-per-improvements 100 --function-approximation-model rlai.gpi.state_action_value.function_approximation.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.001 --feature-extractor rlai.core.environments.gymnasium.CartpoleFeatureExtractor --make-final-policy-greedy True --num-improvements-per-plot 100 --num-improvements-per-checkpoint 1000 --checkpoint-path ~/Desktop/cartpole_checkpoint/cartpole_checkpoint.pickle --save-agent-path ~/Desktop/cartpole_agent.pickle --log INFO" />
<option name="PARAMETERS" value="--random-seed 12345 --agent rlai.gpi.state_action_value.ActionValueMdpAgent --gamma 0.99 --environment rlai.core.environments.gymnasium.Gym --gym-id CartPole-v1 --T 1000 --render-every-nth-episode 100 --video-directory ~/Desktop/cartpole_videos --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 5000 --num-episodes-per-improvement 1 --epsilon 0.01 --q-S-A rlai.gpi.state_action_value.function_approximation.ApproximateStateActionValueEstimator --plot-model --plot-model-per-improvements 100 --function-approximation-model rlai.gpi.state_action_value.function_approximation.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.001 --feature-extractor rlai.core.environments.gymnasium.CartpoleFeatureExtractor --make-final-policy-greedy True --num-improvements-per-plot 100 --num-improvements-per-checkpoint 1000 --checkpoint-path ~/Desktop/cartpole_checkpoint/cartpole_checkpoint.pickle --save-agent-path ~/Desktop/cartpole_agent.pickle --log INFO" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
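The training configuration above moves gamma from 0.95 to 0.99 and epsilon from 0.05 to 0.01. As a reference for what those flags control, here is a tabular sketch of the epsilon-greedy SARSA update they parameterize; the run itself uses function approximation via SKLearnSGD, so this is only an illustration of the algorithm's shape, not rlai's implementation.

import random

GAMMA = 0.99     # --gamma: discount applied to the bootstrapped next-step value
EPSILON = 0.01   # --epsilon: probability of taking a uniformly random action
ETA0 = 0.001     # --eta0: constant step size for the value update

def epsilon_greedy_action(q, state, actions):
    # Explore with probability EPSILON; otherwise act greedily on current estimates.
    if random.random() < EPSILON:
        return random.choice(actions)
    return max(actions, key=lambda a: q.get((state, a), 0.0))

def sarsa_update(q, s, a, r, s_next, a_next):
    # On-policy TD(0): bootstrap from the action actually selected in s_next.
    q_sa = q.get((s, a), 0.0)
    target = r + GAMMA * q.get((s_next, a_next), 0.0)
    q[(s, a)] = q_sa + ETA0 * (target - q_sa)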
28 changes: 14 additions & 14 deletions src/rlai/core/environments/gymnasium.py
@@ -220,13 +220,13 @@ def __init__(
self.gym_native = self.init_gym_native()

if self.gym_id == Gym.LLC_V2:
- self.gym_extender = ContinuousLunarLanderCustomizer()
+ self.gym_customizer = ContinuousLunarLanderCustomizer()
elif self.gym_id == Gym.MCC_V0:
- self.gym_extender = ContinuousMountainCarCustomizer()
+ self.gym_customizer = ContinuousMountainCarCustomizer()
elif self.gym_id == Gym.CARTPOLE_V1:
- self.gym_extender = CartpoleCustomizer()
+ self.gym_customizer = CartpoleCustomizer()
else:
- self.gym_extender = GymCustomizer()
+ self.gym_customizer = GymCustomizer()

self.plot_environment = plot_environment
self.state_reward_scatter_plot = None
@@ -252,7 +252,7 @@ def __init__(
i=i,
name=name
)
- for i, name in zip(range(action_space.n), self.gym_extender.get_action_names(self.gym_native))
+ for i, name in zip(range(action_space.n), self.gym_customizer.get_action_names(self.gym_native))
]

# action space is continuous, and we lack a discretization resolution: initialize a single, multidimensional
@@ -351,10 +351,10 @@ def advance(
else:
gym_action = a.i

- gym_action = self.gym_extender.get_action_to_step(gym_action)
+ gym_action = self.gym_customizer.get_action_to_step(gym_action)
observation, reward, terminated, truncated, _ = self.gym_native.step(action=gym_action)
- observation = self.gym_extender.get_post_step_observation(observation)
- reward, terminated = self.gym_extender.get_reward(self.gym_native, float(reward), observation, terminated, truncated)
+ observation = self.gym_customizer.get_post_step_observation(observation)
+ reward, terminated = self.gym_customizer.get_reward(self.gym_native, float(reward), observation, terminated, truncated)

# call render if rendering manually
if self.check_render_current_episode(True):
@@ -401,7 +401,7 @@ def reset_for_new_run(

observation, _ = self.gym_native.reset()

- observation = self.gym_extender.get_reset_observation(observation)
+ observation = self.gym_customizer.get_reset_observation(observation)

# call render if rendering manually
if self.check_render_current_episode(True):
@@ -510,7 +510,7 @@ def get_state_space_dimensionality(
:return: Number of dimensions.
"""

- return len(self.gym_extender.get_state_dimension_names(self.gym_native))
+ return len(self.gym_customizer.get_state_dimension_names(self.gym_native))

def get_state_dimension_names(
self
@@ -521,7 +521,7 @@ def get_state_dimension_names(
:return: List of names.
"""

- return self.gym_extender.get_state_dimension_names(self.gym_native)
+ return self.gym_customizer.get_state_dimension_names(self.gym_native)

def get_action_space_dimensionality(
self
@@ -532,7 +532,7 @@ def get_action_space_dimensionality(
:return: Number of dimensions.
"""

- assert isinstance(self.gym_extender, ContinuousActionGymCustomizer)
+ assert isinstance(self.gym_customizer, ContinuousActionGymCustomizer)

return len(self.get_action_dimension_names())

@@ -545,9 +545,9 @@ def get_action_dimension_names(
:return: List of names.
"""

- assert isinstance(self.gym_extender, ContinuousActionGymCustomizer)
+ assert isinstance(self.gym_customizer, ContinuousActionGymCustomizer)

- return self.gym_extender.get_action_dimension_names(self.gym_native)
+ return self.gym_customizer.get_action_dimension_names(self.gym_native)


class GymCustomizer(ABC):
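The gymnasium.py changes above rename gym_extender to gym_customizer, and the call sites show the hook points a customizer provides around the native Gym environment: action naming, action and observation translation, and reward shaping. A hedged sketch of such a customizer, with method signatures inferred from the call sites in this diff rather than copied from the rlai source (in practice it would subclass GymCustomizer):

from typing import Any, List, Tuple

class ExampleCartpoleCustomizer:
    # Hypothetical sketch; signatures are assumptions inferred from the call sites above.

    def get_action_names(self, gym_native: Any) -> List[str]:
        # names for the discrete cart-pole actions
        return ['push-left', 'push-right']

    def get_action_to_step(self, action: Any) -> Any:
        # translate the agent's action into what gym_native.step expects (identity here)
        return action

    def get_reset_observation(self, observation: Any) -> Any:
        # transform the observation returned by gym_native.reset (identity here)
        return observation

    def get_post_step_observation(self, observation: Any) -> Any:
        # transform the observation returned by gym_native.step (identity here)
        return observation

    def get_reward(self, gym_native: Any, reward: float, observation: Any, terminated: bool, truncated: bool) -> Tuple[float, bool]:
        # example shaping: end the episode for the agent when the environment truncates
        return reward, terminated or truncated

    def get_state_dimension_names(self, gym_native: Any) -> List[str]:
        return ['position', 'velocity', 'angle', 'angular-velocity']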
@@ -505,7 +505,7 @@ def __init__(
"x + y + z" for this argument. See the Patsy documentation for full details of the formula language. Statistical
learning models used in reinforcement learning generally need to operate "online", learning the reward function
incrementally at each step. An example of such a model would be
- `rlai.gpi.state_action_value.function_approximation.statistical_learning.sklearn.SKLearnSGD`. Online learning
+ `rlai.gpi.state_action_value.function_approximation.models.sklearn.SKLearnSGD`. Online learning
has implications for the use and coding of categorical variables in the model formula. In particular, the full
ranges of state and action levels must be specified up front. See
`test.rlai.gpi.temporal_difference.iteration_test.test_q_learning_iterate_value_q_pi_function_approximation` for
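The docstring edit above corrects the module path for SKLearnSGD. The surrounding text notes that online learning requires declaring the full range of categorical levels up front in the model formula. A hedged Patsy example of what that looks like; the column and level names here are made up, not taken from the rlai cart-pole feature extractor.

import numpy as np
import pandas as pd
from patsy import dmatrix

# Declare both action levels in the formula even though a given online batch may
# contain only one of them; otherwise the design-matrix columns would shift between batches.
formula = "C(action, levels=['push-left', 'push-right']):(position + velocity + angle + angular_velocity)"

batch = pd.DataFrame({
    'action': ['push-left'],
    'position': [0.01],
    'velocity': [-0.02],
    'angle': [0.005],
    'angular_velocity': [0.03]
})

design = np.asarray(dmatrix(formula, batch))
print(design.shape)  # columns for both action levels are present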
16 changes: 14 additions & 2 deletions src/rlai/gpi/utils.py
@@ -127,8 +127,20 @@ def plot_policy_iteration(

# twin-x states per iteration
_state_space_ax = _iteration_ax.twinx()
- _iteration_total_states_line, = _state_space_ax.plot(iterations, iteration_total_states, '--', color='orange', label='total')
- _iteration_num_states_improved_line, = _state_space_ax.plot(iterations, iteration_num_states_improved, '-', color='orange', label='improved')
+ _iteration_total_states_line, = _state_space_ax.plot(
+     iterations,
+     iteration_total_states,
+     '--',
+     color='orange',
+     label='total'
+ )
+ _iteration_num_states_improved_line, = _state_space_ax.plot(
+     iterations,
+     iteration_num_states_improved,
+     '-',
+     color='orange',
+     label='improved'
+ )
_state_space_ax.set_yscale('log')
_state_space_ax.set_ylabel('# states')
_state_space_ax.legend(loc='center right')
48 changes: 42 additions & 6 deletions src/rlai/models/sklearn.py
@@ -277,15 +277,33 @@ def plot(
# plot average return and loss per iteration
self.iteration_ax = axs[0]
iterations = list(range(1, len(self.y_averages) + 1))
- self.iteration_return_line, = self.iteration_ax.plot(iterations, self.y_averages, linewidth=0.75, color='darkgreen', label='Obtained (avg./iter.)')
- self.iteration_loss_line, = self.iteration_ax.plot(iterations, self.loss_averages, linewidth=0.75, color='red', label='Loss (avg./iter.)')
+ self.iteration_return_line, = self.iteration_ax.plot(
+     iterations,
+     self.y_averages,
+     linewidth=0.75,
+     color='darkgreen',
+     label='Obtained (avg./iter.)'
+ )
+ self.iteration_loss_line, = self.iteration_ax.plot(
+     iterations,
+     self.loss_averages,
+     linewidth=0.75,
+     color='red',
+     label='Loss (avg./iter.)'
+ )
self.iteration_ax.set_xlabel('Policy improvement iteration')
self.iteration_ax.set_ylabel('Return (G)')
self.iteration_ax.legend(loc='upper left')

# plot twin-x for average step size per iteration
self.iteration_eta0_ax = self.iteration_ax.twinx()
- self.iteration_eta0_line, = self.iteration_eta0_ax.plot(iterations, self.eta0_averages, linewidth=0.75, color='blue', label='Step size (eta0, avg./iter.)')
+ self.iteration_eta0_line, = self.iteration_eta0_ax.plot(
+     iterations,
+     self.eta0_averages,
+     linewidth=0.75,
+     color='blue',
+     label='Step size (eta0, avg./iter.)'
+ )
self.iteration_eta0_ax.set_yscale('log')
self.iteration_eta0_ax.legend(loc='upper right')

@@ -294,15 +312,33 @@ def plot(
self.time_step_ax = axs[1]
y_values = self.iteration_y_values.get(self.plot_iteration, [])
time_steps = list(range(1, len(y_values) + 1))
- self.time_step_return_line, = self.time_step_ax.plot(time_steps, y_values, linewidth=0.75, color='darkgreen', label='Obtained')
- self.time_step_loss_line, = self.time_step_ax.plot(time_steps, self.iteration_loss_values.get(self.plot_iteration, []), linewidth=0.75, color='red', label='Loss')
+ self.time_step_return_line, = self.time_step_ax.plot(
+     time_steps,
+     y_values,
+     linewidth=0.75,
+     color='darkgreen',
+     label='Obtained'
+ )
+ self.time_step_loss_line, = self.time_step_ax.plot(
+     time_steps,
+     self.iteration_loss_values.get(self.plot_iteration, []),
+     linewidth=0.75,
+     color='red',
+     label='Loss'
+ )
self.time_step_ax.set_xlabel(f'Time step (iteration {self.plot_iteration})')
self.iteration_ax.set_ylabel('Return (G)')
self.time_step_ax.legend(loc='upper left')

# plot twin-x for step size per time step of the most recent plot iteration.
self.time_step_eta0_ax = self.time_step_ax.twinx()
- self.time_step_eta0_line, = self.time_step_eta0_ax.plot(time_steps, self.iteration_eta0_values.get(self.plot_iteration, []), linewidth=0.75, color='blue', label='Step size (eta0)')
+ self.time_step_eta0_line, = self.time_step_eta0_ax.plot(
+     time_steps,
+     self.iteration_eta0_values.get(self.plot_iteration, []),
+     linewidth=0.75,
+     color='blue',
+     label='Step size (eta0)'
+ )
self.time_step_eta0_ax.set_yscale('log')
self.time_step_eta0_ax.legend(loc='upper right')

