Skip to content

Commit

Permalink
fixed unit test for vwp
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Apr 19, 2024
1 parent 479d346 commit 3f7eb32
Showing 1 changed file with 1 addition and 15 deletions.
16 changes: 1 addition & 15 deletions tests/Metrics/test_ValidWeightsProp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,6 @@

logger = logging.getLogger("offline_rl_ope")


class TestImportanceCalc:

def __init__(self) -> None:
self.weight_msk = msk_test_res

class TestImportanceSampler:

def __init__(self) -> None:
self.is_weight_calc = None
self.traj_is_weights = weight_test_res
self.is_weight_calc = TestImportanceCalc()

class TestValidWeightsProp(unittest.TestCase):

def test_call(self):
Expand All @@ -31,9 +18,8 @@ def test_call(self):
denum = torch.sum(msk_test_res, axis=1)
act_res = torch.mean(num/denum).item()
metric = ValidWeightsProp(
is_obj=TestImportanceSampler(),
max_w=max_val,
min_w=min_val
)
pred_res = metric()
pred_res = metric(weights=weight_test_res, weight_msk=msk_test_res)
self.assertEqual(act_res,pred_res)

0 comments on commit 3f7eb32

Please sign in to comment.