Skip to content

Commit

Permalink
reworked weighting approach. Removed discounting option and included …
Browse files Browse the repository at this point in the history
…cumulative option
  • Loading branch information
joshuaspear committed May 29, 2024
1 parent ce50611 commit 7d04bfa
Show file tree
Hide file tree
Showing 10 changed files with 1,298 additions and 995 deletions.
6 changes: 3 additions & 3 deletions src/offline_rl_ope/OPEEstimators/IS.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

Expand All @@ -21,7 +21,7 @@ def __init__(
clip_weights:bool=False,
cache_traj_rewards:bool=False,
clip:float=0.0,
norm_kwargs:Dict[str,Any] = {}
norm_kwargs:Dict[str,Union[str,bool]] = {}
) -> None:
super().__init__(cache_traj_rewards)
assert isinstance(norm_weights,bool)
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
clip_weights:bool=False,
clip: float = 0.0,
cache_traj_rewards:bool=False,
norm_kwargs:Dict[str,Any] = {}
norm_kwargs:Dict[str,Union[str,bool]] = {}
) -> None:
super().__init__(norm_weights=norm_weights, clip_weights=clip_weights,
clip=clip, cache_traj_rewards=cache_traj_rewards,
Expand Down
74 changes: 48 additions & 26 deletions src/offline_rl_ope/OPEEstimators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def __init__(
self,
smooth_eps:float=0.0,
avg_denom:bool=False,
discount:float=1.0,
cumulative:bool=False,
*args,
**kwargs
) -> None:
assert isinstance(smooth_eps,float)
assert isinstance(avg_denom,bool)
assert isinstance(discount,float)
assert isinstance(cumulative,bool)
self.smooth_eps = smooth_eps
self.avg_denom = avg_denom
self.discount = discount

self.cumulative = cumulative
@jaxtyped(typechecker=typechecker)
def calc_norm(
self,
Expand All @@ -49,22 +49,40 @@ def calc_norm(
smooth_eps prevents nan values occurring in instances where there exist
valid time t importance ratios which are, however, all 0. This should
be set as small as possible.
avg_denom: defines the denominator as the average weight for time t
as per http://proceedings.mlr.press/v48/jiang16.pdf
avg_denom defines the denominator as the average importance weight
rather than the sum of importance weights i.e.:
- http://proceedings.mlr.press/v48/jiang16.pdf and;
- https://arxiv.org/pdf/2005.01643
Note:
- If traj_is_weights represents vanilla IS samples then:
- The denominator will be w_{t} = sum_{i=1}^{n} p_{1:H} for all
samples.
vanilla IS samples => traj_is_weights has entries:
$w_{i,H} = \prod_{t=0}^{H_{i}}w_{i,t}$
- If traj_is_weights represents vanilla IS samples:
- The denominator will be:
$\sum_{i=1}^{n} w_{i,H}$ for all samples.
- If cumulative is True, the denominator will be:
$\sum_{i=1}^{n} w_{i,H}$ for all samples i.e., there is no
difference
as the cumulative sums of weights are all the same
- If avg_denom is set to true, the denominator will be
w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:H} where n_{t} is the number of
trajectories of at least length, t.
$\frac{1}{n}\sum_{i=1}^{n} w_{i,H}$
PD samples => traj_is_weights has entries:
$w_{i,t'} = \prod_{t=0}^{t'}w_{i,t}$
- If traj_is_weights represents PD IS samples then:
- The denominator will be w_{t} = sum_{i=1}^{n} p_{1:t}.
- The denominator will be:
$\sum_{i=1}^{n} w_{i,H}$ for all samples i.e., the same as for
vanilla IS
- If avg_denom is set to true, the denominator will be
w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:t} where n_{t} is the number of
trajectories of at least length, t. This definition aligns with
http://proceedings.mlr.press/v48/jiang16.pdf
$\frac{1}{n}\sum_{i=1}^{n} w_{i,H}$
- If cumulative is True, the denominator for time point t' (the
[i,t'] entry, for every trajectory i) will be
$\sum_{j=1}^{n} w_{j,t'}$ i.e., the value will be the same across
all trajectories for a given time point
- If avg_denom is also set to true, the denominator for time point t'
(the [i,t'] entry, for every trajectory i) will be
$\frac{1}{n}\sum_{j=1}^{n} w_{j,t'}$
Args:
traj_is_weights (torch.Tensor): (# trajectories, max(traj_length))
Tensor. traj_is_weights[i,j] defines the jth timestep propensity
Expand All @@ -74,25 +92,29 @@ def calc_norm(
ith trajectory was observed
Returns:
torch.Tensor: Tensor of dimension (# trajectories, 1)
torch.Tensor: Tensor of dimension (1, max(traj_length))
defining the normalisation value for each timestep
"""
# assert isinstance(traj_is_weights,torch.Tensor)
# assert isinstance(is_msk,torch.Tensor)
# assert traj_is_weights.shape == is_msk.shape
# check_array_dim(traj_is_weights,2)
# check_array_dim(is_msk,2)
discnt_tens = torch.full(traj_is_weights.shape, self.discount)
discnt_pows = torch.arange(0, traj_is_weights.shape[1])[None,:].repeat(
traj_is_weights.shape[0],1)
discnt_tens = torch.pow(discnt_tens,discnt_pows)
traj_is_weights = torch.mul(traj_is_weights,discnt_tens)
denom = (
traj_is_weights.sum(dim=0, keepdim=True) + self.smooth_eps
)
if self.cumulative:
# For each timepoint, sum across the trajectories
denom = (
traj_is_weights.sum(dim=0, keepdim=True) + self.smooth_eps
)
else:
# Find the index of the final step for each trajectory
_final_idx = is_msk.cumsum(dim=1).argmax(dim=1)
# Find the associated weight of each trajectory and sum
denom = traj_is_weights[
torch.arange(traj_is_weights.shape[0]), _final_idx].sum()
denom = denom.repeat((1,traj_is_weights.shape[1])) + self.smooth_eps

if self.avg_denom:
denom = denom/(
is_msk.sum(dim=0, keepdim=True)+self.smooth_eps)
denom = denom/traj_is_weights.shape[0]
return denom

@jaxtyped(typechecker=typechecker)
Expand Down
45 changes: 18 additions & 27 deletions tests/Metrics/test_EffectiveSampleSize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,23 @@
import numpy as np
from offline_rl_ope.Metrics import EffectiveSampleSize
from offline_rl_ope import logger
# from ..base import weight_test_res
from ..base import (
single_discrete_action_test as sdat,
duel_discrete_action_test as ddat,
bin_discrete_action_test as bdat
)
from parameterized import parameterized_class
from ..base import test_configs_fmt_class, TestConfig

for test_conf in [sdat,ddat,bdat]:
class TestImportanceSampler:

def __init__(self) -> None:
self.is_weight_calc = None
self.traj_is_weights = test_conf.weight_test_res

@parameterized_class(test_configs_fmt_class)
class EffectiveSampleSizeTest(unittest.TestCase):

class EffectiveSampleSizeTest(unittest.TestCase):

def test_call(self):
num = 2
weights = test_conf.weight_test_res.sum(dim=1)
assert len(weights) == 2
denum = 1 + torch.var(weights)
act_res = (num/denum).item()
metric = EffectiveSampleSize(nan_if_all_0=True)
pred_res = metric(
weights=test_conf.weight_test_res
)
tol = act_res/1000
np.testing.assert_allclose(pred_res, act_res, atol=tol)
test_conf:TestConfig
def test_call(self):
num = 2
weights = self.test_conf.weight_test_res.sum(dim=1)
assert len(weights) == 2
denum = 1 + torch.var(weights)
act_res = (num/denum).item()
metric = EffectiveSampleSize(nan_if_all_0=True)
pred_res = metric(
weights=self.test_conf.weight_test_res
)
tol = act_res/1000
np.testing.assert_allclose(pred_res, act_res, atol=tol)
40 changes: 19 additions & 21 deletions tests/Metrics/test_ValidWeightsProp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,24 @@
import copy
from offline_rl_ope.Metrics import ValidWeightsProp
from offline_rl_ope import logger
# from ..base import weight_test_res, msk_test_res
from ..base import (
single_discrete_action_test as sdat,
duel_discrete_action_test as ddat,
bin_discrete_action_test as bdat
)
from parameterized import parameterized_class
from ..base import test_configs_fmt_class, TestConfig

for test_conf in [sdat,ddat,bdat]:
class TestValidWeightsProp(unittest.TestCase):
@parameterized_class(test_configs_fmt_class)
class TestValidWeightsProp(unittest.TestCase):

def test_call(self):
max_val=10000
min_val=0.000001
num = (test_conf.weight_test_res > min_val) & (test_conf.weight_test_res < max_val)
num = torch.sum(num, axis=1)
denum = torch.sum(test_conf.msk_test_res, axis=1)
act_res = torch.mean(num/denum).item()
metric = ValidWeightsProp(
max_w=max_val,
min_w=min_val
)
pred_res = metric(weights=test_conf.weight_test_res, weight_msk=test_conf.msk_test_res)
self.assertEqual(act_res,pred_res)
test_conf:TestConfig

def test_call(self):
max_val=10000
min_val=0.000001
num = (self.test_conf.weight_test_res > min_val) & (self.test_conf.weight_test_res < max_val)
num = torch.sum(num, axis=1)
denum = torch.sum(self.test_conf.msk_test_res, axis=1)
act_res = torch.mean(num/denum).item()
metric = ValidWeightsProp(
max_w=max_val,
min_w=min_val
)
pred_res = metric(weights=self.test_conf.weight_test_res, weight_msk=self.test_conf.msk_test_res)
self.assertEqual(act_res,pred_res)
Loading

0 comments on commit 7d04bfa

Please sign in to comment.