From 32e24d6382a973452dd20611c22358cd8d5976bd Mon Sep 17 00:00:00 2001 From: Mayara Moromisato <44944954+moromimay@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:18:09 -0300 Subject: [PATCH] [MLOP-635] Rebase Incremental Job/Interval Run branch for test on selected feature sets (#278) * Add interval branch modifications. * Add interval_runs notebook. * Add tests. * Apply style (black, flack8 and mypy). * Fix tests. * Change version to create package dev. --- .gitignore | 1 + CHANGELOG.md | 6 + butterfree/clients/cassandra_client.py | 41 +- butterfree/clients/spark_client.py | 58 +- butterfree/configs/db/metastore_config.py | 28 + butterfree/configs/environment.py | 4 +- butterfree/constants/window_definitions.py | 16 + butterfree/dataframe_service/__init__.py | 9 +- .../dataframe_service/incremental_strategy.py | 116 + butterfree/dataframe_service/partitioning.py | 25 + butterfree/extract/readers/file_reader.py | 12 +- butterfree/extract/readers/reader.py | 88 +- butterfree/extract/source.py | 24 +- butterfree/hooks/__init__.py | 5 + butterfree/hooks/hook.py | 20 + butterfree/hooks/hookable_component.py | 148 ++ .../hooks/schema_compatibility/__init__.py | 9 + ...ssandra_table_schema_compatibility_hook.py | 58 + .../spark_table_schema_compatibility_hook.py | 46 + butterfree/load/sink.py | 13 +- .../historical_feature_store_writer.py | 113 +- .../writers/online_feature_store_writer.py | 50 +- butterfree/load/writers/writer.py | 21 +- butterfree/pipelines/feature_set_pipeline.py | 56 +- .../transform/aggregated_feature_set.py | 49 +- butterfree/transform/feature_set.py | 38 +- butterfree/transform/utils/window_spec.py | 20 +- examples/interval_runs/interval_runs.ipynb | 2152 +++++++++++++++++ setup.py | 2 +- .../integration/butterfree/load/test_sink.py | 35 +- .../butterfree/pipelines/conftest.py | 202 ++ .../pipelines/test_feature_set_pipeline.py | 311 ++- .../butterfree/transform/conftest.py | 55 + .../transform/test_aggregated_feature_set.py | 50 + .../butterfree/transform/test_feature_set.py | 44 + tests/unit/butterfree/clients/conftest.py | 11 +- .../clients/test_cassandra_client.py | 4 +- .../butterfree/clients/test_spark_client.py | 69 +- .../butterfree/dataframe_service/conftest.py | 14 + .../test_incremental_srategy.py | 70 + .../dataframe_service/test_partitioning.py | 20 + tests/unit/butterfree/extract/conftest.py | 55 + .../extract/readers/test_file_reader.py | 10 +- .../butterfree/extract/readers/test_reader.py | 58 + tests/unit/butterfree/hooks/__init__.py | 0 .../hooks/schema_compatibility/__init__.py | 0 ...ssandra_table_schema_compatibility_hook.py | 49 + ...t_spark_table_schema_compatibility_hook.py | 53 + .../hooks/test_hookable_component.py | 107 + tests/unit/butterfree/load/conftest.py | 25 + tests/unit/butterfree/load/test_sink.py | 34 +- .../test_historical_feature_store_writer.py | 144 +- .../test_online_feature_store_writer.py | 41 +- tests/unit/butterfree/pipelines/conftest.py | 63 + .../pipelines/test_feature_set_pipeline.py | 182 +- tests/unit/butterfree/transform/conftest.py | 82 + .../transform/test_aggregated_feature_set.py | 68 +- .../butterfree/transform/test_feature_set.py | 43 +- 58 files changed, 4738 insertions(+), 389 deletions(-) create mode 100644 butterfree/constants/window_definitions.py create mode 100644 butterfree/dataframe_service/incremental_strategy.py create mode 100644 butterfree/dataframe_service/partitioning.py create mode 100644 butterfree/hooks/__init__.py create mode 100644 butterfree/hooks/hook.py create mode 100644 
butterfree/hooks/hookable_component.py create mode 100644 butterfree/hooks/schema_compatibility/__init__.py create mode 100644 butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py create mode 100644 butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py create mode 100644 examples/interval_runs/interval_runs.ipynb create mode 100644 tests/unit/butterfree/dataframe_service/test_incremental_srategy.py create mode 100644 tests/unit/butterfree/dataframe_service/test_partitioning.py create mode 100644 tests/unit/butterfree/hooks/__init__.py create mode 100644 tests/unit/butterfree/hooks/schema_compatibility/__init__.py create mode 100644 tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py create mode 100644 tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py create mode 100644 tests/unit/butterfree/hooks/test_hookable_component.py create mode 100644 tests/unit/butterfree/pipelines/conftest.py diff --git a/.gitignore b/.gitignore index 72b591f3..62434612 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ coverage.xml *.cover .hypothesis/ *cov.xml +test_folder/ # Translations *.mo diff --git a/CHANGELOG.md b/CHANGELOG.md index 72994621..679e9834 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,17 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each ## [Unreleased] ### Added +* [MLOP-636] Create migration classes ([#282](https://github.com/quintoandar/butterfree/pull/282)) + +## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3) +### Added * [MLOP-599] Apply mypy to ButterFree ([#273](https://github.com/quintoandar/butterfree/pull/273)) ### Changed * [MLOP-634] Butterfree dev workflow, set triggers for branches staging and master ([#280](https://github.com/quintoandar/butterfree/pull/280)) * Keep milliseconds when using 'from_ms' argument in timestamp feature ([#284](https://github.com/quintoandar/butterfree/pull/284)) +* [MLOP-633] Butterfree dev workflow, update documentation ([#281](https://github.com/quintoandar/butterfree/commit/74278986a49f1825beee0fd8df65a585764e5524)) +* [MLOP-632] Butterfree dev workflow, automate release description ([#279](https://github.com/quintoandar/butterfree/commit/245eaa594846166972241b03fddc61ee5117b1f7)) ### Fixed * Change trigger for pipeline staging ([#287](https://github.com/quintoandar/butterfree/pull/287)) diff --git a/butterfree/clients/cassandra_client.py b/butterfree/clients/cassandra_client.py index 1e541688..938d4e4d 100644 --- a/butterfree/clients/cassandra_client.py +++ b/butterfree/clients/cassandra_client.py @@ -33,33 +33,31 @@ class CassandraClient(AbstractClient): """Cassandra Client. Attributes: - cassandra_user: username to use in connection. - cassandra_password: password to use in connection. - cassandra_key_space: key space used in connection. - cassandra_host: cassandra endpoint used in connection. + user: username to use in connection. + password: password to use in connection. + keyspace: key space used in connection. + host: cassandra endpoint used in connection. 
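# Minimal usage sketch for the renamed constructor arguments; the contact point,
# keyspace and table names below are placeholders, not values from this change.
from butterfree.clients import CassandraClient

client = CassandraClient(host=["127.0.0.1"], keyspace="feature_store")
schema = client.get_schema("house_listings")  # -> [{"column_name": ..., "type": ...}, ...]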
""" def __init__( self, - cassandra_host: List[str], - cassandra_key_space: str, - cassandra_user: Optional[str] = None, - cassandra_password: Optional[str] = None, + host: List[str], + keyspace: str, + user: Optional[str] = None, + password: Optional[str] = None, ) -> None: - self.cassandra_host = cassandra_host - self.cassandra_key_space = cassandra_key_space - self.cassandra_user = cassandra_user - self.cassandra_password = cassandra_password + self.host = host + self.keyspace = keyspace + self.user = user + self.password = password self._session: Optional[Session] = None @property def conn(self, *, ssl_path: str = None) -> Session: # type: ignore """Establishes a Cassandra connection.""" auth_provider = ( - PlainTextAuthProvider( - username=self.cassandra_user, password=self.cassandra_password - ) - if self.cassandra_user is not None + PlainTextAuthProvider(username=self.user, password=self.password) + if self.user is not None else None ) ssl_opts = ( @@ -73,12 +71,12 @@ def conn(self, *, ssl_path: str = None) -> Session: # type: ignore ) cluster = Cluster( - contact_points=self.cassandra_host, + contact_points=self.host, auth_provider=auth_provider, ssl_options=ssl_opts, load_balancing_policy=RoundRobinPolicy(), ) - self._session = cluster.connect(self.cassandra_key_space) + self._session = cluster.connect(self.keyspace) self._session.row_factory = dict_factory return self._session @@ -106,7 +104,7 @@ def get_schema(self, table: str) -> List[Dict[str, str]]: """ query = ( f"SELECT column_name, type FROM system_schema.columns " # noqa - f"WHERE keyspace_name = '{self.cassandra_key_space}' " # noqa + f"WHERE keyspace_name = '{self.keyspace}' " # noqa f" AND table_name = '{table}';" # noqa ) @@ -114,8 +112,7 @@ def get_schema(self, table: str) -> List[Dict[str, str]]: if not response: raise RuntimeError( - f"No columns found for table: {table}" - f"in key space: {self.cassandra_key_space}" + f"No columns found for table: {table}" f"in key space: {self.keyspace}" ) return response @@ -143,7 +140,7 @@ def _get_create_table_query( else: columns_str = joined_parsed_columns - query = f"CREATE TABLE {self.cassandra_key_space}.{table} " f"({columns_str}); " + query = f"CREATE TABLE {self.keyspace}.{table} " f"({columns_str}); " return query diff --git a/butterfree/clients/spark_client.py b/butterfree/clients/spark_client.py index 0a8c717c..0f0113e2 100644 --- a/butterfree/clients/spark_client.py +++ b/butterfree/clients/spark_client.py @@ -34,9 +34,10 @@ def conn(self) -> SparkSession: def read( self, format: str, - options: Dict[str, Any], + path: Optional[Union[str, List[str]]] = None, schema: Optional[StructType] = None, stream: bool = False, + **options: Any, ) -> DataFrame: """Use the SparkSession.read interface to load data into a dataframe. @@ -45,9 +46,10 @@ def read( Args: format: string with the format to be used by the DataframeReader. - options: options to setup the DataframeReader. + path: optional string or a list of string for file-system. stream: flag to indicate if data must be read in stream mode. schema: an optional pyspark.sql.types.StructType for the input schema. + options: options to setup the DataframeReader. 
Returns: Dataframe @@ -55,14 +57,16 @@ def read( """ if not isinstance(format, str): raise ValueError("format needs to be a string with the desired read format") - if not isinstance(options, dict): - raise ValueError("options needs to be a dict with the setup configurations") + if not isinstance(path, (str, list)): + raise ValueError("path needs to be a string or a list of string") df_reader: Union[ DataStreamReader, DataFrameReader ] = self.conn.readStream if stream else self.conn.read + df_reader = df_reader.schema(schema) if schema else df_reader - return df_reader.format(format).options(**options).load() + + return df_reader.format(format).load(path, **options) # type: ignore def read_table(self, table: str, database: str = None) -> DataFrame: """Use the SparkSession.read interface to read a metastore table. @@ -223,3 +227,47 @@ def create_temporary_view(self, dataframe: DataFrame, name: str) -> Any: if not dataframe.isStreaming: return dataframe.createOrReplaceTempView(name) return dataframe.writeStream.format("memory").queryName(name).start() + + def add_table_partitions( + self, partitions: List[Dict[str, Any]], table: str, database: str = None + ) -> None: + """Add partitions to an existing table. + + Args: + partitions: partitions to add to the table. + It's expected a list of partition dicts to add to the table. + Example: `[{"year": 2020, "month": 8, "day": 14}, ...]` + table: table to add the partitions. + database: name of the database where the table is saved. + """ + for partition_dict in partitions: + if not all( + ( + isinstance(key, str) + and (isinstance(value, str) or isinstance(value, int)) + ) + for key, value in partition_dict.items() + ): + raise ValueError( + "Partition keys must be column names " + "and values must be string or int." + ) + + database_expr = f"`{database}`." if database else "" + key_values_expr = [ + ", ".join( + [ + "{} = {}".format(k, v) + if not isinstance(v, str) + else "{} = '{}'".format(k, v) + for k, v in partition.items() + ] + ) + for partition in partitions + ] + partitions_expr = " ".join(f"PARTITION ( {expr} )" for expr in key_values_expr) + command = ( + f"ALTER TABLE {database_expr}`{table}` ADD IF NOT EXISTS {partitions_expr}" + ) + + self.conn.sql(command) diff --git a/butterfree/configs/db/metastore_config.py b/butterfree/configs/db/metastore_config.py index d94b792c..a3b315d5 100644 --- a/butterfree/configs/db/metastore_config.py +++ b/butterfree/configs/db/metastore_config.py @@ -3,8 +3,11 @@ import os from typing import Any, Dict, List, Optional +from pyspark.sql import DataFrame + from butterfree.configs import environment from butterfree.configs.db import AbstractWriteConfig +from butterfree.dataframe_service import extract_partition_values class MetastoreConfig(AbstractWriteConfig): @@ -87,6 +90,31 @@ def get_options(self, key: str) -> Dict[Optional[str], Optional[str]]: "path": os.path.join(f"{self.file_system}://{self.path}/", key), } + def get_path_with_partitions(self, key: str, dataframe: DataFrame) -> List: + """Get options for AWS S3 from partitioned parquet file. + + Options will be a dictionary with the write and read configuration for + Spark to AWS S3. + + Args: + key: path to save data into AWS S3 bucket. + dataframe: spark dataframe containing data from a feature set. + + Returns: + A list of string for file-system backed data sources. 
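# A minimal sketch of the reworked SparkClient.read interface and the new
# add_table_partitions helper; the path, table and database names are placeholders.
from butterfree.clients import SparkClient

spark_client = SparkClient()

# path is now passed directly and extra keyword arguments are forwarded to the reader.
df = spark_client.read(format="parquet", path="s3a://bucket/historical/house/listings")

# Register partitions on an existing metastore table so new files become queryable.
spark_client.add_table_partitions(
    partitions=[{"year": 2020, "month": 8, "day": 14}],
    table="listings",
    database="feature_store",
)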
+ """ + path_list = [] + dataframe_values = extract_partition_values( + dataframe, partition_columns=["year", "month", "day"] + ) + for row in dataframe_values: + path_list.append( + f"{self.file_system}://{self.path}/{key}/year={row['year']}/" + f"month={row['month']}/day={row['day']}" + ) + + return path_list + def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Translate feature set spark schema to the corresponding database.""" pass diff --git a/butterfree/configs/environment.py b/butterfree/configs/environment.py index 6f5accbc..f98a7a01 100644 --- a/butterfree/configs/environment.py +++ b/butterfree/configs/environment.py @@ -35,8 +35,8 @@ def get_variable(variable_name: str, default_value: str = None) -> Optional[str] """Gets an environment variable. The variable comes from it's explicitly declared value in the running - environment or from the default value declared in the environment.yaml - specification or from the default_value. + environment or from the default value declared in specification or from the + default_value. Args: variable_name: environment variable name. diff --git a/butterfree/constants/window_definitions.py b/butterfree/constants/window_definitions.py new file mode 100644 index 00000000..560904f7 --- /dev/null +++ b/butterfree/constants/window_definitions.py @@ -0,0 +1,16 @@ +"""Allowed windows units and lengths in seconds.""" + +ALLOWED_WINDOWS = { + "second": 1, + "seconds": 1, + "minute": 60, + "minutes": 60, + "hour": 3600, + "hours": 3600, + "day": 86400, + "days": 86400, + "week": 604800, + "weeks": 604800, + "year": 29030400, + "years": 29030400, +} diff --git a/butterfree/dataframe_service/__init__.py b/butterfree/dataframe_service/__init__.py index 5116261d..c227dae2 100644 --- a/butterfree/dataframe_service/__init__.py +++ b/butterfree/dataframe_service/__init__.py @@ -1,4 +1,11 @@ """Dataframe optimization components regarding Butterfree.""" +from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy +from butterfree.dataframe_service.partitioning import extract_partition_values from butterfree.dataframe_service.repartition import repartition_df, repartition_sort_df -__all__ = ["repartition_df", "repartition_sort_df"] +__all__ = [ + "extract_partition_values", + "IncrementalStrategy", + "repartition_df", + "repartition_sort_df", +] diff --git a/butterfree/dataframe_service/incremental_strategy.py b/butterfree/dataframe_service/incremental_strategy.py new file mode 100644 index 00000000..6554d3b7 --- /dev/null +++ b/butterfree/dataframe_service/incremental_strategy.py @@ -0,0 +1,116 @@ +"""IncrementalStrategy entity.""" + +from __future__ import annotations + +from pyspark.sql import DataFrame + + +class IncrementalStrategy: + """Define an incremental strategy to be used on data sources. + + Entity responsible for defining a column expression that will be used to + filter the original data source. The purpose is to get only the data related + to a specific pipeline execution time interval. + + Attributes: + column: column expression on which incremental filter will be applied. + The expression need to result on a date or timestamp format, so the + filter can properly work with the defined upper and lower bounds. + """ + + def __init__(self, column: str = None): + self.column = column + + def from_milliseconds(self, column_name: str) -> IncrementalStrategy: + """Create a column expression from ts column defined as milliseconds. + + Args: + column_name: column name where the filter will be applied. 
+ + Returns: + `IncrementalStrategy` with the defined column expression. + """ + return IncrementalStrategy(column=f"from_unixtime({column_name}/ 1000.0)") + + def from_string(self, column_name: str, mask: str = None) -> IncrementalStrategy: + """Create a column expression from ts column defined as a simple string. + + Args: + column_name: column name where the filter will be applied. + mask: mask defining the date/timestamp format on the string. + + Returns: + `IncrementalStrategy` with the defined column expression. + """ + return IncrementalStrategy(column=f"to_date({column_name}, '{mask}')") + + def from_year_month_day_partitions( + self, + year_column: str = "year", + month_column: str = "month", + day_column: str = "day", + ) -> IncrementalStrategy: + """Create a column expression from year, month and day partitions. + + Args: + year_column: column name from the year partition. + month_column: column name from the month partition. + day_column: column name from the day partition. + + Returns: + `IncrementalStrategy` with the defined column expression. + """ + return IncrementalStrategy( + column=f"concat(string({year_column}), " + f"'-', string({month_column}), " + f"'-', string({day_column}))" + ) + + def get_expression(self, start_date: str = None, end_date: str = None) -> str: + """Get the incremental filter expression using the defined dates. + + Both arguments can be set to defined a specific date interval, but it's + only necessary to set one of the arguments for this method to work. + + Args: + start_date: date lower bound to use in the filter. + end_date: date upper bound to use in the filter. + + Returns: + Filter expression based on defined column and bounds. + + Raises: + ValuerError: If both arguments, start_date and end_date, are None. + ValueError: If the column expression was not defined. + """ + if not self.column: + raise ValueError("column parameter can't be None") + if not (start_date or end_date): + raise ValueError("Both arguments start_date and end_date can't be None.") + if start_date: + expression = f"date({self.column}) >= date('{start_date}')" + if end_date: + expression += f" and date({self.column}) <= date('{end_date}')" + return expression + return f"date({self.column}) <= date('{end_date}')" + + def filter_with_incremental_strategy( + self, dataframe: DataFrame, start_date: str = None, end_date: str = None + ) -> DataFrame: + """Filters the dataframe according to the date boundaries. + + Args: + dataframe: dataframe that will be filtered. + start_date: date lower bound to use in the filter. + end_date: date upper bound to use in the filter. + + Returns: + Filtered dataframe based on defined time boundaries. + """ + return ( + dataframe.where( + self.get_expression(start_date=start_date, end_date=end_date) + ) + if start_date or end_date + else dataframe + ) diff --git a/butterfree/dataframe_service/partitioning.py b/butterfree/dataframe_service/partitioning.py new file mode 100644 index 00000000..21e9b0ab --- /dev/null +++ b/butterfree/dataframe_service/partitioning.py @@ -0,0 +1,25 @@ +"""Module defining partitioning methods.""" + +from typing import Any, Dict, List + +from pyspark.sql import DataFrame + + +def extract_partition_values( + dataframe: DataFrame, partition_columns: List[str] +) -> List[Dict[str, Any]]: + """Extract distinct partition values from a given dataframe. + + Args: + dataframe: dataframe from where to extract partition values. + partition_columns: name of partition columns presented on the dataframe. 
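# A sketch of how an IncrementalStrategy can be combined with a reader; the reader id,
# path, column name and dates are arbitrary example values.
from butterfree.dataframe_service import IncrementalStrategy
from butterfree.extract.readers import FileReader

reader = FileReader(id="listing_events", path="data/listing_events.json", format="json")
reader = reader.with_incremental_strategy(
    IncrementalStrategy().from_milliseconds("event_timestamp")
)
# During build(), the strategy produces a filter such as
# "date(from_unixtime(event_timestamp/ 1000.0)) >= date('2020-07-04')
#  and date(from_unixtime(event_timestamp/ 1000.0)) <= date('2020-08-04')".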
+ + Returns: + distinct partition values. + """ + return ( + dataframe.select(*partition_columns) + .distinct() + .rdd.map(lambda row: row.asDict(True)) + .collect() + ) diff --git a/butterfree/extract/readers/file_reader.py b/butterfree/extract/readers/file_reader.py index 17f68f1c..8cf15599 100644 --- a/butterfree/extract/readers/file_reader.py +++ b/butterfree/extract/readers/file_reader.py @@ -87,9 +87,7 @@ def __init__( self.path = path self.format = format self.schema = schema - self.options = dict( - {"path": self.path}, **format_options if format_options else {} - ) + self.options = dict(format_options if format_options else {}) self.stream = stream def consume(self, client: SparkClient) -> DataFrame: @@ -106,11 +104,15 @@ def consume(self, client: SparkClient) -> DataFrame: """ schema = ( - client.read(format=self.format, options=self.options,).schema + client.read(format=self.format, path=self.path, **self.options).schema if (self.stream and not self.schema) else self.schema ) return client.read( - format=self.format, options=self.options, schema=schema, stream=self.stream, + format=self.format, + schema=schema, + stream=self.stream, + path=self.path, + **self.options, ) diff --git a/butterfree/extract/readers/reader.py b/butterfree/extract/readers/reader.py index 78be2823..597c870f 100644 --- a/butterfree/extract/readers/reader.py +++ b/butterfree/extract/readers/reader.py @@ -2,14 +2,16 @@ from abc import ABC, abstractmethod from functools import reduce -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from pyspark.sql import DataFrame from butterfree.clients import SparkClient +from butterfree.dataframe_service import IncrementalStrategy +from butterfree.hooks import HookableComponent -class Reader(ABC): +class Reader(ABC, HookableComponent): """Abstract base class for Readers. Attributes: @@ -19,9 +21,11 @@ class Reader(ABC): """ - def __init__(self, id: str): + def __init__(self, id: str, incremental_strategy: IncrementalStrategy = None): + super().__init__() self.id = id self.transformations: List[Dict[str, Any]] = [] + self.incremental_strategy = incremental_strategy def with_( self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any @@ -48,14 +52,19 @@ def with_( self.transformations.append(new_transformation) return self - def _apply_transformations(self, df: DataFrame) -> Any: - return reduce( - lambda result_df, transformation: transformation["transformer"]( - result_df, *transformation["args"], **transformation["kwargs"] - ), - self.transformations, - df, - ) + def with_incremental_strategy( + self, incremental_strategy: IncrementalStrategy + ) -> "Reader": + """Define the incremental strategy for the Reader. + + Args: + incremental_strategy: definition of the incremental strategy. + + Returns: + Reader with defined incremental strategy. + """ + self.incremental_strategy = incremental_strategy + return self @abstractmethod def consume(self, client: SparkClient) -> DataFrame: @@ -70,24 +79,61 @@ def consume(self, client: SparkClient) -> DataFrame: :return: Spark dataframe """ - def build(self, client: SparkClient, columns: List[Any] = None) -> None: + def build( + self, + client: SparkClient, + columns: List[Any] = None, + start_date: str = None, + end_date: str = None, + ) -> None: """Register the data got from the reader in the Spark metastore. 
Create a temporary view in Spark metastore referencing the data extracted from the target origin after the application of all the defined pre-processing transformations. + The arguments start_date and end_date are going to be use only when there + is a defined `IncrementalStrategy` on the `Reader`. + Args: client: client responsible for connecting to Spark session. - columns: list of tuples for renaming/filtering the dataset. + columns: list of tuples for selecting/renaming columns on the df. + start_date: lower bound to use in the filter expression. + end_date: upper bound to use in the filter expression. """ - transformed_df = self._apply_transformations(self.consume(client)) - - if columns: - select_expression = [] - for old_expression, new_column_name in columns: - select_expression.append(f"{old_expression} as {new_column_name}") - transformed_df = transformed_df.selectExpr(*select_expression) + column_selection_df = self._select_columns(columns, client) + transformed_df = self._apply_transformations(column_selection_df) + + if self.incremental_strategy: + transformed_df = self.incremental_strategy.filter_with_incremental_strategy( + transformed_df, start_date, end_date + ) + + post_hook_df = self.run_post_hooks(transformed_df) + + post_hook_df.createOrReplaceTempView(self.id) + + def _select_columns( + self, columns: Optional[List[Any]], client: SparkClient + ) -> DataFrame: + df = self.consume(client) + return df.selectExpr( + *( + [ + f"{old_expression} as {new_column_name}" + for old_expression, new_column_name in columns + ] + if columns + else df.columns + ) + ) - transformed_df.createOrReplaceTempView(self.id) + def _apply_transformations(self, df: DataFrame) -> DataFrame: + return reduce( + lambda result_df, transformation: transformation["transformer"]( + result_df, *transformation["args"], **transformation["kwargs"] + ), + self.transformations, + df, + ) diff --git a/butterfree/extract/source.py b/butterfree/extract/source.py index 00ac9e43..6d905c6b 100644 --- a/butterfree/extract/source.py +++ b/butterfree/extract/source.py @@ -6,9 +6,10 @@ from butterfree.clients import SparkClient from butterfree.extract.readers.reader import Reader +from butterfree.hooks import HookableComponent -class Source: +class Source(HookableComponent): """The definition of the the entry point data for the ETL pipeline. A FeatureSet (the next step in the pipeline) expects a single dataframe as @@ -51,31 +52,44 @@ class Source: """ def __init__(self, readers: List[Reader], query: str) -> None: + super().__init__() + self.enable_pre_hooks = False self.readers = readers self.query = query - def construct(self, client: SparkClient) -> DataFrame: + def construct( + self, client: SparkClient, start_date: str = None, end_date: str = None + ) -> DataFrame: """Construct an entry point dataframe for a feature set. This method will assemble multiple readers, by building each one and - querying them using a Spark SQL. + querying them using a Spark SQL. It's important to highlight that in + order to filter a dataframe regarding date boundaries, it's important + to define a IncrementalStrategy, otherwise your data will not be filtered. + Besides, both start and end dates parameters are optional. After that, there's the caching of the dataframe, however since cache() in Spark is lazy, an action is triggered in order to force persistence. Args: client: client responsible for connecting to Spark session. + start_date: user defined start date for filtering. + end_date: user defined end date for filtering. 
Returns: DataFrame with the query result against all readers. """ for reader in self.readers: - reader.build(client) # create temporary views for each reader + reader.build( + client=client, start_date=start_date, end_date=end_date + ) # create temporary views for each reader dataframe = client.sql(self.query) if not dataframe.isStreaming: dataframe.cache().count() - return dataframe + post_hook_df = self.run_post_hooks(dataframe) + + return post_hook_df diff --git a/butterfree/hooks/__init__.py b/butterfree/hooks/__init__.py new file mode 100644 index 00000000..90bedeb2 --- /dev/null +++ b/butterfree/hooks/__init__.py @@ -0,0 +1,5 @@ +"""Holds Hooks definitions.""" +from butterfree.hooks.hook import Hook +from butterfree.hooks.hookable_component import HookableComponent + +__all__ = ["Hook", "HookableComponent"] diff --git a/butterfree/hooks/hook.py b/butterfree/hooks/hook.py new file mode 100644 index 00000000..f7d8c562 --- /dev/null +++ b/butterfree/hooks/hook.py @@ -0,0 +1,20 @@ +"""Hook abstract class entity.""" + +from abc import ABC, abstractmethod + +from pyspark.sql import DataFrame + + +class Hook(ABC): + """Definition of a hook function to call on a Dataframe.""" + + @abstractmethod + def run(self, dataframe: DataFrame) -> DataFrame: + """Run interface for Hook. + + Args: + dataframe: dataframe to use in the Hook. + + Returns: + dataframe result from the Hook. + """ diff --git a/butterfree/hooks/hookable_component.py b/butterfree/hooks/hookable_component.py new file mode 100644 index 00000000..d89babce --- /dev/null +++ b/butterfree/hooks/hookable_component.py @@ -0,0 +1,148 @@ +"""Definition of hookable component.""" + +from __future__ import annotations + +from typing import List + +from pyspark.sql import DataFrame + +from butterfree.hooks.hook import Hook + + +class HookableComponent: + """Defines a component with the ability to hold pre and post hook functions. + + All main module of Butterfree have a common object that enables their integration: + dataframes. Spark's dataframe is the glue that enables the transmission of data + between the main modules. Hooks have a simple interface, they are functions that + accepts a dataframe and outputs a dataframe. These Hooks can be triggered before or + after the main execution of a component. + + Components from Butterfree that inherit HookableComponent entity, are components + that can define a series of steps to occur before or after the execution of their + main functionality. + + Attributes: + pre_hooks: function steps to trigger before component main functionality. + post_hooks: function steps to trigger after component main functionality. + enable_pre_hooks: property to indicate if the component can define pre_hooks. + enable_post_hooks: property to indicate if the component can define post_hooks. + """ + + def __init__(self) -> None: + self.pre_hooks = [] + self.post_hooks = [] + self.enable_pre_hooks = True + self.enable_post_hooks = True + + @property + def pre_hooks(self) -> List[Hook]: + """Function steps to trigger before component main functionality.""" + return self.__pre_hook + + @pre_hooks.setter + def pre_hooks(self, value: List[Hook]) -> None: + if not isinstance(value, list): + raise ValueError("pre_hooks should be a list of Hooks.") + if not all(isinstance(item, Hook) for item in value): + raise ValueError( + "All items on pre_hooks list should be an instance of Hook." 
+ ) + self.__pre_hook = value + + @property + def post_hooks(self) -> List[Hook]: + """Function steps to trigger after component main functionality.""" + return self.__post_hook + + @post_hooks.setter + def post_hooks(self, value: List[Hook]) -> None: + if not isinstance(value, list): + raise ValueError("post_hooks should be a list of Hooks.") + if not all(isinstance(item, Hook) for item in value): + raise ValueError( + "All items on post_hooks list should be an instance of Hook." + ) + self.__post_hook = value + + @property + def enable_pre_hooks(self) -> bool: + """Property to indicate if the component can define pre_hooks.""" + return self.__enable_pre_hooks + + @enable_pre_hooks.setter + def enable_pre_hooks(self, value: bool) -> None: + if not isinstance(value, bool): + raise ValueError("enable_pre_hooks accepts only boolean values.") + self.__enable_pre_hooks = value + + @property + def enable_post_hooks(self) -> bool: + """Property to indicate if the component can define post_hooks.""" + return self.__enable_post_hooks + + @enable_post_hooks.setter + def enable_post_hooks(self, value: bool) -> None: + if not isinstance(value, bool): + raise ValueError("enable_post_hooks accepts only boolean values.") + self.__enable_post_hooks = value + + def add_pre_hook(self, *hooks: Hook) -> HookableComponent: + """Add a pre-hook steps to the component. + + Args: + hooks: Hook steps to add to pre_hook list. + + Returns: + Component with the Hook inserted in pre_hook list. + + Raises: + ValueError: if the component does not accept pre-hooks. + """ + if not self.enable_pre_hooks: + raise ValueError("This component does not enable adding pre-hooks") + self.pre_hooks += list(hooks) + return self + + def add_post_hook(self, *hooks: Hook) -> HookableComponent: + """Add a post-hook steps to the component. + + Args: + hooks: Hook steps to add to post_hook list. + + Returns: + Component with the Hook inserted in post_hook list. + + Raises: + ValueError: if the component does not accept post-hooks. + """ + if not self.enable_post_hooks: + raise ValueError("This component does not enable adding post-hooks") + self.post_hooks += list(hooks) + return self + + def run_pre_hooks(self, dataframe: DataFrame) -> DataFrame: + """Run all defined pre-hook steps from a given dataframe. + + Args: + dataframe: data to input in the defined pre-hook steps. + + Returns: + dataframe after passing for all defined pre-hooks. + """ + for hook in self.pre_hooks: + dataframe = hook.run(dataframe) + return dataframe + + def run_post_hooks(self, dataframe: DataFrame) -> DataFrame: + """Run all defined post-hook steps from a given dataframe. + + Args: + dataframe: data to input in the defined post-hook steps. + + Returns: + dataframe after passing for all defined post-hooks. 
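# A sketch of a user-defined Hook; the deduplication logic is only an example of the
# dataframe-in/dataframe-out contract, not something this change ships.
from pyspark.sql import DataFrame

from butterfree.hooks import Hook


class DropDuplicatesHook(Hook):
    """Example hook that removes duplicated rows before the main step runs."""

    def run(self, dataframe: DataFrame) -> DataFrame:
        return dataframe.dropDuplicates()

# Any HookableComponent can then chain it, e.g. source.add_post_hook(DropDuplicatesHook()).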
+ """ + for hook in self.post_hooks: + dataframe = hook.run(dataframe) + return dataframe diff --git a/butterfree/hooks/schema_compatibility/__init__.py b/butterfree/hooks/schema_compatibility/__init__.py new file mode 100644 index 00000000..edf748bf --- /dev/null +++ b/butterfree/hooks/schema_compatibility/__init__.py @@ -0,0 +1,9 @@ +"""Holds Schema Compatibility Hooks definitions.""" +from butterfree.hooks.schema_compatibility.cassandra_table_schema_compatibility_hook import ( # noqa + CassandraTableSchemaCompatibilityHook, +) +from butterfree.hooks.schema_compatibility.spark_table_schema_compatibility_hook import ( # noqa + SparkTableSchemaCompatibilityHook, +) + +__all__ = ["SparkTableSchemaCompatibilityHook", "CassandraTableSchemaCompatibilityHook"] diff --git a/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py b/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py new file mode 100644 index 00000000..cdb40472 --- /dev/null +++ b/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py @@ -0,0 +1,58 @@ +"""Cassandra table schema compatibility Hook definition.""" + +from pyspark.sql import DataFrame + +from butterfree.clients import CassandraClient +from butterfree.constants import DataType +from butterfree.hooks.hook import Hook + + +class CassandraTableSchemaCompatibilityHook(Hook): + """Hook to verify the schema compatibility with a Cassandra's table. + + Verifies if all columns presented on the dataframe exists and are the same + type on the target Cassandra's table. + + Attributes: + cassandra_client: client to connect to Cassandra DB. + table: table name. + """ + + def __init__(self, cassandra_client: CassandraClient, table: str): + self.cassandra_client = cassandra_client + self.table = table + + def run(self, dataframe: DataFrame) -> DataFrame: + """Check the schema compatibility from a given Dataframe. + + This method does not change anything on the Dataframe. + + Args: + dataframe: dataframe to verify schema compatibility. + + Returns: + unchanged dataframe. + + Raises: + ValueError if the schemas are incompatible. + """ + table_schema = self.cassandra_client.get_schema(self.table) + type_cassandra = [ + type.cassandra + for field_id in range(len(dataframe.schema.fieldNames())) + for type in DataType + if dataframe.schema.fields.__getitem__(field_id).dataType == type.spark + ] + schema = [ + {"column_name": f"{column}", "type": f"{type}"} + for column, type in zip(dataframe.columns, type_cassandra) + ] + + if not all([column in table_schema for column in schema]): + raise ValueError( + "There's a schema incompatibility " + "between the defined dataframe and the Cassandra table.\n" + f"Dataframe schema = {schema}" + f"Target table schema = {table_schema}" + ) + return dataframe diff --git a/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py b/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py new file mode 100644 index 00000000..b08dd56a --- /dev/null +++ b/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py @@ -0,0 +1,46 @@ +"""Spark table schema compatibility Hook definition.""" + +from pyspark.sql import DataFrame + +from butterfree.clients import SparkClient +from butterfree.hooks.hook import Hook + + +class SparkTableSchemaCompatibilityHook(Hook): + """Hook to verify the schema compatibility with a Spark's table. 
+ + Verifies if all columns presented on the dataframe exists and are the same + type on the target Spark's table. + + Attributes: + spark_client: client to connect to Spark's metastore. + table: table name. + database: database name. + """ + + def __init__(self, spark_client: SparkClient, table: str, database: str = None): + self.spark_client = spark_client + self.table_expression = (f"`{database}`." if database else "") + f"`{table}`" + + def run(self, dataframe: DataFrame) -> DataFrame: + """Check the schema compatibility from a given Dataframe. + + This method does not change anything on the Dataframe. + + Args: + dataframe: dataframe to verify schema compatibility. + + Returns: + unchanged dataframe. + + Raises: + ValueError if the schemas are incompatible. + """ + table_schema = self.spark_client.conn.table(self.table_expression).schema + if not all([column in table_schema for column in dataframe.schema]): + raise ValueError( + "The dataframe has a schema incompatible with the defined table.\n" + f"Dataframe schema = {dataframe.schema}" + f"Target table schema = {table_schema}" + ) + return dataframe diff --git a/butterfree/load/sink.py b/butterfree/load/sink.py index b4bf93e8..0b0c10c9 100644 --- a/butterfree/load/sink.py +++ b/butterfree/load/sink.py @@ -5,13 +5,14 @@ from pyspark.sql.streaming import StreamingQuery from butterfree.clients import SparkClient +from butterfree.hooks import HookableComponent from butterfree.load.writers.writer import Writer from butterfree.transform import FeatureSet from butterfree.validations import BasicValidation from butterfree.validations.validation import Validation -class Sink: +class Sink(HookableComponent): """Define the destinations for the feature set pipeline. A Sink is created from a set of writers. The main goal of the Sink is to @@ -26,6 +27,8 @@ class Sink: """ def __init__(self, writers: List[Writer], validation: Optional[Validation] = None): + super().__init__() + self.enable_post_hooks = False self.writers = writers self.validation = validation @@ -94,12 +97,16 @@ def flush( Streaming handlers for each defined writer, if writing streaming dfs. 
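# Sink now runs pre-hooks before validating and writing, so a schema check can be attached
# explicitly; cassandra_client, writers and the table name are assumed to exist already.
from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook
from butterfree.load import Sink

sink = Sink(writers=writers)
sink.add_pre_hook(CassandraTableSchemaCompatibilityHook(cassandra_client, "house_listings"))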
""" + pre_hook_df = self.run_pre_hooks(dataframe) + if self.validation is not None: - self.validation.input(dataframe).check() + self.validation.input(pre_hook_df).check() handlers = [ writer.write( - feature_set=feature_set, dataframe=dataframe, spark_client=spark_client + feature_set=feature_set, + dataframe=pre_hook_df, + spark_client=spark_client, ) for writer in self.writers ] diff --git a/butterfree/load/writers/historical_feature_store_writer.py b/butterfree/load/writers/historical_feature_store_writer.py index d70f68f0..456d9e6b 100644 --- a/butterfree/load/writers/historical_feature_store_writer.py +++ b/butterfree/load/writers/historical_feature_store_writer.py @@ -1,7 +1,7 @@ """Holds the Historical Feature Store writer class.""" import os -from typing import Union +from typing import Any, Union from pyspark.sql.dataframe import DataFrame from pyspark.sql.functions import dayofmonth, month, year @@ -12,6 +12,8 @@ from butterfree.constants import columns from butterfree.constants.spark_constants import DEFAULT_NUM_PARTITIONS from butterfree.dataframe_service import repartition_df +from butterfree.hooks import Hook +from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook from butterfree.load.writers.writer import Writer from butterfree.transform import FeatureSet @@ -60,6 +62,20 @@ class HistoricalFeatureStoreWriter(Writer): For what settings you can use on S3Config and default settings, to read S3Config class. + We can write with interval mode, where HistoricalFeatureStoreWrite + will need to use Dynamic Partition Inserts, + the behaviour of OVERWRITE keyword is controlled by + spark.sql.sources.partitionOverwriteMode configuration property. + The dynamic overwrite mode is enabled Spark will only delete the + partitions for which it has data to be written to. + All the other partitions remain intact. + + >>> spark_client = SparkClient() + >>> writer = HistoricalFeatureStoreWriter(interval_mode=True) + >>> writer.write(feature_set=feature_set, + ... dataframe=dataframe, + ... spark_client=spark_client) + We can instantiate HistoricalFeatureStoreWriter class to validate the df to be written. 
@@ -95,15 +111,17 @@ def __init__( num_partitions: int = None, validation_threshold: float = DEFAULT_VALIDATION_THRESHOLD, debug_mode: bool = False, + interval_mode: bool = False, + check_schema_hook: Hook = None, ): - super(HistoricalFeatureStoreWriter, self).__init__() + super(HistoricalFeatureStoreWriter, self).__init__(debug_mode, interval_mode) self.db_config = db_config or MetastoreConfig() self.database = database or environment.get_variable( "FEATURE_STORE_HISTORICAL_DATABASE" ) self.num_partitions = num_partitions or DEFAULT_NUM_PARTITIONS self.validation_threshold = validation_threshold - self.debug_mode = debug_mode + self.check_schema_hook = check_schema_hook def write( self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient, @@ -122,7 +140,25 @@ def write( """ dataframe = self._create_partitions(dataframe) - dataframe = self._apply_transformations(dataframe) + partition_df = self._apply_transformations(dataframe) + + if self.debug_mode: + dataframe = partition_df + else: + dataframe = self.check_schema( + spark_client, partition_df, feature_set.name, self.database + ) + + if self.interval_mode: + if self.debug_mode: + spark_client.create_temporary_view( + dataframe=dataframe, + name=f"historical_feature_store__{feature_set.name}", + ) + return + + self._incremental_mode(feature_set, dataframe, spark_client) + return if self.debug_mode: spark_client.create_temporary_view( @@ -132,6 +168,7 @@ def write( return s3_key = os.path.join("historical", feature_set.entity, feature_set.name) + spark_client.write_table( dataframe=dataframe, database=self.database, @@ -140,6 +177,34 @@ def write( **self.db_config.get_options(s3_key), ) + def _incremental_mode( + self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient + ) -> None: + + partition_overwrite_mode = spark_client.conn.conf.get( + "spark.sql.sources.partitionOverwriteMode" + ).lower() + + if partition_overwrite_mode != "dynamic": + raise RuntimeError( + "m=load_incremental_table, " + "spark.sql.sources.partitionOverwriteMode={}, " + "msg=partitionOverwriteMode have to be configured to 'dynamic'".format( + partition_overwrite_mode + ) + ) + + s3_key = os.path.join("historical", feature_set.entity, feature_set.name) + options = {"path": self.db_config.get_options(s3_key).get("path")} + + spark_client.write_dataframe( + dataframe=dataframe, + format_=self.db_config.format_, + mode=self.db_config.mode, + **options, + partitionBy=self.PARTITION_BY, + ) + def _assert_validation_count( self, table_name: str, written_count: int, dataframe_count: int ) -> None: @@ -169,12 +234,26 @@ def validate( """ table_name = ( - f"{self.database}.{feature_set.name}" - if not self.debug_mode - else f"historical_feature_store__{feature_set.name}" + f"{feature_set.name}" + if self.interval_mode and not self.debug_mode + else ( + f"{self.database}.{feature_set.name}" + if not self.debug_mode + else f"historical_feature_store__{feature_set.name}" + ) + ) + + written_count = ( + spark_client.read( + self.db_config.format_, + path=self.db_config.get_path_with_partitions(table_name, dataframe), + ).count() + if self.interval_mode and not self.debug_mode + else spark_client.read_table(table_name).count() ) - written_count = spark_client.read_table(table_name).count() + dataframe_count = dataframe.count() + self._assert_validation_count(table_name, written_count, dataframe_count) def _create_partitions(self, dataframe: DataFrame) -> DataFrame: @@ -191,3 +270,21 @@ def _create_partitions(self, dataframe: DataFrame) 
-> DataFrame: columns.PARTITION_DAY, dayofmonth(dataframe[columns.TIMESTAMP_COLUMN]) ) return repartition_df(dataframe, self.PARTITION_BY, self.num_partitions) + + def check_schema( + self, client: Any, dataframe: DataFrame, table_name: str, database: str = None + ) -> DataFrame: + """Instantiate the schema check hook to check schema between dataframe and database. + + Args: + client: client for Spark or Cassandra connections with external services. + dataframe: Spark dataframe containing data from a feature set. + table_name: table name where the dataframe will be saved. + database: database name where the dataframe will be saved. + """ + if not self.check_schema_hook: + self.check_schema_hook = SparkTableSchemaCompatibilityHook( + client, table_name, database + ) + + return self.check_schema_hook.run(dataframe) diff --git a/butterfree/load/writers/online_feature_store_writer.py b/butterfree/load/writers/online_feature_store_writer.py index a81a1040..fade3789 100644 --- a/butterfree/load/writers/online_feature_store_writer.py +++ b/butterfree/load/writers/online_feature_store_writer.py @@ -7,9 +7,11 @@ from pyspark.sql.functions import col, row_number from pyspark.sql.streaming import StreamingQuery -from butterfree.clients import SparkClient +from butterfree.clients import CassandraClient, SparkClient from butterfree.configs.db import AbstractWriteConfig, CassandraConfig from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.hooks import Hook +from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook from butterfree.load.writers.writer import Writer from butterfree.transform import FeatureSet @@ -66,6 +68,12 @@ class OnlineFeatureStoreWriter(Writer): Both methods (writer and validate) will need the Spark Client, Feature Set and DataFrame, to write or to validate, according to OnlineFeatureStoreWriter class arguments. + + There's an important aspect to be highlighted here: if you're using + the incremental mode, we do not check if your data is the newest before + writing to the online feature store. + + This behavior is known and will be fixed soon. 
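# Interval mode on HistoricalFeatureStoreWriter relies on Spark dynamic partition overwrite,
# so the session must be configured first or _incremental_mode raises a RuntimeError.
# Sketch only; feature_set and dataframe are assumed to be defined as in the docstring example,
# and unless debug_mode is on the writer also checks the schema against the existing table.
from butterfree.clients import SparkClient
from butterfree.load.writers import HistoricalFeatureStoreWriter

spark_client = SparkClient()
spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

writer = HistoricalFeatureStoreWriter(interval_mode=True)
writer.write(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)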
""" __name__ = "Online Feature Store Writer" @@ -75,11 +83,13 @@ def __init__( db_config: Union[AbstractWriteConfig, CassandraConfig] = None, debug_mode: bool = False, write_to_entity: bool = False, + interval_mode: bool = False, + check_schema_hook: Hook = None, ): - super(OnlineFeatureStoreWriter, self).__init__() + super(OnlineFeatureStoreWriter, self).__init__(debug_mode, interval_mode) self.db_config = db_config or CassandraConfig() - self.debug_mode = debug_mode self.write_to_entity = write_to_entity + self.check_schema_hook = check_schema_hook @staticmethod def filter_latest(dataframe: DataFrame, id_columns: List[Any]) -> DataFrame: @@ -170,6 +180,22 @@ def write( """ table_name = feature_set.entity if self.write_to_entity else feature_set.name + if not self.debug_mode: + config = ( + self.db_config + if self.db_config == CassandraConfig + else CassandraConfig() + ) + + cassandra_client = CassandraClient( + host=[config.host], + keyspace=config.keyspace, + user=config.username, + password=config.password, + ) + + dataframe = self.check_schema(cassandra_client, dataframe, table_name) + if dataframe.isStreaming: dataframe = self._apply_transformations(dataframe) if self.debug_mode: @@ -236,3 +262,21 @@ def get_db_schema(self, feature_set: FeatureSet) -> List[Dict[Any, Any]]: """ db_schema = self.db_config.translate(feature_set.get_schema()) return db_schema + + def check_schema( + self, client: Any, dataframe: DataFrame, table_name: str, database: str = None + ) -> DataFrame: + """Instantiate the schema check hook to check schema between dataframe and database. + + Args: + client: client for Spark or Cassandra connections with external services. + dataframe: Spark dataframe containing data from a feature set. + table_name: table name where the dataframe will be saved. + database: database name where the dataframe will be saved. + """ + if not self.check_schema_hook: + self.check_schema_hook = CassandraTableSchemaCompatibilityHook( + client, table_name + ) + + return self.check_schema_hook.run(dataframe) diff --git a/butterfree/load/writers/writer.py b/butterfree/load/writers/writer.py index f76b4c25..7e0f9018 100644 --- a/butterfree/load/writers/writer.py +++ b/butterfree/load/writers/writer.py @@ -7,10 +7,11 @@ from pyspark.sql.dataframe import DataFrame from butterfree.clients import SparkClient +from butterfree.hooks import HookableComponent from butterfree.transform import FeatureSet -class Writer(ABC): +class Writer(ABC, HookableComponent): """Abstract base class for Writers. Args: @@ -18,8 +19,11 @@ class Writer(ABC): """ - def __init__(self) -> None: + def __init__(self, debug_mode: bool = False, interval_mode: bool = False) -> None: + super().__init__() self.transformations: List[Dict[str, Any]] = [] + self.debug_mode = debug_mode + self.interval_mode = interval_mode def with_( self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any @@ -70,6 +74,19 @@ def write( """ + @abstractmethod + def check_schema( + self, client: Any, dataframe: DataFrame, table_name: str, database: str = None + ) -> DataFrame: + """Instantiate the schema check hook to check schema between dataframe and database. + + Args: + client: client for Spark or Cassandra connections with external services. + dataframe: Spark dataframe containing data from a feature set. + table_name: table name where the dataframe will be saved. + database: database name where the dataframe will be saved. 
+ """ + @abstractmethod def validate( self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient diff --git a/butterfree/pipelines/feature_set_pipeline.py b/butterfree/pipelines/feature_set_pipeline.py index ce1b7ba4..8aec54ec 100644 --- a/butterfree/pipelines/feature_set_pipeline.py +++ b/butterfree/pipelines/feature_set_pipeline.py @@ -40,11 +40,12 @@ class FeatureSetPipeline: ... ) >>> from butterfree.load import Sink >>> from butterfree.load.writers import HistoricalFeatureStoreWriter - >>> import pyspark.sql.functions as F + >>> from pyspark.sql import functions >>> def divide(df, fs, column1, column2): ... name = fs.get_output_columns()[0] - ... df = df.withColumn(name, F.col(column1) / F.col(column2)) + ... df = df.withColumn(name, + ... functions.col(column1) / functions.col(column2)) ... return df >>> pipeline = FeatureSetPipeline( @@ -67,7 +68,8 @@ class FeatureSetPipeline: ... name="feature1", ... description="test", ... transformation=SparkFunctionTransform( - ... functions=[F.avg, F.stddev_pop] + ... functions=[Function(functions.avg, DataType.DOUBLE), + ... Function(functions.stddev_pop, DataType.DOUBLE)], ... ).with_window( ... partition_by="id", ... order_by=TIMESTAMP_COLUMN, @@ -113,6 +115,19 @@ class FeatureSetPipeline: the defined sources, compute all the transformations and save the data to the specified locations. + We can run the pipeline over a range of dates by passing an end-date + and a start-date, where it will only bring data within this date range. + + >>> pipeline.run(end_date="2020-08-04", start_date="2020-07-04") + + Or run up to a date, where it will only bring data up to the specific date. + + >>> pipeline.run(end_date="2020-08-04") + + Or just a specific date, where you will only bring data for that day. + + >>> pipeline.run_for_date(execution_date="2020-08-04") + """ def __init__( @@ -179,6 +194,7 @@ def run( partition_by: List[str] = None, order_by: List[str] = None, num_processors: int = None, + start_date: str = None, ) -> None: """Runs the defined feature set pipeline. @@ -192,7 +208,11 @@ def run( soon. Use only if strictly necessary. """ - dataframe = self.source.construct(client=self.spark_client) + dataframe = self.source.construct( + client=self.spark_client, + start_date=self.feature_set.define_start_date(start_date), + end_date=end_date, + ) if partition_by: order_by = order_by or partition_by @@ -203,6 +223,7 @@ def run( dataframe = self.feature_set.construct( dataframe=dataframe, client=self.spark_client, + start_date=start_date, end_date=end_date, num_processors=num_processors, ) @@ -219,3 +240,30 @@ def run( feature_set=self.feature_set, spark_client=self.spark_client, ) + + def run_for_date( + self, + execution_date: str = None, + partition_by: List[str] = None, + order_by: List[str] = None, + num_processors: int = None, + ) -> None: + """Runs the defined feature set pipeline for a specific date. + + The pipeline consists in the following steps: + + - Constructs the input dataframe from the data source. + - Construct the feature set dataframe using the defined Features. + - Load the data to the configured sink locations. + + It's important to notice, however, that both parameters partition_by + and num_processors are WIP, we intend to enhance their functionality + soon. Use only if strictly necessary. 
+ """ + self.run( + start_date=execution_date, + end_date=execution_date, + partition_by=partition_by, + order_by=order_by, + num_processors=num_processors, + ) diff --git a/butterfree/transform/aggregated_feature_set.py b/butterfree/transform/aggregated_feature_set.py index f43c12d5..a19efb35 100644 --- a/butterfree/transform/aggregated_feature_set.py +++ b/butterfree/transform/aggregated_feature_set.py @@ -1,6 +1,6 @@ """AggregatedFeatureSet entity.""" import itertools -from datetime import timedelta +from datetime import datetime, timedelta from functools import reduce from typing import Any, Dict, List, Optional, Union @@ -8,6 +8,7 @@ from pyspark.sql import DataFrame, functions from butterfree.clients import SparkClient +from butterfree.constants.window_definitions import ALLOWED_WINDOWS from butterfree.dataframe_service import repartition_df from butterfree.transform import FeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature @@ -488,12 +489,45 @@ def get_schema(self) -> List[Dict[str, Any]]: return schema + @staticmethod + def _get_biggest_window_in_days(definitions: List[str]) -> float: + windows_list = [] + for window in definitions: + windows_list.append( + int(window.split()[0]) * ALLOWED_WINDOWS[window.split()[1]] + ) + return max(windows_list) / (60 * 60 * 24) + + def define_start_date(self, start_date: str = None) -> Optional[str]: + """Get aggregated feature set start date. + + Args: + start_date: start date regarding source dataframe. + + Returns: + start date. + """ + if self._windows and start_date: + window_definition = [ + definition.frame_boundaries.window_definition + for definition in self._windows + ] + biggest_window = self._get_biggest_window_in_days(window_definition) + + return ( + datetime.strptime(start_date, "%Y-%m-%d") + - timedelta(days=int(biggest_window) + 1) + ).strftime("%Y-%m-%d") + + return start_date + def construct( self, dataframe: DataFrame, client: SparkClient, end_date: str = None, num_processors: int = None, + start_date: str = None, ) -> DataFrame: """Use all the features to build the feature set dataframe. @@ -506,6 +540,7 @@ def construct( client: client responsible for connecting to Spark session. end_date: user defined max date for having aggregated data (exclusive). num_processors: cluster total number of processors for repartitioning. + start_date: user defined min date for having aggregated data. Returns: Spark dataframe with all the feature columns. 
@@ -519,10 +554,12 @@ def construct( if not isinstance(dataframe, DataFrame): raise ValueError("source_df must be a dataframe") + pre_hook_df = self.run_pre_hooks(dataframe) + output_df = reduce( lambda df, feature: feature.transform(df), self.keys + [self.timestamp], - dataframe, + pre_hook_df, ) if self._windows and end_date is not None: @@ -558,6 +595,10 @@ def construct( else: output_df = self._aggregate(output_df, features=self.features) + output_df = self.incremental_strategy.filter_with_incremental_strategy( + dataframe=output_df, start_date=start_date, end_date=end_date + ) + output_df = output_df.select(*self.columns).replace( # type: ignore float("nan"), None ) @@ -565,4 +606,6 @@ def construct( output_df = self._filter_duplicated_rows(output_df) output_df.cache().count() - return output_df + post_hook_df = self.run_post_hooks(output_df) + + return post_hook_df diff --git a/butterfree/transform/feature_set.py b/butterfree/transform/feature_set.py index c35e90fa..c2e40a49 100644 --- a/butterfree/transform/feature_set.py +++ b/butterfree/transform/feature_set.py @@ -1,7 +1,7 @@ """FeatureSet entity.""" import itertools from functools import reduce -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import pyspark.sql.functions as F from pyspark.sql import Window @@ -9,6 +9,8 @@ from butterfree.clients import SparkClient from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.dataframe_service import IncrementalStrategy +from butterfree.hooks import HookableComponent from butterfree.transform.features import Feature, KeyFeature, TimestampFeature from butterfree.transform.transformations import ( AggregatedTransform, @@ -16,7 +18,7 @@ ) -class FeatureSet: +class FeatureSet(HookableComponent): """Holds metadata about the feature set and constructs the final dataframe. Attributes: @@ -106,12 +108,14 @@ def __init__( timestamp: TimestampFeature, features: List[Feature], ) -> None: + super().__init__() self.name = name self.entity = entity self.description = description self.keys = keys self.timestamp = timestamp self.features = features + self.incremental_strategy = IncrementalStrategy(column=TIMESTAMP_COLUMN) @property def name(self) -> str: @@ -243,9 +247,6 @@ def columns(self) -> List[str]: def get_schema(self) -> List[Dict[str, Any]]: """Get feature set schema. - Args: - feature_set: object processed with feature set metadata. - Returns: List of dicts regarding cassandra feature set schema. @@ -378,12 +379,24 @@ def _filter_duplicated_rows(self, df: DataFrame) -> DataFrame: return df.select([column for column in self.columns]) + def define_start_date(self, start_date: str = None) -> Optional[str]: + """Get feature set start date. + + Args: + start_date: start date regarding source dataframe. + + Returns: + start date. + """ + return start_date + def construct( self, dataframe: DataFrame, client: SparkClient, end_date: str = None, num_processors: int = None, + start_date: str = None, ) -> DataFrame: """Use all the features to build the feature set dataframe. @@ -393,7 +406,8 @@ def construct( Args: dataframe: input dataframe to be transformed by the features. client: client responsible for connecting to Spark session. - end_date: user defined base date. + start_date: user defined start date. + end_date: user defined end date. num_processors: cluster total number of processors for repartitioning. 
Returns: @@ -403,14 +417,22 @@ def construct( if not isinstance(dataframe, DataFrame): raise ValueError("source_df must be a dataframe") + pre_hook_df = self.run_pre_hooks(dataframe) + output_df = reduce( lambda df, feature: feature.transform(df), self.keys + [self.timestamp] + self.features, - dataframe, + pre_hook_df, ).select(*self.columns) if not output_df.isStreaming: output_df = self._filter_duplicated_rows(output_df) output_df.cache().count() - return output_df + output_df = self.incremental_strategy.filter_with_incremental_strategy( + dataframe=output_df, start_date=start_date, end_date=end_date + ) + + post_hook_df = self.run_post_hooks(output_df) + + return post_hook_df diff --git a/butterfree/transform/utils/window_spec.py b/butterfree/transform/utils/window_spec.py index f3a392f6..a270fec0 100644 --- a/butterfree/transform/utils/window_spec.py +++ b/butterfree/transform/utils/window_spec.py @@ -5,6 +5,7 @@ from pyspark.sql import Column, WindowSpec, functions from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.constants.window_definitions import ALLOWED_WINDOWS class FrameBoundaries: @@ -16,21 +17,6 @@ class FrameBoundaries: it can be second(s), minute(s), hour(s), day(s), week(s) and year(s), """ - __ALLOWED_WINDOWS = { - "second": 1, - "seconds": 1, - "minute": 60, - "minutes": 60, - "hour": 3600, - "hours": 3600, - "day": 86400, - "days": 86400, - "week": 604800, - "weeks": 604800, - "year": 29030400, - "years": 29030400, - } - def __init__(self, mode: Optional[str], window_definition: str): self.mode = mode self.window_definition = window_definition @@ -46,7 +32,7 @@ def window_size(self) -> int: def window_unit(self) -> str: """Returns window unit.""" unit = self.window_definition.split()[1] - if unit not in self.__ALLOWED_WINDOWS and self.mode != "row_windows": + if unit not in ALLOWED_WINDOWS and self.mode != "row_windows": raise ValueError("Not allowed") return unit @@ -59,7 +45,7 @@ def get(self, window: WindowSpec) -> Any: span = self.window_size - 1 return window.rowsBetween(-span, 0) if self.mode == "fixed_windows": - span = self.__ALLOWED_WINDOWS[self.window_unit] * self.window_size + span = ALLOWED_WINDOWS[self.window_unit] * self.window_size return window.rangeBetween(-span, 0) diff --git a/examples/interval_runs/interval_runs.ipynb b/examples/interval_runs/interval_runs.ipynb new file mode 100644 index 00000000..e234da8a --- /dev/null +++ b/examples/interval_runs/interval_runs.ipynb @@ -0,0 +1,2152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# #5 Discovering Butterfree - Interval Runs\n", + "\n", + "Welcome to Discovering Butterfree tutorial series!\n", + "\n", + "This is the fifth tutorial of this series: its goal is to cover interval runs.\n", + "\n", + "Before diving into the tutorial make sure you have a basic understanding of these main data concepts: features, feature sets and the \"Feature Store Architecture\", you can read more about this [here]." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example:\n", + "\n", + "Simulating the following scenario (the same from previous tutorials):\n", + "\n", + "- We want to create a feature set with features about houses for rent (listings).\n", + "\n", + "\n", + "We have an input dataset:\n", + "\n", + "- Table: `listing_events`. 
Table with data about events of house listings.\n", + "\n", + "\n", + "Our desire is to have three resulting datasets with the following schema:\n", + "\n", + "* id: **int**;\n", + "* timestamp: **timestamp**;\n", + "* rent__avg_over_1_day_rolling_windows: **double**;\n", + "* rent__stddev_pop_over_1_day_rolling_windows: **double**.\n", + " \n", + "The first dataset will be computed with just an end date time limit. The second one, on the other hand, uses both start and end date in order to filter data. Finally, the third one will be the result of a daily run. You can understand more about these definitions in our documentation.\n", + "\n", + "The following code blocks will show how to generate this feature set using Butterfree library:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# setup spark\n", + "from pyspark import SparkContext, SparkConf\n", + "from pyspark.sql import session\n", + "\n", + "conf = SparkConf().setAll([('spark.driver.host','127.0.0.1'), ('spark.sql.session.timeZone', 'UTC')])\n", + "sc = SparkContext(conf=conf)\n", + "spark = session.SparkSession(sc)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# fix working dir\n", + "import pathlib\n", + "import os\n", + "path = os.path.join(pathlib.Path().absolute(), '../..')\n", + "os.chdir(path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Showing test data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "listing_events_df = spark.read.json(f\"{path}/examples/data/listing_events.json\")\n", + "listing_events_df.createOrReplaceTempView(\"listing_events\") # creating listing_events view\n", + "\n", + "region = spark.read.json(f\"{path}/examples/data/region.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Listing events table:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
areabathroomsbedroomsidregion_idrenttimestamp
050111113001588302000000
150111120001588647600000
2100122215001588734000000
3100122225001589252400000
4150223330001589943600000
5175224432001589943600000
6250335532001590030000000
7225326632001590116400000
\n", + "
" + ], + "text/plain": [ + " area bathrooms bedrooms id region_id rent timestamp\n", + "0 50 1 1 1 1 1300 1588302000000\n", + "1 50 1 1 1 1 2000 1588647600000\n", + "2 100 1 2 2 2 1500 1588734000000\n", + "3 100 1 2 2 2 2500 1589252400000\n", + "4 150 2 2 3 3 3000 1589943600000\n", + "5 175 2 2 4 4 3200 1589943600000\n", + "6 250 3 3 5 5 3200 1590030000000\n", + "7 225 3 2 6 6 3200 1590116400000" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "listing_events_df.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Region table:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
cityidlatlngregion
0Cerulean173.4448931.75030Kanto
1Veridian2-9.43510-167.11772Kanto
2Cinnabar329.73043117.66164Kanto
3Pallet4-52.95717-81.15251Kanto
4Violet5-47.35798-178.77255Johto
5Olivine651.7282046.21958Johto
\n", + "
" + ], + "text/plain": [ + " city id lat lng region\n", + "0 Cerulean 1 73.44489 31.75030 Kanto\n", + "1 Veridian 2 -9.43510 -167.11772 Kanto\n", + "2 Cinnabar 3 29.73043 117.66164 Kanto\n", + "3 Pallet 4 -52.95717 -81.15251 Kanto\n", + "4 Violet 5 -47.35798 -178.77255 Johto\n", + "5 Olivine 6 51.72820 46.21958 Johto" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "region.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Extract\n", + "\n", + "- For the extract part, we need the `Source` entity and the `FileReader` for the data we have;\n", + "- We need to declare a query in order to bring the results from our lonely reader (it's as simples as a select all statement)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from butterfree.clients import SparkClient\n", + "from butterfree.extract import Source\n", + "from butterfree.extract.readers import FileReader, TableReader\n", + "from butterfree.extract.pre_processing import filter\n", + "\n", + "readers = [\n", + " TableReader(id=\"listing_events\", table=\"listing_events\",),\n", + " FileReader(id=\"region\", path=f\"{path}/examples/data/region.json\", format=\"json\",)\n", + "]\n", + "\n", + "query = \"\"\"\n", + "select\n", + " listing_events.*,\n", + " region.city,\n", + " region.region,\n", + " region.lat,\n", + " region.lng,\n", + " region.region as region_name\n", + "from\n", + " listing_events\n", + " join region\n", + " on listing_events.region_id = region.id\n", + "\"\"\"\n", + "\n", + "source = Source(readers=readers, query=query)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "spark_client = SparkClient()\n", + "source_df = source.construct(spark_client)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And, finally, it's possible to see the results from building our souce dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
areabathroomsbedroomsidregion_idrenttimestampcityregionlatlngregion_name
050111113001588302000000CeruleanKanto73.4448931.75030Kanto
150111120001588647600000CeruleanKanto73.4448931.75030Kanto
2100122215001588734000000VeridianKanto-9.43510-167.11772Kanto
3100122225001589252400000VeridianKanto-9.43510-167.11772Kanto
4150223330001589943600000CinnabarKanto29.73043117.66164Kanto
5175224432001589943600000PalletKanto-52.95717-81.15251Kanto
6250335532001590030000000VioletJohto-47.35798-178.77255Johto
7225326632001590116400000OlivineJohto51.7282046.21958Johto
\n", + "
" + ], + "text/plain": [ + " area bathrooms bedrooms id region_id rent timestamp city \\\n", + "0 50 1 1 1 1 1300 1588302000000 Cerulean \n", + "1 50 1 1 1 1 2000 1588647600000 Cerulean \n", + "2 100 1 2 2 2 1500 1588734000000 Veridian \n", + "3 100 1 2 2 2 2500 1589252400000 Veridian \n", + "4 150 2 2 3 3 3000 1589943600000 Cinnabar \n", + "5 175 2 2 4 4 3200 1589943600000 Pallet \n", + "6 250 3 3 5 5 3200 1590030000000 Violet \n", + "7 225 3 2 6 6 3200 1590116400000 Olivine \n", + "\n", + " region lat lng region_name \n", + "0 Kanto 73.44489 31.75030 Kanto \n", + "1 Kanto 73.44489 31.75030 Kanto \n", + "2 Kanto -9.43510 -167.11772 Kanto \n", + "3 Kanto -9.43510 -167.11772 Kanto \n", + "4 Kanto 29.73043 117.66164 Kanto \n", + "5 Kanto -52.95717 -81.15251 Kanto \n", + "6 Johto -47.35798 -178.77255 Johto \n", + "7 Johto 51.72820 46.21958 Johto " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "source_df.toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Transform\n", + "- At the transform part, a set of `Feature` objects is declared;\n", + "- An Instance of `AggregatedFeatureSet` is used to hold the features;\n", + "- An `AggregatedFeatureSet` can only be created when it is possible to define a unique tuple formed by key columns and a time reference. This is an **architectural requirement** for the data. So least one `KeyFeature` and one `TimestampFeature` is needed;\n", + "- Every `Feature` needs a unique name, a description, and a data-type definition. Besides, in the case of the `AggregatedFeatureSet`, it's also mandatory to have an `AggregatedTransform` operator;\n", + "- An `AggregatedTransform` operator is used, as the name suggests, to define aggregation functions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import functions as F\n", + "\n", + "from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet\n", + "from butterfree.transform.features import Feature, KeyFeature, TimestampFeature\n", + "from butterfree.transform.transformations import AggregatedTransform\n", + "from butterfree.constants import DataType\n", + "from butterfree.transform.utils import Function\n", + "\n", + "keys = [\n", + " KeyFeature(\n", + " name=\"id\",\n", + " description=\"Unique identificator code for houses.\",\n", + " dtype=DataType.BIGINT,\n", + " )\n", + "]\n", + "\n", + "# from_ms = True because the data originally is not in a Timestamp format.\n", + "ts_feature = TimestampFeature(from_ms=True)\n", + "\n", + "features = [\n", + " Feature(\n", + " name=\"rent\",\n", + " description=\"Rent value by month described in the listing.\",\n", + " transformation=AggregatedTransform(\n", + " functions=[\n", + " Function(F.avg, DataType.DOUBLE),\n", + " Function(F.stddev_pop, DataType.DOUBLE),\n", + " ],\n", + " filter_expression=\"region_name = 'Kanto'\",\n", + " ),\n", + " )\n", + "]\n", + "\n", + "aggregated_feature_set = AggregatedFeatureSet(\n", + " name=\"house_listings\",\n", + " entity=\"house\", # entity: to which \"business context\" this feature set belongs\n", + " description=\"Features describing a house listing.\",\n", + " keys=keys,\n", + " timestamp=ts_feature,\n", + " features=features,\n", + ").with_windows(definitions=[\"1 day\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we'll define our first aggregated feature set, with just an `end_date` parameter:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "aggregated_feature_set_windows_df = aggregated_feature_set.construct(\n", + " source_df, \n", + " spark_client, \n", + " end_date=\"2020-05-30\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resulting dataset is:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windows
012020-05-01NaNNaN
112020-05-021300.00.0
212020-05-03NaNNaN
312020-05-062000.00.0
412020-05-07NaNNaN
522020-05-01NaNNaN
622020-05-071500.00.0
722020-05-08NaNNaN
822020-05-132500.00.0
922020-05-14NaNNaN
1032020-05-01NaNNaN
1132020-05-213000.00.0
1232020-05-22NaNNaN
1342020-05-01NaNNaN
1442020-05-213200.00.0
1542020-05-22NaNNaN
1652020-05-01NaNNaN
1762020-05-01NaNNaN
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-01 NaN \n", + "1 1 2020-05-02 1300.0 \n", + "2 1 2020-05-03 NaN \n", + "3 1 2020-05-06 2000.0 \n", + "4 1 2020-05-07 NaN \n", + "5 2 2020-05-01 NaN \n", + "6 2 2020-05-07 1500.0 \n", + "7 2 2020-05-08 NaN \n", + "8 2 2020-05-13 2500.0 \n", + "9 2 2020-05-14 NaN \n", + "10 3 2020-05-01 NaN \n", + "11 3 2020-05-21 3000.0 \n", + "12 3 2020-05-22 NaN \n", + "13 4 2020-05-01 NaN \n", + "14 4 2020-05-21 3200.0 \n", + "15 4 2020-05-22 NaN \n", + "16 5 2020-05-01 NaN \n", + "17 6 2020-05-01 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 0.0 \n", + "2 NaN \n", + "3 0.0 \n", + "4 NaN \n", + "5 NaN \n", + "6 0.0 \n", + "7 NaN \n", + "8 0.0 \n", + "9 NaN \n", + "10 NaN \n", + "11 0.0 \n", + "12 NaN \n", + "13 NaN \n", + "14 0.0 \n", + "15 NaN \n", + "16 NaN \n", + "17 NaN " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aggregated_feature_set_windows_df.orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's possible to see that if we use both a `start date` and `end_date` values. Then we'll achieve a time slice of the last dataframe, as it's possible to see:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windows
012020-05-062000.00.0
112020-05-07NaNNaN
222020-05-06NaNNaN
322020-05-071500.00.0
422020-05-08NaNNaN
522020-05-132500.00.0
622020-05-14NaNNaN
732020-05-06NaNNaN
832020-05-213000.00.0
942020-05-06NaNNaN
1042020-05-213200.00.0
1152020-05-06NaNNaN
1262020-05-06NaNNaN
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-06 2000.0 \n", + "1 1 2020-05-07 NaN \n", + "2 2 2020-05-06 NaN \n", + "3 2 2020-05-07 1500.0 \n", + "4 2 2020-05-08 NaN \n", + "5 2 2020-05-13 2500.0 \n", + "6 2 2020-05-14 NaN \n", + "7 3 2020-05-06 NaN \n", + "8 3 2020-05-21 3000.0 \n", + "9 4 2020-05-06 NaN \n", + "10 4 2020-05-21 3200.0 \n", + "11 5 2020-05-06 NaN \n", + "12 6 2020-05-06 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 0.0 \n", + "1 NaN \n", + "2 NaN \n", + "3 0.0 \n", + "4 NaN \n", + "5 0.0 \n", + "6 NaN \n", + "7 NaN \n", + "8 0.0 \n", + "9 NaN \n", + "10 0.0 \n", + "11 NaN \n", + "12 NaN " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aggregated_feature_set.construct(\n", + " source_df, \n", + " spark_client, \n", + " end_date=\"2020-05-21\",\n", + " start_date=\"2020-05-06\",\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load\n", + "\n", + "- For the load part we need `Writer` instances and a `Sink`;\n", + "- `writers` define where to load the data;\n", + "- The `Sink` gets the transformed data (feature set) and trigger the load to all the defined `writers`;\n", + "- `debug_mode` will create a temporary view instead of trying to write in a real data store." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from butterfree.load.writers import (\n", + " HistoricalFeatureStoreWriter,\n", + " OnlineFeatureStoreWriter,\n", + ")\n", + "from butterfree.load import Sink\n", + "\n", + "writers = [HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True), \n", + " OnlineFeatureStoreWriter(debug_mode=True, interval_mode=True)]\n", + "sink = Sink(writers=writers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pipeline\n", + "\n", + "- The `Pipeline` entity wraps all the other defined elements.\n", + "- `run` command will trigger the execution of the pipeline, end-to-end." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from butterfree.pipelines import FeatureSetPipeline\n", + "\n", + "pipeline = FeatureSetPipeline(source=source, feature_set=aggregated_feature_set, sink=sink)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first run will use just an `end_date` as parameter:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "result_df = pipeline.run(end_date=\"2020-05-30\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windowsyearmonthday
012020-05-01NaNNaN202051
112020-05-021300.00.0202052
212020-05-03NaNNaN202053
312020-05-062000.00.0202056
412020-05-07NaNNaN202057
522020-05-01NaNNaN202051
622020-05-071500.00.0202057
722020-05-08NaNNaN202058
822020-05-132500.00.02020513
922020-05-14NaNNaN2020514
1032020-05-01NaNNaN202051
1132020-05-213000.00.02020521
1232020-05-22NaNNaN2020522
1342020-05-01NaNNaN202051
1442020-05-213200.00.02020521
1542020-05-22NaNNaN2020522
1652020-05-01NaNNaN202051
1762020-05-01NaNNaN202051
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-01 NaN \n", + "1 1 2020-05-02 1300.0 \n", + "2 1 2020-05-03 NaN \n", + "3 1 2020-05-06 2000.0 \n", + "4 1 2020-05-07 NaN \n", + "5 2 2020-05-01 NaN \n", + "6 2 2020-05-07 1500.0 \n", + "7 2 2020-05-08 NaN \n", + "8 2 2020-05-13 2500.0 \n", + "9 2 2020-05-14 NaN \n", + "10 3 2020-05-01 NaN \n", + "11 3 2020-05-21 3000.0 \n", + "12 3 2020-05-22 NaN \n", + "13 4 2020-05-01 NaN \n", + "14 4 2020-05-21 3200.0 \n", + "15 4 2020-05-22 NaN \n", + "16 5 2020-05-01 NaN \n", + "17 6 2020-05-01 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows year month day \n", + "0 NaN 2020 5 1 \n", + "1 0.0 2020 5 2 \n", + "2 NaN 2020 5 3 \n", + "3 0.0 2020 5 6 \n", + "4 NaN 2020 5 7 \n", + "5 NaN 2020 5 1 \n", + "6 0.0 2020 5 7 \n", + "7 NaN 2020 5 8 \n", + "8 0.0 2020 5 13 \n", + "9 NaN 2020 5 14 \n", + "10 NaN 2020 5 1 \n", + "11 0.0 2020 5 21 \n", + "12 NaN 2020 5 22 \n", + "13 NaN 2020 5 1 \n", + "14 0.0 2020 5 21 \n", + "15 NaN 2020 5 22 \n", + "16 NaN 2020 5 1 \n", + "17 NaN 2020 5 1 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"historical_feature_store__house_listings\").orderBy(\n", + " \"id\", \"timestamp\"\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windows
012020-05-07NaNNaN
122020-05-14NaNNaN
232020-05-22NaNNaN
342020-05-22NaNNaN
452020-05-01NaNNaN
562020-05-01NaNNaN
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-07 NaN \n", + "1 2 2020-05-14 NaN \n", + "2 3 2020-05-22 NaN \n", + "3 4 2020-05-22 NaN \n", + "4 5 2020-05-01 NaN \n", + "5 6 2020-05-01 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + "5 NaN " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- We can see that we were able to create all the desired features in an easy way\n", + "- The **historical feature set** holds all the data, and we can see that it is partitioned by year, month and day (columns added in the `HistoricalFeatureStoreWriter`)\n", + "- In the **online feature set** there is only the latest data for each id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The second run, on the other hand, will use both a `start_date` and `end_date` as parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "result_df = pipeline.run(end_date=\"2020-05-21\", start_date=\"2020-05-06\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windowsyearmonthday
012020-05-062000.00.0202056
112020-05-07NaNNaN202057
222020-05-06NaNNaN202056
322020-05-071500.00.0202057
422020-05-08NaNNaN202058
522020-05-132500.00.02020513
622020-05-14NaNNaN2020514
732020-05-06NaNNaN202056
832020-05-213000.00.02020521
942020-05-06NaNNaN202056
1042020-05-213200.00.02020521
1152020-05-06NaNNaN202056
1262020-05-06NaNNaN202056
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-06 2000.0 \n", + "1 1 2020-05-07 NaN \n", + "2 2 2020-05-06 NaN \n", + "3 2 2020-05-07 1500.0 \n", + "4 2 2020-05-08 NaN \n", + "5 2 2020-05-13 2500.0 \n", + "6 2 2020-05-14 NaN \n", + "7 3 2020-05-06 NaN \n", + "8 3 2020-05-21 3000.0 \n", + "9 4 2020-05-06 NaN \n", + "10 4 2020-05-21 3200.0 \n", + "11 5 2020-05-06 NaN \n", + "12 6 2020-05-06 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows year month day \n", + "0 0.0 2020 5 6 \n", + "1 NaN 2020 5 7 \n", + "2 NaN 2020 5 6 \n", + "3 0.0 2020 5 7 \n", + "4 NaN 2020 5 8 \n", + "5 0.0 2020 5 13 \n", + "6 NaN 2020 5 14 \n", + "7 NaN 2020 5 6 \n", + "8 0.0 2020 5 21 \n", + "9 NaN 2020 5 6 \n", + "10 0.0 2020 5 21 \n", + "11 NaN 2020 5 6 \n", + "12 NaN 2020 5 6 " + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"historical_feature_store__house_listings\").orderBy(\n", + " \"id\", \"timestamp\"\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windows
012020-05-07NaNNaN
122020-05-14NaNNaN
232020-05-213000.00.0
342020-05-213200.00.0
452020-05-06NaNNaN
562020-05-06NaNNaN
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-07 NaN \n", + "1 2 2020-05-14 NaN \n", + "2 3 2020-05-21 3000.0 \n", + "3 4 2020-05-21 3200.0 \n", + "4 5 2020-05-06 NaN \n", + "5 6 2020-05-06 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 NaN \n", + "2 0.0 \n", + "3 0.0 \n", + "4 NaN \n", + "5 NaN " + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, the third run, will use only an `execution_date` as a parameter." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "result_df = pipeline.run_for_date(execution_date=\"2020-05-21\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windowsyearmonthday
012020-05-21NaNNaN2020521
122020-05-21NaNNaN2020521
232020-05-213000.00.02020521
342020-05-213200.00.02020521
452020-05-21NaNNaN2020521
562020-05-21NaNNaN2020521
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-21 NaN \n", + "1 2 2020-05-21 NaN \n", + "2 3 2020-05-21 3000.0 \n", + "3 4 2020-05-21 3200.0 \n", + "4 5 2020-05-21 NaN \n", + "5 6 2020-05-21 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows year month day \n", + "0 NaN 2020 5 21 \n", + "1 NaN 2020 5 21 \n", + "2 0.0 2020 5 21 \n", + "3 0.0 2020 5 21 \n", + "4 NaN 2020 5 21 \n", + "5 NaN 2020 5 21 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"historical_feature_store__house_listings\").orderBy(\n", + " \"id\", \"timestamp\"\n", + ").orderBy(\"id\", \"timestamp\").toPandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idtimestamprent__avg_over_1_day_rolling_windowsrent__stddev_pop_over_1_day_rolling_windows
012020-05-21NaNNaN
122020-05-21NaNNaN
232020-05-213000.00.0
342020-05-213200.00.0
452020-05-21NaNNaN
562020-05-21NaNNaN
\n", + "
" + ], + "text/plain": [ + " id timestamp rent__avg_over_1_day_rolling_windows \\\n", + "0 1 2020-05-21 NaN \n", + "1 2 2020-05-21 NaN \n", + "2 3 2020-05-21 3000.0 \n", + "3 4 2020-05-21 3200.0 \n", + "4 5 2020-05-21 NaN \n", + "5 6 2020-05-21 NaN \n", + "\n", + " rent__stddev_pop_over_1_day_rolling_windows \n", + "0 NaN \n", + "1 NaN \n", + "2 0.0 \n", + "3 0.0 \n", + "4 NaN \n", + "5 NaN " + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/setup.py b/setup.py index bf471fec..4adcbce9 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import find_packages, setup __package_name__ = "butterfree" -__version__ = "1.1.3.dev0" +__version__ = "1.2.0.dev0" __repository_url__ = "https://github.com/quintoandar/butterfree" with open("requirements.txt") as f: diff --git a/tests/integration/butterfree/load/test_sink.py b/tests/integration/butterfree/load/test_sink.py index d00f4806..f507a335 100644 --- a/tests/integration/butterfree/load/test_sink.py +++ b/tests/integration/butterfree/load/test_sink.py @@ -9,9 +9,10 @@ ) -def test_sink(input_dataframe, feature_set): +def test_sink(input_dataframe, feature_set, mocker): # arrange client = SparkClient() + client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic") feature_set_df = feature_set.construct(input_dataframe, client) target_latest_df = OnlineFeatureStoreWriter.filter_latest( feature_set_df, id_columns=[key.name for key in feature_set.keys] @@ -20,14 +21,23 @@ def test_sink(input_dataframe, feature_set): # setup historical writer s3config = Mock() + s3config.mode = "overwrite" + s3config.format_ = "parquet" s3config.get_options = Mock( - return_value={ - "mode": "overwrite", - "format_": "parquet", - "path": "test_folder/historical/entity/feature_set", - } + return_value={"path": "test_folder/historical/entity/feature_set"} + ) + s3config.get_path_with_partitions = Mock( + return_value="test_folder/historical/entity/feature_set" + ) + + historical_writer = HistoricalFeatureStoreWriter( + db_config=s3config, interval_mode=True ) - historical_writer = HistoricalFeatureStoreWriter(db_config=s3config) + + schema_dataframe = historical_writer._create_partitions(feature_set_df) + historical_writer.check_schema_hook = mocker.stub("check_schema_hook") + historical_writer.check_schema_hook.run = mocker.stub("run") + historical_writer.check_schema_hook.run.return_value = schema_dataframe # setup online writer # TODO: Change for CassandraConfig when Cassandra for test is ready @@ -39,6 +49,10 @@ def test_sink(input_dataframe, feature_set): ) online_writer = OnlineFeatureStoreWriter(db_config=online_config) + online_writer.check_schema_hook = mocker.stub("check_schema_hook") + online_writer.check_schema_hook.run = mocker.stub("run") + online_writer.check_schema_hook.run.return_value = feature_set_df + writers = [historical_writer, online_writer] sink = Sink(writers) @@ -47,13 +61,14 @@ def test_sink(input_dataframe, feature_set): 
sink.flush(feature_set, feature_set_df, client) # get historical results - historical_result_df = client.read_table( - feature_set.name, historical_writer.database + historical_result_df = client.read( + s3config.format_, + path=s3config.get_path_with_partitions(feature_set.name, feature_set_df), ) # get online results online_result_df = client.read( - online_config.format_, options=online_config.get_options(feature_set.name) + online_config.format_, **online_config.get_options(feature_set.name) ) # assert diff --git a/tests/integration/butterfree/pipelines/conftest.py b/tests/integration/butterfree/pipelines/conftest.py index 79894176..73da163e 100644 --- a/tests/integration/butterfree/pipelines/conftest.py +++ b/tests/integration/butterfree/pipelines/conftest.py @@ -1,7 +1,19 @@ import pytest +from pyspark.sql import DataFrame +from pyspark.sql import functions as F from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy +from butterfree.extract import Source +from butterfree.extract.readers import TableReader +from butterfree.load import Sink +from butterfree.load.writers import HistoricalFeatureStoreWriter +from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.transformations import SparkFunctionTransform +from butterfree.transform.utils import Function @pytest.fixture() @@ -74,3 +86,193 @@ def fixed_windows_output_feature_set_dataframe(spark_context, spark_session): df = df.withColumn(TIMESTAMP_COLUMN, df.timestamp.cast(DataType.TIMESTAMP.spark)) return df + + +@pytest.fixture() +def mocked_date_df(spark_context, spark_session): + data = [ + {"id": 1, "ts": "2016-04-11 11:31:11", "feature": 200}, + {"id": 1, "ts": "2016-04-12 11:44:12", "feature": 300}, + {"id": 1, "ts": "2016-04-13 11:46:24", "feature": 400}, + {"id": 1, "ts": "2016-04-14 12:03:21", "feature": 500}, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) + + return df + + +@pytest.fixture() +def fixed_windows_output_feature_set_date_dataframe(spark_context, spark_session): + data = [ + { + "id": 1, + "timestamp": "2016-04-12 11:44:12", + "feature__avg_over_1_day_fixed_windows": 300, + "feature__stddev_pop_over_1_day_fixed_windows": 0, + "year": 2016, + "month": 4, + "day": 12, + }, + { + "id": 1, + "timestamp": "2016-04-13 11:46:24", + "feature__avg_over_1_day_fixed_windows": 400, + "feature__stddev_pop_over_1_day_fixed_windows": 0, + "year": 2016, + "month": 4, + "day": 13, + }, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn(TIMESTAMP_COLUMN, df.timestamp.cast(DataType.TIMESTAMP.spark)) + + return df + + +@pytest.fixture() +def feature_set_pipeline( + spark_context, spark_session, +): + + feature_set_pipeline = FeatureSetPipeline( + source=Source( + readers=[ + TableReader(id="b_source", table="b_table",).with_incremental_strategy( + incremental_strategy=IncrementalStrategy(column="timestamp") + ), + ], + query=f"select * from b_source ", # noqa + ), + feature_set=FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=SparkFunctionTransform( + functions=[ + Function(F.avg, 
DataType.FLOAT), + Function(F.stddev_pop, DataType.FLOAT), + ], + ).with_window( + partition_by="id", + order_by=TIMESTAMP_COLUMN, + mode="fixed_windows", + window_definition=["1 day"], + ), + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ), + sink=Sink(writers=[HistoricalFeatureStoreWriter(debug_mode=True)]), + ) + + return feature_set_pipeline + + +@pytest.fixture() +def pipeline_interval_run_target_dfs( + spark_session, spark_context +) -> (DataFrame, DataFrame, DataFrame): + first_data = [ + { + "id": 1, + "timestamp": "2016-04-11 11:31:11", + "feature": 200, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 11, + }, + { + "id": 1, + "timestamp": "2016-04-12 11:44:12", + "feature": 300, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 12, + }, + { + "id": 1, + "timestamp": "2016-04-13 11:46:24", + "feature": 400, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 13, + }, + ] + + second_data = first_data + [ + { + "id": 1, + "timestamp": "2016-04-14 12:03:21", + "feature": 500, + "run_id": 2, + "year": 2016, + "month": 4, + "day": 14, + }, + ] + + third_data = [ + { + "id": 1, + "timestamp": "2016-04-11 11:31:11", + "feature": 200, + "run_id": 3, + "year": 2016, + "month": 4, + "day": 11, + }, + { + "id": 1, + "timestamp": "2016-04-12 11:44:12", + "feature": 300, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 12, + }, + { + "id": 1, + "timestamp": "2016-04-13 11:46:24", + "feature": 400, + "run_id": 1, + "year": 2016, + "month": 4, + "day": 13, + }, + { + "id": 1, + "timestamp": "2016-04-14 12:03:21", + "feature": 500, + "run_id": 2, + "year": 2016, + "month": 4, + "day": 14, + }, + ] + + first_run_df = spark_session.read.json( + spark_context.parallelize(first_data, 1) + ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark)) + second_run_df = spark_session.read.json( + spark_context.parallelize(second_data, 1) + ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark)) + third_run_df = spark_session.read.json( + spark_context.parallelize(third_data, 1) + ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark)) + + return first_run_df, second_run_df, third_run_df diff --git a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py index 23d200c1..a302dc9e 100644 --- a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py +++ b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py @@ -4,21 +4,48 @@ from pyspark.sql import DataFrame from pyspark.sql import functions as F +from butterfree.clients import SparkClient from butterfree.configs import environment +from butterfree.configs.db import MetastoreConfig from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy from butterfree.extract import Source from butterfree.extract.readers import TableReader +from butterfree.hooks import Hook from butterfree.load import Sink from butterfree.load.writers import HistoricalFeatureStoreWriter from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline from butterfree.testing.dataframe import assert_dataframe_equality from butterfree.transform import FeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature -from 
butterfree.transform.transformations import CustomTransform, SparkFunctionTransform +from butterfree.transform.transformations import ( + CustomTransform, + SparkFunctionTransform, + SQLExpressionTransform, +) from butterfree.transform.utils import Function +class AddHook(Hook): + def __init__(self, value): + self.value = value + + def run(self, dataframe): + return dataframe.withColumn("feature", F.expr(f"feature + {self.value}")) + + +class RunHook(Hook): + def __init__(self, id): + self.id = id + + def run(self, dataframe): + return dataframe.withColumn( + "run_id", + F.when(F.lit(self.id).isNotNull(), F.lit(self.id)).otherwise(F.lit(None)), + ) + + def create_temp_view(dataframe: DataFrame, name): dataframe.createOrReplaceTempView(name) @@ -38,9 +65,21 @@ def divide(df, fs, column1, column2): return df +def create_ymd(dataframe): + return ( + dataframe.withColumn("year", F.year(F.col("timestamp"))) + .withColumn("month", F.month(F.col("timestamp"))) + .withColumn("day", F.dayofmonth(F.col("timestamp"))) + ) + + class TestFeatureSetPipeline: def test_feature_set_pipeline( - self, mocked_df, spark_session, fixed_windows_output_feature_set_dataframe + self, + mocked_df, + spark_session, + fixed_windows_output_feature_set_dataframe, + mocker, ): # arrange table_reader_id = "a_source" @@ -53,13 +92,25 @@ def test_feature_set_pipeline( table_reader_db=table_reader_db, table_reader_table=table_reader_table, ) + + spark_client = SparkClient() + spark_client.conn.conf.set( + "spark.sql.sources.partitionOverwriteMode", "dynamic" + ) + dbconfig = Mock() + dbconfig.mode = "overwrite" + dbconfig.format_ = "parquet" dbconfig.get_options = Mock( - return_value={ - "mode": "overwrite", - "format_": "parquet", - "path": "test_folder/historical/entity/feature_set", - } + return_value={"path": "test_folder/historical/entity/feature_set"} + ) + + historical_writer = HistoricalFeatureStoreWriter(db_config=dbconfig) + + historical_writer.check_schema_hook = mocker.stub("check_schema_hook") + historical_writer.check_schema_hook.run = mocker.stub("run") + historical_writer.check_schema_hook.run.return_value = ( + fixed_windows_output_feature_set_dataframe ) # act @@ -112,7 +163,7 @@ def test_feature_set_pipeline( ], timestamp=TimestampFeature(), ), - sink=Sink(writers=[HistoricalFeatureStoreWriter(db_config=dbconfig)],), + sink=Sink(writers=[historical_writer]), ) test_pipeline.run() @@ -129,3 +180,247 @@ def test_feature_set_pipeline( # tear down shutil.rmtree("test_folder") + + def test_feature_set_pipeline_with_dates( + self, + mocked_date_df, + spark_session, + fixed_windows_output_feature_set_date_dataframe, + feature_set_pipeline, + mocker, + ): + # arrange + table_reader_table = "b_table" + create_temp_view(dataframe=mocked_date_df, name=table_reader_table) + + historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) + + feature_set_pipeline.sink.writers = [historical_writer] + + # act + feature_set_pipeline.run(start_date="2016-04-12", end_date="2016-04-13") + + df = spark_session.sql("select * from historical_feature_store__feature_set") + + # assert + assert_dataframe_equality(df, fixed_windows_output_feature_set_date_dataframe) + + def test_feature_set_pipeline_with_execution_date( + self, + mocked_date_df, + spark_session, + fixed_windows_output_feature_set_date_dataframe, + feature_set_pipeline, + mocker, + ): + # arrange + table_reader_table = "b_table" + create_temp_view(dataframe=mocked_date_df, name=table_reader_table) + + target_df = 
fixed_windows_output_feature_set_date_dataframe.filter( + "timestamp < '2016-04-13'" + ) + + historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) + + feature_set_pipeline.sink.writers = [historical_writer] + + # act + feature_set_pipeline.run_for_date(execution_date="2016-04-12") + + df = spark_session.sql("select * from historical_feature_store__feature_set") + + # assert + assert_dataframe_equality(df, target_df) + + def test_pipeline_with_hooks(self, spark_session, mocker): + # arrange + hook1 = AddHook(value=1) + + spark_session.sql( + "select 1 as id, timestamp('2020-01-01') as timestamp, 0 as feature" + ).createOrReplaceTempView("test") + + target_df = spark_session.sql( + "select 1 as id, timestamp('2020-01-01') as timestamp, 6 as feature, 2020 " + "as year, 1 as month, 1 as day" + ) + + historical_writer = HistoricalFeatureStoreWriter(debug_mode=True) + + test_pipeline = FeatureSetPipeline( + source=Source( + readers=[TableReader(id="reader", table="test",).add_post_hook(hook1)], + query="select * from reader", + ).add_post_hook(hook1), + feature_set=FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=SQLExpressionTransform(expression="feature + 1"), + dtype=DataType.INTEGER, + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ) + .add_pre_hook(hook1) + .add_post_hook(hook1), + sink=Sink(writers=[historical_writer],).add_pre_hook(hook1), + ) + + # act + test_pipeline.run() + output_df = spark_session.table("historical_feature_store__feature_set") + + # assert + output_df.show() + assert_dataframe_equality(output_df, target_df) + + def test_pipeline_interval_run( + self, mocked_date_df, pipeline_interval_run_target_dfs, spark_session + ): + """Testing pipeline's idempotent interval run feature. + Source data: + +-------+---+-------------------+-------------------+ + |feature| id| ts| timestamp| + +-------+---+-------------------+-------------------+ + | 200| 1|2016-04-11 11:31:11|2016-04-11 11:31:11| + | 300| 1|2016-04-12 11:44:12|2016-04-12 11:44:12| + | 400| 1|2016-04-13 11:46:24|2016-04-13 11:46:24| + | 500| 1|2016-04-14 12:03:21|2016-04-14 12:03:21| + +-------+---+-------------------+-------------------+ + The test executes 3 runs for different time intervals. The input data has 4 data + points: 2016-04-11, 2016-04-12, 2016-04-13 and 2016-04-14. The following run + specifications are: + 1) Interval: from 2016-04-11 to 2016-04-13 + Target table result: + +---+-------+---+-----+------+-------------------+----+ + |day|feature| id|month|run_id| timestamp|year| + +---+-------+---+-----+------+-------------------+----+ + | 11| 200| 1| 4| 1|2016-04-11 11:31:11|2016| + | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016| + | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016| + +---+-------+---+-----+------+-------------------+----+ + 2) Interval: only 2016-04-14. + Target table result: + +---+-------+---+-----+------+-------------------+----+ + |day|feature| id|month|run_id| timestamp|year| + +---+-------+---+-----+------+-------------------+----+ + | 11| 200| 1| 4| 1|2016-04-11 11:31:11|2016| + | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016| + | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016| + | 14| 500| 1| 4| 2|2016-04-14 12:03:21|2016| + +---+-------+---+-----+------+-------------------+----+ + 3) Interval: only 2016-04-11. 
+ Target table result: + +---+-------+---+-----+------+-------------------+----+ + |day|feature| id|month|run_id| timestamp|year| + +---+-------+---+-----+------+-------------------+----+ + | 11| 200| 1| 4| 3|2016-04-11 11:31:11|2016| + | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016| + | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016| + | 14| 500| 1| 4| 2|2016-04-14 12:03:21|2016| + +---+-------+---+-----+------+-------------------+----+ + """ + # arrange + create_temp_view(dataframe=mocked_date_df, name="input_data") + + db = environment.get_variable("FEATURE_STORE_HISTORICAL_DATABASE") + path = "test_folder/historical/entity/feature_set" + + spark_session.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic") + spark_session.sql(f"create database if not exists {db}") + spark_session.sql( + f"create table if not exists {db}.feature_set_interval " + f"(id int, timestamp timestamp, feature int, " + f"run_id int, year int, month int, day int);" + ) + + dbconfig = MetastoreConfig() + dbconfig.get_options = Mock( + return_value={"mode": "overwrite", "format_": "parquet", "path": path} + ) + + historical_writer = HistoricalFeatureStoreWriter( + db_config=dbconfig, interval_mode=True + ) + + first_run_hook = RunHook(id=1) + second_run_hook = RunHook(id=2) + third_run_hook = RunHook(id=3) + + ( + first_run_target_df, + second_run_target_df, + third_run_target_df, + ) = pipeline_interval_run_target_dfs + + test_pipeline = FeatureSetPipeline( + source=Source( + readers=[ + TableReader(id="id", table="input_data",).with_incremental_strategy( + IncrementalStrategy("ts") + ), + ], + query="select * from id ", + ), + feature_set=FeatureSet( + name="feature_set_interval", + entity="entity", + description="", + keys=[KeyFeature(name="id", description="", dtype=DataType.INTEGER,)], + timestamp=TimestampFeature(from_column="ts"), + features=[ + Feature(name="feature", description="", dtype=DataType.INTEGER), + Feature(name="run_id", description="", dtype=DataType.INTEGER), + ], + ), + sink=Sink([historical_writer],), + ) + + # act and assert + dbconfig.get_path_with_partitions = Mock( + return_value=[ + "test_folder/historical/entity/feature_set/year=2016/month=4/day=11", + "test_folder/historical/entity/feature_set/year=2016/month=4/day=12", + "test_folder/historical/entity/feature_set/year=2016/month=4/day=13", + ] + ) + test_pipeline.feature_set.add_pre_hook(first_run_hook) + test_pipeline.run(end_date="2016-04-13", start_date="2016-04-11") + first_run_output_df = spark_session.read.parquet(path) + assert_dataframe_equality(first_run_output_df, first_run_target_df) + + dbconfig.get_path_with_partitions = Mock( + return_value=[ + "test_folder/historical/entity/feature_set/year=2016/month=4/day=14", + ] + ) + test_pipeline.feature_set.add_pre_hook(second_run_hook) + test_pipeline.run_for_date("2016-04-14") + second_run_output_df = spark_session.read.parquet(path) + assert_dataframe_equality(second_run_output_df, second_run_target_df) + + dbconfig.get_path_with_partitions = Mock( + return_value=[ + "test_folder/historical/entity/feature_set/year=2016/month=4/day=11", + ] + ) + test_pipeline.feature_set.add_pre_hook(third_run_hook) + test_pipeline.run_for_date("2016-04-11") + third_run_output_df = spark_session.read.parquet(path) + assert_dataframe_equality(third_run_output_df, third_run_target_df) + + # tear down + shutil.rmtree("test_folder") diff --git a/tests/integration/butterfree/transform/conftest.py b/tests/integration/butterfree/transform/conftest.py index 6621c9a3..fe0cc572 100644 --- 
a/tests/integration/butterfree/transform/conftest.py +++ b/tests/integration/butterfree/transform/conftest.py @@ -395,3 +395,58 @@ def rolling_windows_output_feature_set_dataframe_base_date( df = df.withColumn(TIMESTAMP_COLUMN, df.origin_ts.cast(DataType.TIMESTAMP.spark)) return df + + +@fixture +def feature_set_dates_dataframe(spark_context, spark_session): + data = [ + {"id": 1, "ts": "2016-04-11 11:31:11", "feature": 200}, + {"id": 1, "ts": "2016-04-12 11:44:12", "feature": 300}, + {"id": 1, "ts": "2016-04-13 11:46:24", "feature": 400}, + {"id": 1, "ts": "2016-04-14 12:03:21", "feature": 500}, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) + df = df.withColumn("ts", df.ts.cast(DataType.TIMESTAMP.spark)) + + return df + + +@fixture +def feature_set_dates_output_dataframe(spark_context, spark_session): + data = [ + {"id": 1, "timestamp": "2016-04-11 11:31:11", "feature": 200}, + {"id": 1, "timestamp": "2016-04-12 11:44:12", "feature": 300}, + ] + df = spark_session.read.json(spark_context.parallelize(data, 1)) + df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark)) + + return df + + +@fixture +def rolling_windows_output_date_boundaries(spark_context, spark_session): + data = [ + { + "id": 1, + "ts": "2016-04-11 00:00:00", + "feature__avg_over_1_day_rolling_windows": None, + "feature__avg_over_1_week_rolling_windows": None, + "feature__stddev_pop_over_1_day_rolling_windows": None, + "feature__stddev_pop_over_1_week_rolling_windows": None, + }, + { + "id": 1, + "ts": "2016-04-12 00:00:00", + "feature__avg_over_1_day_rolling_windows": 200.0, + "feature__avg_over_1_week_rolling_windows": 200.0, + "feature__stddev_pop_over_1_day_rolling_windows": 0.0, + "feature__stddev_pop_over_1_week_rolling_windows": 0.0, + }, + ] + df = spark_session.read.json( + spark_context.parallelize(data).map(lambda x: json.dumps(x)) + ) + df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark)) + + return df diff --git a/tests/integration/butterfree/transform/test_aggregated_feature_set.py b/tests/integration/butterfree/transform/test_aggregated_feature_set.py index 559dbcb8..bc3ebb6c 100644 --- a/tests/integration/butterfree/transform/test_aggregated_feature_set.py +++ b/tests/integration/butterfree/transform/test_aggregated_feature_set.py @@ -241,3 +241,53 @@ def test_construct_with_pivot( # assert assert_dataframe_equality(output_df, target_df_pivot_agg) + + def test_construct_rolling_windows_with_date_boundaries( + self, feature_set_dates_dataframe, rolling_windows_output_date_boundaries, + ): + # given + + spark_client = SparkClient() + + # arrange + + feature_set = AggregatedFeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=AggregatedTransform( + functions=[ + Function(F.avg, DataType.DOUBLE), + Function(F.stddev_pop, DataType.DOUBLE), + ], + ), + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ).with_windows(definitions=["1 day", "1 week"]) + + # act + output_df = feature_set.construct( + feature_set_dates_dataframe, + client=spark_client, + start_date="2016-04-11", + end_date="2016-04-12", + ).orderBy("timestamp") + + target_df = rolling_windows_output_date_boundaries.orderBy( + feature_set.timestamp_column + ).select(feature_set.columns) + + # 
assert + assert_dataframe_equality(output_df, target_df) diff --git a/tests/integration/butterfree/transform/test_feature_set.py b/tests/integration/butterfree/transform/test_feature_set.py index 4872ded2..25f70b6e 100644 --- a/tests/integration/butterfree/transform/test_feature_set.py +++ b/tests/integration/butterfree/transform/test_feature_set.py @@ -77,3 +77,47 @@ def test_construct( # assert assert_dataframe_equality(output_df, target_df) + + def test_construct_with_date_boundaries( + self, feature_set_dates_dataframe, feature_set_dates_output_dataframe + ): + # given + + spark_client = SparkClient() + + # arrange + + feature_set = FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature(name="feature", description="test", dtype=DataType.FLOAT,), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(), + ) + + output_df = ( + feature_set.construct( + feature_set_dates_dataframe, + client=spark_client, + start_date="2016-04-11", + end_date="2016-04-12", + ) + .orderBy(feature_set.timestamp_column) + .select(feature_set.columns) + ) + + target_df = feature_set_dates_output_dataframe.orderBy( + feature_set.timestamp_column + ).select(feature_set.columns) + + # assert + assert_dataframe_equality(output_df, target_df) diff --git a/tests/unit/butterfree/clients/conftest.py b/tests/unit/butterfree/clients/conftest.py index fda11f8e..ffb2db88 100644 --- a/tests/unit/butterfree/clients/conftest.py +++ b/tests/unit/butterfree/clients/conftest.py @@ -46,11 +46,16 @@ def mocked_stream_df() -> Mock: return mock +@pytest.fixture() +def mock_spark_sql() -> Mock: + mock = Mock() + mock.sql = mock + return mock + + @pytest.fixture def cassandra_client() -> CassandraClient: - return CassandraClient( - cassandra_host=["mock"], cassandra_key_space="dummy_keyspace" - ) + return CassandraClient(host=["mock"], keyspace="dummy_keyspace") @pytest.fixture diff --git a/tests/unit/butterfree/clients/test_cassandra_client.py b/tests/unit/butterfree/clients/test_cassandra_client.py index 8785485b..aa52e6f8 100644 --- a/tests/unit/butterfree/clients/test_cassandra_client.py +++ b/tests/unit/butterfree/clients/test_cassandra_client.py @@ -15,9 +15,7 @@ def sanitize_string(query: str) -> str: class TestCassandraClient: def test_conn(self, cassandra_client: CassandraClient) -> None: # arrange - cassandra_client = CassandraClient( - cassandra_host=["mock"], cassandra_key_space="dummy_keyspace" - ) + cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace") # act start_conn = cassandra_client._session diff --git a/tests/unit/butterfree/clients/test_spark_client.py b/tests/unit/butterfree/clients/test_spark_client.py index 58d53a40..9f641506 100644 --- a/tests/unit/butterfree/clients/test_spark_client.py +++ b/tests/unit/butterfree/clients/test_spark_client.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Union +from datetime import datetime +from typing import Any, Optional, Union from unittest.mock import Mock import pytest @@ -26,19 +27,20 @@ def test_conn(self) -> None: assert start_conn is None @pytest.mark.parametrize( - "format, options, stream, schema", + "format, path, stream, schema, options", [ - ("parquet", {"path": "path/to/file"}, False, None), - ("csv", {"path": "path/to/file", "header": True}, False, None), - ("json", {"path": "path/to/file"}, True, None), + ("parquet", ["path/to/file"], False, None, {}), + ("csv", "path/to/file", 
False, None, {"header": True}), + ("json", "path/to/file", True, None, {}), ], ) def test_read( self, format: str, - options: Dict[str, Any], stream: bool, schema: Optional[StructType], + path: Any, + options: Any, target_df: DataFrame, mocked_spark_read: Mock, ) -> None: @@ -48,26 +50,25 @@ def test_read( spark_client._session = mocked_spark_read # act - result_df = spark_client.read(format, options, schema, stream) + result_df = spark_client.read( + format=format, schema=schema, stream=stream, path=path, **options + ) # assert mocked_spark_read.format.assert_called_once_with(format) - mocked_spark_read.options.assert_called_once_with(**options) + mocked_spark_read.load.assert_called_once_with(path, **options) assert target_df.collect() == result_df.collect() @pytest.mark.parametrize( - "format, options", - [(None, {"path": "path/to/file"}), ("csv", "not a valid options")], + "format, path", [(None, "path/to/file"), ("csv", 123)], ) - def test_read_invalid_params( - self, format: Optional[str], options: Union[Dict[str, Any], str] - ) -> None: + def test_read_invalid_params(self, format: Optional[str], path: Any) -> None: # arrange spark_client = SparkClient() # act and assert with pytest.raises(ValueError): - spark_client.read(format, options) # type: ignore + spark_client.read(format=format, path=path) # type: ignore def test_sql(self, target_df: DataFrame) -> None: # arrange @@ -252,3 +253,43 @@ def test_create_temporary_view( # assert assert_dataframe_equality(target_df, result_df) + + def test_add_table_partitions(self, mock_spark_sql: Mock): + # arrange + target_command = ( + f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS " + f"PARTITION ( year = 2020, month = 8, day = 14 ) " + f"PARTITION ( year = 2020, month = 8, day = 15 ) " + f"PARTITION ( year = 2020, month = 8, day = 16 )" + ) + + spark_client = SparkClient() + spark_client._session = mock_spark_sql + partitions = [ + {"year": 2020, "month": 8, "day": 14}, + {"year": 2020, "month": 8, "day": 15}, + {"year": 2020, "month": 8, "day": 16}, + ] + + # act + spark_client.add_table_partitions(partitions, "table", "db") + + # assert + mock_spark_sql.assert_called_once_with(target_command) + + @pytest.mark.parametrize( + "partition", + [ + [{"float_partition": 2.72}], + [{123: 2020}], + [{"date": datetime(year=2020, month=8, day=18)}], + ], + ) + def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition): + # arrange + spark_client = SparkClient() + spark_client._session = mock_spark_sql + + # act and assert + with pytest.raises(ValueError): + spark_client.add_table_partitions(partition, "table", "db") diff --git a/tests/unit/butterfree/dataframe_service/conftest.py b/tests/unit/butterfree/dataframe_service/conftest.py index 867bc80a..09470c9a 100644 --- a/tests/unit/butterfree/dataframe_service/conftest.py +++ b/tests/unit/butterfree/dataframe_service/conftest.py @@ -25,3 +25,17 @@ def input_df(spark_context, spark_session): return spark_session.read.json( spark_context.parallelize(data, 1), schema="timestamp timestamp" ) + + +@pytest.fixture() +def test_partitioning_input_df(spark_context, spark_session): + data = [ + {"feature": 1, "year": 2009, "month": 8, "day": 20}, + {"feature": 2, "year": 2009, "month": 8, "day": 20}, + {"feature": 3, "year": 2020, "month": 8, "day": 20}, + {"feature": 4, "year": 2020, "month": 9, "day": 20}, + {"feature": 5, "year": 2020, "month": 9, "day": 20}, + {"feature": 6, "year": 2020, "month": 8, "day": 20}, + {"feature": 7, "year": 2020, "month": 8, "day": 21}, + ] + return 
spark_session.read.json(spark_context.parallelize(data, 1)) diff --git a/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py b/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py new file mode 100644 index 00000000..a140ceb3 --- /dev/null +++ b/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py @@ -0,0 +1,70 @@ +from butterfree.dataframe_service import IncrementalStrategy + + +class TestIncrementalStrategy: + def test_from_milliseconds(self): + # arrange + incremental_strategy = IncrementalStrategy().from_milliseconds("ts") + target_expression = "date(from_unixtime(ts/ 1000.0)) >= date('2020-01-01')" + + # act + result_expression = incremental_strategy.get_expression(start_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_from_string(self): + # arrange + incremental_strategy = IncrementalStrategy().from_string( + "dt", mask="dd/MM/yyyy" + ) + target_expression = "date(to_date(dt, 'dd/MM/yyyy')) >= date('2020-01-01')" + + # act + result_expression = incremental_strategy.get_expression(start_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_from_year_month_day_partitions(self): + # arrange + incremental_strategy = IncrementalStrategy().from_year_month_day_partitions( + year_column="y", month_column="m", day_column="d" + ) + target_expression = ( + "date(concat(string(y), " + "'-', string(m), " + "'-', string(d))) >= date('2020-01-01')" + ) + + # act + result_expression = incremental_strategy.get_expression(start_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_get_expression_with_just_end_date(self): + # arrange + incremental_strategy = IncrementalStrategy(column="dt") + target_expression = "date(dt) <= date('2020-01-01')" + + # act + result_expression = incremental_strategy.get_expression(end_date="2020-01-01") + + # assert + assert target_expression.split() == result_expression.split() + + def test_get_expression_with_start_and_end_date(self): + # arrange + incremental_strategy = IncrementalStrategy(column="dt") + target_expression = ( + "date(dt) >= date('2019-12-30') and date(dt) <= date('2020-01-01')" + ) + + # act + result_expression = incremental_strategy.get_expression( + start_date="2019-12-30", end_date="2020-01-01" + ) + + # assert + assert target_expression.split() == result_expression.split() diff --git a/tests/unit/butterfree/dataframe_service/test_partitioning.py b/tests/unit/butterfree/dataframe_service/test_partitioning.py new file mode 100644 index 00000000..3a6b5b40 --- /dev/null +++ b/tests/unit/butterfree/dataframe_service/test_partitioning.py @@ -0,0 +1,20 @@ +from butterfree.dataframe_service import extract_partition_values + + +class TestPartitioning: + def test_extract_partition_values(self, test_partitioning_input_df): + # arrange + target_values = [ + {"year": 2009, "month": 8, "day": 20}, + {"year": 2020, "month": 8, "day": 20}, + {"year": 2020, "month": 9, "day": 20}, + {"year": 2020, "month": 8, "day": 21}, + ] + + # act + result_values = extract_partition_values( + test_partitioning_input_df, partition_columns=["year", "month", "day"] + ) + + # assert + assert result_values == target_values diff --git a/tests/unit/butterfree/extract/conftest.py b/tests/unit/butterfree/extract/conftest.py index ab6f525c..3d0e763d 100644 --- a/tests/unit/butterfree/extract/conftest.py +++ b/tests/unit/butterfree/extract/conftest.py @@ -1,6 +1,7 @@ from 
unittest.mock import Mock import pytest +from pyspark.sql.functions import col, to_date from butterfree.constants.columns import TIMESTAMP_COLUMN @@ -17,6 +18,60 @@ def target_df(spark_context, spark_session): return spark_session.read.json(spark_context.parallelize(data, 1)) +@pytest.fixture() +def incremental_source_df(spark_context, spark_session): + data = [ + { + "id": 1, + "feature": 100, + "date_str": "28/07/2020", + "milliseconds": 1595894400000, + "year": 2020, + "month": 7, + "day": 28, + }, + { + "id": 1, + "feature": 110, + "date_str": "29/07/2020", + "milliseconds": 1595980800000, + "year": 2020, + "month": 7, + "day": 29, + }, + { + "id": 1, + "feature": 120, + "date_str": "30/07/2020", + "milliseconds": 1596067200000, + "year": 2020, + "month": 7, + "day": 30, + }, + { + "id": 2, + "feature": 150, + "date_str": "31/07/2020", + "milliseconds": 1596153600000, + "year": 2020, + "month": 7, + "day": 31, + }, + { + "id": 2, + "feature": 200, + "date_str": "01/08/2020", + "milliseconds": 1596240000000, + "year": 2020, + "month": 8, + "day": 1, + }, + ] + return spark_session.read.json(spark_context.parallelize(data, 1)).withColumn( + "date", to_date(col("date_str"), "dd/MM/yyyy") + ) + + @pytest.fixture() def spark_client(): return Mock() diff --git a/tests/unit/butterfree/extract/readers/test_file_reader.py b/tests/unit/butterfree/extract/readers/test_file_reader.py index d337d4fe..9e1c42bc 100644 --- a/tests/unit/butterfree/extract/readers/test_file_reader.py +++ b/tests/unit/butterfree/extract/readers/test_file_reader.py @@ -36,11 +36,11 @@ def test_consume( # act output_df = file_reader.consume(spark_client) - options = dict({"path": path}, **format_options if format_options else {}) + options = dict(format_options if format_options else {}) # assert spark_client.read.assert_called_once_with( - format=format, options=options, schema=schema, stream=False + format=format, schema=schema, stream=False, path=path, **options ) assert target_df.collect() == output_df.collect() @@ -51,7 +51,7 @@ def test_consume_with_stream_without_schema(self, spark_client, target_df): schema = None format_options = None stream = True - options = dict({"path": path}) + options = dict({}) spark_client.read.return_value = target_df file_reader = FileReader( @@ -64,11 +64,11 @@ def test_consume_with_stream_without_schema(self, spark_client, target_df): # assert # assert call for schema infer - spark_client.read.assert_any_call(format=format, options=options) + spark_client.read.assert_any_call(format=format, path=path, **options) # assert call for stream read # stream spark_client.read.assert_called_with( - format=format, options=options, schema=output_df.schema, stream=stream + format=format, schema=output_df.schema, stream=stream, path=path, **options ) assert target_df.collect() == output_df.collect() diff --git a/tests/unit/butterfree/extract/readers/test_reader.py b/tests/unit/butterfree/extract/readers/test_reader.py index c210a756..78160553 100644 --- a/tests/unit/butterfree/extract/readers/test_reader.py +++ b/tests/unit/butterfree/extract/readers/test_reader.py @@ -1,7 +1,9 @@ import pytest from pyspark.sql.functions import expr +from butterfree.dataframe_service import IncrementalStrategy from butterfree.extract.readers import FileReader +from butterfree.testing.dataframe import assert_dataframe_equality def add_value_transformer(df, column, value): @@ -152,3 +154,59 @@ def test_build_with_columns( # assert assert column_target_df.collect() == result_df.collect() + + def 
test_build_with_incremental_strategy( + self, incremental_source_df, spark_client, spark_session + ): + # arrange + readers = [ + # directly from column + FileReader( + id="test_1", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy(column="date") + ), + # from milliseconds + FileReader( + id="test_2", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy().from_milliseconds( + column_name="milliseconds" + ) + ), + # from str + FileReader( + id="test_3", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=IncrementalStrategy().from_string( + column_name="date_str", mask="dd/MM/yyyy" + ) + ), + # from year, month, day partitions + FileReader( + id="test_4", path="path/to/file", format="format" + ).with_incremental_strategy( + incremental_strategy=( + IncrementalStrategy().from_year_month_day_partitions() + ) + ), + ] + + spark_client.read.return_value = incremental_source_df + target_df = incremental_source_df.where( + "date >= date('2020-07-29') and date <= date('2020-07-31')" + ) + + # act + for reader in readers: + reader.build( + client=spark_client, start_date="2020-07-29", end_date="2020-07-31" + ) + + output_dfs = [ + spark_session.table(f"test_{i + 1}") for i, _ in enumerate(readers) + ] + + # assert + for output_df in output_dfs: + assert_dataframe_equality(output_df=output_df, target_df=target_df) diff --git a/tests/unit/butterfree/hooks/__init__.py b/tests/unit/butterfree/hooks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/butterfree/hooks/schema_compatibility/__init__.py b/tests/unit/butterfree/hooks/schema_compatibility/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py b/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py new file mode 100644 index 00000000..eccb8d8c --- /dev/null +++ b/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py @@ -0,0 +1,49 @@ +from unittest.mock import MagicMock + +import pytest + +from butterfree.clients import CassandraClient +from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook + + +class TestCassandraTableSchemaCompatibilityHook: + def test_run_compatible_schema(self, spark_session): + cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace") + + cassandra_client.sql = MagicMock( # type: ignore + return_value=[ + {"column_name": "feature1", "type": "text"}, + {"column_name": "feature2", "type": "int"}, + ] + ) + + table = "table" + + input_dataframe = spark_session.sql("select 'abc' as feature1, 1 as feature2") + + hook = CassandraTableSchemaCompatibilityHook(cassandra_client, table) + + # act and assert + assert hook.run(input_dataframe) == input_dataframe + + def test_run_incompatible_schema(self, spark_session): + cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace") + + cassandra_client.sql = MagicMock( # type: ignore + return_value=[ + {"column_name": "feature1", "type": "text"}, + {"column_name": "feature2", "type": "bigint"}, + ] + ) + + table = "table" + + input_dataframe = spark_session.sql("select 'abc' as feature1, 1 as feature2") + + hook = CassandraTableSchemaCompatibilityHook(cassandra_client, table) + + # act and assert + with pytest.raises( + ValueError, 
match="There's a schema incompatibility between" + ): + hook.run(input_dataframe) diff --git a/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py b/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py new file mode 100644 index 00000000..3a31b600 --- /dev/null +++ b/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py @@ -0,0 +1,53 @@ +import pytest + +from butterfree.clients import SparkClient +from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook + + +class TestSparkTableSchemaCompatibilityHook: + @pytest.mark.parametrize( + "table, database, target_table_expression", + [("table", "database", "`database`.`table`"), ("table", None, "`table`")], + ) + def test_build_table_expression(self, table, database, target_table_expression): + # arrange + spark_client = SparkClient() + + # act + result_table_expression = SparkTableSchemaCompatibilityHook( + spark_client, table, database + ).table_expression + + # assert + assert target_table_expression == result_table_expression + + def test_run_compatible_schema(self, spark_session): + # arrange + spark_client = SparkClient() + target_table = spark_session.sql( + "select 1 as feature_a, 'abc' as feature_b, true as other_feature" + ) + input_dataframe = spark_session.sql("select 1 as feature_a, 'abc' as feature_b") + target_table.registerTempTable("test") + + hook = SparkTableSchemaCompatibilityHook(spark_client, "test") + + # act and assert + assert hook.run(input_dataframe) == input_dataframe + + def test_run_incompatible_schema(self, spark_session): + # arrange + spark_client = SparkClient() + target_table = spark_session.sql( + "select 1 as feature_a, 'abc' as feature_b, true as other_feature" + ) + input_dataframe = spark_session.sql( + "select 1 as feature_a, 'abc' as feature_b, true as unregisted_column" + ) + target_table.registerTempTable("test") + + hook = SparkTableSchemaCompatibilityHook(spark_client, "test") + + # act and assert + with pytest.raises(ValueError, match="The dataframe has a schema incompatible"): + hook.run(input_dataframe) diff --git a/tests/unit/butterfree/hooks/test_hookable_component.py b/tests/unit/butterfree/hooks/test_hookable_component.py new file mode 100644 index 00000000..37e34e69 --- /dev/null +++ b/tests/unit/butterfree/hooks/test_hookable_component.py @@ -0,0 +1,107 @@ +import pytest +from pyspark.sql.functions import expr + +from butterfree.hooks import Hook, HookableComponent +from butterfree.testing.dataframe import assert_dataframe_equality + + +class TestComponent(HookableComponent): + def construct(self, dataframe): + pre_hook_df = self.run_pre_hooks(dataframe) + construct_df = pre_hook_df.withColumn("feature", expr("feature * feature")) + return self.run_post_hooks(construct_df) + + +class AddHook(Hook): + def __init__(self, value): + self.value = value + + def run(self, dataframe): + return dataframe.withColumn("feature", expr(f"feature + {self.value}")) + + +class TestHookableComponent: + def test_add_hooks(self): + # arrange + hook1 = AddHook(value=1) + hook2 = AddHook(value=2) + hook3 = AddHook(value=3) + hook4 = AddHook(value=4) + hookable_component = HookableComponent() + + # act + hookable_component.add_pre_hook(hook1, hook2) + hookable_component.add_post_hook(hook3, hook4) + + # assert + assert hookable_component.pre_hooks == [hook1, hook2] + assert hookable_component.post_hooks == [hook3, hook4] + + @pytest.mark.parametrize( + 
"enable_pre_hooks, enable_post_hooks", + [("not boolean", False), (False, "not boolean")], + ) + def test_invalid_enable_hook(self, enable_pre_hooks, enable_post_hooks): + # arrange + hookable_component = HookableComponent() + + # act and assert + with pytest.raises(ValueError): + hookable_component.enable_pre_hooks = enable_pre_hooks + hookable_component.enable_post_hooks = enable_post_hooks + + @pytest.mark.parametrize( + "pre_hooks, post_hooks", + [ + ([AddHook(1)], "not a list of hooks"), + ([AddHook(1)], [AddHook(1), 2, 3]), + ("not a list of hooks", [AddHook(1)]), + ([AddHook(1), 2, 3], [AddHook(1)]), + ], + ) + def test_invalid_hooks(self, pre_hooks, post_hooks): + # arrange + hookable_component = HookableComponent() + + # act and assert + with pytest.raises(ValueError): + hookable_component.pre_hooks = pre_hooks + hookable_component.post_hooks = post_hooks + + @pytest.mark.parametrize( + "pre_hook, enable_pre_hooks, post_hook, enable_post_hooks", + [ + (AddHook(value=1), False, AddHook(value=1), True), + (AddHook(value=1), True, AddHook(value=1), False), + ("not a pre-hook", True, AddHook(value=1), True), + (AddHook(value=1), True, "not a pre-hook", True), + ], + ) + def test_add_invalid_hooks( + self, pre_hook, enable_pre_hooks, post_hook, enable_post_hooks + ): + # arrange + hookable_component = HookableComponent() + hookable_component.enable_pre_hooks = enable_pre_hooks + hookable_component.enable_post_hooks = enable_post_hooks + + # act and assert + with pytest.raises(ValueError): + hookable_component.add_pre_hook(pre_hook) + hookable_component.add_post_hook(post_hook) + + def test_run_hooks(self, spark_session): + # arrange + input_dataframe = spark_session.sql("select 2 as feature") + test_component = ( + TestComponent() + .add_pre_hook(AddHook(value=1)) + .add_post_hook(AddHook(value=1)) + ) + target_table = spark_session.sql("select 10 as feature") + + # act + output_df = test_component.construct(input_dataframe) + + # assert + assert_dataframe_equality(output_df, target_table) diff --git a/tests/unit/butterfree/load/conftest.py b/tests/unit/butterfree/load/conftest.py index 7c2549c5..4dcf25c9 100644 --- a/tests/unit/butterfree/load/conftest.py +++ b/tests/unit/butterfree/load/conftest.py @@ -32,6 +32,31 @@ def feature_set(): ) +@fixture +def feature_set_incremental(): + key_features = [ + KeyFeature(name="id", description="Description", dtype=DataType.INTEGER) + ] + ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN) + features = [ + Feature( + name="feature", + description="test", + transformation=AggregatedTransform( + functions=[Function(functions.sum, DataType.INTEGER)] + ), + ), + ] + return AggregatedFeatureSet( + "feature_set", + "entity", + "description", + keys=key_features, + timestamp=ts_feature, + features=features, + ) + + @fixture def feature_set_dataframe(spark_context, spark_session): data = [ diff --git a/tests/unit/butterfree/load/test_sink.py b/tests/unit/butterfree/load/test_sink.py index 93b5e279..ef377f67 100644 --- a/tests/unit/butterfree/load/test_sink.py +++ b/tests/unit/butterfree/load/test_sink.py @@ -120,7 +120,7 @@ def test_flush_with_writers_list_empty(self): with pytest.raises(ValueError): Sink(writers=writer) - def test_flush_streaming_df(self, feature_set): + def test_flush_streaming_df(self, feature_set, mocker): """Testing the return of the streaming handlers by the sink.""" # arrange spark_client = SparkClient() @@ -136,10 +136,25 @@ def test_flush_streaming_df(self, feature_set): mocked_stream_df.start.return_value = 
Mock(spec=StreamingQuery) online_feature_store_writer = OnlineFeatureStoreWriter() + + online_feature_store_writer.check_schema_hook = mocker.stub("check_schema_hook") + online_feature_store_writer.check_schema_hook.run = mocker.stub("run") + online_feature_store_writer.check_schema_hook.run.return_value = ( + mocked_stream_df + ) + online_feature_store_writer_on_entity = OnlineFeatureStoreWriter( write_to_entity=True ) + online_feature_store_writer_on_entity.check_schema_hook = mocker.stub( + "check_schema_hook" + ) + online_feature_store_writer_on_entity.check_schema_hook.run = mocker.stub("run") + online_feature_store_writer_on_entity.check_schema_hook.run.return_value = ( + mocked_stream_df + ) + sink = Sink( writers=[ online_feature_store_writer, @@ -162,7 +177,7 @@ def test_flush_streaming_df(self, feature_set): assert isinstance(handler, StreamingQuery) def test_flush_with_multiple_online_writers( - self, feature_set, feature_set_dataframe + self, feature_set, feature_set_dataframe, mocker ): """Testing the flow of writing to a feature-set table and to an entity table.""" # arrange @@ -173,10 +188,25 @@ def test_flush_with_multiple_online_writers( feature_set.name = "my_feature_set" online_feature_store_writer = OnlineFeatureStoreWriter() + + online_feature_store_writer.check_schema_hook = mocker.stub("check_schema_hook") + online_feature_store_writer.check_schema_hook.run = mocker.stub("run") + online_feature_store_writer.check_schema_hook.run.return_value = ( + feature_set_dataframe + ) + online_feature_store_writer_on_entity = OnlineFeatureStoreWriter( write_to_entity=True ) + online_feature_store_writer_on_entity.check_schema_hook = mocker.stub( + "check_schema_hook" + ) + online_feature_store_writer_on_entity.check_schema_hook.run = mocker.stub("run") + online_feature_store_writer_on_entity.check_schema_hook.run.return_value = ( + feature_set_dataframe + ) + sink = Sink( writers=[online_feature_store_writer, online_feature_store_writer_on_entity] ) diff --git a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py index 14c067f9..aac806f7 100644 --- a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py +++ b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py @@ -19,10 +19,15 @@ def test_write( feature_set, ): # given - spark_client = mocker.stub("spark_client") + spark_client = SparkClient() spark_client.write_table = mocker.stub("write_table") writer = HistoricalFeatureStoreWriter() + schema_dataframe = writer._create_partitions(feature_set_dataframe) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = schema_dataframe + # when writer.write( feature_set=feature_set, @@ -41,7 +46,76 @@ def test_write( assert ( writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"] ) - assert feature_set.name == spark_client.write_table.call_args[1]["table_name"] + + def test_write_interval_mode( + self, + feature_set_dataframe, + historical_feature_set_dataframe, + mocker, + feature_set, + ): + # given + spark_client = SparkClient() + spark_client.write_dataframe = mocker.stub("write_dataframe") + spark_client.conn.conf.set( + "spark.sql.sources.partitionOverwriteMode", "dynamic" + ) + writer = HistoricalFeatureStoreWriter(interval_mode=True) + + schema_dataframe = writer._create_partitions(feature_set_dataframe) + 
writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = schema_dataframe + + # when + writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) + result_df = spark_client.write_dataframe.call_args[1]["dataframe"] + + # then + assert_dataframe_equality(historical_feature_set_dataframe, result_df) + + assert ( + writer.db_config.format_ + == spark_client.write_dataframe.call_args[1]["format_"] + ) + assert ( + writer.db_config.mode == spark_client.write_dataframe.call_args[1]["mode"] + ) + assert ( + writer.PARTITION_BY + == spark_client.write_dataframe.call_args[1]["partitionBy"] + ) + + def test_write_interval_mode_invalid_partition_mode( + self, + feature_set_dataframe, + historical_feature_set_dataframe, + mocker, + feature_set, + ): + # given + spark_client = SparkClient() + spark_client.write_dataframe = mocker.stub("write_dataframe") + spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "static") + + writer = HistoricalFeatureStoreWriter(interval_mode=True) + + schema_dataframe = writer._create_partitions(feature_set_dataframe) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = schema_dataframe + + # when + with pytest.raises(RuntimeError): + _ = writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) def test_write_in_debug_mode( self, @@ -49,6 +123,7 @@ def test_write_in_debug_mode( historical_feature_set_dataframe, feature_set, spark_session, + mocker, ): # given spark_client = SparkClient() @@ -65,33 +140,75 @@ def test_write_in_debug_mode( # then assert_dataframe_equality(historical_feature_set_dataframe, result_df) - def test_validate(self, feature_set_dataframe, mocker, feature_set): + def test_write_in_debug_mode_with_interval_mode( + self, + feature_set_dataframe, + historical_feature_set_dataframe, + feature_set, + spark_session, + ): + # given + spark_client = SparkClient() + writer = HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True) + + # when + writer.write( + feature_set=feature_set, + dataframe=feature_set_dataframe, + spark_client=spark_client, + ) + result_df = spark_session.table(f"historical_feature_store__{feature_set.name}") + + # then + assert_dataframe_equality(historical_feature_set_dataframe, result_df) + + def test_validate(self, historical_feature_set_dataframe, mocker, feature_set): # given spark_client = mocker.stub("spark_client") spark_client.read_table = mocker.stub("read_table") - spark_client.read_table.return_value = feature_set_dataframe + spark_client.read_table.return_value = historical_feature_set_dataframe writer = HistoricalFeatureStoreWriter() # when - writer.validate(feature_set, feature_set_dataframe, spark_client) + writer.validate(feature_set, historical_feature_set_dataframe, spark_client) # then spark_client.read_table.assert_called_once() - def test_validate_false(self, feature_set_dataframe, mocker, feature_set): + def test_validate_interval_mode( + self, historical_feature_set_dataframe, mocker, feature_set + ): # given spark_client = mocker.stub("spark_client") - spark_client.read_table = mocker.stub("read_table") + spark_client.read = mocker.stub("read") + spark_client.read.return_value = historical_feature_set_dataframe + + writer = HistoricalFeatureStoreWriter(interval_mode=True) + + # 
when + writer.validate(feature_set, historical_feature_set_dataframe, spark_client) + + # then + spark_client.read.assert_called_once() + + def test_validate_false( + self, historical_feature_set_dataframe, mocker, feature_set + ): + # given + spark_client = mocker.stub("spark_client") + spark_client.read = mocker.stub("read") # limiting df to 1 row, now the counts should'n be the same - spark_client.read_table.return_value = feature_set_dataframe.limit(1) + spark_client.read.return_value = historical_feature_set_dataframe.limit(1) - writer = HistoricalFeatureStoreWriter() + writer = HistoricalFeatureStoreWriter(interval_mode=True) # when with pytest.raises(AssertionError): - _ = writer.validate(feature_set, feature_set_dataframe, spark_client) + _ = writer.validate( + feature_set, historical_feature_set_dataframe, spark_client + ) def test__create_partitions(self, spark_session, spark_context): # arrange @@ -201,8 +318,15 @@ def test_write_with_transform( # given spark_client = mocker.stub("spark_client") spark_client.write_table = mocker.stub("write_table") + writer = HistoricalFeatureStoreWriter().with_(json_transform) + schema_dataframe = writer._create_partitions(feature_set_dataframe) + json_dataframe = writer._apply_transformations(schema_dataframe) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = json_dataframe + # when writer.write( feature_set=feature_set, diff --git a/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py index 87823c55..384ec152 100644 --- a/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py +++ b/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py @@ -68,6 +68,10 @@ def test_write( spark_client.write_dataframe = mocker.stub("write_dataframe") writer = OnlineFeatureStoreWriter(cassandra_config) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = feature_set_dataframe + # when writer.write(feature_set, feature_set_dataframe, spark_client) @@ -94,11 +98,16 @@ def test_write_in_debug_mode( latest_feature_set_dataframe, feature_set, spark_session, + mocker, ): # given spark_client = SparkClient() writer = OnlineFeatureStoreWriter(debug_mode=True) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = feature_set_dataframe + # when writer.write( feature_set=feature_set, @@ -110,9 +119,7 @@ def test_write_in_debug_mode( # then assert_dataframe_equality(latest_feature_set_dataframe, result_df) - def test_write_in_debug_and_stream_mode( - self, feature_set, spark_session, - ): + def test_write_in_debug_and_stream_mode(self, feature_set, spark_session, mocker): # arrange spark_client = SparkClient() @@ -125,6 +132,10 @@ def test_write_in_debug_and_stream_mode( writer = OnlineFeatureStoreWriter(debug_mode=True) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = mocked_stream_df + # act handler = writer.write( feature_set=feature_set, @@ -140,7 +151,7 @@ def test_write_in_debug_and_stream_mode( assert isinstance(handler, StreamingQuery) @pytest.mark.parametrize("has_checkpoint", [True, False]) - def 
test_write_stream(self, feature_set, has_checkpoint, monkeypatch): + def test_write_stream(self, feature_set, has_checkpoint, monkeypatch, mocker): # arrange spark_client = SparkClient() spark_client.write_stream = Mock() @@ -163,6 +174,10 @@ def test_write_stream(self, feature_set, has_checkpoint, monkeypatch): writer = OnlineFeatureStoreWriter(cassandra_config) writer.filter_latest = Mock() + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = dataframe + # act stream_handler = writer.write(feature_set, dataframe, spark_client) @@ -186,7 +201,7 @@ def test_get_db_schema(self, cassandra_config, test_feature_set, expected_schema assert schema == expected_schema - def test_write_stream_on_entity(self, feature_set, monkeypatch): + def test_write_stream_on_entity(self, feature_set, monkeypatch, mocker): """Test write method with stream dataframe and write_to_entity enabled. The main purpose of this test is assert the correct setup of stream checkpoint @@ -209,6 +224,10 @@ def test_write_stream_on_entity(self, feature_set, monkeypatch): writer = OnlineFeatureStoreWriter(write_to_entity=True) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = dataframe + # act stream_handler = writer.write(feature_set, dataframe, spark_client) @@ -237,6 +256,10 @@ def test_write_with_transform( spark_client.write_dataframe = mocker.stub("write_dataframe") writer = OnlineFeatureStoreWriter(cassandra_config).with_(json_transform) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = feature_set_dataframe + # when writer.write(feature_set, feature_set_dataframe, spark_client) @@ -270,6 +293,10 @@ def test_write_with_kafka_config( kafka_config = KafkaConfig() writer = OnlineFeatureStoreWriter(kafka_config).with_(json_transform) + writer.check_schema_hook = mocker.stub("check_schema_hook") + writer.check_schema_hook.run = mocker.stub("run") + writer.check_schema_hook.run.return_value = feature_set_dataframe + # when writer.write(feature_set, feature_set_dataframe, spark_client) @@ -293,6 +320,10 @@ def test_write_with_custom_kafka_config( json_transform ) + custom_writer.check_schema_hook = mocker.stub("check_schema_hook") + custom_writer.check_schema_hook.run = mocker.stub("run") + custom_writer.check_schema_hook.run.return_value = feature_set_dataframe + # when custom_writer.write(feature_set, feature_set_dataframe, spark_client) diff --git a/tests/unit/butterfree/pipelines/conftest.py b/tests/unit/butterfree/pipelines/conftest.py new file mode 100644 index 00000000..47e65efb --- /dev/null +++ b/tests/unit/butterfree/pipelines/conftest.py @@ -0,0 +1,63 @@ +from unittest.mock import Mock + +from pyspark.sql import functions +from pytest import fixture + +from butterfree.clients import SparkClient +from butterfree.constants import DataType +from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.extract import Source +from butterfree.extract.readers import TableReader +from butterfree.load import Sink +from butterfree.load.writers import HistoricalFeatureStoreWriter +from butterfree.pipelines import FeatureSetPipeline +from butterfree.transform import FeatureSet +from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from 
butterfree.transform.transformations import SparkFunctionTransform +from butterfree.transform.utils import Function + + +@fixture() +def feature_set_pipeline(): + test_pipeline = FeatureSetPipeline( + spark_client=SparkClient(), + source=Mock( + spec=Source, + readers=[TableReader(id="source_a", database="db", table="table",)], + query="select * from source_a", + ), + feature_set=Mock( + spec=FeatureSet, + name="feature_set", + entity="entity", + description="description", + keys=[ + KeyFeature( + name="user_id", + description="The user's Main ID or device ID", + dtype=DataType.INTEGER, + ) + ], + timestamp=TimestampFeature(from_column="ts"), + features=[ + Feature( + name="listing_page_viewed__rent_per_month", + description="Average of something.", + transformation=SparkFunctionTransform( + functions=[ + Function(functions.avg, DataType.FLOAT), + Function(functions.stddev_pop, DataType.FLOAT), + ], + ).with_window( + partition_by="user_id", + order_by=TIMESTAMP_COLUMN, + window_definition=["7 days", "2 weeks"], + mode="fixed_windows", + ), + ), + ], + ), + sink=Mock(spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],), + ) + + return test_pipeline diff --git a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py index 1bc3c707..7bae6606 100644 --- a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py +++ b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py @@ -17,12 +17,8 @@ from butterfree.load.writers.writer import Writer from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline from butterfree.transform import FeatureSet -from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature -from butterfree.transform.transformations import ( - AggregatedTransform, - SparkFunctionTransform, -) +from butterfree.transform.transformations import SparkFunctionTransform from butterfree.transform.utils import Function @@ -104,115 +100,29 @@ def test_feature_set_args(self): assert len(pipeline.sink.writers) == 2 assert all(isinstance(writer, Writer) for writer in pipeline.sink.writers) - def test_run(self, spark_session): - test_pipeline = FeatureSetPipeline( - spark_client=SparkClient(), - source=Mock( - spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], - query="select * from source_a", - ), - feature_set=Mock( - spec=FeatureSet, - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="listing_page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ).with_window( - partition_by="user_id", - order_by=TIMESTAMP_COLUMN, - window_definition=["7 days", "2 weeks"], - mode="fixed_windows", - ), - ), - ], - ), - sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], - ), - ) - + def test_run(self, spark_session, feature_set_pipeline): # feature_set need to return a real df for streaming validation sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) - test_pipeline.feature_set.construct.return_value = sample_df + 
feature_set_pipeline.feature_set.construct.return_value = sample_df - test_pipeline.run() + feature_set_pipeline.run() - test_pipeline.source.construct.assert_called_once() - test_pipeline.feature_set.construct.assert_called_once() - test_pipeline.sink.flush.assert_called_once() - test_pipeline.sink.validate.assert_called_once() - - def test_run_with_repartition(self, spark_session): - test_pipeline = FeatureSetPipeline( - spark_client=SparkClient(), - source=Mock( - spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], - query="select * from source_a", - ), - feature_set=Mock( - spec=FeatureSet, - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="listing_page_viewed__rent_per_month", - description="Average of something.", - transformation=SparkFunctionTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ).with_window( - partition_by="user_id", - order_by=TIMESTAMP_COLUMN, - window_definition=["7 days", "2 weeks"], - mode="fixed_windows", - ), - ), - ], - ), - sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], - ), - ) + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() + def test_run_with_repartition(self, spark_session, feature_set_pipeline): # feature_set need to return a real df for streaming validation sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) - test_pipeline.feature_set.construct.return_value = sample_df + feature_set_pipeline.feature_set.construct.return_value = sample_df - test_pipeline.run(partition_by=["id"]) + feature_set_pipeline.run(partition_by=["id"]) - test_pipeline.source.construct.assert_called_once() - test_pipeline.feature_set.construct.assert_called_once() - test_pipeline.sink.flush.assert_called_once() - test_pipeline.sink.validate.assert_called_once() + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() def test_source_raise(self): with pytest.raises(ValueError, match="source must be a Source instance"): @@ -343,52 +253,26 @@ def test_sink_raise(self): sink=Mock(writers=[HistoricalFeatureStoreWriter(db_config=None)],), ) - def test_run_agg_with_end_date(self, spark_session): - test_pipeline = FeatureSetPipeline( - spark_client=SparkClient(), - source=Mock( - spec=Source, - readers=[TableReader(id="source_a", database="db", table="table",)], - query="select * from source_a", - ), - feature_set=Mock( - spec=AggregatedFeatureSet, - name="feature_set", - entity="entity", - description="description", - keys=[ - KeyFeature( - name="user_id", - description="The user's Main ID or device ID", - dtype=DataType.INTEGER, - ) - ], - timestamp=TimestampFeature(from_column="ts"), - features=[ - Feature( - name="listing_page_viewed__rent_per_month", - description="Average of something.", - transformation=AggregatedTransform( - functions=[ - Function(functions.avg, DataType.FLOAT), - Function(functions.stddev_pop, DataType.FLOAT), - ], 
- ), - ), - ], - ), - sink=Mock( - spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)], - ), - ) + def test_run_agg_with_end_date(self, spark_session, feature_set_pipeline): + # feature_set need to return a real df for streaming validation + sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) + feature_set_pipeline.feature_set.construct.return_value = sample_df + + feature_set_pipeline.run(end_date="2016-04-18") + + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() + def test_run_agg_with_start_date(self, spark_session, feature_set_pipeline): # feature_set need to return a real df for streaming validation sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}]) - test_pipeline.feature_set.construct.return_value = sample_df + feature_set_pipeline.feature_set.construct.return_value = sample_df - test_pipeline.run(end_date="2016-04-18") + feature_set_pipeline.run(start_date="2020-08-04") - test_pipeline.source.construct.assert_called_once() - test_pipeline.feature_set.construct.assert_called_once() - test_pipeline.sink.flush.assert_called_once() - test_pipeline.sink.validate.assert_called_once() + feature_set_pipeline.source.construct.assert_called_once() + feature_set_pipeline.feature_set.construct.assert_called_once() + feature_set_pipeline.sink.flush.assert_called_once() + feature_set_pipeline.sink.validate.assert_called_once() diff --git a/tests/unit/butterfree/transform/conftest.py b/tests/unit/butterfree/transform/conftest.py index 2d7d3e50..febc8bbc 100644 --- a/tests/unit/butterfree/transform/conftest.py +++ b/tests/unit/butterfree/transform/conftest.py @@ -1,11 +1,19 @@ import json from unittest.mock import Mock +from pyspark.sql import functions from pytest import fixture from butterfree.constants import DataType from butterfree.constants.columns import TIMESTAMP_COLUMN +from butterfree.transform import FeatureSet +from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.transformations import ( + AggregatedTransform, + SparkFunctionTransform, +) +from butterfree.transform.utils import Function def make_dataframe(spark_context, spark_session): @@ -297,3 +305,77 @@ def key_id(): @fixture def timestamp_c(): return TimestampFeature() + + +@fixture +def feature_set(): + feature_set = FeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature1", + description="test", + transformation=SparkFunctionTransform( + functions=[ + Function(functions.avg, DataType.FLOAT), + Function(functions.stddev_pop, DataType.DOUBLE), + ] + ).with_window( + partition_by="id", + order_by=TIMESTAMP_COLUMN, + mode="fixed_windows", + window_definition=["2 minutes", "15 minutes"], + ), + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.BIGINT, + ) + ], + timestamp=TimestampFeature(), + ) + + return feature_set + + +@fixture +def agg_feature_set(): + feature_set = AggregatedFeatureSet( + name="feature_set", + entity="entity", + description="description", + features=[ + Feature( + name="feature1", + description="test", + transformation=AggregatedTransform( + functions=[ + Function(functions.avg, DataType.DOUBLE), + 
Function(functions.stddev_pop, DataType.FLOAT), + ], + ), + ), + Feature( + name="feature2", + description="test", + transformation=AggregatedTransform( + functions=[Function(functions.count, DataType.ARRAY_STRING)] + ), + ), + ], + keys=[ + KeyFeature( + name="id", + description="The user's Main ID or device ID", + dtype=DataType.BIGINT, + ) + ], + timestamp=TimestampFeature(), + ).with_windows(definitions=["1 week", "2 days"]) + + return feature_set diff --git a/tests/unit/butterfree/transform/test_aggregated_feature_set.py b/tests/unit/butterfree/transform/test_aggregated_feature_set.py index 2c404fea..8025d6f8 100644 --- a/tests/unit/butterfree/transform/test_aggregated_feature_set.py +++ b/tests/unit/butterfree/transform/test_aggregated_feature_set.py @@ -89,7 +89,7 @@ def test_agg_feature_set_with_window( output_df = fs.construct(dataframe, spark_client, end_date="2016-05-01") assert_dataframe_equality(output_df, rolling_windows_agg_dataframe) - def test_get_schema(self): + def test_get_schema(self, agg_feature_set): expected_schema = [ {"column_name": "id", "type": LongType(), "primary_key": True}, {"column_name": "timestamp", "type": TimestampType(), "primary_key": False}, @@ -125,40 +125,7 @@ def test_get_schema(self): }, ] - feature_set = AggregatedFeatureSet( - name="feature_set", - entity="entity", - description="description", - features=[ - Feature( - name="feature1", - description="test", - transformation=AggregatedTransform( - functions=[ - Function(functions.avg, DataType.DOUBLE), - Function(functions.stddev_pop, DataType.FLOAT), - ], - ), - ), - Feature( - name="feature2", - description="test", - transformation=AggregatedTransform( - functions=[Function(functions.count, DataType.ARRAY_STRING)] - ), - ), - ], - keys=[ - KeyFeature( - name="id", - description="The user's Main ID or device ID", - dtype=DataType.BIGINT, - ) - ], - timestamp=TimestampFeature(), - ).with_windows(definitions=["1 week", "2 days"]) - - schema = feature_set.get_schema() + schema = agg_feature_set.get_schema() assert schema == expected_schema @@ -389,3 +356,34 @@ def test_feature_transform_with_data_type_array(self, spark_context, spark_sessi # assert assert_dataframe_equality(target_df, output_df) + + def test_define_start_date(self, agg_feature_set): + start_date = agg_feature_set.define_start_date("2020-08-04") + + assert isinstance(start_date, str) + assert start_date == "2020-07-27" + + def test_feature_set_start_date( + self, timestamp_c, feature_set_with_distinct_dataframe, + ): + fs = AggregatedFeatureSet( + name="name", + entity="entity", + description="description", + features=[ + Feature( + name="feature", + description="test", + transformation=AggregatedTransform( + functions=[Function(functions.sum, DataType.INTEGER)] + ), + ), + ], + keys=[KeyFeature(name="h3", description="test", dtype=DataType.STRING)], + timestamp=timestamp_c, + ).with_windows(["10 days", "3 weeks", "90 days"]) + + # assert + start_date = fs.define_start_date("2016-04-14") + + assert start_date == "2016-01-14" diff --git a/tests/unit/butterfree/transform/test_feature_set.py b/tests/unit/butterfree/transform/test_feature_set.py index bdb1ff7d..43d937be 100644 --- a/tests/unit/butterfree/transform/test_feature_set.py +++ b/tests/unit/butterfree/transform/test_feature_set.py @@ -12,13 +12,11 @@ from butterfree.clients import SparkClient from butterfree.constants import DataType -from butterfree.constants.columns import TIMESTAMP_COLUMN from butterfree.testing.dataframe import assert_dataframe_equality from 
butterfree.transform import FeatureSet -from butterfree.transform.features import Feature, KeyFeature, TimestampFeature +from butterfree.transform.features import Feature from butterfree.transform.transformations import ( AggregatedTransform, - SparkFunctionTransform, SQLExpressionTransform, ) from butterfree.transform.utils import Function @@ -341,7 +339,7 @@ def test_feature_set_with_invalid_feature(self, key_id, timestamp_c, dataframe): timestamp=timestamp_c, ).construct(dataframe, spark_client) - def test_get_schema(self): + def test_get_schema(self, feature_set): expected_schema = [ {"column_name": "id", "type": LongType(), "primary_key": True}, {"column_name": "timestamp", "type": TimestampType(), "primary_key": False}, @@ -367,37 +365,6 @@ def test_get_schema(self): }, ] - feature_set = FeatureSet( - name="feature_set", - entity="entity", - description="description", - features=[ - Feature( - name="feature1", - description="test", - transformation=SparkFunctionTransform( - functions=[ - Function(F.avg, DataType.FLOAT), - Function(F.stddev_pop, DataType.DOUBLE), - ] - ).with_window( - partition_by="id", - order_by=TIMESTAMP_COLUMN, - mode="fixed_windows", - window_definition=["2 minutes", "15 minutes"], - ), - ), - ], - keys=[ - KeyFeature( - name="id", - description="The user's Main ID or device ID", - dtype=DataType.BIGINT, - ) - ], - timestamp=TimestampFeature(), - ) - schema = feature_set.get_schema() assert schema == expected_schema @@ -421,3 +388,9 @@ def test_feature_without_datatype(self, key_id, timestamp_c, dataframe): keys=[key_id], timestamp=timestamp_c, ).construct(dataframe, spark_client) + + def test_define_start_date(self, feature_set): + start_date = feature_set.define_start_date("2020-08-04") + + assert isinstance(start_date, str) + assert start_date == "2020-08-04"
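
For reference, below is a minimal usage sketch of the interval-run behaviour these tests exercise, assembled only from names that appear in this patch (IncrementalStrategy, TableReader.with_incremental_strategy, HistoricalFeatureStoreWriter(interval_mode=True), FeatureSetPipeline.run and run_for_date). It is a sketch, not part of the patch: the table name "orders_table", the reader id "orders" and the feature/key names are illustrative assumptions.

from butterfree.constants import DataType
from butterfree.dataframe_service import IncrementalStrategy
from butterfree.extract import Source
from butterfree.extract.readers import TableReader
from butterfree.load import Sink
from butterfree.load.writers import HistoricalFeatureStoreWriter
from butterfree.pipelines import FeatureSetPipeline
from butterfree.transform import FeatureSet
from butterfree.transform.features import Feature, KeyFeature, TimestampFeature

# Illustrative reader: the incremental strategy filters the source on the "ts"
# column so each run only reads rows inside the requested date interval.
reader = TableReader(id="orders", table="orders_table").with_incremental_strategy(
    IncrementalStrategy(column="ts")
)

pipeline = FeatureSetPipeline(
    source=Source(readers=[reader], query="select * from orders"),
    feature_set=FeatureSet(
        name="orders_feature_set",          # illustrative name
        entity="orders",
        description="illustrative feature set for an interval run",
        keys=[KeyFeature(name="id", description="key", dtype=DataType.INTEGER)],
        timestamp=TimestampFeature(from_column="ts"),
        features=[
            Feature(name="feature", description="value", dtype=DataType.INTEGER),
        ],
    ),
    # interval_mode=True writes only the partitions touched by the run; the tests
    # above also set spark.sql.sources.partitionOverwriteMode=dynamic so that
    # re-running an interval overwrites just those partitions (idempotent runs).
    sink=Sink(writers=[HistoricalFeatureStoreWriter(interval_mode=True)]),
)

# Backfill a closed interval, then reprocess single days; repeating a day
# rewrites only that day's year/month/day partitions.
pipeline.run(start_date="2016-04-11", end_date="2016-04-13")
pipeline.run_for_date(execution_date="2016-04-14")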