Skip to content

Commit

Permalink
Treat truncation as termination.
Browse files: browse the repository at this point in the history
  • Loading branch information
MatthewGerber committed Nov 26, 2023
1 parent f631815 commit dde6c3c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion run_configurations/CartPole train FA.run.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/rlai/runners/trainer.py" />
<option name="PARAMETERS" value="--random-seed 12345 --agent rlai.gpi.state_action_value.ActionValueMdpAgent --gamma 0.95 --environment rlai.core.environments.gymnasium.Gym --gym-id CartPole-v1 --render-every-nth-episode 1000 --video-directory ~/Desktop/cartpole_videos --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 30000 --num-episodes-per-improvement 1 --num-updates-per-improvement 1 --epsilon 0.2 --q-S-A rlai.gpi.state_action_value.function_approximation.ApproximateStateActionValueEstimator --plot-model --plot-model-per-improvements 1000 --function-approximation-model rlai.gpi.state_action_value.function_approximation.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.00001 --feature-extractor rlai.core.environments.gymnasium.CartpoleFeatureExtractor --make-final-policy-greedy True --num-improvements-per-plot 1000 --num-improvements-per-checkpoint 1000 --checkpoint-path ~/Desktop/cartpole_checkpoint/cartpole_checkpoint.pickle --save-agent-path ~/Desktop/cartpole_agent.pickle --log INFO" />
<option name="PARAMETERS" value="--random-seed 12345 --agent rlai.gpi.state_action_value.ActionValueMdpAgent --gamma 0.95 --environment rlai.core.environments.gymnasium.Gym --gym-id CartPole-v1 --render-every-nth-episode 1000 --video-directory ~/Desktop/cartpole_videos --train-function rlai.gpi.temporal_difference.iteration.iterate_value_q_pi --mode SARSA --num-improvements 50000 --num-episodes-per-improvement 1 --num-updates-per-improvement 1 --epsilon 0.2 --q-S-A rlai.gpi.state_action_value.function_approximation.ApproximateStateActionValueEstimator --plot-model --plot-model-per-improvements 1000 --function-approximation-model rlai.gpi.state_action_value.function_approximation.models.sklearn.SKLearnSGD --loss squared_error --sgd-alpha 0.0 --learning-rate constant --eta0 0.00001 --feature-extractor rlai.core.environments.gymnasium.CartpoleFeatureExtractor --make-final-policy-greedy True --num-improvements-per-plot 1000 --num-improvements-per-checkpoint 1000 --checkpoint-path ~/Desktop/cartpole_checkpoint/cartpole_checkpoint.pickle --save-agent-path ~/Desktop/cartpole_agent.pickle --log INFO" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
Expand Down
8 changes: 7 additions & 1 deletion src/rlai/core/environments/gymnasium.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import math
import os
import warnings
Expand Down Expand Up @@ -233,7 +234,12 @@ def advance(
else:
fuel_used = required_fuel

observation, reward, terminated, _, _ = self.gym_native.step(action=gym_action)
observation, reward, terminated, truncated, _ = self.gym_native.step(action=gym_action)

# truncation is a special case of termination
if truncated and not terminated:
logging.info(f'Episode was truncated after {t + 1} step(s). Terminating.')
terminated = truncated

# update fuel remaining if needed
fuel_remaining = None
Expand Down

0 comments on commit dde6c3c

Please sign in to comment.