diff --git a/README.md b/README.md index 19da9e7..728001b 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,8 @@ The different kinds of importance samples can also be visualised by querying the * Updated PropensityModels structure for sklearn and added a helper class for compatability with torch * Full runtime typechecking with jaxtyping * Fixed bug with IS methods where the average was being taken twice +* Significantly simplified API, especially integrating Policy classes with propensity models +* Generalised d3rlpy API to allow for wrapping continuous policies with D3RlPyTorchAlgoPredict #### 5.0.1 * Fixed bug where GreedyDeterministic couldn't handle multi-dimensional action spaces diff --git a/examples/PropensityTrainingLoop.py b/examples/PropensityTrainingLoop.py new file mode 100644 index 0000000..778445d --- /dev/null +++ b/examples/PropensityTrainingLoop.py @@ -0,0 +1,179 @@ +from abc import abstractmethod +from typing import Any, Callable, Dict, Tuple, List +from offline_rl_ope.PropensityModels import (PropensityTorchBase) +from pymlrf.SupervisedLearning.torch import ( + train, validate_single_epoch + ) +from pymlrf.Structs.torch import DatasetOutput +from pymlrf.utils import set_seed +import numpy as np +from torch.utils.data import Dataset, DataLoader +import torch +from torch.optim import Adam +import os + +from offline_rl_ope import logger + + +class PropensityTrainingLoop: + + @abstractmethod + def fit(self, *args, **kwargs) -> Dict[str,Any]: + pass + + +class PropensityDataset(Dataset): + + def __init__( + self, + x:np.array, + y:np.array + ) -> None: + super().__init__() + if x.shape[0] != y.shape[0]: + raise Exception + if len(x.shape) != 2: + raise Exception + if len(y.shape) != 2: + raise Exception + self.x = x + self.y = y + self.__len = self.x.shape[0] + + def __len__(self)->int: + return self.__len + + def __getitem__(self, idx:int)->Tuple[np.array]: + return self.x[idx,:], self.y[idx,:] + + +class PropensityCollector: + + def __init__(self, trgt_type=torch.float) -> None: + self.trgt_type=trgt_type + + def __call__(self, batch:List)->DatasetOutput: + in_dict = {"x":[]} + out_dict = {"y":[]} + for row in batch: + in_dict["x"].append(row[0]) + out_dict["y"].append(row[1]) + in_dict["x"] = torch.tensor(in_dict["x"], dtype=torch.float) + out_dict["y"] = torch.tensor(out_dict["y"], dtype=self.trgt_type) + return DatasetOutput(input=in_dict, output=out_dict) + +propensity_collector = PropensityCollector() + +class TorchTrainingLoop: + + def train( + self, + model:PropensityTorchBase, + x_train: np.array, + y_train: np.array, + x_val: np.array, + y_val: np.array, + batch_size:int, + shuffle:bool, + lr:float, + gpu:bool, + criterion:Callable, + epochs:int, + seed:int, + save_dir:str, + early_stopping_func:Callable + ): + + train_dataset = PropensityDataset(x=x_train, y=y_train) + train_data_loader=DataLoader( + dataset=train_dataset, batch_size=batch_size, shuffle=shuffle, + collate_fn=propensity_collector + ) + + val_dataset = PropensityDataset(x=x_val, y=y_val) + val_data_loader=DataLoader( + dataset=val_dataset, batch_size=batch_size, shuffle=shuffle, + collate_fn=propensity_collector + ) + optimizer = Adam( + params=model.parameters(), + lr=lr + ) + mo, optimal_epoch = train( + model=model, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + gpu=gpu, + optimizer=optimizer, + criterion=criterion, + epochs=epochs, + logger=logger, + seed=seed, + save_dir=save_dir, + early_stopping_func=early_stopping_func + ) + metric_df = mo.all_metrics_to_df() 
+ metric_df.to_csv(os.path.join(save_dir, "training_metric_df.csv")) + res = {} + for key, metric in mo.metrics.items(): + res[key] = metric.value_dict[f"epoch_{optimal_epoch}"] + res["optimal_epoch"] = optimal_epoch + return res + + def test( + self, + model: PropensityTorchBase, + x_test: np.array, + y_test: np.array, + gpu:bool, + criterion:Callable, + batch_size:int, + seed:int + ): + dataset = PropensityDataset(x=x_test, y=y_test) + data_loader=DataLoader( + dataset=dataset, batch_size=batch_size, shuffle=False, + collate_fn=propensity_collector + ) + set_seed(seed) + losses, preds = validate_single_epoch( + model=model, + data_loader=data_loader, + gpu=gpu, + criterion=criterion + ) + res = {"mean_criterion_over_batch": np.mean(losses)} + return res + + def train_test( + self, + model:PropensityTorchBase, + x_train: np.array, + y_train: np.array, + x_val: np.array, + y_val: np.array, + x_test: np.array, + y_test: np.array, + batch_size:int, + shuffle:bool, + lr:float, + gpu:bool, + criterion:Callable, + epochs:int, + seed:int, + save_dir:str + ): + train_res = self.train( + model=model, x_train=x_train, y_train=y_train, x_val=x_val, + y_val=y_val, batch_size=batch_size, shuffle=shuffle, + lr=lr, gpu=gpu, criterion=criterion, epochs=epochs, seed=seed, + save_dir=save_dir + ) + test_res = self.test( + model=model, x_test=x_test, y_test=y_test, gpu=gpu, + batch_size=batch_size, seed=seed + ) + return {**train_res, **test_res} + + +torch_training_loop = TorchTrainingLoop() \ No newline at end of file diff --git a/examples/d3rlpy_training_api.py b/examples/d3rlpy_training_api.py index b9ea024..1fa066a 100644 --- a/examples/d3rlpy_training_api.py +++ b/examples/d3rlpy_training_api.py @@ -11,6 +11,7 @@ from sklearn.multiclass import OneVsRestClassifier from xgboost import XGBClassifier import pandas as pd +import shutil # Import callbacks from offline_rl_ope.api.d3rlpy.Callbacks import ( @@ -23,9 +24,9 @@ from offline_rl_ope.api.d3rlpy.Scorers import ( ISEstimatorScorer, ISDiscreteActionDistScorer, QueryScorer ) -from offline_rl_ope.components.Policy import BehavPolicy +from offline_rl_ope.components.Policy import NumpyPolicyFuncWrapper, Policy from offline_rl_ope.PropensityModels.sklearn import ( - MultiOutputMultiClassTrainer, SklearnTorchTrainerWrapper) + SklearnDiscrete) if __name__=="__main__": @@ -60,18 +61,14 @@ behav_est.fit(X=observations, Y=actions.reshape(-1,1)) - sklearn_trainer = MultiOutputMultiClassTrainer( + sklearn_trainer = SklearnDiscrete( theoretical_action_classes=[np.array([0,1])], estimator=behav_est ) sklearn_trainer.fitted_cls = [pd.Series(actions).unique()] - gbt_est = SklearnTorchTrainerWrapper( - sklearn_trainer=sklearn_trainer - ) - #gbt_est = GbtEst(estimator=behav_est) - gbt_policy_be = BehavPolicy( - policy_func=gbt_est, + gbt_policy_be = Policy( + policy_func=NumpyPolicyFuncWrapper(sklearn_trainer.predict_proba), collect_res=False ) @@ -80,6 +77,7 @@ is_types=["vanilla", "per_decision"], behav_policy=gbt_policy_be, dataset=dataset, + action_dim=1, eval_policy_kwargs={ "gpu": False, "collect_act": True @@ -162,5 +160,5 @@ # evaluate trained algorithm evaluate_qlearning_with_environment(algo=dqn, env=env) - # shutil.rmtree("d3rlpy_data") - # shutil.rmtree("d3rlpy_logs") \ No newline at end of file + shutil.rmtree("d3rlpy_data") + shutil.rmtree("d3rlpy_logs") \ No newline at end of file diff --git a/examples/static_torch_continuous.py b/examples/static_torch_continuous.py new file mode 100644 index 0000000..5739602 --- /dev/null +++ 
b/examples/static_torch_continuous.py @@ -0,0 +1,239 @@ +from d3rlpy.algos import SACConfig +from d3rlpy.datasets import get_pendulum +from typing import Dict +from d3rlpy.ope import FQEConfig, FQE +from d3rlpy.metrics import (SoftOPCEvaluator, + InitialStateValueEstimationEvaluator) +from d3rlpy.dataset import BasicTransitionPicker, ReplayBuffer, InfiniteBuffer +import numpy as np +import torch +from torch.distributions import Normal + +from pymlrf.SupervisedLearning.torch import ( + PercEpsImprove) +from pymlrf.FileSystem import DirectoryHandler + +from offline_rl_ope.Dataset import ISEpisode +from offline_rl_ope.components.Policy import Policy, GreedyDeterministic +from offline_rl_ope.components.ImportanceSampler import ISWeightOrchestrator +from offline_rl_ope.OPEEstimators import ( + ISEstimator, DREstimator, D3rlpyQlearnDM) +from offline_rl_ope.PropensityModels.torch import FullGuassian, TorchRegTrainer +from offline_rl_ope.LowerBounds.HCOPE import get_lower_bound + +from offline_rl_ope.api.d3rlpy.Misc import D3RlPyTorchAlgoPredict +from offline_rl_ope.types import PropensityTorchOutputType + +from PropensityTrainingLoop import torch_training_loop + +class GaussianLossWrapper: + + def __init__(self) -> None: + self.scorer = torch.nn.GaussianNLLLoss() + + def __call__( + self, + y_pred:PropensityTorchOutputType, + y_true:Dict[str,torch.Tensor] + ) -> torch.Tensor: + res = self.scorer( + input=y_pred["loc"], + var=y_pred["scale"], + target=y_true["y"] + ) + return res + +if __name__ == "__main__": + # obtain dataset + dataset, env = get_pendulum() + + # setup algorithm + gamma = 0.99 + sac = SACConfig(gamma=gamma).create() + + dataset = ReplayBuffer( + buffer=InfiniteBuffer(), + episodes=dataset.episodes[0:100] + ) + + # Fit the behaviour model + observations = [] + actions = [] + tp = BasicTransitionPicker() + for ep in dataset.episodes: + for i in range(ep.transition_count): + _transition = tp(ep,i) + observations.append(_transition.observation.reshape(1,-1)) + actions.append(_transition.action) + + observations = np.concatenate(observations) + actions = np.concatenate(actions) + + assert len(env.observation_space.shape) == 1 + estimator = FullGuassian( + input_dim=env.observation_space.shape[0], + #layers_dim=[64,64], + layers_dim=[64,64], + m_out_dim=1, + sd_out_dim=1 + ) + estimator = TorchRegTrainer( + estimator=estimator, + dist_func=Normal, + gpu=False + ) + early_stop_criteria = PercEpsImprove(eps=0, direction="gr") + meta_data = { + "train_loss_criteria": "gauss_nll", + "val_loss_criteria": "gauss_nll" + } + criterion = GaussianLossWrapper() + + prop_output_dh = DirectoryHandler(loc="./propensity_output") + if not prop_output_dh.is_created: + prop_output_dh.create() + else: + prop_output_dh.clear() + + torch_training_loop.train( + model=estimator.estimator, + x_train=observations, + y_train=actions.reshape(-1,1), + x_val=observations, + y_val=actions.reshape(-1,1), + batch_size=32, + shuffle=True, + lr=0.01, + gpu=False, + criterion=criterion, + epochs=4, + seed=1, + save_dir=prop_output_dh.loc, + early_stopping_func=early_stop_criteria + ) + + policy_be = Policy( + policy_func=estimator.predict_proba, + collect_res=False + ) + + no_obs_steps = int(len(actions)*0.025) + n_epochs=1 + n_steps_per_epoch = no_obs_steps + n_steps = no_obs_steps*n_epochs + sac.fit( + dataset, n_steps=n_steps, n_steps_per_epoch=n_steps_per_epoch, + with_timestamp=False + ) + + fqe_scorers = { + "soft_opc": SoftOPCEvaluator( + return_threshold=70, + episodes=dataset.episodes + ), + "init_state_val": 
InitialStateValueEstimationEvaluator( + episodes=dataset.episodes + ) + } + + + fqe_config = FQEConfig(learning_rate=1e-4) + #discrete_fqe = DiscreteFQE(algo=dqn, **fqe_init_kwargs) + fqe = FQE(algo=sac, config=fqe_config, device=False) + + fqe.fit(dataset, evaluators=fqe_scorers, n_steps=no_obs_steps) + + + # Static OPE evaluation + policy_func = D3RlPyTorchAlgoPredict( + predict_func=sac.predict, + action_dim=1 + ) + eval_policy = GreedyDeterministic( + policy_func=policy_func, collect_res=False, + collect_act=True, gpu=False + ) + + episodes = [] + for ep in dataset.episodes: + episodes.append(ISEpisode( + state=torch.Tensor(ep.observations), + action=torch.Tensor(ep.actions).view(-1,1), + reward=torch.Tensor(ep.rewards)) + ) + + is_weight_calculator = ISWeightOrchestrator("vanilla", "per_decision", + behav_policy=policy_be) + is_weight_calculator.update( + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes], + eval_policy=eval_policy) + + fqe_dm_model = D3rlpyQlearnDM(model=fqe) + + is_estimator = ISEstimator(norm_weights=False, cache_traj_rewards=True) + wis_estimator = ISEstimator(norm_weights=True) + wis_estimator_smooth = ISEstimator(norm_weights=True, norm_kwargs={ + "smooth_eps":0.0000001 + }) + w_dr_estimator = DREstimator( + dm_model=fqe_dm_model, norm_weights=True, + ignore_nan=True) + + + res = is_estimator.predict( + rewards=[ep.reward for ep in episodes], discount=0.99, + weights=is_weight_calculator["vanilla"].traj_is_weights, + is_msk=is_weight_calculator.weight_msk, + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes], + ) + print(res) + + res = is_estimator.predict( + weights=is_weight_calculator["per_decision"].traj_is_weights, + rewards=[ep.reward for ep in episodes], discount=0.99, + is_msk=is_weight_calculator.weight_msk, + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes], + ) + print(res) + traj_rewards = is_estimator.traj_rewards_cache.squeeze().numpy() + print(get_lower_bound(X=traj_rewards, delta=0.05)) + + res = wis_estimator.predict( + rewards=[ep.reward for ep in episodes], discount=0.99, + weights=is_weight_calculator["vanilla"].traj_is_weights, + is_msk=is_weight_calculator.weight_msk, + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes] + ) + print(res) + + res = wis_estimator.predict( + weights=is_weight_calculator["per_decision"].traj_is_weights, + rewards=[ep.reward for ep in episodes], discount=0.99, + is_msk=is_weight_calculator.weight_msk, + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes], + ) + print(res) + + res = wis_estimator_smooth.predict( + weights=is_weight_calculator["vanilla"].traj_is_weights, + rewards=[ep.reward for ep in episodes], discount=0.99, + is_msk=is_weight_calculator.weight_msk, + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes], + ) + print(res) + + res = w_dr_estimator.predict( + weights=is_weight_calculator["per_decision"].traj_is_weights, + rewards=[ep.reward for ep in episodes], discount=0.99, + is_msk=is_weight_calculator.weight_msk, + states=[ep.state for ep in episodes], + actions=[ep.action for ep in episodes], + ) + print(res) + diff --git a/examples/static.py b/examples/static_xgboost_discrete.py similarity index 93% rename from examples/static.py rename to examples/static_xgboost_discrete.py index 34b827c..19be689 100644 --- a/examples/static.py +++ b/examples/static_xgboost_discrete.py @@ -12,12 +12,13 @@ import torch from offline_rl_ope.Dataset 
import ISEpisode -from offline_rl_ope.components.Policy import BehavPolicy, GreedyDeterministic +from offline_rl_ope.components.Policy import ( + GreedyDeterministic, Policy, NumpyPolicyFuncWrapper) from offline_rl_ope.components.ImportanceSampler import ISWeightOrchestrator from offline_rl_ope.OPEEstimators import ( ISEstimator, DREstimator, D3rlpyQlearnDM) from offline_rl_ope.PropensityModels.sklearn import ( - MultiOutputMultiClassTrainer, SklearnTorchTrainerWrapper) + SklearnDiscrete) from offline_rl_ope.LowerBounds.HCOPE import get_lower_bound from offline_rl_ope.api.d3rlpy.Misc import D3RlPyTorchAlgoPredict @@ -54,16 +55,14 @@ behav_est.fit(X=observations, Y=actions.reshape(-1,1)) - sklearn_trainer = MultiOutputMultiClassTrainer( + sklearn_trainer = SklearnDiscrete( theoretical_action_classes=[np.array([0,1])], estimator=behav_est ) sklearn_trainer.fitted_cls = [pd.Series(actions).unique()] - gbt_est = SklearnTorchTrainerWrapper( - sklearn_trainer=sklearn_trainer - ) - gbt_policy_be = BehavPolicy( - policy_func=gbt_est, + + gbt_policy_be = Policy( + policy_func=NumpyPolicyFuncWrapper(sklearn_trainer.predict_proba), collect_res=False ) @@ -93,7 +92,10 @@ # Static OPE evaluation - policy_func = D3RlPyTorchAlgoPredict(predict_func=dqn.predict) + policy_func = D3RlPyTorchAlgoPredict( + predict_func=dqn.predict, + action_dim=1 + ) eval_policy = GreedyDeterministic( policy_func=policy_func, collect_res=False, collect_act=True, gpu=False diff --git a/propensity_output/epoch_1_train_preds.pkl b/propensity_output/epoch_1_train_preds.pkl new file mode 100644 index 0000000..c7e4f6d Binary files /dev/null and b/propensity_output/epoch_1_train_preds.pkl differ diff --git a/propensity_output/epoch_1_val_preds.pkl b/propensity_output/epoch_1_val_preds.pkl new file mode 100644 index 0000000..e8cd87c Binary files /dev/null and b/propensity_output/epoch_1_val_preds.pkl differ diff --git a/propensity_output/epoch_2_train_preds.pkl b/propensity_output/epoch_2_train_preds.pkl new file mode 100644 index 0000000..7e08a05 Binary files /dev/null and b/propensity_output/epoch_2_train_preds.pkl differ diff --git a/propensity_output/epoch_2_val_preds.pkl b/propensity_output/epoch_2_val_preds.pkl new file mode 100644 index 0000000..8ce3784 Binary files /dev/null and b/propensity_output/epoch_2_val_preds.pkl differ diff --git a/propensity_output/epoch_3_train_preds.pkl b/propensity_output/epoch_3_train_preds.pkl new file mode 100644 index 0000000..5b1046b Binary files /dev/null and b/propensity_output/epoch_3_train_preds.pkl differ diff --git a/propensity_output/epoch_3_val_preds.pkl b/propensity_output/epoch_3_val_preds.pkl new file mode 100644 index 0000000..7f0bafa Binary files /dev/null and b/propensity_output/epoch_3_val_preds.pkl differ diff --git a/propensity_output/epoch_4_train_preds.pkl b/propensity_output/epoch_4_train_preds.pkl new file mode 100644 index 0000000..d089577 Binary files /dev/null and b/propensity_output/epoch_4_train_preds.pkl differ diff --git a/propensity_output/epoch_4_val_preds.pkl b/propensity_output/epoch_4_val_preds.pkl new file mode 100644 index 0000000..f05af17 Binary files /dev/null and b/propensity_output/epoch_4_val_preds.pkl differ diff --git a/propensity_output/mdl_chkpnt_epoch_1.pt b/propensity_output/mdl_chkpnt_epoch_1.pt new file mode 100644 index 0000000..91cb6ca Binary files /dev/null and b/propensity_output/mdl_chkpnt_epoch_1.pt differ diff --git a/propensity_output/mdl_chkpnt_epoch_2.pt b/propensity_output/mdl_chkpnt_epoch_2.pt new file mode 100644 index 
0000000..0a73a78 Binary files /dev/null and b/propensity_output/mdl_chkpnt_epoch_2.pt differ diff --git a/propensity_output/mdl_chkpnt_epoch_3.pt b/propensity_output/mdl_chkpnt_epoch_3.pt new file mode 100644 index 0000000..deabebc Binary files /dev/null and b/propensity_output/mdl_chkpnt_epoch_3.pt differ diff --git a/propensity_output/mdl_chkpnt_epoch_4.pt b/propensity_output/mdl_chkpnt_epoch_4.pt new file mode 100644 index 0000000..2300ba1 Binary files /dev/null and b/propensity_output/mdl_chkpnt_epoch_4.pt differ diff --git a/propensity_output/training_metric_df.csv b/propensity_output/training_metric_df.csv new file mode 100644 index 0000000..072c7dc --- /dev/null +++ b/propensity_output/training_metric_df.csv @@ -0,0 +1,9 @@ +,raw_vals,metric_name +epoch_1,0.5026572188648346,epoch_train_loss +epoch_2,0.45641687404069453,epoch_train_loss +epoch_3,0.4506564313404394,epoch_train_loss +epoch_4,0.4499283042737856,epoch_train_loss +epoch_1,0.4542446283801114,epoch_val_loss +epoch_2,0.43553749836608163,epoch_val_loss +epoch_3,0.4922322245233471,epoch_val_loss +epoch_4,0.4404787527628977,epoch_val_loss diff --git a/src/offline_rl_ope/PropensityModels/sklearn/Discrete.py b/src/offline_rl_ope/PropensityModels/sklearn/Discrete.py index 5d2d150..b06b1ee 100644 --- a/src/offline_rl_ope/PropensityModels/sklearn/Discrete.py +++ b/src/offline_rl_ope/PropensityModels/sklearn/Discrete.py @@ -9,10 +9,10 @@ from ..base import PropensityTrainer __all__ = [ - "MultiOutputMultiClassTrainer" + "SklearnDiscrete" ] -class MultiOutputMultiClassTrainer(PropensityTrainer): +class SklearnDiscrete(PropensityTrainer): def __init__( self, diff --git a/src/offline_rl_ope/PropensityModels/sklearn/__init__.py b/src/offline_rl_ope/PropensityModels/sklearn/__init__.py index 92c14d4..aed4cdf 100644 --- a/src/offline_rl_ope/PropensityModels/sklearn/__init__.py +++ b/src/offline_rl_ope/PropensityModels/sklearn/__init__.py @@ -1,2 +1 @@ -from .Discrete import * -from .utils import * \ No newline at end of file +from .Discrete import * \ No newline at end of file diff --git a/src/offline_rl_ope/PropensityModels/sklearn/utils.py b/src/offline_rl_ope/PropensityModels/sklearn/utils.py deleted file mode 100644 index 889ed7a..0000000 --- a/src/offline_rl_ope/PropensityModels/sklearn/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from jaxtyping import jaxtyped, Float -from typeguard import typechecked as typechecker - -from .Discrete import MultiOutputMultiClassTrainer -from ...types import StateTensor, ActionTensor - -class SklearnTorchTrainerWrapper: - - def __init__( - self, - sklearn_trainer:MultiOutputMultiClassTrainer - ) -> None: - assert isinstance(sklearn_trainer,MultiOutputMultiClassTrainer) - self.__sklearn_trainer = sklearn_trainer - - @jaxtyped(typechecker=typechecker) - def __call__( - self, - y:ActionTensor, - x:StateTensor - )->torch.Tensor: - # assert isinstance(y,torch.Tensor) - # assert isinstance(x,torch.Tensor) - res = torch.Tensor( - self.__sklearn_trainer.predict_proba( - y=y.detach().cpu().numpy(), - x=x.detach().cpu().numpy() - ) - ) - return res \ No newline at end of file diff --git a/src/offline_rl_ope/_version.py b/src/offline_rl_ope/_version.py index 3f4942a..0bc38c4 100644 --- a/src/offline_rl_ope/_version.py +++ b/src/offline_rl_ope/_version.py @@ -1,38 +1 @@ -__version__ = "6.0.0" - -import numpy as np -in_arr = np.array([[0.91518986, 0.08481016], - [0.70076823, 0.29923177], - [0.9462318 , 0.05376823], - [0.9555495 , 0.04445055], - [0.824901 , 0.17509903], - [0.4951074 , 0.5048926 ], - 
[0.7441163 , 0.25588372], - [0.15970033, 0.84029967], - [0.18569005, 0.81430995], - [0.1911512 , 0.8088488 ], - [0.1069113 , 0.8930887 ], - [0.02992624, 0.97007376], - [0.5498897 , 0.4501103 ], - [0.02385765, 0.97614235], - [0.5980996 , 0.40190044], - [0.01977807, 0.9802219 ], - [0.46130735, 0.53869265], - [0.05885756, 0.94114244], - [0.06812495, 0.93187505], - [0.06184661, 0.9381534 ], - [0.05487567, 0.9451243 ], - [0.405338 , 0.594662 ], - [0.16424012, 0.8357599 ], - [0.06257224, 0.93742776], - [0.13872904, 0.86127096], - [0.17525399, 0.824746 ], - [0.27241033, 0.72758967], - [0.4413392 , 0.5586608 ], - [0.20125568, 0.7987443 ], - [0.7028972 , 0.2971028 ], - [0.13157076, 0.86842924], - [0.6233928 , 0.37660718], - [0.7283679 , 0.2716321 ], - [0.68545943, 0.31454057]]) -in_arr = in_arr[:,0][:,None] \ No newline at end of file +__version__ = "6.0.0" \ No newline at end of file diff --git a/src/offline_rl_ope/api/d3rlpy/Callbacks/IS.py b/src/offline_rl_ope/api/d3rlpy/Callbacks/IS.py index a12ff4b..c67bf9e 100644 --- a/src/offline_rl_ope/api/d3rlpy/Callbacks/IS.py +++ b/src/offline_rl_ope/api/d3rlpy/Callbacks/IS.py @@ -6,7 +6,7 @@ from d3rlpy.interface import QLearningAlgoProtocol from d3rlpy.dataset import ReplayBuffer -from ....components.Policy import Policy, GreedyDeterministic +from ....components.Policy import BasePolicy, GreedyDeterministic from ....components.ImportanceSampler import ISWeightOrchestrator from .base import OPECallbackBase from ..Misc import D3RlPyTorchAlgoPredict @@ -23,8 +23,9 @@ class ISCallback(ISWeightOrchestrator, OPECallbackBase): def __init__( self, is_types:List[str], - behav_policy: Policy, - dataset: ReplayBuffer, + behav_policy: BasePolicy, + dataset: ReplayBuffer, + action_dim:int, eval_policy_kwargs:Dict[str,Any] = {}, debug:bool=False, debug_path:str="", @@ -35,6 +36,7 @@ def __init__( self.states:List[torch.Tensor] = [] self.actions:List[torch.Tensor] = [] self.rewards:List[torch.Tensor] = [] + self.action_dim = action_dim for traj in dataset.episodes: self.states.append(torch.Tensor(traj.observations)) self.actions.append(torch.Tensor(traj.actions)) @@ -62,7 +64,9 @@ def run( total_step:int ) -> None: policy_func = D3RlPyTorchAlgoPredict( - predict_func=algo.predict) + predict_func=algo.predict, + action_dim=self.action_dim + ) eval_policy = GreedyDeterministic( policy_func=policy_func, **self.eval_policy_kwargs diff --git a/src/offline_rl_ope/api/d3rlpy/Misc.py b/src/offline_rl_ope/api/d3rlpy/Misc.py index d815068..6ebda2f 100644 --- a/src/offline_rl_ope/api/d3rlpy/Misc.py +++ b/src/offline_rl_ope/api/d3rlpy/Misc.py @@ -1,17 +1,27 @@ import torch +from jaxtyping import jaxtyped +from typeguard import typechecked as typechecker + from .types import D3rlpyAlgoPredictProtocal +from ...types import StateTensor, ActionTensor + -from ...RuntimeChecks import check_array_dim __all__ = ["D3RlPyTorchAlgoPredict"] class D3RlPyTorchAlgoPredict: - def __init__(self, predict_func:D3rlpyAlgoPredictProtocal): + def __init__( + self, + predict_func:D3rlpyAlgoPredictProtocal, + action_dim:int + ): self.predict_func = predict_func - - def __call__(self, x:torch.Tensor): - pred = self.predict_func(x.cpu().numpy()) - check_array_dim(pred,1) - pred = pred.reshape(-1,1) + self.action_dim = action_dim + + @jaxtyped(typechecker=typechecker) + def __call__(self, x:StateTensor)->ActionTensor: + pred = self.predict_func(x.cpu().numpy()).reshape( + -1, self.action_dim + ) return torch.Tensor(pred) diff --git a/src/offline_rl_ope/components/ImportanceSampler.py 
b/src/offline_rl_ope/components/ImportanceSampler.py index 1032018..23ee841 100644 --- a/src/offline_rl_ope/components/ImportanceSampler.py +++ b/src/offline_rl_ope/components/ImportanceSampler.py @@ -7,14 +7,20 @@ from ..RuntimeChecks import check_array_dim, check_array_shape -from .Policy import Policy +from .Policy import BasePolicy from .. import logger from ..types import (StateTensor,ActionTensor,WeightTensor) +__all__ = [ + "ISWeightCalculator", "ImportanceSampler", "VanillaIS", "PerDecisionIS", + "ISWeightOrchestrator" + ] + + class ISWeightCalculator: - def __init__(self, behav_policy:Policy) -> None: - assert isinstance(behav_policy,Policy) + def __init__(self, behav_policy:BasePolicy) -> None: + assert isinstance(behav_policy,BasePolicy) self.__behav_policy = behav_policy self.is_weights = torch.empty(0) self.weight_msk = torch.empty(0) @@ -25,7 +31,7 @@ def get_traj_w( self, states:StateTensor, actions:ActionTensor, - eval_policy:Policy + eval_policy:BasePolicy )->Float[torch.Tensor, "traj_length"]: """Function to calculate the timestep IS weights over a trajectory i.e., for each timestep (t) Tensor(\pi_{e}(a_{t}|s_{t})/\pi_{b}(a_{t}|s_{t})) @@ -46,7 +52,7 @@ def get_traj_w( # check_array_dim(actions,2) # assert isinstance(states, torch.Tensor) # assert isinstance(actions, torch.Tensor) - #assert isinstance(eval_policy, Policy) + #assert isinstance(eval_policy, BasePolicy) with torch.no_grad(): behav_probs = self.__behav_policy( @@ -70,7 +76,7 @@ def get_dataset_w( self, states:List[torch.Tensor], actions:List[torch.Tensor], - eval_policy:Policy + eval_policy:BasePolicy )->Tuple[WeightTensor, WeightTensor]: """_summary_ @@ -81,7 +87,7 @@ def get_dataset_w( (traj_length, number of actions). Note, this is likely (traj_length,1) if for example a discrete action space has been flattened from [0,1]^2 to [0,1,2,3] - eval_policy (Policy): Policy class defining the target policy to be + eval_policy (BasePolicy): Policy class defining the target policy to be evaluated Returns: @@ -94,7 +100,7 @@ def get_dataset_w( ith trajectory was observed """ assert len(states) == len(actions) - #assert isinstance(eval_policy, Policy) + #assert isinstance(eval_policy, BasePolicy) # weight_res = torch.zeros(size=(len(states),h)) # weight_msk = torch.zeros(size=(len(states),h)) weight_res_lst:List[torch.Tensor] = [] @@ -131,7 +137,7 @@ def update( self, states:List[torch.Tensor], actions:List[torch.Tensor], - eval_policy:Policy + eval_policy:BasePolicy ): _is_weights, _weight_msk = self.get_dataset_w( states=states, actions=actions, eval_policy=eval_policy) @@ -242,7 +248,7 @@ class ISWeightOrchestrator(ISWeightCalculator): "per_decision": PerDecisionIS } - def __init__(self, *args, behav_policy:Policy) -> None: + def __init__(self, *args, behav_policy:BasePolicy) -> None: super().__init__(behav_policy=behav_policy) self.is_samplers:Dict[str,ImportanceSampler] = {} for arg in args: @@ -263,7 +269,7 @@ def update( self, states:List[torch.Tensor], actions:List[torch.Tensor], - eval_policy:Policy + eval_policy:BasePolicy ): super().update(states=states, actions=actions, eval_policy=eval_policy) for sampler in self.is_samplers.keys(): diff --git a/src/offline_rl_ope/components/Policy.py b/src/offline_rl_ope/components/Policy.py index d2e2db7..a442152 100644 --- a/src/offline_rl_ope/components/Policy.py +++ b/src/offline_rl_ope/components/Policy.py @@ -3,9 +3,10 @@ from typing import Callable, List from jaxtyping import jaxtyped, Float from typeguard import typechecked as typechecker +import numpy as np 
from ..RuntimeChecks import check_array_dim -from ..types import (StateTensor,ActionTensor) +from ..types import (StateTensor,ActionTensor,StateArray,ActionArray) def postproc_pass(x:torch.Tensor)->torch.Tensor: @@ -19,10 +20,43 @@ def preproc_cuda(x:torch.Tensor)->torch.Tensor: __all__ = [ - "Policy", "GreedyDeterministic", "BehavPolicy" + "Policy", "GreedyDeterministic", "BasePolicy", "NumpyPolicyFuncWrapper", + "NumpyGreedyPolicyFuncWrapper" ] -class Policy(metaclass=ABCMeta): + +class NumpyPolicyFuncWrapper: + + def __init__(self, policy_func:Callable[..., torch.Tensor]) -> None: + self.policy_func = policy_func + + def __call__( + self, + state:StateTensor, + action:ActionTensor + )->Float[torch.Tensor, "traj_length 1"]: + res = self.policy_func( + state.cpu().detach().numpy(), + action.cpu().detach().numpy(), + ) + return torch.Tensor(res) + +class NumpyGreedyPolicyFuncWrapper: + + def __init__(self, policy_func:Callable[..., torch.Tensor]) -> None: + self.policy_func = policy_func + + def __call__( + self, + state:StateTensor + )->ActionTensor: + res = self.policy_func( + state.cpu().detach().numpy() + ) + return torch.Tensor(res) + + +class BasePolicy(metaclass=ABCMeta): def __init__( self, @@ -90,11 +124,14 @@ def __call__( """ pass -class BehavPolicy(Policy): +class Policy(BasePolicy): def __init__( self, - policy_func:Callable[[torch.Tensor,torch.Tensor], torch.Tensor], + policy_func:Callable[ + [StateTensor,ActionTensor], + Float[torch.Tensor, "traj_length 1"] + ], collect_res:bool=False, collect_act:bool=False, gpu:bool=False @@ -111,12 +148,13 @@ def __call__( # assert isinstance(state,torch.Tensor) # assert isinstance(action,torch.Tensor) # check_array_dim(action,2) - res = self.policy_func(y=action, x=state) + state = self.preproc_tens(state) + res = self.policy_func(state, action) + res = self.postproc_tens(res) self.collect_res_fn(res) return res - -class GreedyDeterministic(Policy): +class GreedyDeterministic(BasePolicy): def __init__( self, @@ -140,7 +178,7 @@ def __call__( # assert isinstance(action,torch.Tensor) # check_array_dim(action,2) state = self.preproc_tens(state) - greedy_action = self.policy_func(x=state) + greedy_action = self.policy_func(state) check_array_dim(greedy_action,2) assert action.shape == greedy_action.shape greedy_action = self.postproc_tens(greedy_action) diff --git a/src/offline_rl_ope/components/__init__.py b/src/offline_rl_ope/components/__init__.py index 57f8aff..e69de29 100644 --- a/src/offline_rl_ope/components/__init__.py +++ b/src/offline_rl_ope/components/__init__.py @@ -1,3 +0,0 @@ -from .Policy import * -from .ImportanceSampler import ( - ISWeightCalculator, ISWeightOrchestrator, PerDecisionIS, VanillaIS) \ No newline at end of file diff --git a/src/offline_rl_ope/types.py b/src/offline_rl_ope/types.py index 9943617..282bbcb 100644 --- a/src/offline_rl_ope/types.py +++ b/src/offline_rl_ope/types.py @@ -42,4 +42,14 @@ def __call__(self, x:torch.Tensor) -> PropensityTorchOutputType: ... def eval(self)->None: - ... \ No newline at end of file + ... + + +@runtime_checkable +class PropensitySklearnContinuousType(Protocol): + + def predict_proba(self, X:StateArray) -> ActionArray: + ... + + def predict(self, X:StateArray) -> ActionArray: + ... 
diff --git a/tests/components/test_ImportanceSampler.py b/tests/components/test_ImportanceSampler.py index c0019be..b6b8a5f 100644 --- a/tests/components/test_ImportanceSampler.py +++ b/tests/components/test_ImportanceSampler.py @@ -3,7 +3,7 @@ import torch import numpy as np import copy -from offline_rl_ope.components.Policy import Policy +from offline_rl_ope.components.Policy import BasePolicy from offline_rl_ope.components.ImportanceSampler import ( VanillaIS, PerDecisionIS, ISWeightCalculator ) @@ -66,7 +66,7 @@ def setUp(self) -> None: be_policy_mock = TestPolicy(self.test_conf.test_action_probs) behav_policy = MagicMock( - spec=Policy, + spec=BasePolicy, side_effect=be_policy_mock ) #behav_policy.__call__ = MagicMock(side_effect=) @@ -92,7 +92,7 @@ def test_get_traj_w(self): #eval_policy = TestPolicy(self.test_conf.test_eval_action_probs) e_policy_mock = TestPolicy(self.test_conf.test_eval_action_probs) eval_policy = MagicMock( - spec=Policy, + spec=BasePolicy, side_effect=e_policy_mock ) for s,a in zip(self.test_conf.test_state_vals, self.test_conf.test_action_vals): @@ -115,7 +115,7 @@ def test_get_dataset_w(self): #eval_policy = TestPolicy(self.test_conf.test_eval_action_probs) e_policy_mock = TestPolicy(self.test_conf.test_eval_action_probs) eval_policy = MagicMock( - spec=Policy, + spec=BasePolicy, side_effect=e_policy_mock ) is_weights, weight_msk = self.is_sampler.get_dataset_w( diff --git a/tests/components/test_Policy.py b/tests/components/test_Policy.py index b378dd1..d817957 100644 --- a/tests/components/test_Policy.py +++ b/tests/components/test_Policy.py @@ -3,7 +3,7 @@ import torch import numpy as np from offline_rl_ope.components.Policy import ( - GreedyDeterministic, BehavPolicy) + GreedyDeterministic, Policy) from offline_rl_ope import logger from parameterized import parameterized_class from ..base import test_configs_fmt_class, TestConfig @@ -158,12 +158,12 @@ def __init__(self) -> None: pass @parameterized_class(test_configs_fmt_class) -class BehavPolicyTest(unittest.TestCase): +class PolicyTest(unittest.TestCase): test_conf:TestConfig def setUp(self) -> None: - def __mock_return(y, x): + def __mock_return(x,y): lkp = { "_".join( [ @@ -183,7 +183,7 @@ def __mock_return(y, x): #policy_func = MockPolicyClass() #policy_func.__call__ = MagicMock(side_effect=__mock_return) #self.policy = BehavPolicy(policy_func) - self.policy = BehavPolicy( + self.policy = Policy( policy_func=MagicMock(side_effect=__mock_return))
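A condensed sketch of the simplified propensity/Policy wiring introduced by this change, taken from examples/static_xgboost_discrete.py above. The fitted behaviour classifier behav_est and the flattened actions array are assumed to exist exactly as constructed in that example; nothing here is new API beyond what the diff shows.

import numpy as np
import pandas as pd

from offline_rl_ope.components.Policy import Policy, NumpyPolicyFuncWrapper
from offline_rl_ope.PropensityModels.sklearn import SklearnDiscrete

# SklearnDiscrete replaces MultiOutputMultiClassTrainer; behav_est is assumed
# to be the fitted sklearn behaviour model from the example above.
sklearn_trainer = SklearnDiscrete(
    theoretical_action_classes=[np.array([0, 1])],
    estimator=behav_est
)
sklearn_trainer.fitted_cls = [pd.Series(actions).unique()]

# SklearnTorchTrainerWrapper and BehavPolicy are removed: the numpy
# predict_proba is wrapped directly and handed to the generic Policy class.
gbt_policy_be = Policy(
    policy_func=NumpyPolicyFuncWrapper(sklearn_trainer.predict_proba),
    collect_res=False
)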
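D3RlPyTorchAlgoPredict now takes an explicit action_dim, so the same wrapper serves both discrete and continuous d3rlpy algorithms, and ISCallback gains the matching action_dim argument. The sketch below is condensed from the static examples above; sac (or dqn) is assumed to be a trained d3rlpy algorithm, episodes the list of ISEpisode objects built in those examples, and gbt_policy_be the behaviour Policy from the previous sketch.

from offline_rl_ope.api.d3rlpy.Misc import D3RlPyTorchAlgoPredict
from offline_rl_ope.components.Policy import GreedyDeterministic
from offline_rl_ope.components.ImportanceSampler import ISWeightOrchestrator

# Wrap the trained d3rlpy policy; action_dim=1 covers both the flattened
# discrete action space and the 1-d continuous pendulum action space.
policy_func = D3RlPyTorchAlgoPredict(
    predict_func=sac.predict,  # or dqn.predict in the discrete example
    action_dim=1
)
eval_policy = GreedyDeterministic(
    policy_func=policy_func, collect_res=False,
    collect_act=True, gpu=False
)

# Importance weights are then built exactly as in the static examples.
is_weight_calculator = ISWeightOrchestrator(
    "vanilla", "per_decision", behav_policy=gbt_policy_be
)
is_weight_calculator.update(
    states=[ep.state for ep in episodes],
    actions=[ep.action for ep in episodes],
    eval_policy=eval_policy
)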