Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Custom Metrics #56

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion casual_inference/evaluator/aatest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Union

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from scipy.stats import kstest
from typing_extensions import Self

from ..model import CustomMetric
from ..statistical_testing import t_test
from .base import BaseEvaluator

Expand All @@ -30,7 +33,7 @@ def __init__(self, n_simulation: int = 1000, sample_rate: float = 1.0) -> None:
self.sample_rate = sample_rate

# ignore mypy error temporary, because the "Self" type support on mypy is ongoing. https://github.com/python/mypy/pull/11666
def evaluate(self, data: pd.DataFrame, unit_col: str, metrics: list[str]) -> Self: # type: ignore
def evaluate(self, data: pd.DataFrame, unit_col: str, metrics: list[Union[str, CustomMetric]]) -> Self: # type: ignore
"""split data n times, and calculate statistics n times, then store it as an attribute.

Parameters
Expand Down
5 changes: 3 additions & 2 deletions casual_inference/evaluator/abtest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import pandas as pd
import pandas.api.types as pd_types
Expand All @@ -9,6 +9,7 @@
from scipy.stats import chisquare
from typing_extensions import Self

from ..model import CustomMetric
from ..statistical_testing import eval_ttest_significance, t_test
from .base import BaseEvaluator

Expand Down Expand Up @@ -41,7 +42,7 @@ def evaluate(
self,
data: pd.DataFrame,
unit_col: str,
metrics: list[str],
metrics: list[Union[str, CustomMetric]],
variant_col: str = "variant",
segment_col: Optional[str] = None,
) -> Self: # type: ignore
Expand Down
7 changes: 5 additions & 2 deletions casual_inference/evaluator/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from abc import ABC, abstractmethod
from typing import Union

import pandas as pd
import plotly.graph_objs as go
from typing_extensions import Self

from ..model import CustomMetric


class BaseEvaluator(ABC):
def __init__(self) -> None:
self.stats: pd.DataFrame = pd.DataFrame()

@abstractmethod
# ignore mypy error temporary, because the "Self" type support on mypy is ongoing. https://github.com/python/mypy/pull/11666
def evaluate(self, data: pd.DataFrame, unit_col: str, metrics: list[str]) -> Self: # type: ignore
def evaluate(self, data: pd.DataFrame, unit_col: str, metrics: list[Union[str, CustomMetric]]) -> Self: # type: ignore
return self

@abstractmethod
Expand All @@ -26,7 +29,7 @@ def _validate_evaluate_executed(self) -> None:
if self.stats.shape[0] == 0:
raise ValueError("Evaluated statistics haven't been calculated. Please call evaluate() in advance.")

def _validate_passed_data(self, data: pd.DataFrame, unit_col: str, metrics: list[str]) -> None:
def _validate_passed_data(self, data: pd.DataFrame, unit_col: str, metrics: list[Union[str, CustomMetric]]) -> None:
if data.shape[0] != data[unit_col].nunique():
raise ValueError("passed dataframe hasn't been aggregated by the randomization unit.")
if len(metrics) == 0:
Expand Down
7 changes: 6 additions & 1 deletion casual_inference/evaluator/linear_regression.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Union

import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import statsmodels.formula.api as smf
from typing_extensions import Self

from ..model import CustomMetric
from .base import BaseEvaluator


Expand All @@ -25,7 +28,7 @@ def evaluate(
self,
data: pd.DataFrame,
unit_col: str,
metrics: list[str],
metrics: list[Union[str, CustomMetric]],
treatment_col: str = "treatment",
covariates: list[str] = [],
) -> Self: # type: ignore
Expand Down Expand Up @@ -61,6 +64,8 @@ def evaluate(
covariates_str = "+ " + "+ ".join(covariates)

for metric in metrics:
if isinstance(metric, CustomMetric):
raise ValueError("CustomMetric is not supported in this evaluator.")
model = smf.ols(formula=f"{metric} ~ {treatment_col} {covariates_str}", data=data).fit()
self.models[metric] = model

Expand Down
5 changes: 3 additions & 2 deletions casual_inference/evaluator/samplesize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional
from typing import Optional, Union

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from typing_extensions import Self

from ..model import CustomMetric
from .base import BaseEvaluator


Expand All @@ -14,7 +15,7 @@ def __init__(self) -> None:
super().__init__()

# ignore mypy error temporary, because the "Self" type support on mypy is ongoing. https://github.com/python/mypy/pull/11666
def evaluate(self, data: pd.DataFrame, unit_col: str, metrics: list[str], n_variant: int = 2) -> Self: # type: ignore
def evaluate(self, data: pd.DataFrame, unit_col: str, metrics: list[Union[str, CustomMetric]], n_variant: int = 2) -> Self: # type: ignore
"""Calculate statistics of metrics and mde with simulating A/B test threshold.

Parameters
Expand Down
8 changes: 8 additions & 0 deletions casual_inference/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass


@dataclass
class CustomMetric:
name: str
denominator: str
numerator: str
16 changes: 12 additions & 4 deletions casual_inference/statistical_testing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Union

import numpy as np
import pandas as pd
from scipy.stats import t, ttest_ind_from_stats

from .model import CustomMetric


def t_test(data: pd.DataFrame, unit_col: str, variant_col: str, metrics: list[str]) -> pd.DataFrame:
def t_test(
data: pd.DataFrame, unit_col: str, variant_col: str, metrics: list[Union[str, CustomMetric]]
) -> pd.DataFrame:
"""_summary_

Parameters
Expand All @@ -30,14 +36,16 @@ def t_test(data: pd.DataFrame, unit_col: str, variant_col: str, metrics: list[st
raise ValueError("metrics hasn't been specified.")
if data[variant_col].min() != 1:
raise ValueError("the control variant seems not to exist.")
normal_metrics = [metric for metric in metrics if isinstance(metric, CustomMetric) == False]
# custom_metrics = [metric for metric in metrics if isinstance(metric, CustomMetric)]
means = (
data.groupby(variant_col)[metrics].mean().stack().reset_index().rename(columns={"level_1": "metric", 0: "mean"})
data.groupby(variant_col)[normal_metrics].mean().stack().reset_index().rename(columns={"level_1": "metric", 0: "mean"})
)
vars = (
data.groupby(variant_col)[metrics].var().stack().reset_index().rename(columns={"level_1": "metric", 0: "var"})
data.groupby(variant_col)[normal_metrics].var().stack().reset_index().rename(columns={"level_1": "metric", 0: "var"})
)
counts = (
data.groupby(variant_col)[metrics]
data.groupby(variant_col)[normal_metrics]
.count()
.stack()
.reset_index()
Expand Down
Loading
Loading