diff --git a/.readthedocs.yml b/.readthedocs.yml
index 9f9649b..7484b97 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -4,13 +4,20 @@
 # Required
 version: 2
 
+# Build documentation in the docs/ directory with Sphinx
+build:
+  os: ubuntu-20.04
+  tools:
+    python: "3.8"
+# jobs:
+#   pre_build:
+#     - cp -r notebooks docs/
+
 # Build documentation in the docs/ directory with Sphinx
 sphinx:
+  builder: html
   configuration: docs/conf.py
-
-# Build documentation with MkDocs
-#mkdocs:
-#  configuration: mkdocs.yml
+  fail_on_warning: false
 
 # Optionally build your docs in additional formats such as PDF and ePub
 formats:
@@ -18,7 +25,6 @@ formats:
 
 # Optionally set the version of Python and requirements required to build your docs
 python:
-  version: 3.7
   install:
     - requirements: docs/requirements.txt
     - requirements: requirements.txt
\ No newline at end of file
diff --git a/Readme.rst b/Readme.rst
index 923f4ea..89b1a69 100644
--- a/Readme.rst
+++ b/Readme.rst
@@ -265,5 +265,5 @@ Tags
 
 **RU**: аплифт моделирование, Uplift модель
 
-**ZH**: 隆起建模,因果推断,因果效应,因果关系,个人治疗效应,真正的电梯,净电梯
+**ZH**: uplift增量建模, 因果推断, 因果效应, 因果关系, 个体干预因果效应, 真实增量, 净增量, 增量建模
diff --git a/docs/_static/images/x5_table_scheme.png b/docs/_static/images/x5_table_scheme.png
new file mode 100644
index 0000000..1d65bca
Binary files /dev/null and b/docs/_static/images/x5_table_scheme.png differ
diff --git a/docs/api/metrics/index.rst b/docs/api/metrics/index.rst
index b60d4d6..44ca0b9 100644
--- a/docs/api/metrics/index.rst
+++ b/docs/api/metrics/index.rst
@@ -17,4 +17,5 @@
     ./response_rate_by_percentile
     ./treatment_balance_curve
     ./average_squared_deviation
+    ./max_prof_uplift
     ./make_uplift_scorer
\ No newline at end of file
diff --git a/docs/api/metrics/max_prof_uplift.rst b/docs/api/metrics/max_prof_uplift.rst
new file mode 100644
index 0000000..105cc1e
--- /dev/null
+++ b/docs/api/metrics/max_prof_uplift.rst
@@ -0,0 +1,5 @@
+**********************************************
+`sklift.metrics <./>`_.max_prof_uplift
+**********************************************
+
+.. autofunction:: sklift.metrics.metrics.max_prof_uplift
\ No newline at end of file
diff --git a/docs/changelog.md b/docs/changelog.md
index 849f5d2..a045457 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -8,9 +8,28 @@
 * 🔨 something that previously didn’t work as documented – or according to reasonable expectations – should now work.
 * ❗️ you will need to change your code to have the same effect in the future; or a feature will be removed in the future.
 
+## Version 0.5.0
+
+### [sklift.models](https://www.uplift-modeling.com/en/v0.5.0/api/models/index.html)
+
+* 🔥 Add [ClassTransformationReg](https://www.uplift-modeling.com/en/v0.5.0/api/models.html#sklift.models.models.ClassTransformationReg) model by [@mcullan](https://github.com/mcullan) and [@ElisovaIra](https://github.com/ElisovaIra).
+* 🔨 Add the ability to process pandas Series with mismatched indexes in [TwoModels](https://www.uplift-modeling.com/en/v0.5.0/api/models.html#sklift.models.models.TwoModels) by [@flashlight101](https://github.com/flashlight101).
+
+### [sklift.metrics](https://www.uplift-modeling.com/en/v0.5.0/api/metrics/index.html)
+
+* 🔥 Add new metric [Maximum profit uplift measure](https://www.uplift-modeling.com/en/v0.5.0/api/metrics/max_prof_uplift.html) by [@rooti123](https://github.com/rooti123).
+
+### [sklift.datasets](https://www.uplift-modeling.com/en/v0.5.0/api/datasets/index.html)
+
+* 💥 Add a hash-based checker for all datasets by [@flashlight101](https://github.com/flashlight101).
+* 📝 Add a [scheme](https://www.uplift-modeling.com/en/v0.5.0/api/datasets/fetch_x5.html) of the X5 dataframes.
+
+### Miscellaneous
+* 📝 Improve Chinese tags by [@00helloworld](https://github.com/00helloworld)
+
 ## Version 0.4.1
 
-### [sklift.datasets](https://www.uplift-modeling.com/en/v0.4.0/api/datasets/index.html)
+### [sklift.datasets](https://www.uplift-modeling.com/en/v0.4.1/api/datasets/index.html)
 
 * 🔨 Fix bug in dataset links.
 * 📝 Add about a company section
diff --git a/docs/index.rst b/docs/index.rst
index 6d790d2..571e72d 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -153,4 +153,4 @@ Tags
 
 **RU**: аплифт моделирование, Uplift модель
 
-**ZH**: 隆起建模,因果推断,因果效应,因果关系,个人治疗效应,真正的电梯,净电梯
\ No newline at end of file
+**ZH**: uplift增量建模, 因果推断, 因果效应, 因果关系, 个体干预因果效应, 真实增量, 净增量, 增量建模
diff --git a/docs/requirements.txt b/docs/requirements.txt
index b34aed6..79735c3 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,4 @@
-sphinx-autobuild
-sphinx_rtd_theme
+sphinx==5.1.1
+sphinx-rtd-theme==1.0.0
 myst-parser
 sphinxcontrib-bibtex
\ No newline at end of file
diff --git a/sklift/__init__.py b/sklift/__init__.py
index f0ede3d..2b8877c 100644
--- a/sklift/__init__.py
+++ b/sklift/__init__.py
@@ -1 +1 @@
-__version__ = '0.4.1'
+__version__ = '0.5.0'
diff --git a/sklift/datasets/datasets.py b/sklift/datasets/datasets.py
index ae5c24b..2838c03 100644
--- a/sklift/datasets/datasets.py
+++ b/sklift/datasets/datasets.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+import hashlib
 
 import pandas as pd
 import requests
@@ -95,6 +96,11 @@ def _get_data(data_home, url, dest_subdir, dest_filename, download_if_missing,
         raise IOError("Dataset missing")
     return dest_path
 
+def _get_file_hash(csv_path):
+    with open(csv_path, 'rb') as file_to_check:
+        data = file_to_check.read()
+    return hashlib.md5(data).hexdigest()
+
 
 def clear_data_dir(path=None):
     """Delete all the content of the data home cache.
@@ -170,11 +176,19 @@ def fetch_lenta(data_home=None, dest_subdir=None, download_if_missing=True, retu
         :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
     """
-    url = 'https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz'
-    filename = url.split('/')[-1]
-    csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir,
+    lenta_metadata = {
+        'url': 'https://sklift.s3.eu-west-2.amazonaws.com/lenta_dataset.csv.gz',
+        'hash': '6ab28ff0989ed8b8647f530e2e86452f'
+    }
+
+    filename = lenta_metadata['url'].split('/')[-1]
+    csv_path = _get_data(data_home=data_home, url=lenta_metadata['url'], dest_subdir=dest_subdir,
                          dest_filename=filename, download_if_missing=download_if_missing)
+
+    if _get_file_hash(csv_path) != lenta_metadata['hash']:
+        raise ValueError(f"The {filename} file is broken, please clean the directory "
+                         f"with the clear_data_dir function and run the function again")
 
     target_col = 'response_att'
     treatment_col = 'group'
@@ -262,11 +276,24 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True):
         :func:`.fetch_megafon`: Load and return the MegaFon Uplift Competition dataset (classification).
""" - url_train = 'https://sklift.s3.eu-west-2.amazonaws.com/uplift_train.csv.gz' - file_train = url_train.split('/')[-1] - csv_train_path = _get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir, + + x5_metadata = { + 'url_train': 'https://sklift.s3.eu-west-2.amazonaws.com/uplift_train.csv.gz', + 'url_clients': 'https://sklift.s3.eu-west-2.amazonaws.com/clients.csv.gz', + 'url_purchases': 'https://sklift.s3.eu-west-2.amazonaws.com/purchases.csv.gz', + 'uplift_hash': '2720bbb659daa9e0989b2777b6a42d19', + 'clients_hash': 'b9cdeb2806b732771de03e819b3354c5', + 'purchases_hash': '48d2de13428e24e8b61d66fef02957a8' + } + file_train = x5_metadata['url_train'].split('/')[-1] + csv_train_path = _get_data(data_home=data_home, url=x5_metadata['url_train'], dest_subdir=dest_subdir, dest_filename=file_train, download_if_missing=download_if_missing) + + if _get_file_hash(csv_train_path) != x5_metadata['uplift_hash']: + raise ValueError(f"The {file_train} file is broken,\ + please clean the directory with the clean_data_dir function, and run the function again") + train = pd.read_csv(csv_train_path) train_features = list(train.columns) @@ -277,19 +304,27 @@ def fetch_x5(data_home=None, dest_subdir=None, download_if_missing=True): train = train.drop([target_col, treatment_col], axis=1) - url_clients = 'https://sklift.s3.eu-west-2.amazonaws.com/clients.csv.gz' - file_clients = url_clients.split('/')[-1] - csv_clients_path = _get_data(data_home=data_home, url=url_clients, dest_subdir=dest_subdir, + file_clients = x5_metadata['url_clients'].split('/')[-1] + csv_clients_path = _get_data(data_home=data_home, url=x5_metadata['url_clients'], dest_subdir=dest_subdir, dest_filename=file_clients, download_if_missing=download_if_missing) + + if _get_file_hash(csv_clients_path) != x5_metadata['clients_hash']: + raise ValueError(f"The {file_clients} file is broken,\ + please clean the directory with the clean_data_dir function, and run the function again") + clients = pd.read_csv(csv_clients_path) clients_features = list(clients.columns) - url_purchases = 'https://sklift.s3.eu-west-2.amazonaws.com/purchases.csv.gz' - file_purchases = url_purchases.split('/')[-1] - csv_purchases_path = _get_data(data_home=data_home, url=url_purchases, dest_subdir=dest_subdir, + file_purchases = x5_metadata['url_purchases'].split('/')[-1] + csv_purchases_path = _get_data(data_home=data_home, url=x5_metadata['url_purchases'], dest_subdir=dest_subdir, dest_filename=file_purchases, download_if_missing=download_if_missing) + + if _get_file_hash(csv_clients_path) != x5_metadata['purchases_hash']: + raise ValueError(f"The {file_purchases} file is broken,\ + please clean the directory with the clean_data_dir function, and run the function again") + purchases = pd.read_csv(csv_purchases_path) purchases_features = list(purchases.columns) @@ -391,16 +426,27 @@ def fetch_criteo(target_col='visit', treatment_col='treatment', data_home=None, raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. 
" f"Got value target_col={target_col}.") + criteo_metadata = { + 'url': '', + 'criteo_hash': '' + } + if percent10: - url = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz' + criteo_metadata['url'] = 'https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo10.csv.gz' + criteo_metadata['criteo_hash'] = 'fe159bcee2cea57548e48eb2d7d5d00c' else: - url = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz" + criteo_metadata['url'] = "https://criteo-bucket.s3.eu-central-1.amazonaws.com/criteo.csv.gz" + criteo_metadata['criteo_hash'] = 'd2236769ef69e9be52556110102911ec' - filename = url.split('/')[-1] - csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir, + filename = criteo_metadata['url'].split('/')[-1] + csv_path = _get_data(data_home=data_home, url=criteo_metadata['url'], dest_subdir=dest_subdir, dest_filename=filename, download_if_missing=download_if_missing) + if _get_file_hash(csv_path) != criteo_metadata['criteo_hash']: + raise ValueError(f"The {filename} file is broken,\ + please clean the directory with the clean_data_dir function, and run the function again") + dtypes = { 'exposure': 'Int8', 'treatment': 'Int8', @@ -497,11 +543,19 @@ def fetch_hillstrom(target_col='visit', data_home=None, dest_subdir=None, downlo raise ValueError(f"The target_col must be an element of {target_cols + ['all']}. " f"Got value target_col={target_col}.") - url = 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz' - filename = url.split('/')[-1] - csv_path = _get_data(data_home=data_home, url=url, dest_subdir=dest_subdir, + hillstrom_metadata = { + 'url': 'https://hillstorm1.s3.us-east-2.amazonaws.com/hillstorm_no_indices.csv.gz', + 'hillstrom_hash': 'a68a81291f53a14f4e29002629803ba3' + } + + filename = hillstrom_metadata['url'].split('/')[-1] + csv_path = _get_data(data_home=data_home, url=hillstrom_metadata['url'], dest_subdir=dest_subdir, dest_filename=filename, download_if_missing=download_if_missing) + + if _get_file_hash(csv_path) != hillstrom_metadata['hillstrom_hash']: + raise ValueError(f"The {filename} file is broken,\ + please clean the directory with the clean_data_dir function, and run the function again") treatment_col = 'segment' @@ -582,12 +636,21 @@ def fetch_megafon(data_home=None, dest_subdir=None, download_if_missing=True, :func:`.fetch_hillstrom`: Load and return Kevin Hillstrom Dataset MineThatData (classification or regression). 
""" - url_train = 'https://sklift.s3.eu-west-2.amazonaws.com/megafon_dataset.csv.gz' - file_train = url_train.split('/')[-1] - csv_train_path = _get_data(data_home=data_home, url=url_train, dest_subdir=dest_subdir, - dest_filename=file_train, + megafon_metadata = { + 'url': 'https://sklift.s3.eu-west-2.amazonaws.com/megafon_dataset.csv.gz', + 'megafon_hash': 'ee8d45a343d4d2cf90bb756c93959ecd' + } + + filename = megafon_metadata['url'].split('/')[-1] + csv_path = _get_data(data_home=data_home, url=megafon_metadata['url'], dest_subdir=dest_subdir, + dest_filename=filename, download_if_missing=download_if_missing) - train = pd.read_csv(csv_train_path) + + if _get_file_hash(csv_path) != megafon_metadata['megafon_hash']: + raise ValueError(f"The {filename} file is broken,\ + please clean the directory with the clean_data_dir function, and run the function again") + + train = pd.read_csv(csv_path) target_col = 'conversion' treatment_col = 'treatment_group' diff --git a/sklift/datasets/descr/x5.rst b/sklift/datasets/descr/x5.rst index ca6553f..f6c789f 100644 --- a/sklift/datasets/descr/x5.rst +++ b/sklift/datasets/descr/x5.rst @@ -17,6 +17,9 @@ Data contains several parts: * clients.csv: general info about clients; * purchases.csv: clients’ purchase history prior to communication. +.. image:: ../../_static/images/x5_table_scheme.png + :alt: X5 table schema + Fields ################ diff --git a/sklift/metrics/__init__.py b/sklift/metrics/__init__.py index 0b67f45..64c10a7 100644 --- a/sklift/metrics/__init__.py +++ b/sklift/metrics/__init__.py @@ -3,7 +3,7 @@ qini_curve, perfect_qini_curve, qini_auc_score, uplift_at_k, response_rate_by_percentile, weighted_average_uplift, uplift_by_percentile, treatment_balance_curve, - average_squared_deviation, make_uplift_scorer + average_squared_deviation, make_uplift_scorer, max_prof_uplift ) __all__ = [ @@ -11,5 +11,5 @@ 'qini_curve', 'perfect_qini_curve', 'qini_auc_score', 'uplift_at_k', 'response_rate_by_percentile', 'weighted_average_uplift', 'uplift_by_percentile', 'treatment_balance_curve', - 'average_squared_deviation', 'make_uplift_scorer' + 'average_squared_deviation', 'make_uplift_scorer', 'max_prof_uplift' ] diff --git a/sklift/metrics/metrics.py b/sklift/metrics/metrics.py index 87b6a74..87078fa 100644 --- a/sklift/metrics/metrics.py +++ b/sklift/metrics/metrics.py @@ -826,3 +826,74 @@ def average_squared_deviation(y_true_train, uplift_train, treatment_train, y_tru strategy=strategy, bins=bins) return np.mean(np.square(uplift_by_percentile_train['uplift'] - uplift_by_percentile_val['uplift'])) + + +def max_prof_uplift(df_sorted, treatment_name, churn_name, pos_outcome, benefit, c_incentive, c_contact, a_cost=0): + """Compute the maximum profit generated from an uplift model decided campaign + + This can be visualised by plotting plt.plot(perc, cumulative_profit) + + Args: + df_sorted (pandas dataframe): dataframe with descending uplift predictions for each customer (i.e. highest 1st) + treatment_name (string): column name of treatment columm (assuming 1 = treated) + churn_name (string): column name of churn column + pos_outcome (int or float): 1 or 0 value in churn column indicating a positive outcome (i.e. 
+            purchase = 1, whereas churn = 0)
+        benefit (int or float): the benefit of retaining a customer (e.g., the average customer lifetime value)
+        c_incentive (int or float): the cost of the incentive if a customer accepts the offer
+        c_contact (int or float): the cost of contacting a customer regardless of conversion
+        a_cost (int or float): the fixed administration cost for the campaign
+
+    Returns:
+        1d array-like: the cumulative fraction of customers targeted (x axis for plotting)
+        1d array-like: the cumulative profit per customer (y axis for plotting)
+
+    References:
+        Floris Devriendt, Jeroen Berrevoets, Wouter Verbeke. Why you should stop predicting customer churn and start using uplift models.
+    """
+    # VARIABLES
+    # n_ct0      no. of people not treated
+    # n_ct1      no. of people treated
+    # n_y1_ct0   no. of people not treated with +ve outcome
+    # n_y1_ct1   no. of people treated with +ve outcome
+    # r_y1_ct0   rate of +ve outcomes among people not treated
+    # r_y1_ct1   rate of +ve outcomes among people treated
+    # cs         cumsum() of each variable
+
+    n_ct0 = np.where(df_sorted[treatment_name] == 0, 1, 0)
+    cs_n_ct0 = pd.Series(n_ct0.cumsum())
+
+    n_ct1 = np.where(df_sorted[treatment_name] == 1, 1, 0)
+    cs_n_ct1 = pd.Series(n_ct1.cumsum())
+
+    if pos_outcome == 0:
+        n_y1_ct0 = np.where((df_sorted[treatment_name] == 0) & (df_sorted[churn_name] == 0), 1, 0)
+        n_y1_ct1 = np.where((df_sorted[treatment_name] == 1) & (df_sorted[churn_name] == 0), 1, 0)
+
+    elif pos_outcome == 1:
+        n_y1_ct0 = np.where((df_sorted[treatment_name] == 0) & (df_sorted[churn_name] == 1), 1, 0)
+        n_y1_ct1 = np.where((df_sorted[treatment_name] == 1) & (df_sorted[churn_name] == 1), 1, 0)
+
+    cs_n_y1_ct0 = pd.Series(n_y1_ct0.cumsum())
+    cs_n_y1_ct1 = pd.Series(n_y1_ct1.cumsum())
+
+    cs_r_y1_ct0 = (cs_n_y1_ct0 / cs_n_ct0).fillna(0)
+    cs_r_y1_ct1 = (cs_n_y1_ct1 / cs_n_ct1).fillna(0)
+
+    cs_uplift = cs_r_y1_ct1 - cs_r_y1_ct0
+
+    # Dataframe of all calculated variables (not returned; kept for inspection)
+    df = pd.concat([cs_n_ct0, cs_n_ct1, cs_n_y1_ct0, cs_n_y1_ct1, cs_r_y1_ct0, cs_r_y1_ct1, cs_uplift], axis=1)
+    df.columns = ['cs_n_ct0', 'cs_n_ct1', 'cs_n_y1_ct0', 'cs_n_y1_ct1', 'cs_r_y1_ct0', 'cs_r_y1_ct1', 'cs_uplift']
+
+    x = cs_n_ct0 + cs_n_ct1
+    n_total = cs_n_ct0.max() + cs_n_ct1.max()
+
+    t_profit = (x * cs_uplift * benefit) - (c_incentive * x * cs_r_y1_ct1) - (c_contact * x) - a_cost
+    perc = x / n_total
+    cumulative_profit = t_profit / n_total
+
+    return perc, cumulative_profit
diff --git a/sklift/models/__init__.py b/sklift/models/__init__.py
index 05ef032..81290fe 100644
--- a/sklift/models/__init__.py
+++ b/sklift/models/__init__.py
@@ -1,6 +1,3 @@
-from .models import (SoloModel, ClassTransformation, TwoModels)
+from .models import SoloModel, ClassTransformation, ClassTransformationReg, TwoModels
 
-__all__ = [
-    'SoloModel',
-    'ClassTransformation',
-    'TwoModels']
+__all__ = ['SoloModel', 'ClassTransformation', 'ClassTransformationReg', 'TwoModels']
diff --git a/sklift/models/models.py b/sklift/models/models.py
index 16f3c08..5220608 100644
--- a/sklift/models/models.py
+++ b/sklift/models/models.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 import pandas as pd
 from sklearn.base import BaseEstimator
@@ -22,7 +24,7 @@ class SoloModel(BaseEstimator):
 
     Args:
         estimator (estimator object implementing 'fit'): The object to use to fit the data.
         method (string, ’dummy’ or ’treatment_interaction’, default='dummy'): Specifies the approach:
-
+
         * ``'dummy'``:
             Single model;
         * ``'treatment_interaction'``:
@@ -268,6 +270,137 @@ def predict(self, X):
         return uplift
 
 
+class ClassTransformationReg(BaseEstimator):
+    """aka CATE-generating (Conditional Average Treatment Effect) transformation approach for continuous targets.
+
+    Redefine the target variable as ``Z = Y * (W - p) / (p * (1 - p))``,
+
+    where ``Y`` is the target vector, ``W`` is the vector of binary communication flags, and ``p`` is the propensity
+    score (the probability that each observation is assigned to the treatment group).
+
+    Then, train a regressor on ``Z`` to predict uplift.
+
+    Returns uplift predictions and optionally propensity predictions.
+
+    The propensity score can be a scalar value (e.g. p = 0.5), which would mean that every subject has an identical
+    probability of being assigned to the treatment group.
+
+    Alternatively, the propensity can be learned using a classifier model. In this case, the model predicts
+    the probability that a given subject would be assigned to the treatment group.
+
+    Read more in the :ref:`User Guide `.
+
+    Args:
+        estimator (estimator object implementing 'fit'): The object to use to fit the data.
+        propensity_val (float): A constant propensity value, which assumes every subject has an equal probability of assignment to the treatment group.
+        propensity_estimator (estimator object with `predict_proba`): The object used to predict the propensity score if `propensity_val` is not given.
+
+
+    Example::
+
+        # import approach
+        from sklift.models import ClassTransformationReg
+        # import any estimator adhering to scikit-learn conventions
+        from sklearn.linear_model import LinearRegression, LogisticRegression
+
+
+        # define approach
+        ct = ClassTransformationReg(estimator=LinearRegression(), propensity_estimator=LogisticRegression())
+        # fit the model
+        ct = ct.fit(X_train, y_train, treat_train)
+        # predict uplift
+        uplift_ct = ct.predict(X_val)
+
+    References:
+        Maciej Jaskowski and Szymon Jaroszewicz. Uplift modeling for clinical trial data.
+        ICML Workshop on Clinical Data Analysis, 2012.
+
+    See Also:
+
+        **Other approaches:**
+
+        * :class:`.SoloModel`: Single model approach.
+        * :class:`.TwoModels`: Double classifier approach.
+        * :class:`.ClassTransformation`: Binary classifier transformation approach.
+    """
+
+    def __init__(self, estimator, propensity_val=None, propensity_estimator=None):
+
+        if (propensity_val is None) and (propensity_estimator is None):
+            raise ValueError('`propensity_val` and `propensity_estimator` cannot both be equal to `None`. Both arguments are currently null.')
+        elif (propensity_val is not None) and (propensity_estimator is not None):
+            raise ValueError('Exactly one of (`propensity_val`, `propensity_estimator`) must be None, and the other must be defined. Both arguments are currently non-null.')
+
+        self.estimator = estimator
+        self.propensity_val = propensity_val
+        self.propensity_estimator = propensity_estimator
+
+        self._type_of_target = None
+
+    def fit(self, X, y, treatment, estimator_fit_params=None):
+        """Fit the model according to the given training data.
+
+        Args:
+            X (array-like, shape (n_samples, n_features)): Training vector, where n_samples is the number of samples and
+                n_features is the number of features.
+            y (array-like, shape (n_samples,)): Target vector relative to X.
+            treatment (array-like, shape (n_samples,)): Binary treatment vector relative to X.
+            estimator_fit_params (dict, optional): Parameters to pass to the fit method of the estimator.
+
+        Returns:
+            object: self
+        """
+
+        check_consistent_length(X, y, treatment)
+        check_is_binary(treatment)
+        self._type_of_target = type_of_target(y)
+
+        if self.propensity_val is not None:
+            p = self.propensity_val
+
+        elif self.propensity_estimator is not None:
+            self.propensity_estimator.fit(X, treatment)
+            p = self.propensity_estimator.predict_proba(X)[:, 1]
+
+        y_mod = y * ((treatment - p) / (p * (1 - p)))
+
+        if estimator_fit_params is None:
+            estimator_fit_params = {}
+        self.estimator.fit(X, y_mod, **estimator_fit_params)
+
+        return self
+
+    def predict_propensity(self, X):
+        """Predict propensity values.
+
+        Args:
+            X (array-like, shape (n_samples, n_features)): Training vector, where n_samples is the number of samples
+                and n_features is the number of features.
+
+        Returns:
+            array (shape (n_samples,)): propensity
+        """
+
+        if self.propensity_estimator is not None:
+            return self.propensity_estimator.predict_proba(X)[:, 1]
+        else:
+            return self.propensity_val
+
+    def predict(self, X):
+        """Perform uplift on samples in X.
+
+        Args:
+            X (array-like, shape (n_samples, n_features)): Training vector, where n_samples is the number of samples
+                and n_features is the number of features.
+
+        Returns:
+            array (shape (n_samples,)): uplift
+        """
+
+        uplift = self.estimator.predict(X)
+        return uplift
+
+
 class TwoModels(BaseEstimator):
     """aka naïve approach, or difference score method, or double classifier approach.
 
@@ -381,8 +514,18 @@ def fit(self, X, y, treatment, estimator_trmnt_fit_params=None, estimator_ctrl_f
         check_is_binary(treatment)
         self._type_of_target = type_of_target(y)
 
-        X_ctrl, y_ctrl = X[treatment == 0], y[treatment == 0]
-        X_trmnt, y_trmnt = X[treatment == 1], y[treatment == 1]
+        y_copy = y.copy()
+        treatment_copy = treatment.copy()
+
+        if isinstance(X, (pd.Series, pd.DataFrame)) and isinstance(y_copy, pd.Series) and not X.index.equals(y_copy.index):
+            y_copy.index = X.index
+            warnings.warn("Target indexes do not match data indexes, re-indexing has been performed")
+        if isinstance(X, (pd.Series, pd.DataFrame)) and isinstance(treatment_copy, pd.Series) and not X.index.equals(treatment_copy.index):
+            treatment_copy.index = X.index
+            warnings.warn("Treatment indexes do not match data indexes, re-indexing has been performed")
+
+        X_ctrl, y_ctrl = X[treatment_copy == 0], y_copy[treatment_copy == 0]
+        X_trmnt, y_trmnt = X[treatment_copy == 1], y_copy[treatment_copy == 1]
 
         if estimator_trmnt_fit_params is None:
             estimator_trmnt_fit_params = {}
diff --git a/sklift/tests/test_models.py b/sklift/tests/test_models.py
index 1afa939..2e58281 100644
--- a/sklift/tests/test_models.py
+++ b/sklift/tests/test_models.py
@@ -1,5 +1,6 @@
 import pytest
 import numpy as np
+import pandas as pd
 from sklearn.linear_model import LogisticRegression, LinearRegression
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
@@ -91,3 +92,16 @@ def test_same_estimator_error():
     with pytest.raises(ValueError):
         TwoModels(est, est)
 
+@pytest.mark.parametrize(
+    "X, y, treatment",
+    [
+        (pd.DataFrame(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), columns=['a', 'b', 'c'], index=[0, 1, 2]),
+         pd.Series(np.array([1, 0, 1]), index=[0, 2, 3]), pd.Series(np.array([0, 0, 1]), index=[0, 1, 2])),
+        (pd.DataFrame(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), columns=['a', 'b', 'c'], index=[0, 1, 2]),
+         pd.Series(np.array([1, 0, 1]), index=[0, 1, 2]), pd.Series(np.array([0, 0, 1]), index=[1, 2, 3]))
+    ]
+)
+def test_input_data(X, y, treatment):
+    model = TwoModels(LinearRegression(), LinearRegression())
+    with pytest.warns(UserWarning):
+        model.fit(X, y, treatment)
\ No newline at end of file
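
Reviewer note: below is a minimal end-to-end sketch (not part of the patch) of how the two headline additions fit together: fitting ClassTransformationReg and scoring the resulting ranking with max_prof_uplift. The synthetic data and the benefit/cost values are illustrative assumptions only.

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression

from sklift.models import ClassTransformationReg
from sklift.metrics import max_prof_uplift

# Toy data: 1000 customers, 3 features, random treatment assignment,
# and a binary outcome slightly lifted by treatment.
rng = np.random.default_rng(42)
n = 1000
X = pd.DataFrame(rng.normal(size=(n, 3)), columns=['f1', 'f2', 'f3'])
treatment = pd.Series(rng.integers(0, 2, size=n))
y = (rng.random(n) < 0.3 + 0.1 * treatment).astype(int)

# Transform the target with a learned propensity score and fit a regressor on Z.
ct = ClassTransformationReg(estimator=LinearRegression(),
                            propensity_estimator=LogisticRegression())
ct.fit(X, y, treatment)
uplift = ct.predict(X)

# max_prof_uplift expects a dataframe sorted by predicted uplift, descending.
df_sorted = (pd.DataFrame({'treatment': treatment, 'outcome': y, 'uplift': uplift})
             .sort_values('uplift', ascending=False)
             .reset_index(drop=True))

perc, cumulative_profit = max_prof_uplift(
    df_sorted, treatment_name='treatment', churn_name='outcome', pos_outcome=1,
    benefit=200, c_incentive=10, c_contact=1, a_cost=100)

# The peak of the curve is the best achievable profit per customer;
# plt.plot(perc, cumulative_profit) visualises the whole profit curve.
print(f"Max profit per customer: {cumulative_profit.max():.2f}")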