
Commit

Generalised d3rlpy API to allow for wrapping continuous policies with D3RlPyTorchAlgoPredict.
joshuaspear committed Jun 3, 2024
1 parent 568b6c2 commit d7ecc1a
Showing 30 changed files with 561 additions and 136 deletions.
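To illustrate the headline change, the sketch below shows how a continuous-action d3rlpy algorithm might be wrapped as an evaluation policy. This is a hypothetical sketch: the import path and constructor argument shown for D3RlPyTorchAlgoPredict are assumptions, not taken from this commit; only Policy, its policy_func argument and collect_res appear in the diff below.

    # Hypothetical sketch; D3RlPyTorchAlgoPredict's import path and constructor
    # signature are assumed, not confirmed by this commit.
    import d3rlpy
    from offline_rl_ope.api.d3rlpy import D3RlPyTorchAlgoPredict  # assumed path
    from offline_rl_ope.components.Policy import Policy

    # Any continuous-action algorithm, e.g. SAC; training/loading omitted here
    algo = d3rlpy.algos.SACConfig().create(device=False)

    eval_policy = Policy(
        policy_func=D3RlPyTorchAlgoPredict(predict_func=algo.predict),  # assumed signature
        collect_res=False
    )
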
2 changes: 2 additions & 0 deletions README.md
@@ -95,6 +95,8 @@ The different kinds of importance samples can also be visualised by querying the
 * Updated PropensityModels structure for sklearn and added a helper class for compatability with torch
 * Full runtime typechecking with jaxtyping
 * Fixed bug with IS methods where the average was being taken twice
+* Significantly simplified API, especially integrating Policy classes with propensity models
+* Generalised d3rlpy API to allow for wrapping continuous policies with D3RlPyTorchAlgoPredict
 
 #### 5.0.1
 * Fixed bug where GreedyDeterministic couldn't handle multi-dimensional action spaces
179 changes: 179 additions & 0 deletions examples/PropensityTrainingLoop.py
@@ -0,0 +1,179 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Tuple, List
from offline_rl_ope.PropensityModels import PropensityTorchBase
from pymlrf.SupervisedLearning.torch import (
    train, validate_single_epoch
)
from pymlrf.Structs.torch import DatasetOutput
from pymlrf.utils import set_seed
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from torch.optim import Adam
import os

from offline_rl_ope import logger


class PropensityTrainingLoop(ABC):
    """Abstract interface for propensity-model training loops."""

    @abstractmethod
    def fit(self, *args, **kwargs) -> Dict[str, Any]:
        pass


class PropensityDataset(Dataset):

    def __init__(
        self,
        x: np.ndarray,
        y: np.ndarray
    ) -> None:
        super().__init__()
        # Basic shape validation: both arrays must be 2D with matching rows
        if x.shape[0] != y.shape[0]:
            raise ValueError("x and y must contain the same number of rows")
        if len(x.shape) != 2:
            raise ValueError("x must be a 2-dimensional array")
        if len(y.shape) != 2:
            raise ValueError("y must be a 2-dimensional array")
        self.x = x
        self.y = y
        self.__len = self.x.shape[0]

    def __len__(self) -> int:
        return self.__len

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        return self.x[idx, :], self.y[idx, :]


class PropensityCollector:
    """Collate function mapping a batch of (x, y) rows to a DatasetOutput."""

    def __init__(self, trgt_type=torch.float) -> None:
        self.trgt_type = trgt_type

    def __call__(self, batch: List) -> DatasetOutput:
        in_dict = {"x": []}
        out_dict = {"y": []}
        for row in batch:
            in_dict["x"].append(row[0])
            out_dict["y"].append(row[1])
        # Stack the numpy rows first; building a tensor directly from a list of
        # arrays is slow and triggers a warning in recent torch versions
        in_dict["x"] = torch.tensor(np.stack(in_dict["x"]), dtype=torch.float)
        out_dict["y"] = torch.tensor(np.stack(out_dict["y"]), dtype=self.trgt_type)
        return DatasetOutput(input=in_dict, output=out_dict)


propensity_collector = PropensityCollector()

class TorchTrainingLoop:

    def train(
        self,
        model: PropensityTorchBase,
        x_train: np.ndarray,
        y_train: np.ndarray,
        x_val: np.ndarray,
        y_val: np.ndarray,
        batch_size: int,
        shuffle: bool,
        lr: float,
        gpu: bool,
        criterion: Callable,
        epochs: int,
        seed: int,
        save_dir: str,
        early_stopping_func: Callable
    ) -> Dict[str, Any]:

        train_dataset = PropensityDataset(x=x_train, y=y_train)
        train_data_loader = DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=shuffle,
            collate_fn=propensity_collector
        )

        val_dataset = PropensityDataset(x=x_val, y=y_val)
        val_data_loader = DataLoader(
            dataset=val_dataset, batch_size=batch_size, shuffle=shuffle,
            collate_fn=propensity_collector
        )
        optimizer = Adam(
            params=model.parameters(),
            lr=lr
        )
        mo, optimal_epoch = train(
            model=model,
            train_data_loader=train_data_loader,
            val_data_loader=val_data_loader,
            gpu=gpu,
            optimizer=optimizer,
            criterion=criterion,
            epochs=epochs,
            logger=logger,
            seed=seed,
            save_dir=save_dir,
            early_stopping_func=early_stopping_func
        )
        # Persist per-epoch metrics and return the values at the optimal epoch
        metric_df = mo.all_metrics_to_df()
        metric_df.to_csv(os.path.join(save_dir, "training_metric_df.csv"))
        res = {}
        for key, metric in mo.metrics.items():
            res[key] = metric.value_dict[f"epoch_{optimal_epoch}"]
        res["optimal_epoch"] = optimal_epoch
        return res

    def test(
        self,
        model: PropensityTorchBase,
        x_test: np.ndarray,
        y_test: np.ndarray,
        gpu: bool,
        criterion: Callable,
        batch_size: int,
        seed: int
    ) -> Dict[str, Any]:
        dataset = PropensityDataset(x=x_test, y=y_test)
        data_loader = DataLoader(
            dataset=dataset, batch_size=batch_size, shuffle=False,
            collate_fn=propensity_collector
        )
        set_seed(seed)
        losses, _ = validate_single_epoch(
            model=model,
            data_loader=data_loader,
            gpu=gpu,
            criterion=criterion
        )
        res = {"mean_criterion_over_batch": np.mean(losses)}
        return res

    def train_test(
        self,
        model: PropensityTorchBase,
        x_train: np.ndarray,
        y_train: np.ndarray,
        x_val: np.ndarray,
        y_val: np.ndarray,
        x_test: np.ndarray,
        y_test: np.ndarray,
        batch_size: int,
        shuffle: bool,
        lr: float,
        gpu: bool,
        criterion: Callable,
        epochs: int,
        seed: int,
        save_dir: str,
        early_stopping_func: Callable
    ) -> Dict[str, Any]:
        # Pass early_stopping_func through to train and criterion through to
        # test; both are required by those methods
        train_res = self.train(
            model=model, x_train=x_train, y_train=y_train, x_val=x_val,
            y_val=y_val, batch_size=batch_size, shuffle=shuffle,
            lr=lr, gpu=gpu, criterion=criterion, epochs=epochs, seed=seed,
            save_dir=save_dir, early_stopping_func=early_stopping_func
        )
        test_res = self.test(
            model=model, x_test=x_test, y_test=y_test, gpu=gpu,
            criterion=criterion, batch_size=batch_size, seed=seed
        )
        return {**train_res, **test_res}


torch_training_loop = TorchTrainingLoop()
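Below is a hypothetical driver for the loop defined above; it is not part of the example file. my_propensity_model stands in for an already-constructed PropensityTorchBase instance, the data are synthetic, the BCE criterion is illustrative, and the early-stopping callable's expected signature comes from pymlrf and is assumed here.

    # Hypothetical usage of torch_training_loop (illustrative only)
    import numpy as np
    import torch

    rng = np.random.default_rng(0)
    x = rng.normal(size=(1000, 8)).astype(np.float32)            # observations
    y = rng.integers(0, 2, size=(1000, 1)).astype(np.float32)    # binary actions

    def no_early_stopping(*args, **kwargs):
        # Placeholder: pymlrf's early-stopping contract is assumed here
        return False

    res = torch_training_loop.train_test(
        model=my_propensity_model,   # assumed PropensityTorchBase instance
        x_train=x[:600], y_train=y[:600],
        x_val=x[600:800], y_val=y[600:800],
        x_test=x[800:], y_test=y[800:],
        batch_size=64, shuffle=True, lr=1e-3, gpu=False,
        criterion=torch.nn.BCELoss(),
        epochs=10, seed=42, save_dir="propensity_training_run",
        early_stopping_func=no_early_stopping,
    )
    print(res["optimal_epoch"], res["mean_criterion_over_batch"])
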
20 changes: 9 additions & 11 deletions examples/d3rlpy_training_api.py
@@ -11,6 +11,7 @@
 from sklearn.multiclass import OneVsRestClassifier
 from xgboost import XGBClassifier
 import pandas as pd
+import shutil
 
 # Import callbacks
 from offline_rl_ope.api.d3rlpy.Callbacks import (
@@ -23,9 +24,9 @@
 from offline_rl_ope.api.d3rlpy.Scorers import (
     ISEstimatorScorer, ISDiscreteActionDistScorer, QueryScorer
 )
-from offline_rl_ope.components.Policy import BehavPolicy
+from offline_rl_ope.components.Policy import NumpyPolicyFuncWrapper, Policy
 from offline_rl_ope.PropensityModels.sklearn import (
-    MultiOutputMultiClassTrainer, SklearnTorchTrainerWrapper)
+    SklearnDiscrete)
 
 if __name__=="__main__":
 
@@ -60,18 +61,14 @@
 
     behav_est.fit(X=observations, Y=actions.reshape(-1,1))
 
-    sklearn_trainer = MultiOutputMultiClassTrainer(
+    sklearn_trainer = SklearnDiscrete(
         theoretical_action_classes=[np.array([0,1])],
         estimator=behav_est
     )
-    sklearn_trainer.fitted_cls = [pd.Series(actions).unique()]
-    gbt_est = SklearnTorchTrainerWrapper(
-        sklearn_trainer=sklearn_trainer
-    )
 
     #gbt_est = GbtEst(estimator=behav_est)
-    gbt_policy_be = BehavPolicy(
-        policy_func=gbt_est,
+    gbt_policy_be = Policy(
+        policy_func=NumpyPolicyFuncWrapper(sklearn_trainer.predict_proba),
         collect_res=False
     )
 
@@ -80,6 +77,7 @@
         is_types=["vanilla", "per_decision"],
         behav_policy=gbt_policy_be,
         dataset=dataset,
+        action_dim=1,
         eval_policy_kwargs={
             "gpu": False,
             "collect_act": True
@@ -162,5 +160,5 @@
 
     # evaluate trained algorithm
     evaluate_qlearning_with_environment(algo=dqn, env=env)
-    # shutil.rmtree("d3rlpy_data")
-    # shutil.rmtree("d3rlpy_logs")
+    shutil.rmtree("d3rlpy_data")
+    shutil.rmtree("d3rlpy_logs")
(remaining 27 changed files not shown)
