End truncation at discount convergence to zero, to avoid spurious rewards equal to zero cancelling truncation too soon.
MatthewGerber committed Aug 13, 2024
1 parent 6b40b91 commit 95d639a
Showing 4 changed files with 44 additions and 28 deletions.
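
The change in this commit: the post-truncation exit test previously multiplied the next reward by the accumulated discount, so any reward that happened to equal zero made the test pass immediately and cut the post-truncation phase short; the test now looks at the discount alone. A minimal standalone sketch of the difference (the discount factor and reward values below are hypothetical, not taken from the repository):

import numpy as np

gamma = 0.99   # hypothetical discount factor
reward = 0.0   # a reward of zero arrives one step after truncation
k = 1          # number of post-truncation steps taken so far

# old check: discounted *reward* close to zero -- fires immediately because the
# reward itself is zero, even though later rewards could still matter.
old_check = np.isclose(reward * gamma ** k, 0.0)   # True (spurious exit)

# new check: *discount* close to zero -- unaffected by zero rewards, so the
# episode keeps running until gamma ** k has decayed to numerical zero.
new_check = np.isclose(gamma ** k, 0.0)            # False (keep running)

print(old_check, new_check)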
26 changes: 12 additions & 14 deletions src/rlai/gpi/monte_carlo/evaluation.py
@@ -76,15 +76,14 @@ def evaluate_v_pi(
             t += 1
             agent.sense(state, t)

-            # if we've truncated and the discounted reward has converged to zero, then there's no point in running
-            # longer.
+            # if we've truncated and the discount has converged to zero, then the return at the truncation time will
+            # not change by running longer. we've got an accurate return estimate at truncation. exit the episode.
             if truncation_time_step is not None:
-                steps_past_truncation = (t - truncation_time_step)
-                discounted_reward = next_reward.r * (agent.gamma ** steps_past_truncation)
-                if np.isclose(discounted_reward, 0.0):
+                num_post_truncation_steps = (t - truncation_time_step)
+                post_truncation_discount = agent.gamma ** num_post_truncation_steps
+                if np.isclose(post_truncation_discount, 0.0):
                     raise ValueError(
-                        f'Discounted reward converged to zero after {steps_past_truncation} post-truncation '
-                        f'step(s).'
+                        f'Post-truncation discount converged to zero after {num_post_truncation_steps} step(s).'
                     )

             # if anything blows up, then let the environment know that we are exiting the episode.
@@ -208,15 +207,14 @@ def evaluate_q_pi(
             t += 1
             episode_generation_agent.sense(state, t)

-            # if we've truncated and the discounted reward has converged to zero, then there's no point in running
-            # longer.
+            # if we've truncated and the discount has converged to zero, then the return at the truncation time will
+            # not change by running longer. we've got an accurate return estimate at truncation. exit the episode.
             if truncation_time_step is not None:
-                steps_past_truncation = (t - truncation_time_step)
-                discounted_reward = next_reward.r * (agent.gamma ** steps_past_truncation)
-                if np.isclose(discounted_reward, 0.0):
+                num_post_truncation_steps = (t - truncation_time_step)
+                post_truncation_discount = agent.gamma ** num_post_truncation_steps
+                if np.isclose(post_truncation_discount, 0.0):
                     raise ValueError(
-                        f'Discounted reward converged to zero after {steps_past_truncation} post-truncation '
-                        f'step(s).'
+                        f'Post-truncation discount converged to zero after {num_post_truncation_steps} step(s).'
                     )

             # if anything blows up, then let the environment know that we are exiting the episode.
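
With the new test, the length of the post-truncation phase is determined by gamma and np.isclose's default tolerances: against a target of 0.0 the relative term vanishes, so the test fires once gamma ** k <= 1e-8 (the default atol), i.e., after roughly log(1e-8) / log(gamma) steps. A small sketch under that assumption (the discount factors are hypothetical):

import numpy as np

def steps_until_discount_vanishes(gamma: float, atol: float = 1e-8) -> int:
    """Smallest k such that gamma ** k <= atol, which is when
    np.isclose(gamma ** k, 0.0) first returns True at the default tolerances."""
    return int(np.ceil(np.log(atol) / np.log(gamma)))

for gamma in (0.9, 0.99, 0.999):
    print(gamma, steps_until_discount_vanishes(gamma))
# approximately: 0.9 -> 175 steps, 0.99 -> 1833, 0.999 -> 18412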
26 changes: 12 additions & 14 deletions src/rlai/policy_gradient/monte_carlo/reinforce.py
@@ -213,15 +213,14 @@ def improve(
             t += 1
             agent.sense(state, t)

-            # if we've truncated and the discounted reward has converged to zero, then there's no point in running
-            # longer.
+            # if we've truncated and the discount has converged to zero, then the return at the truncation time will
+            # not change by running longer. we've got an accurate return estimate at truncation. exit the episode.
             if truncation_time_step is not None:
-                steps_past_truncation = (t - truncation_time_step)
-                discounted_reward = next_reward.r * (gamma ** steps_past_truncation)
-                if np.isclose(discounted_reward, 0.0):
+                num_post_truncation_steps = (t - truncation_time_step)
+                post_truncation_discount = gamma ** num_post_truncation_steps
+                if np.isclose(post_truncation_discount, 0.0):
                     raise ValueError(
-                        f'Discounted reward converged to zero after {steps_past_truncation} post-truncation '
-                        f'step(s).'
+                        f'Post-truncation discount converged to zero after {num_post_truncation_steps} step(s).'
                     )

             # if anything blows up, then let the environment know that we are exiting the episode.
@@ -639,15 +638,14 @@ def iterate(
             t += 1
             self.agent.sense(state, t)

-            # if we've truncated and the discounted reward has converged to zero, then there's no point in running
-            # longer.
+            # if we've truncated and the discount has converged to zero, then the return at the truncation time will
+            # not change by running longer. we've got an accurate return estimate at truncation. exit the episode.
             if truncation_time_step is not None:
-                steps_past_truncation = (t - truncation_time_step)
-                discounted_reward = next_reward.r * (self.agent.gamma ** steps_past_truncation)
-                if np.isclose(discounted_reward, 0.0):
+                num_post_truncation_steps = (t - truncation_time_step)
+                post_truncation_discount = self.agent.gamma ** num_post_truncation_steps
+                if np.isclose(post_truncation_discount, 0.0):
                     raise ValueError(
-                        f'Discounted reward converged to zero after {steps_past_truncation} post-truncation '
-                        f'step(s).'
+                        f'Post-truncation discount converged to zero after {num_post_truncation_steps} step(s).'
                     )

             # if anything blows up, then let the environment know that we are exiting the episode.
20 changes: 20 additions & 0 deletions src/rlai/utils.py
@@ -7,6 +7,7 @@
 from typing import List, Any, Optional, Callable, Tuple, TextIO

 import numpy as np
+import scipy.stats
 from numpy import linalg as la
 from numpy.random import RandomState
@@ -587,3 +588,22 @@ def insert_index_into_path(
     path_parts.insert(1, f'-{index}')

     return ''.join(path_parts)
+
+
+def get_sample_size(
+        confidence: float,
+        std: float,
+        margin_of_error: float
+) -> int:
+    """
+    Get the sample size needed to estimate a mean, given a confidence level, standard deviation, and margin of error.
+
+    :param confidence: Confidence in (0.0, 1.0).
+    :param std: Standard deviation.
+    :param margin_of_error: Margin of error.
+    :return: Sample size.
+    """
+
+    z = scipy.stats.norm.ppf(1.0 - ((1.0 - confidence) / 2.0))
+
+    return int(np.ceil(((z * std) / margin_of_error) ** 2.0))
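
A hypothetical usage of the new helper (the numbers are made up for illustration, and the rounding up to a whole sample assumes the integer ceiling applied above): at 95% confidence the two-sided z value is about 1.96, so estimating a mean with standard deviation 10 to within a margin of error of 2 needs n = ((1.96 * 10) / 2) ** 2 ≈ 96.04, i.e., 97 samples.

from rlai.utils import get_sample_size

# 95% confidence, standard deviation 10, margin of error 2 (hypothetical values)
n = get_sample_size(confidence=0.95, std=10.0, margin_of_error=2.0)
print(n)  # 97: z = scipy.stats.norm.ppf(0.975) ≈ 1.96, and ((1.96 * 10) / 2) ** 2 ≈ 96.04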
Binary file not shown.
