agg granularity to restrict group by
lyliyu committed Jun 15, 2023
1 parent 0a21113 commit 4821dc5
Showing 4 changed files with 65 additions and 13 deletions.
11 changes: 8 additions & 3 deletions framework/feature_factory/__init__.py
@@ -5,12 +5,14 @@
from framework.feature_factory.feature import Feature, FeatureSet, Multiplier
from framework.configobj import ConfigObj
from framework.feature_factory.helpers import Helpers
from framework.feature_factory.agg_granularity import AggregationGranularity
import re
import logging
import datetime
import inspect
from collections import OrderedDict
from typing import List
from enum import Enum

logger = logging.getLogger(__name__)

@@ -19,7 +21,7 @@ class Feature_Factory():
def __init__(self):
self.helpers = Helpers()

def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[FeatureSet], withTrendsForFeatures: List[FeatureSet] = None):
def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[FeatureSet], withTrendsForFeatures: List[FeatureSet] = None, granularityEnum: Enum = None):
"""
Appends features to incoming df. The features columns and groupby cols will be deduped and validated.
If there's a group by, the groupby cols will be applied before appending features.
@@ -50,10 +52,13 @@ def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[Featur

# valid_result, undef_cols = self.helpers._validate_col(df, *groupBy_cols)
# assert valid_result, "groupby cols {} are not defined in df columns {}".format(undef_cols, df.columns)
granularity_validator = AggregationGranularity(granularityEnum) if granularityEnum else None
for feature in features:
assert True if ((len(feature.aggs) > 0) and (len(
groupBy_cols) > 0) or feature.agg_func is None) else False, "{} has either aggs or groupBys " \
"but not both, ensure both are present".format(feature.name)
if granularity_validator:
granularity_validator.validate(feature, groupBy_cols)
# feature_cols.append(feature.assembled_column)
# feature_cols.append(F.col(feature.output_alias))
agg_cols += [agg_col for agg_col in feature.aggs]
@@ -76,7 +81,7 @@ def append_features(self, df: DataFrame, groupBy_cols, feature_sets: List[Featur
# new_df = df.select(*df.columns + feature_cols)
return final_df

def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names = [], withTrendsForFeatures: List[FeatureSet] = None):
def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names = [], withTrendsForFeatures: List[FeatureSet] = None, granularityEnum: Enum = None):
"""
Appends features to incoming df. The features columns and groupby cols will be deduped and validated.
If there's a group by, the groupby cols will be applied before appending features.
@@ -89,5 +94,5 @@ def append_catalog(self, df: DataFrame, groupBy_cols, catalog_cls, feature_names
# dct = self._get_all_features(catalog_cls)
dct = catalog_cls.get_all_features()
fs = FeatureSet(dct)
return self.append_features(df, groupBy_cols, [fs], withTrendsForFeatures)
return self.append_features(df, groupBy_cols, [fs], withTrendsForFeatures, granularityEnum)
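
For context, a minimal sketch of how the new granularityEnum argument might be passed to append_features — the Granularity enum, sales_df, and sales_feature_set names below are hypothetical stand-ins, not part of this commit; the test changes further down show the commit's own usage:

from enum import IntEnum
import pyspark.sql.functions as f
from framework.feature_factory import Feature_Factory

class Granularity(IntEnum):  # hypothetical enum; lower value = finer grain
    PRODUCT_ID = 1
    PRODUCT_DIVISION = 2
    COUNTRY = 3

ff = Feature_Factory()
# Each feature's agg_granularity is checked against the group-by columns before
# aggregation; omitting granularityEnum (or passing None) skips the check.
result_df = ff.append_features(
    sales_df,                                          # hypothetical input DataFrame
    groupBy_cols=[f.col("ss_customer_sk").alias("customer_id"), "PRODUCT_ID"],
    feature_sets=[sales_feature_set],                  # hypothetical FeatureSet
    granularityEnum=Granularity,
)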

26 changes: 26 additions & 0 deletions framework/feature_factory/agg_granularity.py
@@ -0,0 +1,26 @@
from enum import IntEnum, EnumMeta
from pyspark.sql.column import Column

class AggregationGranularity:
def __init__(self, granularity: EnumMeta) -> None:
assert isinstance(granularity, EnumMeta), "Granularity should be of type Enum."
self.granularity = granularity


def validate(self, feat, groupby_list):
if not feat.agg_granularity:
return None
min_granularity_level = float("inf")
for level in groupby_list:
if isinstance(level, str):
try:
level = self.granularity[level]
except:
print(f"{level} is not part of {self.granularity}")
continue
if isinstance(level, Column):
continue
min_granularity_level = min(min_granularity_level, level.value)
assert min_granularity_level <= feat.agg_granularity.value, f"Required granularity for {feat.name} is {feat.agg_granularity}"
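
In plain terms, validate resolves each group-by entry through the enum (string names are looked up, Column objects are skipped), takes the finest — that is, lowest-valued — level found, and asserts it is at or below the feature's declared agg_granularity. A self-contained restatement of that rule with a hypothetical enum:

from enum import IntEnum

class Granularity(IntEnum):  # hypothetical; lower value = finer grain
    PRODUCT_ID = 1
    PRODUCT_DIVISION = 2
    COUNTRY = 3

required = Granularity.PRODUCT_DIVISION            # the feature's agg_granularity
groupby_levels = [Granularity["PRODUCT_ID"], Granularity["COUNTRY"]]

# Passes: the finest group-by level (PRODUCT_ID = 1) is at or below
# PRODUCT_DIVISION = 2, so the requested aggregation is fine-grained enough.
assert min(level.value for level in groupby_levels) <= required.value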


10 changes: 7 additions & 3 deletions framework/feature_factory/feature.py
@@ -21,7 +21,8 @@ def __init__(self,
_agg_func=None,
_agg_alias:str=None,
_kind="multipliable",
_is_temporary=False):
_is_temporary=False,
_agg_granularity=None):
"""
:param _name: name of the feature
@@ -47,6 +48,7 @@
self.kind = _kind
self.is_temporary = _is_temporary
self.names = None
self.agg_granularity = _agg_granularity

def set_feature_name(self, name: str):
self._name = name
@@ -135,15 +137,17 @@ def create(cls, base_col: Column,
filter: List[Column] = [],
negative_value=None,
agg_func=None,
agg_alias: str = None):
agg_alias: str = None,
agg_granularity: str = None):

return Feature(
_name = "",
_base_col = base_col,
_filter = filter,
_negative_value = negative_value,
_agg_func = agg_func,
_agg_alias = agg_alias)
_agg_alias = agg_alias,
_agg_granularity = agg_granularity)


class FeatureSet:
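A minimal sketch of declaring a feature with the new agg_granularity argument — the Granularity enum here is hypothetical, while SalesCatalog in the test file below is the commit's own example:

from enum import IntEnum
import pyspark.sql.functions as f
from framework.feature_factory.feature import Feature

class Granularity(IntEnum):  # hypothetical; lower value = finer grain
    PRODUCT_ID = 1
    PRODUCT_DIVISION = 2
    COUNTRY = 3

# total_sales may only be aggregated when the group-by includes a level at
# least as fine as PRODUCT_DIVISION.
total_sales = Feature.create(
    base_col=f.col("ss_net_paid"),
    agg_func=f.sum,
    agg_granularity=Granularity.PRODUCT_DIVISION,
)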
31 changes: 24 additions & 7 deletions test/test_feature_catalog.py
@@ -8,6 +8,7 @@
from pyspark.sql.types import StructType
from test.local_spark_singleton import SparkSingleton
from framework.feature_factory.catalog import CatalogBase
from enum import IntEnum

class CommonCatalog(CatalogBase):
total_sales = Feature.create(
@@ -18,15 +19,22 @@ class CommonCatalog(CatalogBase):
total_quants = Feature.create(base_col=f.col("ss_quantity"),
agg_func=f.sum)

class Granularity(IntEnum):
PRODUCT_ID = 1,
PRODUCT_DIVISION = 2,
COUNTRY = 3

class SalesCatalog(CommonCatalog):
_valid_sales_filter = f.col("ss_net_paid") > 0

total_sales = Feature.create(
base_col=CommonCatalog.total_sales,
filter=_valid_sales_filter,
agg_func=f.sum
agg_func=f.sum,
agg_granularity=Granularity.PRODUCT_DIVISION
)



def generate_sales_catalog(CommonCatalog):
class SalesCatalog(CommonCatalog):
_valid_sales_filter = f.col("ss_net_paid") > 0
@@ -40,10 +48,13 @@ class SalesCatalog(CommonCatalog):

class TestSalesCatalog(unittest.TestCase):
def setUp(self):
with open("test/data/sales_store_schema.json") as f:
sales_schema = StructType.fromJson(json.load(f))
self.sales_df = SparkSingleton.get_instance().read.csv("test/data/sales_store_tpcds.csv", schema=sales_schema, header=True)

with open("test/data/sales_store_schema.json") as fp:
sales_schema = StructType.fromJson(json.load(fp))
df = SparkSingleton.get_instance().read.csv("test/data/sales_store_tpcds.csv", schema=sales_schema, header=True)
self.sales_df = df.withColumn("PRODUCT_ID", f.lit("product"))\
.withColumn("PRODUCT_DIVISION", f.lit("division"))\
.withColumn("COUNTRY", f.lit("country"))

def test_append_catalog(self):
customer_id = f.col("ss_customer_sk").alias("customer_id")
ff = Feature_Factory()
@@ -57,4 +68,10 @@ def test_common_catalog(self):
salesCatalogClass = generate_sales_catalog(CommonCatalog=CommonCatalog)
df = ff.append_catalog(self.sales_df, [customer_id], salesCatalogClass)
assert df.count() > 0
assert "total_sales" in df.columns and "total_quants" in df.columns
assert "total_sales" in df.columns and "total_quants" in df.columns

def test_granularity(self):
customer_id = f.col("ss_customer_sk").alias("customer_id")
ff = Feature_Factory()
df = ff.append_catalog(self.sales_df, [customer_id, "PRODUCT_ID"], SalesCatalog, granularityEnum=Granularity)
assert df.count() > 0
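
For the failing direction (not covered by this commit's tests), grouping only by a level coarser than a feature requires should trip the assertion in AggregationGranularity.validate. A hedged sketch of such a test method, which would live alongside test_granularity in TestSalesCatalog:

def test_granularity_too_coarse(self):
    # Hypothetical counterpart to test_granularity: COUNTRY (3) is coarser
    # than the PRODUCT_DIVISION (2) required by SalesCatalog.total_sales,
    # so validation should raise an AssertionError before aggregation.
    ff = Feature_Factory()
    with self.assertRaises(AssertionError):
        ff.append_catalog(self.sales_df, ["COUNTRY"], SalesCatalog, granularityEnum=Granularity)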
