From 32e24d6382a973452dd20611c22358cd8d5976bd Mon Sep 17 00:00:00 2001
From: Mayara Moromisato <44944954+moromimay@users.noreply.github.com>
Date: Fri, 19 Feb 2021 10:18:09 -0300
Subject: [PATCH] [MLOP-635] Rebase Incremental Job/Interval Run branch for
test on selected feature sets (#278)
* Add interval branch modifications.
* Add interval_runs notebook.
* Add tests.
* Apply style (black, flake8 and mypy).
* Fix tests.
* Change version to create package dev.
---
.gitignore | 1 +
CHANGELOG.md | 6 +
butterfree/clients/cassandra_client.py | 41 +-
butterfree/clients/spark_client.py | 58 +-
butterfree/configs/db/metastore_config.py | 28 +
butterfree/configs/environment.py | 4 +-
butterfree/constants/window_definitions.py | 16 +
butterfree/dataframe_service/__init__.py | 9 +-
.../dataframe_service/incremental_strategy.py | 116 +
butterfree/dataframe_service/partitioning.py | 25 +
butterfree/extract/readers/file_reader.py | 12 +-
butterfree/extract/readers/reader.py | 88 +-
butterfree/extract/source.py | 24 +-
butterfree/hooks/__init__.py | 5 +
butterfree/hooks/hook.py | 20 +
butterfree/hooks/hookable_component.py | 148 ++
.../hooks/schema_compatibility/__init__.py | 9 +
...ssandra_table_schema_compatibility_hook.py | 58 +
.../spark_table_schema_compatibility_hook.py | 46 +
butterfree/load/sink.py | 13 +-
.../historical_feature_store_writer.py | 113 +-
.../writers/online_feature_store_writer.py | 50 +-
butterfree/load/writers/writer.py | 21 +-
butterfree/pipelines/feature_set_pipeline.py | 56 +-
.../transform/aggregated_feature_set.py | 49 +-
butterfree/transform/feature_set.py | 38 +-
butterfree/transform/utils/window_spec.py | 20 +-
examples/interval_runs/interval_runs.ipynb | 2152 +++++++++++++++++
setup.py | 2 +-
.../integration/butterfree/load/test_sink.py | 35 +-
.../butterfree/pipelines/conftest.py | 202 ++
.../pipelines/test_feature_set_pipeline.py | 311 ++-
.../butterfree/transform/conftest.py | 55 +
.../transform/test_aggregated_feature_set.py | 50 +
.../butterfree/transform/test_feature_set.py | 44 +
tests/unit/butterfree/clients/conftest.py | 11 +-
.../clients/test_cassandra_client.py | 4 +-
.../butterfree/clients/test_spark_client.py | 69 +-
.../butterfree/dataframe_service/conftest.py | 14 +
.../test_incremental_srategy.py | 70 +
.../dataframe_service/test_partitioning.py | 20 +
tests/unit/butterfree/extract/conftest.py | 55 +
.../extract/readers/test_file_reader.py | 10 +-
.../butterfree/extract/readers/test_reader.py | 58 +
tests/unit/butterfree/hooks/__init__.py | 0
.../hooks/schema_compatibility/__init__.py | 0
...ssandra_table_schema_compatibility_hook.py | 49 +
...t_spark_table_schema_compatibility_hook.py | 53 +
.../hooks/test_hookable_component.py | 107 +
tests/unit/butterfree/load/conftest.py | 25 +
tests/unit/butterfree/load/test_sink.py | 34 +-
.../test_historical_feature_store_writer.py | 144 +-
.../test_online_feature_store_writer.py | 41 +-
tests/unit/butterfree/pipelines/conftest.py | 63 +
.../pipelines/test_feature_set_pipeline.py | 182 +-
tests/unit/butterfree/transform/conftest.py | 82 +
.../transform/test_aggregated_feature_set.py | 68 +-
.../butterfree/transform/test_feature_set.py | 43 +-
58 files changed, 4738 insertions(+), 389 deletions(-)
create mode 100644 butterfree/constants/window_definitions.py
create mode 100644 butterfree/dataframe_service/incremental_strategy.py
create mode 100644 butterfree/dataframe_service/partitioning.py
create mode 100644 butterfree/hooks/__init__.py
create mode 100644 butterfree/hooks/hook.py
create mode 100644 butterfree/hooks/hookable_component.py
create mode 100644 butterfree/hooks/schema_compatibility/__init__.py
create mode 100644 butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py
create mode 100644 butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py
create mode 100644 examples/interval_runs/interval_runs.ipynb
create mode 100644 tests/unit/butterfree/dataframe_service/test_incremental_srategy.py
create mode 100644 tests/unit/butterfree/dataframe_service/test_partitioning.py
create mode 100644 tests/unit/butterfree/hooks/__init__.py
create mode 100644 tests/unit/butterfree/hooks/schema_compatibility/__init__.py
create mode 100644 tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py
create mode 100644 tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py
create mode 100644 tests/unit/butterfree/hooks/test_hookable_component.py
create mode 100644 tests/unit/butterfree/pipelines/conftest.py
diff --git a/.gitignore b/.gitignore
index 72b591f3..62434612 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,6 +47,7 @@ coverage.xml
*.cover
.hypothesis/
*cov.xml
+test_folder/
# Translations
*.mo
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72994621..679e9834 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,11 +5,17 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each
## [Unreleased]
### Added
+* [MLOP-636] Create migration classes ([#282](https://github.com/quintoandar/butterfree/pull/282))
+
+## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3)
+### Added
* [MLOP-599] Apply mypy to ButterFree ([#273](https://github.com/quintoandar/butterfree/pull/273))
### Changed
* [MLOP-634] Butterfree dev workflow, set triggers for branches staging and master ([#280](https://github.com/quintoandar/butterfree/pull/280))
* Keep milliseconds when using 'from_ms' argument in timestamp feature ([#284](https://github.com/quintoandar/butterfree/pull/284))
+* [MLOP-633] Butterfree dev workflow, update documentation ([#281](https://github.com/quintoandar/butterfree/commit/74278986a49f1825beee0fd8df65a585764e5524))
+* [MLOP-632] Butterfree dev workflow, automate release description ([#279](https://github.com/quintoandar/butterfree/commit/245eaa594846166972241b03fddc61ee5117b1f7))
### Fixed
* Change trigger for pipeline staging ([#287](https://github.com/quintoandar/butterfree/pull/287))
diff --git a/butterfree/clients/cassandra_client.py b/butterfree/clients/cassandra_client.py
index 1e541688..938d4e4d 100644
--- a/butterfree/clients/cassandra_client.py
+++ b/butterfree/clients/cassandra_client.py
@@ -33,33 +33,31 @@ class CassandraClient(AbstractClient):
"""Cassandra Client.
Attributes:
- cassandra_user: username to use in connection.
- cassandra_password: password to use in connection.
- cassandra_key_space: key space used in connection.
- cassandra_host: cassandra endpoint used in connection.
+ user: username to use in connection.
+ password: password to use in connection.
+ keyspace: key space used in connection.
+ host: cassandra endpoint used in connection.
"""
def __init__(
self,
- cassandra_host: List[str],
- cassandra_key_space: str,
- cassandra_user: Optional[str] = None,
- cassandra_password: Optional[str] = None,
+ host: List[str],
+ keyspace: str,
+ user: Optional[str] = None,
+ password: Optional[str] = None,
) -> None:
- self.cassandra_host = cassandra_host
- self.cassandra_key_space = cassandra_key_space
- self.cassandra_user = cassandra_user
- self.cassandra_password = cassandra_password
+ self.host = host
+ self.keyspace = keyspace
+ self.user = user
+ self.password = password
self._session: Optional[Session] = None
@property
def conn(self, *, ssl_path: str = None) -> Session: # type: ignore
"""Establishes a Cassandra connection."""
auth_provider = (
- PlainTextAuthProvider(
- username=self.cassandra_user, password=self.cassandra_password
- )
- if self.cassandra_user is not None
+ PlainTextAuthProvider(username=self.user, password=self.password)
+ if self.user is not None
else None
)
ssl_opts = (
@@ -73,12 +71,12 @@ def conn(self, *, ssl_path: str = None) -> Session: # type: ignore
)
cluster = Cluster(
- contact_points=self.cassandra_host,
+ contact_points=self.host,
auth_provider=auth_provider,
ssl_options=ssl_opts,
load_balancing_policy=RoundRobinPolicy(),
)
- self._session = cluster.connect(self.cassandra_key_space)
+ self._session = cluster.connect(self.keyspace)
self._session.row_factory = dict_factory
return self._session
@@ -106,7 +104,7 @@ def get_schema(self, table: str) -> List[Dict[str, str]]:
"""
query = (
f"SELECT column_name, type FROM system_schema.columns " # noqa
- f"WHERE keyspace_name = '{self.cassandra_key_space}' " # noqa
+ f"WHERE keyspace_name = '{self.keyspace}' " # noqa
f" AND table_name = '{table}';" # noqa
)
@@ -114,8 +112,7 @@ def get_schema(self, table: str) -> List[Dict[str, str]]:
if not response:
raise RuntimeError(
- f"No columns found for table: {table}"
- f"in key space: {self.cassandra_key_space}"
+ f"No columns found for table: {table}" f"in key space: {self.keyspace}"
)
return response
@@ -143,7 +140,7 @@ def _get_create_table_query(
else:
columns_str = joined_parsed_columns
- query = f"CREATE TABLE {self.cassandra_key_space}.{table} " f"({columns_str}); "
+ query = f"CREATE TABLE {self.keyspace}.{table} " f"({columns_str}); "
return query
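For reference, a minimal usage sketch of the renamed constructor arguments; the host, keyspace and credentials below are hypothetical, and the schema lookup is left commented out because it needs a reachable Cassandra cluster:

from butterfree.clients import CassandraClient

client = CassandraClient(
    host=["localhost"],
    keyspace="feature_store",
    user="cassandra",
    password="cassandra",
)
# client.get_schema("house_listings")  # queries system_schema.columns on a live cluster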
diff --git a/butterfree/clients/spark_client.py b/butterfree/clients/spark_client.py
index 0a8c717c..0f0113e2 100644
--- a/butterfree/clients/spark_client.py
+++ b/butterfree/clients/spark_client.py
@@ -34,9 +34,10 @@ def conn(self) -> SparkSession:
def read(
self,
format: str,
- options: Dict[str, Any],
+ path: Optional[Union[str, List[str]]] = None,
schema: Optional[StructType] = None,
stream: bool = False,
+ **options: Any,
) -> DataFrame:
"""Use the SparkSession.read interface to load data into a dataframe.
@@ -45,9 +46,10 @@ def read(
Args:
format: string with the format to be used by the DataframeReader.
- options: options to setup the DataframeReader.
+            path: optional string or a list of strings for file-system backed data sources.
stream: flag to indicate if data must be read in stream mode.
schema: an optional pyspark.sql.types.StructType for the input schema.
+ options: options to setup the DataframeReader.
Returns:
Dataframe
@@ -55,14 +57,16 @@ def read(
"""
if not isinstance(format, str):
raise ValueError("format needs to be a string with the desired read format")
- if not isinstance(options, dict):
- raise ValueError("options needs to be a dict with the setup configurations")
+ if not isinstance(path, (str, list)):
+ raise ValueError("path needs to be a string or a list of string")
df_reader: Union[
DataStreamReader, DataFrameReader
] = self.conn.readStream if stream else self.conn.read
+
df_reader = df_reader.schema(schema) if schema else df_reader
- return df_reader.format(format).options(**options).load()
+
+ return df_reader.format(format).load(path, **options) # type: ignore
def read_table(self, table: str, database: str = None) -> DataFrame:
"""Use the SparkSession.read interface to read a metastore table.
@@ -223,3 +227,47 @@ def create_temporary_view(self, dataframe: DataFrame, name: str) -> Any:
if not dataframe.isStreaming:
return dataframe.createOrReplaceTempView(name)
return dataframe.writeStream.format("memory").queryName(name).start()
+
+ def add_table_partitions(
+ self, partitions: List[Dict[str, Any]], table: str, database: str = None
+ ) -> None:
+ """Add partitions to an existing table.
+
+ Args:
+ partitions: partitions to add to the table.
+                Expects a list of partition dicts, one per partition to add.
+ Example: `[{"year": 2020, "month": 8, "day": 14}, ...]`
+ table: table to add the partitions.
+ database: name of the database where the table is saved.
+ """
+ for partition_dict in partitions:
+ if not all(
+ (
+ isinstance(key, str)
+ and (isinstance(value, str) or isinstance(value, int))
+ )
+ for key, value in partition_dict.items()
+ ):
+ raise ValueError(
+ "Partition keys must be column names "
+ "and values must be string or int."
+ )
+
+ database_expr = f"`{database}`." if database else ""
+ key_values_expr = [
+ ", ".join(
+ [
+ "{} = {}".format(k, v)
+ if not isinstance(v, str)
+ else "{} = '{}'".format(k, v)
+ for k, v in partition.items()
+ ]
+ )
+ for partition in partitions
+ ]
+ partitions_expr = " ".join(f"PARTITION ( {expr} )" for expr in key_values_expr)
+ command = (
+ f"ALTER TABLE {database_expr}`{table}` ADD IF NOT EXISTS {partitions_expr}"
+ )
+
+ self.conn.sql(command)
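To make the generated statement concrete, here is a standalone sketch that mirrors the expression building in add_table_partitions for a sample partition list; the database and table names are made up, and no Spark session is required:

partitions = [{"year": 2020, "month": 8, "day": 14}, {"year": 2020, "month": 8, "day": 15}]
database = "feature_store"  # hypothetical database name
table = "house_listings"    # hypothetical table name

database_expr = f"`{database}`." if database else ""
key_values_expr = [
    ", ".join(
        f"{k} = {v}" if not isinstance(v, str) else f"{k} = '{v}'"
        for k, v in partition.items()
    )
    for partition in partitions
]
partitions_expr = " ".join(f"PARTITION ( {expr} )" for expr in key_values_expr)
print(f"ALTER TABLE {database_expr}`{table}` ADD IF NOT EXISTS {partitions_expr}")
# ALTER TABLE `feature_store`.`house_listings` ADD IF NOT EXISTS
#   PARTITION ( year = 2020, month = 8, day = 14 ) PARTITION ( year = 2020, month = 8, day = 15 )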
diff --git a/butterfree/configs/db/metastore_config.py b/butterfree/configs/db/metastore_config.py
index d94b792c..a3b315d5 100644
--- a/butterfree/configs/db/metastore_config.py
+++ b/butterfree/configs/db/metastore_config.py
@@ -3,8 +3,11 @@
import os
from typing import Any, Dict, List, Optional
+from pyspark.sql import DataFrame
+
from butterfree.configs import environment
from butterfree.configs.db import AbstractWriteConfig
+from butterfree.dataframe_service import extract_partition_values
class MetastoreConfig(AbstractWriteConfig):
@@ -87,6 +90,31 @@ def get_options(self, key: str) -> Dict[Optional[str], Optional[str]]:
"path": os.path.join(f"{self.file_system}://{self.path}/", key),
}
+ def get_path_with_partitions(self, key: str, dataframe: DataFrame) -> List:
+ """Get options for AWS S3 from partitioned parquet file.
+
+ Options will be a dictionary with the write and read configuration for
+ Spark to AWS S3.
+
+ Args:
+ key: path to save data into AWS S3 bucket.
+ dataframe: spark dataframe containing data from a feature set.
+
+ Returns:
+            A list of strings for file-system backed data sources.
+ """
+ path_list = []
+ dataframe_values = extract_partition_values(
+ dataframe, partition_columns=["year", "month", "day"]
+ )
+ for row in dataframe_values:
+ path_list.append(
+ f"{self.file_system}://{self.path}/{key}/year={row['year']}/"
+ f"month={row['month']}/day={row['day']}"
+ )
+
+ return path_list
+
def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Translate feature set spark schema to the corresponding database."""
pass
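As an illustration of the paths produced by get_path_with_partitions, a small standalone sketch; the file system, bucket and key are hypothetical, and the partition values stand in for what extract_partition_values would return:

file_system = "s3a"
path = "bucket-name"
key = "historical/house/listings"
dataframe_values = [
    {"year": 2020, "month": 8, "day": 14},
    {"year": 2020, "month": 8, "day": 15},
]

path_list = [
    f"{file_system}://{path}/{key}/year={row['year']}/"
    f"month={row['month']}/day={row['day']}"
    for row in dataframe_values
]
print(path_list[0])  # s3a://bucket-name/historical/house/listings/year=2020/month=8/day=14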
diff --git a/butterfree/configs/environment.py b/butterfree/configs/environment.py
index 6f5accbc..f98a7a01 100644
--- a/butterfree/configs/environment.py
+++ b/butterfree/configs/environment.py
@@ -35,8 +35,8 @@ def get_variable(variable_name: str, default_value: str = None) -> Optional[str]
"""Gets an environment variable.
The variable comes from its explicitly declared value in the running
- environment or from the default value declared in the environment.yaml
- specification or from the default_value.
+    environment or from the default value declared in the specification or from the
+ default_value.
Args:
variable_name: environment variable name.
diff --git a/butterfree/constants/window_definitions.py b/butterfree/constants/window_definitions.py
new file mode 100644
index 00000000..560904f7
--- /dev/null
+++ b/butterfree/constants/window_definitions.py
@@ -0,0 +1,16 @@
+"""Allowed windows units and lengths in seconds."""
+
+ALLOWED_WINDOWS = {
+ "second": 1,
+ "seconds": 1,
+ "minute": 60,
+ "minutes": 60,
+ "hour": 3600,
+ "hours": 3600,
+ "day": 86400,
+ "days": 86400,
+ "week": 604800,
+ "weeks": 604800,
+ "year": 29030400,
+ "years": 29030400,
+}
diff --git a/butterfree/dataframe_service/__init__.py b/butterfree/dataframe_service/__init__.py
index 5116261d..c227dae2 100644
--- a/butterfree/dataframe_service/__init__.py
+++ b/butterfree/dataframe_service/__init__.py
@@ -1,4 +1,11 @@
"""Dataframe optimization components regarding Butterfree."""
+from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy
+from butterfree.dataframe_service.partitioning import extract_partition_values
from butterfree.dataframe_service.repartition import repartition_df, repartition_sort_df
-__all__ = ["repartition_df", "repartition_sort_df"]
+__all__ = [
+ "extract_partition_values",
+ "IncrementalStrategy",
+ "repartition_df",
+ "repartition_sort_df",
+]
diff --git a/butterfree/dataframe_service/incremental_strategy.py b/butterfree/dataframe_service/incremental_strategy.py
new file mode 100644
index 00000000..6554d3b7
--- /dev/null
+++ b/butterfree/dataframe_service/incremental_strategy.py
@@ -0,0 +1,116 @@
+"""IncrementalStrategy entity."""
+
+from __future__ import annotations
+
+from pyspark.sql import DataFrame
+
+
+class IncrementalStrategy:
+ """Define an incremental strategy to be used on data sources.
+
+ Entity responsible for defining a column expression that will be used to
+ filter the original data source. The purpose is to get only the data related
+ to a specific pipeline execution time interval.
+
+ Attributes:
+ column: column expression on which incremental filter will be applied.
+            The expression needs to result in a date or timestamp format, so the
+ filter can properly work with the defined upper and lower bounds.
+ """
+
+ def __init__(self, column: str = None):
+ self.column = column
+
+ def from_milliseconds(self, column_name: str) -> IncrementalStrategy:
+ """Create a column expression from ts column defined as milliseconds.
+
+ Args:
+ column_name: column name where the filter will be applied.
+
+ Returns:
+ `IncrementalStrategy` with the defined column expression.
+ """
+ return IncrementalStrategy(column=f"from_unixtime({column_name}/ 1000.0)")
+
+ def from_string(self, column_name: str, mask: str = None) -> IncrementalStrategy:
+ """Create a column expression from ts column defined as a simple string.
+
+ Args:
+ column_name: column name where the filter will be applied.
+ mask: mask defining the date/timestamp format on the string.
+
+ Returns:
+ `IncrementalStrategy` with the defined column expression.
+ """
+ return IncrementalStrategy(column=f"to_date({column_name}, '{mask}')")
+
+ def from_year_month_day_partitions(
+ self,
+ year_column: str = "year",
+ month_column: str = "month",
+ day_column: str = "day",
+ ) -> IncrementalStrategy:
+ """Create a column expression from year, month and day partitions.
+
+ Args:
+ year_column: column name from the year partition.
+ month_column: column name from the month partition.
+ day_column: column name from the day partition.
+
+ Returns:
+ `IncrementalStrategy` with the defined column expression.
+ """
+ return IncrementalStrategy(
+ column=f"concat(string({year_column}), "
+ f"'-', string({month_column}), "
+ f"'-', string({day_column}))"
+ )
+
+ def get_expression(self, start_date: str = None, end_date: str = None) -> str:
+ """Get the incremental filter expression using the defined dates.
+
+        Both arguments can be set to define a specific date interval, but only
+        one of them needs to be set for this method to work.
+
+ Args:
+ start_date: date lower bound to use in the filter.
+ end_date: date upper bound to use in the filter.
+
+ Returns:
+ Filter expression based on defined column and bounds.
+
+ Raises:
+            ValueError: If both arguments, start_date and end_date, are None.
+ ValueError: If the column expression was not defined.
+ """
+ if not self.column:
+ raise ValueError("column parameter can't be None")
+ if not (start_date or end_date):
+ raise ValueError("Both arguments start_date and end_date can't be None.")
+ if start_date:
+ expression = f"date({self.column}) >= date('{start_date}')"
+ if end_date:
+ expression += f" and date({self.column}) <= date('{end_date}')"
+ return expression
+ return f"date({self.column}) <= date('{end_date}')"
+
+ def filter_with_incremental_strategy(
+ self, dataframe: DataFrame, start_date: str = None, end_date: str = None
+ ) -> DataFrame:
+ """Filters the dataframe according to the date boundaries.
+
+ Args:
+ dataframe: dataframe that will be filtered.
+ start_date: date lower bound to use in the filter.
+ end_date: date upper bound to use in the filter.
+
+ Returns:
+ Filtered dataframe based on defined time boundaries.
+ """
+ return (
+ dataframe.where(
+ self.get_expression(start_date=start_date, end_date=end_date)
+ )
+ if start_date or end_date
+ else dataframe
+ )
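A short usage sketch of the strategies above, assuming the butterfree package built from this patch is installed; the ts column name is illustrative:

from butterfree.dataframe_service import IncrementalStrategy

partition_strategy = IncrementalStrategy().from_year_month_day_partitions()
print(partition_strategy.get_expression(start_date="2020-07-04", end_date="2020-08-04"))
# date(concat(string(year), '-', string(month), '-', string(day))) >= date('2020-07-04')
#   and date(concat(string(year), '-', string(month), '-', string(day))) <= date('2020-08-04')

ts_strategy = IncrementalStrategy().from_milliseconds("ts")
print(ts_strategy.get_expression(end_date="2020-08-04"))
# date(from_unixtime(ts/ 1000.0)) <= date('2020-08-04')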
diff --git a/butterfree/dataframe_service/partitioning.py b/butterfree/dataframe_service/partitioning.py
new file mode 100644
index 00000000..21e9b0ab
--- /dev/null
+++ b/butterfree/dataframe_service/partitioning.py
@@ -0,0 +1,25 @@
+"""Module defining partitioning methods."""
+
+from typing import Any, Dict, List
+
+from pyspark.sql import DataFrame
+
+
+def extract_partition_values(
+ dataframe: DataFrame, partition_columns: List[str]
+) -> List[Dict[str, Any]]:
+ """Extract distinct partition values from a given dataframe.
+
+ Args:
+ dataframe: dataframe from where to extract partition values.
+ partition_columns: name of partition columns presented on the dataframe.
+
+ Returns:
+ distinct partition values.
+ """
+ return (
+ dataframe.select(*partition_columns)
+ .distinct()
+ .rdd.map(lambda row: row.asDict(True))
+ .collect()
+ )
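A minimal sketch of extract_partition_values in action, assuming a local pyspark installation; the dataframe contents are made up:

from pyspark.sql import SparkSession

from butterfree.dataframe_service import extract_partition_values

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(1, 2020, 8, 14), (2, 2020, 8, 14), (3, 2020, 8, 15)],
    ["id", "year", "month", "day"],
)
values = extract_partition_values(df, partition_columns=["year", "month", "day"])
print(values)
# [{'year': 2020, 'month': 8, 'day': 14}, {'year': 2020, 'month': 8, 'day': 15}]  (order may vary)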
diff --git a/butterfree/extract/readers/file_reader.py b/butterfree/extract/readers/file_reader.py
index 17f68f1c..8cf15599 100644
--- a/butterfree/extract/readers/file_reader.py
+++ b/butterfree/extract/readers/file_reader.py
@@ -87,9 +87,7 @@ def __init__(
self.path = path
self.format = format
self.schema = schema
- self.options = dict(
- {"path": self.path}, **format_options if format_options else {}
- )
+ self.options = dict(format_options if format_options else {})
self.stream = stream
def consume(self, client: SparkClient) -> DataFrame:
@@ -106,11 +104,15 @@ def consume(self, client: SparkClient) -> DataFrame:
"""
schema = (
- client.read(format=self.format, options=self.options,).schema
+ client.read(format=self.format, path=self.path, **self.options).schema
if (self.stream and not self.schema)
else self.schema
)
return client.read(
- format=self.format, options=self.options, schema=schema, stream=self.stream,
+ format=self.format,
+ schema=schema,
+ stream=self.stream,
+ path=self.path,
+ **self.options,
)
diff --git a/butterfree/extract/readers/reader.py b/butterfree/extract/readers/reader.py
index 78be2823..597c870f 100644
--- a/butterfree/extract/readers/reader.py
+++ b/butterfree/extract/readers/reader.py
@@ -2,14 +2,16 @@
from abc import ABC, abstractmethod
from functools import reduce
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
from pyspark.sql import DataFrame
from butterfree.clients import SparkClient
+from butterfree.dataframe_service import IncrementalStrategy
+from butterfree.hooks import HookableComponent
-class Reader(ABC):
+class Reader(ABC, HookableComponent):
"""Abstract base class for Readers.
Attributes:
@@ -19,9 +21,11 @@ class Reader(ABC):
"""
- def __init__(self, id: str):
+ def __init__(self, id: str, incremental_strategy: IncrementalStrategy = None):
+ super().__init__()
self.id = id
self.transformations: List[Dict[str, Any]] = []
+ self.incremental_strategy = incremental_strategy
def with_(
self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any
@@ -48,14 +52,19 @@ def with_(
self.transformations.append(new_transformation)
return self
- def _apply_transformations(self, df: DataFrame) -> Any:
- return reduce(
- lambda result_df, transformation: transformation["transformer"](
- result_df, *transformation["args"], **transformation["kwargs"]
- ),
- self.transformations,
- df,
- )
+ def with_incremental_strategy(
+ self, incremental_strategy: IncrementalStrategy
+ ) -> "Reader":
+ """Define the incremental strategy for the Reader.
+
+ Args:
+ incremental_strategy: definition of the incremental strategy.
+
+ Returns:
+ Reader with defined incremental strategy.
+ """
+ self.incremental_strategy = incremental_strategy
+ return self
@abstractmethod
def consume(self, client: SparkClient) -> DataFrame:
@@ -70,24 +79,61 @@ def consume(self, client: SparkClient) -> DataFrame:
:return: Spark dataframe
"""
- def build(self, client: SparkClient, columns: List[Any] = None) -> None:
+ def build(
+ self,
+ client: SparkClient,
+ columns: List[Any] = None,
+ start_date: str = None,
+ end_date: str = None,
+ ) -> None:
"""Register the data got from the reader in the Spark metastore.
Create a temporary view in Spark metastore referencing the data
extracted from the target origin after the application of all the
defined pre-processing transformations.
+        The arguments start_date and end_date are going to be used only when there
+ is a defined `IncrementalStrategy` on the `Reader`.
+
Args:
client: client responsible for connecting to Spark session.
- columns: list of tuples for renaming/filtering the dataset.
+ columns: list of tuples for selecting/renaming columns on the df.
+ start_date: lower bound to use in the filter expression.
+ end_date: upper bound to use in the filter expression.
"""
- transformed_df = self._apply_transformations(self.consume(client))
-
- if columns:
- select_expression = []
- for old_expression, new_column_name in columns:
- select_expression.append(f"{old_expression} as {new_column_name}")
- transformed_df = transformed_df.selectExpr(*select_expression)
+ column_selection_df = self._select_columns(columns, client)
+ transformed_df = self._apply_transformations(column_selection_df)
+
+ if self.incremental_strategy:
+ transformed_df = self.incremental_strategy.filter_with_incremental_strategy(
+ transformed_df, start_date, end_date
+ )
+
+ post_hook_df = self.run_post_hooks(transformed_df)
+
+ post_hook_df.createOrReplaceTempView(self.id)
+
+ def _select_columns(
+ self, columns: Optional[List[Any]], client: SparkClient
+ ) -> DataFrame:
+ df = self.consume(client)
+ return df.selectExpr(
+ *(
+ [
+ f"{old_expression} as {new_column_name}"
+ for old_expression, new_column_name in columns
+ ]
+ if columns
+ else df.columns
+ )
+ )
- transformed_df.createOrReplaceTempView(self.id)
+ def _apply_transformations(self, df: DataFrame) -> DataFrame:
+ return reduce(
+ lambda result_df, transformation: transformation["transformer"](
+ result_df, *transformation["args"], **transformation["kwargs"]
+ ),
+ self.transformations,
+ df,
+ )
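A sketch of wiring an incremental strategy into a reader, assuming this patch is installed; the path, format and date column are hypothetical:

from butterfree.clients import SparkClient
from butterfree.dataframe_service import IncrementalStrategy
from butterfree.extract.readers import FileReader

reader = FileReader(id="events", path="data/events", format="parquet").with_incremental_strategy(
    IncrementalStrategy().from_string("event_date", mask="yyyy-MM-dd")
)
# Registers the "events" temporary view, filtered to the given interval.
reader.build(client=SparkClient(), start_date="2020-07-04", end_date="2020-08-04")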
diff --git a/butterfree/extract/source.py b/butterfree/extract/source.py
index 00ac9e43..6d905c6b 100644
--- a/butterfree/extract/source.py
+++ b/butterfree/extract/source.py
@@ -6,9 +6,10 @@
from butterfree.clients import SparkClient
from butterfree.extract.readers.reader import Reader
+from butterfree.hooks import HookableComponent
-class Source:
+class Source(HookableComponent):
"""The definition of the the entry point data for the ETL pipeline.
A FeatureSet (the next step in the pipeline) expects a single dataframe as
@@ -51,31 +52,44 @@ class Source:
"""
def __init__(self, readers: List[Reader], query: str) -> None:
+ super().__init__()
+ self.enable_pre_hooks = False
self.readers = readers
self.query = query
- def construct(self, client: SparkClient) -> DataFrame:
+ def construct(
+ self, client: SparkClient, start_date: str = None, end_date: str = None
+ ) -> DataFrame:
"""Construct an entry point dataframe for a feature set.
This method will assemble multiple readers, by building each one and
- querying them using a Spark SQL.
+        querying them using Spark SQL. Note that, to filter a dataframe by date
+        boundaries, an IncrementalStrategy must be defined on its readers;
+        otherwise the data will not be filtered. Both the start and end date
+        parameters are optional.
After that, there's the caching of the dataframe, however since cache()
in Spark is lazy, an action is triggered in order to force persistence.
Args:
client: client responsible for connecting to Spark session.
+ start_date: user defined start date for filtering.
+ end_date: user defined end date for filtering.
Returns:
DataFrame with the query result against all readers.
"""
for reader in self.readers:
- reader.build(client) # create temporary views for each reader
+ reader.build(
+ client=client, start_date=start_date, end_date=end_date
+ ) # create temporary views for each reader
dataframe = client.sql(self.query)
if not dataframe.isStreaming:
dataframe.cache().count()
- return dataframe
+ post_hook_df = self.run_post_hooks(dataframe)
+
+ return post_hook_df
diff --git a/butterfree/hooks/__init__.py b/butterfree/hooks/__init__.py
new file mode 100644
index 00000000..90bedeb2
--- /dev/null
+++ b/butterfree/hooks/__init__.py
@@ -0,0 +1,5 @@
+"""Holds Hooks definitions."""
+from butterfree.hooks.hook import Hook
+from butterfree.hooks.hookable_component import HookableComponent
+
+__all__ = ["Hook", "HookableComponent"]
diff --git a/butterfree/hooks/hook.py b/butterfree/hooks/hook.py
new file mode 100644
index 00000000..f7d8c562
--- /dev/null
+++ b/butterfree/hooks/hook.py
@@ -0,0 +1,20 @@
+"""Hook abstract class entity."""
+
+from abc import ABC, abstractmethod
+
+from pyspark.sql import DataFrame
+
+
+class Hook(ABC):
+ """Definition of a hook function to call on a Dataframe."""
+
+ @abstractmethod
+ def run(self, dataframe: DataFrame) -> DataFrame:
+ """Run interface for Hook.
+
+ Args:
+ dataframe: dataframe to use in the Hook.
+
+ Returns:
+ dataframe result from the Hook.
+ """
diff --git a/butterfree/hooks/hookable_component.py b/butterfree/hooks/hookable_component.py
new file mode 100644
index 00000000..d89babce
--- /dev/null
+++ b/butterfree/hooks/hookable_component.py
@@ -0,0 +1,148 @@
+"""Definition of hookable component."""
+
+from __future__ import annotations
+
+from typing import List
+
+from pyspark.sql import DataFrame
+
+from butterfree.hooks.hook import Hook
+
+
+class HookableComponent:
+ """Defines a component with the ability to hold pre and post hook functions.
+
+    All main modules of Butterfree have a common object that enables their
+    integration: dataframes. Spark dataframes are the glue that enables the
+    transmission of data between the main modules. Hooks have a simple
+    interface: they are functions that accept a dataframe and output a
+    dataframe. These Hooks can be triggered before or after the main
+    execution of a component.
+
+    Butterfree components that inherit from the HookableComponent entity can
+    define a series of steps to occur before or after the execution of their
+    main functionality.
+
+ Attributes:
+ pre_hooks: function steps to trigger before component main functionality.
+ post_hooks: function steps to trigger after component main functionality.
+ enable_pre_hooks: property to indicate if the component can define pre_hooks.
+ enable_post_hooks: property to indicate if the component can define post_hooks.
+ """
+
+ def __init__(self) -> None:
+ self.pre_hooks = []
+ self.post_hooks = []
+ self.enable_pre_hooks = True
+ self.enable_post_hooks = True
+
+ @property
+ def pre_hooks(self) -> List[Hook]:
+ """Function steps to trigger before component main functionality."""
+ return self.__pre_hook
+
+ @pre_hooks.setter
+ def pre_hooks(self, value: List[Hook]) -> None:
+ if not isinstance(value, list):
+ raise ValueError("pre_hooks should be a list of Hooks.")
+ if not all(isinstance(item, Hook) for item in value):
+ raise ValueError(
+ "All items on pre_hooks list should be an instance of Hook."
+ )
+ self.__pre_hook = value
+
+ @property
+ def post_hooks(self) -> List[Hook]:
+ """Function steps to trigger after component main functionality."""
+ return self.__post_hook
+
+ @post_hooks.setter
+ def post_hooks(self, value: List[Hook]) -> None:
+ if not isinstance(value, list):
+ raise ValueError("post_hooks should be a list of Hooks.")
+ if not all(isinstance(item, Hook) for item in value):
+ raise ValueError(
+ "All items on post_hooks list should be an instance of Hook."
+ )
+ self.__post_hook = value
+
+ @property
+ def enable_pre_hooks(self) -> bool:
+ """Property to indicate if the component can define pre_hooks."""
+ return self.__enable_pre_hooks
+
+ @enable_pre_hooks.setter
+ def enable_pre_hooks(self, value: bool) -> None:
+ if not isinstance(value, bool):
+ raise ValueError("enable_pre_hooks accepts only boolean values.")
+ self.__enable_pre_hooks = value
+
+ @property
+ def enable_post_hooks(self) -> bool:
+ """Property to indicate if the component can define post_hooks."""
+ return self.__enable_post_hooks
+
+ @enable_post_hooks.setter
+ def enable_post_hooks(self, value: bool) -> None:
+ if not isinstance(value, bool):
+ raise ValueError("enable_post_hooks accepts only boolean values.")
+ self.__enable_post_hooks = value
+
+ def add_pre_hook(self, *hooks: Hook) -> HookableComponent:
+ """Add a pre-hook steps to the component.
+
+ Args:
+ hooks: Hook steps to add to pre_hook list.
+
+ Returns:
+ Component with the Hook inserted in pre_hook list.
+
+ Raises:
+ ValueError: if the component does not accept pre-hooks.
+ """
+ if not self.enable_pre_hooks:
+ raise ValueError("This component does not enable adding pre-hooks")
+ self.pre_hooks += list(hooks)
+ return self
+
+ def add_post_hook(self, *hooks: Hook) -> HookableComponent:
+ """Add a post-hook steps to the component.
+
+ Args:
+ hooks: Hook steps to add to post_hook list.
+
+ Returns:
+ Component with the Hook inserted in post_hook list.
+
+ Raises:
+ ValueError: if the component does not accept post-hooks.
+ """
+ if not self.enable_post_hooks:
+ raise ValueError("This component does not enable adding post-hooks")
+ self.post_hooks += list(hooks)
+ return self
+
+ def run_pre_hooks(self, dataframe: DataFrame) -> DataFrame:
+ """Run all defined pre-hook steps from a given dataframe.
+
+ Args:
+ dataframe: data to input in the defined pre-hook steps.
+
+ Returns:
+            dataframe after passing through all defined pre-hooks.
+ """
+ for hook in self.pre_hooks:
+ dataframe = hook.run(dataframe)
+ return dataframe
+
+ def run_post_hooks(self, dataframe: DataFrame) -> DataFrame:
+ """Run all defined post-hook steps from a given dataframe.
+
+ Args:
+ dataframe: data to input in the defined post-hook steps.
+
+ Returns:
+            dataframe after passing through all defined post-hooks.
+ """
+ for hook in self.post_hooks:
+ dataframe = hook.run(dataframe)
+ return dataframe
diff --git a/butterfree/hooks/schema_compatibility/__init__.py b/butterfree/hooks/schema_compatibility/__init__.py
new file mode 100644
index 00000000..edf748bf
--- /dev/null
+++ b/butterfree/hooks/schema_compatibility/__init__.py
@@ -0,0 +1,9 @@
+"""Holds Schema Compatibility Hooks definitions."""
+from butterfree.hooks.schema_compatibility.cassandra_table_schema_compatibility_hook import ( # noqa
+ CassandraTableSchemaCompatibilityHook,
+)
+from butterfree.hooks.schema_compatibility.spark_table_schema_compatibility_hook import ( # noqa
+ SparkTableSchemaCompatibilityHook,
+)
+
+__all__ = ["SparkTableSchemaCompatibilityHook", "CassandraTableSchemaCompatibilityHook"]
diff --git a/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py b/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py
new file mode 100644
index 00000000..cdb40472
--- /dev/null
+++ b/butterfree/hooks/schema_compatibility/cassandra_table_schema_compatibility_hook.py
@@ -0,0 +1,58 @@
+"""Cassandra table schema compatibility Hook definition."""
+
+from pyspark.sql import DataFrame
+
+from butterfree.clients import CassandraClient
+from butterfree.constants import DataType
+from butterfree.hooks.hook import Hook
+
+
+class CassandraTableSchemaCompatibilityHook(Hook):
+ """Hook to verify the schema compatibility with a Cassandra's table.
+
+ Verifies if all columns presented on the dataframe exists and are the same
+ type on the target Cassandra's table.
+
+ Attributes:
+ cassandra_client: client to connect to Cassandra DB.
+ table: table name.
+ """
+
+ def __init__(self, cassandra_client: CassandraClient, table: str):
+ self.cassandra_client = cassandra_client
+ self.table = table
+
+ def run(self, dataframe: DataFrame) -> DataFrame:
+ """Check the schema compatibility from a given Dataframe.
+
+ This method does not change anything on the Dataframe.
+
+ Args:
+ dataframe: dataframe to verify schema compatibility.
+
+ Returns:
+ unchanged dataframe.
+
+ Raises:
+            ValueError: if the schemas are incompatible.
+ """
+ table_schema = self.cassandra_client.get_schema(self.table)
+ type_cassandra = [
+ type.cassandra
+ for field_id in range(len(dataframe.schema.fieldNames()))
+ for type in DataType
+ if dataframe.schema.fields.__getitem__(field_id).dataType == type.spark
+ ]
+ schema = [
+ {"column_name": f"{column}", "type": f"{type}"}
+ for column, type in zip(dataframe.columns, type_cassandra)
+ ]
+
+ if not all([column in table_schema for column in schema]):
+ raise ValueError(
+ "There's a schema incompatibility "
+ "between the defined dataframe and the Cassandra table.\n"
+ f"Dataframe schema = {schema}"
+ f"Target table schema = {table_schema}"
+ )
+ return dataframe
diff --git a/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py b/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py
new file mode 100644
index 00000000..b08dd56a
--- /dev/null
+++ b/butterfree/hooks/schema_compatibility/spark_table_schema_compatibility_hook.py
@@ -0,0 +1,46 @@
+"""Spark table schema compatibility Hook definition."""
+
+from pyspark.sql import DataFrame
+
+from butterfree.clients import SparkClient
+from butterfree.hooks.hook import Hook
+
+
+class SparkTableSchemaCompatibilityHook(Hook):
+ """Hook to verify the schema compatibility with a Spark's table.
+
+ Verifies if all columns presented on the dataframe exists and are the same
+ type on the target Spark's table.
+
+ Attributes:
+ spark_client: client to connect to Spark's metastore.
+ table: table name.
+ database: database name.
+ """
+
+ def __init__(self, spark_client: SparkClient, table: str, database: str = None):
+ self.spark_client = spark_client
+ self.table_expression = (f"`{database}`." if database else "") + f"`{table}`"
+
+ def run(self, dataframe: DataFrame) -> DataFrame:
+ """Check the schema compatibility from a given Dataframe.
+
+ This method does not change anything on the Dataframe.
+
+ Args:
+ dataframe: dataframe to verify schema compatibility.
+
+ Returns:
+ unchanged dataframe.
+
+ Raises:
+            ValueError: if the schemas are incompatible.
+ """
+ table_schema = self.spark_client.conn.table(self.table_expression).schema
+ if not all([column in table_schema for column in dataframe.schema]):
+ raise ValueError(
+ "The dataframe has a schema incompatible with the defined table.\n"
+ f"Dataframe schema = {dataframe.schema}"
+ f"Target table schema = {table_schema}"
+ )
+ return dataframe
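A usage sketch for the schema compatibility hooks, assuming this patch is installed; the table and database names are hypothetical:

from butterfree.clients import SparkClient
from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook

hook = SparkTableSchemaCompatibilityHook(
    SparkClient(), table="house_listings", database="feature_store"
)
# hook.run(dataframe) raises ValueError when the dataframe schema does not match
# the target table; otherwise it returns the dataframe unchanged.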
diff --git a/butterfree/load/sink.py b/butterfree/load/sink.py
index b4bf93e8..0b0c10c9 100644
--- a/butterfree/load/sink.py
+++ b/butterfree/load/sink.py
@@ -5,13 +5,14 @@
from pyspark.sql.streaming import StreamingQuery
from butterfree.clients import SparkClient
+from butterfree.hooks import HookableComponent
from butterfree.load.writers.writer import Writer
from butterfree.transform import FeatureSet
from butterfree.validations import BasicValidation
from butterfree.validations.validation import Validation
-class Sink:
+class Sink(HookableComponent):
"""Define the destinations for the feature set pipeline.
A Sink is created from a set of writers. The main goal of the Sink is to
@@ -26,6 +27,8 @@ class Sink:
"""
def __init__(self, writers: List[Writer], validation: Optional[Validation] = None):
+ super().__init__()
+ self.enable_post_hooks = False
self.writers = writers
self.validation = validation
@@ -94,12 +97,16 @@ def flush(
Streaming handlers for each defined writer, if writing streaming dfs.
"""
+ pre_hook_df = self.run_pre_hooks(dataframe)
+
if self.validation is not None:
- self.validation.input(dataframe).check()
+ self.validation.input(pre_hook_df).check()
handlers = [
writer.write(
- feature_set=feature_set, dataframe=dataframe, spark_client=spark_client
+ feature_set=feature_set,
+ dataframe=pre_hook_df,
+ spark_client=spark_client,
)
for writer in self.writers
]
diff --git a/butterfree/load/writers/historical_feature_store_writer.py b/butterfree/load/writers/historical_feature_store_writer.py
index d70f68f0..456d9e6b 100644
--- a/butterfree/load/writers/historical_feature_store_writer.py
+++ b/butterfree/load/writers/historical_feature_store_writer.py
@@ -1,7 +1,7 @@
"""Holds the Historical Feature Store writer class."""
import os
-from typing import Union
+from typing import Any, Union
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import dayofmonth, month, year
@@ -12,6 +12,8 @@
from butterfree.constants import columns
from butterfree.constants.spark_constants import DEFAULT_NUM_PARTITIONS
from butterfree.dataframe_service import repartition_df
+from butterfree.hooks import Hook
+from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook
from butterfree.load.writers.writer import Writer
from butterfree.transform import FeatureSet
@@ -60,6 +62,20 @@ class HistoricalFeatureStoreWriter(Writer):
For what settings you can use on S3Config and default settings,
to read S3Config class.
+        We can write in interval mode, where HistoricalFeatureStoreWriter
+        needs to use Dynamic Partition Inserts: the behaviour of the OVERWRITE
+        keyword is controlled by the spark.sql.sources.partitionOverwriteMode
+        configuration property. When the dynamic overwrite mode is enabled,
+        Spark only deletes the partitions for which it has data to be written.
+        All the other partitions remain intact.
+
+ >>> spark_client = SparkClient()
+ >>> writer = HistoricalFeatureStoreWriter(interval_mode=True)
+ >>> writer.write(feature_set=feature_set,
+ ... dataframe=dataframe,
+ ... spark_client=spark_client)
+
We can instantiate HistoricalFeatureStoreWriter class to validate the df
to be written.
@@ -95,15 +111,17 @@ def __init__(
num_partitions: int = None,
validation_threshold: float = DEFAULT_VALIDATION_THRESHOLD,
debug_mode: bool = False,
+ interval_mode: bool = False,
+ check_schema_hook: Hook = None,
):
- super(HistoricalFeatureStoreWriter, self).__init__()
+ super(HistoricalFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
self.db_config = db_config or MetastoreConfig()
self.database = database or environment.get_variable(
"FEATURE_STORE_HISTORICAL_DATABASE"
)
self.num_partitions = num_partitions or DEFAULT_NUM_PARTITIONS
self.validation_threshold = validation_threshold
- self.debug_mode = debug_mode
+ self.check_schema_hook = check_schema_hook
def write(
self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient,
@@ -122,7 +140,25 @@ def write(
"""
dataframe = self._create_partitions(dataframe)
- dataframe = self._apply_transformations(dataframe)
+ partition_df = self._apply_transformations(dataframe)
+
+ if self.debug_mode:
+ dataframe = partition_df
+ else:
+ dataframe = self.check_schema(
+ spark_client, partition_df, feature_set.name, self.database
+ )
+
+ if self.interval_mode:
+ if self.debug_mode:
+ spark_client.create_temporary_view(
+ dataframe=dataframe,
+ name=f"historical_feature_store__{feature_set.name}",
+ )
+ return
+
+ self._incremental_mode(feature_set, dataframe, spark_client)
+ return
if self.debug_mode:
spark_client.create_temporary_view(
@@ -132,6 +168,7 @@ def write(
return
s3_key = os.path.join("historical", feature_set.entity, feature_set.name)
+
spark_client.write_table(
dataframe=dataframe,
database=self.database,
@@ -140,6 +177,34 @@ def write(
**self.db_config.get_options(s3_key),
)
+ def _incremental_mode(
+ self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
+ ) -> None:
+
+ partition_overwrite_mode = spark_client.conn.conf.get(
+ "spark.sql.sources.partitionOverwriteMode"
+ ).lower()
+
+ if partition_overwrite_mode != "dynamic":
+ raise RuntimeError(
+ "m=load_incremental_table, "
+ "spark.sql.sources.partitionOverwriteMode={}, "
+ "msg=partitionOverwriteMode have to be configured to 'dynamic'".format(
+ partition_overwrite_mode
+ )
+ )
+
+ s3_key = os.path.join("historical", feature_set.entity, feature_set.name)
+ options = {"path": self.db_config.get_options(s3_key).get("path")}
+
+ spark_client.write_dataframe(
+ dataframe=dataframe,
+ format_=self.db_config.format_,
+ mode=self.db_config.mode,
+ **options,
+ partitionBy=self.PARTITION_BY,
+ )
+
def _assert_validation_count(
self, table_name: str, written_count: int, dataframe_count: int
) -> None:
@@ -169,12 +234,26 @@ def validate(
"""
table_name = (
- f"{self.database}.{feature_set.name}"
- if not self.debug_mode
- else f"historical_feature_store__{feature_set.name}"
+ f"{feature_set.name}"
+ if self.interval_mode and not self.debug_mode
+ else (
+ f"{self.database}.{feature_set.name}"
+ if not self.debug_mode
+ else f"historical_feature_store__{feature_set.name}"
+ )
+ )
+
+ written_count = (
+ spark_client.read(
+ self.db_config.format_,
+ path=self.db_config.get_path_with_partitions(table_name, dataframe),
+ ).count()
+ if self.interval_mode and not self.debug_mode
+ else spark_client.read_table(table_name).count()
)
- written_count = spark_client.read_table(table_name).count()
+
dataframe_count = dataframe.count()
+
self._assert_validation_count(table_name, written_count, dataframe_count)
def _create_partitions(self, dataframe: DataFrame) -> DataFrame:
@@ -191,3 +270,21 @@ def _create_partitions(self, dataframe: DataFrame) -> DataFrame:
columns.PARTITION_DAY, dayofmonth(dataframe[columns.TIMESTAMP_COLUMN])
)
return repartition_df(dataframe, self.PARTITION_BY, self.num_partitions)
+
+ def check_schema(
+ self, client: Any, dataframe: DataFrame, table_name: str, database: str = None
+ ) -> DataFrame:
+ """Instantiate the schema check hook to check schema between dataframe and database.
+
+ Args:
+ client: client for Spark or Cassandra connections with external services.
+ dataframe: Spark dataframe containing data from a feature set.
+ table_name: table name where the dataframe will be saved.
+ database: database name where the dataframe will be saved.
+ """
+ if not self.check_schema_hook:
+ self.check_schema_hook = SparkTableSchemaCompatibilityHook(
+ client, table_name, database
+ )
+
+ return self.check_schema_hook.run(dataframe)
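A sketch of the Spark configuration that interval mode relies on: writing in interval mode (outside debug mode) raises a RuntimeError unless dynamic partition overwrite is enabled. The feature_set and dataframe in the commented call are placeholders:

from butterfree.clients import SparkClient
from butterfree.load.writers import HistoricalFeatureStoreWriter

spark_client = SparkClient()
# Required by _incremental_mode: only the partitions receiving data are overwritten.
spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

writer = HistoricalFeatureStoreWriter(interval_mode=True)
# writer.write(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)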
diff --git a/butterfree/load/writers/online_feature_store_writer.py b/butterfree/load/writers/online_feature_store_writer.py
index a81a1040..fade3789 100644
--- a/butterfree/load/writers/online_feature_store_writer.py
+++ b/butterfree/load/writers/online_feature_store_writer.py
@@ -7,9 +7,11 @@
from pyspark.sql.functions import col, row_number
from pyspark.sql.streaming import StreamingQuery
-from butterfree.clients import SparkClient
+from butterfree.clients import CassandraClient, SparkClient
from butterfree.configs.db import AbstractWriteConfig, CassandraConfig
from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.hooks import Hook
+from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook
from butterfree.load.writers.writer import Writer
from butterfree.transform import FeatureSet
@@ -66,6 +68,12 @@ class OnlineFeatureStoreWriter(Writer):
Both methods (writer and validate) will need the Spark Client,
Feature Set and DataFrame, to write or to validate,
according to OnlineFeatureStoreWriter class arguments.
+
+    An important caveat: when using the incremental mode, we do not check
+    whether your data is the newest before writing it to the online feature
+    store.
+
+    This is a known limitation and will be fixed soon.
"""
__name__ = "Online Feature Store Writer"
@@ -75,11 +83,13 @@ def __init__(
db_config: Union[AbstractWriteConfig, CassandraConfig] = None,
debug_mode: bool = False,
write_to_entity: bool = False,
+ interval_mode: bool = False,
+ check_schema_hook: Hook = None,
):
- super(OnlineFeatureStoreWriter, self).__init__()
+ super(OnlineFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
self.db_config = db_config or CassandraConfig()
- self.debug_mode = debug_mode
self.write_to_entity = write_to_entity
+ self.check_schema_hook = check_schema_hook
@staticmethod
def filter_latest(dataframe: DataFrame, id_columns: List[Any]) -> DataFrame:
@@ -170,6 +180,22 @@ def write(
"""
table_name = feature_set.entity if self.write_to_entity else feature_set.name
+ if not self.debug_mode:
+ config = (
+ self.db_config
+ if self.db_config == CassandraConfig
+ else CassandraConfig()
+ )
+
+ cassandra_client = CassandraClient(
+ host=[config.host],
+ keyspace=config.keyspace,
+ user=config.username,
+ password=config.password,
+ )
+
+ dataframe = self.check_schema(cassandra_client, dataframe, table_name)
+
if dataframe.isStreaming:
dataframe = self._apply_transformations(dataframe)
if self.debug_mode:
@@ -236,3 +262,21 @@ def get_db_schema(self, feature_set: FeatureSet) -> List[Dict[Any, Any]]:
"""
db_schema = self.db_config.translate(feature_set.get_schema())
return db_schema
+
+ def check_schema(
+ self, client: Any, dataframe: DataFrame, table_name: str, database: str = None
+ ) -> DataFrame:
+ """Instantiate the schema check hook to check schema between dataframe and database.
+
+ Args:
+ client: client for Spark or Cassandra connections with external services.
+ dataframe: Spark dataframe containing data from a feature set.
+ table_name: table name where the dataframe will be saved.
+ database: database name where the dataframe will be saved.
+ """
+ if not self.check_schema_hook:
+ self.check_schema_hook = CassandraTableSchemaCompatibilityHook(
+ client, table_name
+ )
+
+ return self.check_schema_hook.run(dataframe)
diff --git a/butterfree/load/writers/writer.py b/butterfree/load/writers/writer.py
index f76b4c25..7e0f9018 100644
--- a/butterfree/load/writers/writer.py
+++ b/butterfree/load/writers/writer.py
@@ -7,10 +7,11 @@
from pyspark.sql.dataframe import DataFrame
from butterfree.clients import SparkClient
+from butterfree.hooks import HookableComponent
from butterfree.transform import FeatureSet
-class Writer(ABC):
+class Writer(ABC, HookableComponent):
"""Abstract base class for Writers.
Args:
@@ -18,8 +19,11 @@ class Writer(ABC):
"""
- def __init__(self) -> None:
+ def __init__(self, debug_mode: bool = False, interval_mode: bool = False) -> None:
+ super().__init__()
self.transformations: List[Dict[str, Any]] = []
+ self.debug_mode = debug_mode
+ self.interval_mode = interval_mode
def with_(
self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any
@@ -70,6 +74,19 @@ def write(
"""
+ @abstractmethod
+ def check_schema(
+ self, client: Any, dataframe: DataFrame, table_name: str, database: str = None
+ ) -> DataFrame:
+ """Instantiate the schema check hook to check schema between dataframe and database.
+
+ Args:
+ client: client for Spark or Cassandra connections with external services.
+ dataframe: Spark dataframe containing data from a feature set.
+ table_name: table name where the dataframe will be saved.
+ database: database name where the dataframe will be saved.
+ """
+
@abstractmethod
def validate(
self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
diff --git a/butterfree/pipelines/feature_set_pipeline.py b/butterfree/pipelines/feature_set_pipeline.py
index ce1b7ba4..8aec54ec 100644
--- a/butterfree/pipelines/feature_set_pipeline.py
+++ b/butterfree/pipelines/feature_set_pipeline.py
@@ -40,11 +40,12 @@ class FeatureSetPipeline:
... )
>>> from butterfree.load import Sink
>>> from butterfree.load.writers import HistoricalFeatureStoreWriter
- >>> import pyspark.sql.functions as F
+ >>> from pyspark.sql import functions
>>> def divide(df, fs, column1, column2):
... name = fs.get_output_columns()[0]
- ... df = df.withColumn(name, F.col(column1) / F.col(column2))
+ ... df = df.withColumn(name,
+ ... functions.col(column1) / functions.col(column2))
... return df
>>> pipeline = FeatureSetPipeline(
@@ -67,7 +68,8 @@ class FeatureSetPipeline:
... name="feature1",
... description="test",
... transformation=SparkFunctionTransform(
- ... functions=[F.avg, F.stddev_pop]
+ ... functions=[Function(functions.avg, DataType.DOUBLE),
+ ... Function(functions.stddev_pop, DataType.DOUBLE)],
... ).with_window(
... partition_by="id",
... order_by=TIMESTAMP_COLUMN,
@@ -113,6 +115,19 @@ class FeatureSetPipeline:
the defined sources, compute all the transformations and save the data
to the specified locations.
+    We can run the pipeline over a range of dates by passing a start date
+    and an end date, so that only data within this date range is processed.
+
+ >>> pipeline.run(end_date="2020-08-04", start_date="2020-07-04")
+
+    Or run up to a date, processing only data up to that specific date.
+
+ >>> pipeline.run(end_date="2020-08-04")
+
+    Or run for just a specific date, processing only that day's data.
+
+ >>> pipeline.run_for_date(execution_date="2020-08-04")
+
"""
def __init__(
@@ -179,6 +194,7 @@ def run(
partition_by: List[str] = None,
order_by: List[str] = None,
num_processors: int = None,
+ start_date: str = None,
) -> None:
"""Runs the defined feature set pipeline.
@@ -192,7 +208,11 @@ def run(
soon. Use only if strictly necessary.
"""
- dataframe = self.source.construct(client=self.spark_client)
+ dataframe = self.source.construct(
+ client=self.spark_client,
+ start_date=self.feature_set.define_start_date(start_date),
+ end_date=end_date,
+ )
if partition_by:
order_by = order_by or partition_by
@@ -203,6 +223,7 @@ def run(
dataframe = self.feature_set.construct(
dataframe=dataframe,
client=self.spark_client,
+ start_date=start_date,
end_date=end_date,
num_processors=num_processors,
)
@@ -219,3 +240,30 @@ def run(
feature_set=self.feature_set,
spark_client=self.spark_client,
)
+
+ def run_for_date(
+ self,
+ execution_date: str = None,
+ partition_by: List[str] = None,
+ order_by: List[str] = None,
+ num_processors: int = None,
+ ) -> None:
+ """Runs the defined feature set pipeline for a specific date.
+
+        The pipeline consists of the following steps:
+
+        - Construct the input dataframe from the data source.
+        - Construct the feature set dataframe using the defined Features.
+        - Load the data to the configured sink locations.
+
+ It's important to notice, however, that both parameters partition_by
+        and num_processors are WIP; we intend to enhance their functionality
+ soon. Use only if strictly necessary.
+ """
+ self.run(
+ start_date=execution_date,
+ end_date=execution_date,
+ partition_by=partition_by,
+ order_by=order_by,
+ num_processors=num_processors,
+ )
diff --git a/butterfree/transform/aggregated_feature_set.py b/butterfree/transform/aggregated_feature_set.py
index f43c12d5..a19efb35 100644
--- a/butterfree/transform/aggregated_feature_set.py
+++ b/butterfree/transform/aggregated_feature_set.py
@@ -1,6 +1,6 @@
"""AggregatedFeatureSet entity."""
import itertools
-from datetime import timedelta
+from datetime import datetime, timedelta
from functools import reduce
from typing import Any, Dict, List, Optional, Union
@@ -8,6 +8,7 @@
from pyspark.sql import DataFrame, functions
from butterfree.clients import SparkClient
+from butterfree.constants.window_definitions import ALLOWED_WINDOWS
from butterfree.dataframe_service import repartition_df
from butterfree.transform import FeatureSet
from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
@@ -488,12 +489,45 @@ def get_schema(self) -> List[Dict[str, Any]]:
return schema
+ @staticmethod
+ def _get_biggest_window_in_days(definitions: List[str]) -> float:
+ windows_list = []
+ for window in definitions:
+ windows_list.append(
+ int(window.split()[0]) * ALLOWED_WINDOWS[window.split()[1]]
+ )
+ return max(windows_list) / (60 * 60 * 24)
+
+ def define_start_date(self, start_date: str = None) -> Optional[str]:
+ """Get aggregated feature set start date.
+
+ Args:
+ start_date: start date regarding source dataframe.
+
+ Returns:
+ start date.
+ """
+ if self._windows and start_date:
+ window_definition = [
+ definition.frame_boundaries.window_definition
+ for definition in self._windows
+ ]
+ biggest_window = self._get_biggest_window_in_days(window_definition)
+
+ return (
+ datetime.strptime(start_date, "%Y-%m-%d")
+ - timedelta(days=int(biggest_window) + 1)
+ ).strftime("%Y-%m-%d")
+
+ return start_date
+
def construct(
self,
dataframe: DataFrame,
client: SparkClient,
end_date: str = None,
num_processors: int = None,
+ start_date: str = None,
) -> DataFrame:
"""Use all the features to build the feature set dataframe.
@@ -506,6 +540,7 @@ def construct(
client: client responsible for connecting to Spark session.
end_date: user defined max date for having aggregated data (exclusive).
num_processors: cluster total number of processors for repartitioning.
+ start_date: user defined min date for having aggregated data.
Returns:
Spark dataframe with all the feature columns.
@@ -519,10 +554,12 @@ def construct(
if not isinstance(dataframe, DataFrame):
raise ValueError("source_df must be a dataframe")
+ pre_hook_df = self.run_pre_hooks(dataframe)
+
output_df = reduce(
lambda df, feature: feature.transform(df),
self.keys + [self.timestamp],
- dataframe,
+ pre_hook_df,
)
if self._windows and end_date is not None:
@@ -558,6 +595,10 @@ def construct(
else:
output_df = self._aggregate(output_df, features=self.features)
+ output_df = self.incremental_strategy.filter_with_incremental_strategy(
+ dataframe=output_df, start_date=start_date, end_date=end_date
+ )
+
output_df = output_df.select(*self.columns).replace( # type: ignore
float("nan"), None
)
@@ -565,4 +606,6 @@ def construct(
output_df = self._filter_duplicated_rows(output_df)
output_df.cache().count()
- return output_df
+ post_hook_df = self.run_post_hooks(output_df)
+
+ return post_hook_df
diff --git a/butterfree/transform/feature_set.py b/butterfree/transform/feature_set.py
index c35e90fa..c2e40a49 100644
--- a/butterfree/transform/feature_set.py
+++ b/butterfree/transform/feature_set.py
@@ -1,7 +1,7 @@
"""FeatureSet entity."""
import itertools
from functools import reduce
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
import pyspark.sql.functions as F
from pyspark.sql import Window
@@ -9,6 +9,8 @@
from butterfree.clients import SparkClient
from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.dataframe_service import IncrementalStrategy
+from butterfree.hooks import HookableComponent
from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
from butterfree.transform.transformations import (
AggregatedTransform,
@@ -16,7 +18,7 @@
)
-class FeatureSet:
+class FeatureSet(HookableComponent):
"""Holds metadata about the feature set and constructs the final dataframe.
Attributes:
@@ -106,12 +108,14 @@ def __init__(
timestamp: TimestampFeature,
features: List[Feature],
) -> None:
+ super().__init__()
self.name = name
self.entity = entity
self.description = description
self.keys = keys
self.timestamp = timestamp
self.features = features
+ self.incremental_strategy = IncrementalStrategy(column=TIMESTAMP_COLUMN)
@property
def name(self) -> str:
@@ -243,9 +247,6 @@ def columns(self) -> List[str]:
def get_schema(self) -> List[Dict[str, Any]]:
"""Get feature set schema.
- Args:
- feature_set: object processed with feature set metadata.
-
Returns:
List of dicts regarding cassandra feature set schema.
@@ -378,12 +379,24 @@ def _filter_duplicated_rows(self, df: DataFrame) -> DataFrame:
return df.select([column for column in self.columns])
+ def define_start_date(self, start_date: str = None) -> Optional[str]:
+ """Get feature set start date.
+
+ Args:
+ start_date: start date regarding source dataframe.
+
+ Returns:
+ start date.
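+
+ For a plain (non-aggregated) feature set no window adjustment is
+ needed, so the start date is returned unchanged.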
+ """
+ return start_date
+
def construct(
self,
dataframe: DataFrame,
client: SparkClient,
end_date: str = None,
num_processors: int = None,
+ start_date: str = None,
) -> DataFrame:
"""Use all the features to build the feature set dataframe.
@@ -393,7 +406,8 @@ def construct(
Args:
dataframe: input dataframe to be transformed by the features.
client: client responsible for connecting to Spark session.
- end_date: user defined base date.
+ start_date: user defined start date.
+ end_date: user defined end date.
num_processors: cluster total number of processors for repartitioning.
Returns:
@@ -403,14 +417,22 @@ def construct(
if not isinstance(dataframe, DataFrame):
raise ValueError("source_df must be a dataframe")
+ pre_hook_df = self.run_pre_hooks(dataframe)
+
output_df = reduce(
lambda df, feature: feature.transform(df),
self.keys + [self.timestamp] + self.features,
- dataframe,
+ pre_hook_df,
).select(*self.columns)
if not output_df.isStreaming:
output_df = self._filter_duplicated_rows(output_df)
output_df.cache().count()
- return output_df
+ output_df = self.incremental_strategy.filter_with_incremental_strategy(
+ dataframe=output_df, start_date=start_date, end_date=end_date
+ )
+
+ post_hook_df = self.run_post_hooks(output_df)
+
+ return post_hook_df
diff --git a/butterfree/transform/utils/window_spec.py b/butterfree/transform/utils/window_spec.py
index f3a392f6..a270fec0 100644
--- a/butterfree/transform/utils/window_spec.py
+++ b/butterfree/transform/utils/window_spec.py
@@ -5,6 +5,7 @@
from pyspark.sql import Column, WindowSpec, functions
from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.constants.window_definitions import ALLOWED_WINDOWS
class FrameBoundaries:
@@ -16,21 +17,6 @@ class FrameBoundaries:
it can be second(s), minute(s), hour(s), day(s), week(s) and year(s),
"""
- __ALLOWED_WINDOWS = {
- "second": 1,
- "seconds": 1,
- "minute": 60,
- "minutes": 60,
- "hour": 3600,
- "hours": 3600,
- "day": 86400,
- "days": 86400,
- "week": 604800,
- "weeks": 604800,
- "year": 29030400,
- "years": 29030400,
- }
-
def __init__(self, mode: Optional[str], window_definition: str):
self.mode = mode
self.window_definition = window_definition
@@ -46,7 +32,7 @@ def window_size(self) -> int:
def window_unit(self) -> str:
"""Returns window unit."""
unit = self.window_definition.split()[1]
- if unit not in self.__ALLOWED_WINDOWS and self.mode != "row_windows":
+ if unit not in ALLOWED_WINDOWS and self.mode != "row_windows":
raise ValueError("Not allowed")
return unit
@@ -59,7 +45,7 @@ def get(self, window: WindowSpec) -> Any:
span = self.window_size - 1
return window.rowsBetween(-span, 0)
if self.mode == "fixed_windows":
- span = self.__ALLOWED_WINDOWS[self.window_unit] * self.window_size
+ span = ALLOWED_WINDOWS[self.window_unit] * self.window_size
return window.rangeBetween(-span, 0)
diff --git a/examples/interval_runs/interval_runs.ipynb b/examples/interval_runs/interval_runs.ipynb
new file mode 100644
index 00000000..e234da8a
--- /dev/null
+++ b/examples/interval_runs/interval_runs.ipynb
@@ -0,0 +1,2152 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# #5 Discovering Butterfree - Interval Runs\n",
+ "\n",
+ "Welcome to Discovering Butterfree tutorial series!\n",
+ "\n",
+ "This is the fifth tutorial of this series: its goal is to cover interval runs.\n",
+ "\n",
+ "Before diving into the tutorial make sure you have a basic understanding of these main data concepts: features, feature sets and the \"Feature Store Architecture\", you can read more about this [here]."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Example:\n",
+ "\n",
+ "Simulating the following scenario (the same from previous tutorials):\n",
+ "\n",
+ "- We want to create a feature set with features about houses for rent (listings).\n",
+ "\n",
+ "\n",
+ "We have an input dataset:\n",
+ "\n",
+ "- Table: `listing_events`. Table with data about events of house listings.\n",
+ "\n",
+ "\n",
+ "Our desire is to have three resulting datasets with the following schema:\n",
+ "\n",
+ "* id: **int**;\n",
+ "* timestamp: **timestamp**;\n",
+ "* rent__avg_over_1_day_rolling_windows: **double**;\n",
+ "* rent__stddev_pop_over_1_day_rolling_windows: **double**.\n",
+ " \n",
+ "The first dataset will be computed with just an end date time limit. The second one, on the other hand, uses both start and end date in order to filter data. Finally, the third one will be the result of a daily run. You can understand more about these definitions in our documentation.\n",
+ "\n",
+ "The following code blocks will show how to generate this feature set using Butterfree library:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setup spark\n",
+ "from pyspark import SparkContext, SparkConf\n",
+ "from pyspark.sql import session\n",
+ "\n",
+ "conf = SparkConf().setAll([('spark.driver.host','127.0.0.1'), ('spark.sql.session.timeZone', 'UTC')])\n",
+ "sc = SparkContext(conf=conf)\n",
+ "spark = session.SparkSession(sc)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fix working dir\n",
+ "import pathlib\n",
+ "import os\n",
+ "path = os.path.join(pathlib.Path().absolute(), '../..')\n",
+ "os.chdir(path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Showing test data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "listing_events_df = spark.read.json(f\"{path}/examples/data/listing_events.json\")\n",
+ "listing_events_df.createOrReplaceTempView(\"listing_events\") # creating listing_events view\n",
+ "\n",
+ "region = spark.read.json(f\"{path}/examples/data/region.json\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Listing events table:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " area | \n",
+ " bathrooms | \n",
+ " bedrooms | \n",
+ " id | \n",
+ " region_id | \n",
+ " rent | \n",
+ " timestamp | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 50 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1300 | \n",
+ " 1588302000000 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 50 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2000 | \n",
+ " 1588647600000 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 100 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 1500 | \n",
+ " 1588734000000 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 100 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2500 | \n",
+ " 1589252400000 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 150 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 3000 | \n",
+ " 1589943600000 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 175 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 4 | \n",
+ " 4 | \n",
+ " 3200 | \n",
+ " 1589943600000 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 250 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 3200 | \n",
+ " 1590030000000 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 225 | \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 6 | \n",
+ " 6 | \n",
+ " 3200 | \n",
+ " 1590116400000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " area bathrooms bedrooms id region_id rent timestamp\n",
+ "0 50 1 1 1 1 1300 1588302000000\n",
+ "1 50 1 1 1 1 2000 1588647600000\n",
+ "2 100 1 2 2 2 1500 1588734000000\n",
+ "3 100 1 2 2 2 2500 1589252400000\n",
+ "4 150 2 2 3 3 3000 1589943600000\n",
+ "5 175 2 2 4 4 3200 1589943600000\n",
+ "6 250 3 3 5 5 3200 1590030000000\n",
+ "7 225 3 2 6 6 3200 1590116400000"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "listing_events_df.toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Region table:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " city | \n",
+ " id | \n",
+ " lat | \n",
+ " lng | \n",
+ " region | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Cerulean | \n",
+ " 1 | \n",
+ " 73.44489 | \n",
+ " 31.75030 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Veridian | \n",
+ " 2 | \n",
+ " -9.43510 | \n",
+ " -167.11772 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Cinnabar | \n",
+ " 3 | \n",
+ " 29.73043 | \n",
+ " 117.66164 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Pallet | \n",
+ " 4 | \n",
+ " -52.95717 | \n",
+ " -81.15251 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Violet | \n",
+ " 5 | \n",
+ " -47.35798 | \n",
+ " -178.77255 | \n",
+ " Johto | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " Olivine | \n",
+ " 6 | \n",
+ " 51.72820 | \n",
+ " 46.21958 | \n",
+ " Johto | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " city id lat lng region\n",
+ "0 Cerulean 1 73.44489 31.75030 Kanto\n",
+ "1 Veridian 2 -9.43510 -167.11772 Kanto\n",
+ "2 Cinnabar 3 29.73043 117.66164 Kanto\n",
+ "3 Pallet 4 -52.95717 -81.15251 Kanto\n",
+ "4 Violet 5 -47.35798 -178.77255 Johto\n",
+ "5 Olivine 6 51.72820 46.21958 Johto"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "region.toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Extract\n",
+ "\n",
+ "- For the extract part, we need the `Source` entity and the `FileReader` for the data we have;\n",
+ "- We need to declare a query in order to bring the results from our lonely reader (it's as simples as a select all statement)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from butterfree.clients import SparkClient\n",
+ "from butterfree.extract import Source\n",
+ "from butterfree.extract.readers import FileReader, TableReader\n",
+ "from butterfree.extract.pre_processing import filter\n",
+ "\n",
+ "readers = [\n",
+ " TableReader(id=\"listing_events\", table=\"listing_events\",),\n",
+ " FileReader(id=\"region\", path=f\"{path}/examples/data/region.json\", format=\"json\",)\n",
+ "]\n",
+ "\n",
+ "query = \"\"\"\n",
+ "select\n",
+ " listing_events.*,\n",
+ " region.city,\n",
+ " region.region,\n",
+ " region.lat,\n",
+ " region.lng,\n",
+ " region.region as region_name\n",
+ "from\n",
+ " listing_events\n",
+ " join region\n",
+ " on listing_events.region_id = region.id\n",
+ "\"\"\"\n",
+ "\n",
+ "source = Source(readers=readers, query=query)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "spark_client = SparkClient()\n",
+ "source_df = source.construct(spark_client)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And, finally, it's possible to see the results from building our souce dataset:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " area | \n",
+ " bathrooms | \n",
+ " bedrooms | \n",
+ " id | \n",
+ " region_id | \n",
+ " rent | \n",
+ " timestamp | \n",
+ " city | \n",
+ " region | \n",
+ " lat | \n",
+ " lng | \n",
+ " region_name | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 50 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1300 | \n",
+ " 1588302000000 | \n",
+ " Cerulean | \n",
+ " Kanto | \n",
+ " 73.44489 | \n",
+ " 31.75030 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 50 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2000 | \n",
+ " 1588647600000 | \n",
+ " Cerulean | \n",
+ " Kanto | \n",
+ " 73.44489 | \n",
+ " 31.75030 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 100 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 1500 | \n",
+ " 1588734000000 | \n",
+ " Veridian | \n",
+ " Kanto | \n",
+ " -9.43510 | \n",
+ " -167.11772 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 100 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2500 | \n",
+ " 1589252400000 | \n",
+ " Veridian | \n",
+ " Kanto | \n",
+ " -9.43510 | \n",
+ " -167.11772 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 150 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 3000 | \n",
+ " 1589943600000 | \n",
+ " Cinnabar | \n",
+ " Kanto | \n",
+ " 29.73043 | \n",
+ " 117.66164 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 175 | \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 4 | \n",
+ " 4 | \n",
+ " 3200 | \n",
+ " 1589943600000 | \n",
+ " Pallet | \n",
+ " Kanto | \n",
+ " -52.95717 | \n",
+ " -81.15251 | \n",
+ " Kanto | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 250 | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 5 | \n",
+ " 5 | \n",
+ " 3200 | \n",
+ " 1590030000000 | \n",
+ " Violet | \n",
+ " Johto | \n",
+ " -47.35798 | \n",
+ " -178.77255 | \n",
+ " Johto | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 225 | \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 6 | \n",
+ " 6 | \n",
+ " 3200 | \n",
+ " 1590116400000 | \n",
+ " Olivine | \n",
+ " Johto | \n",
+ " 51.72820 | \n",
+ " 46.21958 | \n",
+ " Johto | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " area bathrooms bedrooms id region_id rent timestamp city \\\n",
+ "0 50 1 1 1 1 1300 1588302000000 Cerulean \n",
+ "1 50 1 1 1 1 2000 1588647600000 Cerulean \n",
+ "2 100 1 2 2 2 1500 1588734000000 Veridian \n",
+ "3 100 1 2 2 2 2500 1589252400000 Veridian \n",
+ "4 150 2 2 3 3 3000 1589943600000 Cinnabar \n",
+ "5 175 2 2 4 4 3200 1589943600000 Pallet \n",
+ "6 250 3 3 5 5 3200 1590030000000 Violet \n",
+ "7 225 3 2 6 6 3200 1590116400000 Olivine \n",
+ "\n",
+ " region lat lng region_name \n",
+ "0 Kanto 73.44489 31.75030 Kanto \n",
+ "1 Kanto 73.44489 31.75030 Kanto \n",
+ "2 Kanto -9.43510 -167.11772 Kanto \n",
+ "3 Kanto -9.43510 -167.11772 Kanto \n",
+ "4 Kanto 29.73043 117.66164 Kanto \n",
+ "5 Kanto -52.95717 -81.15251 Kanto \n",
+ "6 Johto -47.35798 -178.77255 Johto \n",
+ "7 Johto 51.72820 46.21958 Johto "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "source_df.toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Transform\n",
+ "- At the transform part, a set of `Feature` objects is declared;\n",
+ "- An Instance of `AggregatedFeatureSet` is used to hold the features;\n",
+ "- An `AggregatedFeatureSet` can only be created when it is possible to define a unique tuple formed by key columns and a time reference. This is an **architectural requirement** for the data. So least one `KeyFeature` and one `TimestampFeature` is needed;\n",
+ "- Every `Feature` needs a unique name, a description, and a data-type definition. Besides, in the case of the `AggregatedFeatureSet`, it's also mandatory to have an `AggregatedTransform` operator;\n",
+ "- An `AggregatedTransform` operator is used, as the name suggests, to define aggregation functions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyspark.sql import functions as F\n",
+ "\n",
+ "from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet\n",
+ "from butterfree.transform.features import Feature, KeyFeature, TimestampFeature\n",
+ "from butterfree.transform.transformations import AggregatedTransform\n",
+ "from butterfree.constants import DataType\n",
+ "from butterfree.transform.utils import Function\n",
+ "\n",
+ "keys = [\n",
+ " KeyFeature(\n",
+ " name=\"id\",\n",
+ " description=\"Unique identificator code for houses.\",\n",
+ " dtype=DataType.BIGINT,\n",
+ " )\n",
+ "]\n",
+ "\n",
+ "# from_ms = True because the data originally is not in a Timestamp format.\n",
+ "ts_feature = TimestampFeature(from_ms=True)\n",
+ "\n",
+ "features = [\n",
+ " Feature(\n",
+ " name=\"rent\",\n",
+ " description=\"Rent value by month described in the listing.\",\n",
+ " transformation=AggregatedTransform(\n",
+ " functions=[\n",
+ " Function(F.avg, DataType.DOUBLE),\n",
+ " Function(F.stddev_pop, DataType.DOUBLE),\n",
+ " ],\n",
+ " filter_expression=\"region_name = 'Kanto'\",\n",
+ " ),\n",
+ " )\n",
+ "]\n",
+ "\n",
+ "aggregated_feature_set = AggregatedFeatureSet(\n",
+ " name=\"house_listings\",\n",
+ " entity=\"house\", # entity: to which \"business context\" this feature set belongs\n",
+ " description=\"Features describring a house listing.\",\n",
+ " keys=keys,\n",
+ " timestamp=ts_feature,\n",
+ " features=features,\n",
+ ").with_windows(definitions=[\"1 day\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, we'll define out first aggregated feature set, with just an `end date` parameter:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "aggregated_feature_set_windows_df = aggregated_feature_set.construct(\n",
+ " source_df, \n",
+ " spark_client, \n",
+ " end_date=\"2020-05-30\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The resulting dataset is:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2020-05-02 | \n",
+ " 1300.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 2020-05-03 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 2020-05-06 | \n",
+ " 2000.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 2020-05-07 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 2 | \n",
+ " 2020-05-07 | \n",
+ " 1500.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 2 | \n",
+ " 2020-05-08 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 2 | \n",
+ " 2020-05-13 | \n",
+ " 2500.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 2 | \n",
+ " 2020-05-14 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 3 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 3 | \n",
+ " 2020-05-22 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 4 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 4 | \n",
+ " 2020-05-22 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 5 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 6 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-01 NaN \n",
+ "1 1 2020-05-02 1300.0 \n",
+ "2 1 2020-05-03 NaN \n",
+ "3 1 2020-05-06 2000.0 \n",
+ "4 1 2020-05-07 NaN \n",
+ "5 2 2020-05-01 NaN \n",
+ "6 2 2020-05-07 1500.0 \n",
+ "7 2 2020-05-08 NaN \n",
+ "8 2 2020-05-13 2500.0 \n",
+ "9 2 2020-05-14 NaN \n",
+ "10 3 2020-05-01 NaN \n",
+ "11 3 2020-05-21 3000.0 \n",
+ "12 3 2020-05-22 NaN \n",
+ "13 4 2020-05-01 NaN \n",
+ "14 4 2020-05-21 3200.0 \n",
+ "15 4 2020-05-22 NaN \n",
+ "16 5 2020-05-01 NaN \n",
+ "17 6 2020-05-01 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows \n",
+ "0 NaN \n",
+ "1 0.0 \n",
+ "2 NaN \n",
+ "3 0.0 \n",
+ "4 NaN \n",
+ "5 NaN \n",
+ "6 0.0 \n",
+ "7 NaN \n",
+ "8 0.0 \n",
+ "9 NaN \n",
+ "10 NaN \n",
+ "11 0.0 \n",
+ "12 NaN \n",
+ "13 NaN \n",
+ "14 0.0 \n",
+ "15 NaN \n",
+ "16 NaN \n",
+ "17 NaN "
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "aggregated_feature_set_windows_df.orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It's possible to see that if we use both a `start date` and `end_date` values. Then we'll achieve a time slice of the last dataframe, as it's possible to see:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-06 | \n",
+ " 2000.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2020-05-07 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 2020-05-07 | \n",
+ " 1500.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 2020-05-08 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 2020-05-13 | \n",
+ " 2500.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 2 | \n",
+ " 2020-05-14 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 3 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 4 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 5 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 6 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-06 2000.0 \n",
+ "1 1 2020-05-07 NaN \n",
+ "2 2 2020-05-06 NaN \n",
+ "3 2 2020-05-07 1500.0 \n",
+ "4 2 2020-05-08 NaN \n",
+ "5 2 2020-05-13 2500.0 \n",
+ "6 2 2020-05-14 NaN \n",
+ "7 3 2020-05-06 NaN \n",
+ "8 3 2020-05-21 3000.0 \n",
+ "9 4 2020-05-06 NaN \n",
+ "10 4 2020-05-21 3200.0 \n",
+ "11 5 2020-05-06 NaN \n",
+ "12 6 2020-05-06 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows \n",
+ "0 0.0 \n",
+ "1 NaN \n",
+ "2 NaN \n",
+ "3 0.0 \n",
+ "4 NaN \n",
+ "5 0.0 \n",
+ "6 NaN \n",
+ "7 NaN \n",
+ "8 0.0 \n",
+ "9 NaN \n",
+ "10 0.0 \n",
+ "11 NaN \n",
+ "12 NaN "
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "aggregated_feature_set.construct(\n",
+ " source_df, \n",
+ " spark_client, \n",
+ " end_date=\"2020-05-21\",\n",
+ " start_date=\"2020-05-06\",\n",
+ ").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load\n",
+ "\n",
+ "- For the load part we need `Writer` instances and a `Sink`;\n",
+ "- `writers` define where to load the data;\n",
+ "- The `Sink` gets the transformed data (feature set) and trigger the load to all the defined `writers`;\n",
+ "- `debug_mode` will create a temporary view instead of trying to write in a real data store."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from butterfree.load.writers import (\n",
+ " HistoricalFeatureStoreWriter,\n",
+ " OnlineFeatureStoreWriter,\n",
+ ")\n",
+ "from butterfree.load import Sink\n",
+ "\n",
+ "writers = [HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True), \n",
+ " OnlineFeatureStoreWriter(debug_mode=True, interval_mode=True)]\n",
+ "sink = Sink(writers=writers)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Pipeline\n",
+ "\n",
+ "- The `Pipeline` entity wraps all the other defined elements.\n",
+ "- `run` command will trigger the execution of the pipeline, end-to-end."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from butterfree.pipelines import FeatureSetPipeline\n",
+ "\n",
+ "pipeline = FeatureSetPipeline(source=source, feature_set=aggregated_feature_set, sink=sink)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The first run will use just an `end_date` as parameter:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "result_df = pipeline.run(end_date=\"2020-05-30\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ " year | \n",
+ " month | \n",
+ " day | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2020-05-02 | \n",
+ " 1300.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 2 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 2020-05-03 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 3 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1 | \n",
+ " 2020-05-06 | \n",
+ " 2000.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 2020-05-07 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 2 | \n",
+ " 2020-05-07 | \n",
+ " 1500.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 2 | \n",
+ " 2020-05-08 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 8 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 2 | \n",
+ " 2020-05-13 | \n",
+ " 2500.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 13 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 2 | \n",
+ " 2020-05-14 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 14 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 3 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 3 | \n",
+ " 2020-05-22 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 22 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 4 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 4 | \n",
+ " 2020-05-22 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 22 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 5 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 6 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-01 NaN \n",
+ "1 1 2020-05-02 1300.0 \n",
+ "2 1 2020-05-03 NaN \n",
+ "3 1 2020-05-06 2000.0 \n",
+ "4 1 2020-05-07 NaN \n",
+ "5 2 2020-05-01 NaN \n",
+ "6 2 2020-05-07 1500.0 \n",
+ "7 2 2020-05-08 NaN \n",
+ "8 2 2020-05-13 2500.0 \n",
+ "9 2 2020-05-14 NaN \n",
+ "10 3 2020-05-01 NaN \n",
+ "11 3 2020-05-21 3000.0 \n",
+ "12 3 2020-05-22 NaN \n",
+ "13 4 2020-05-01 NaN \n",
+ "14 4 2020-05-21 3200.0 \n",
+ "15 4 2020-05-22 NaN \n",
+ "16 5 2020-05-01 NaN \n",
+ "17 6 2020-05-01 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows year month day \n",
+ "0 NaN 2020 5 1 \n",
+ "1 0.0 2020 5 2 \n",
+ "2 NaN 2020 5 3 \n",
+ "3 0.0 2020 5 6 \n",
+ "4 NaN 2020 5 7 \n",
+ "5 NaN 2020 5 1 \n",
+ "6 0.0 2020 5 7 \n",
+ "7 NaN 2020 5 8 \n",
+ "8 0.0 2020 5 13 \n",
+ "9 NaN 2020 5 14 \n",
+ "10 NaN 2020 5 1 \n",
+ "11 0.0 2020 5 21 \n",
+ "12 NaN 2020 5 22 \n",
+ "13 NaN 2020 5 1 \n",
+ "14 0.0 2020 5 21 \n",
+ "15 NaN 2020 5 22 \n",
+ "16 NaN 2020 5 1 \n",
+ "17 NaN 2020 5 1 "
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "spark.table(\"historical_feature_store__house_listings\").orderBy(\n",
+ " \"id\", \"timestamp\"\n",
+ ").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-07 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2020-05-14 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 2020-05-22 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 2020-05-22 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 2020-05-01 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-07 NaN \n",
+ "1 2 2020-05-14 NaN \n",
+ "2 3 2020-05-22 NaN \n",
+ "3 4 2020-05-22 NaN \n",
+ "4 5 2020-05-01 NaN \n",
+ "5 6 2020-05-01 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows \n",
+ "0 NaN \n",
+ "1 NaN \n",
+ "2 NaN \n",
+ "3 NaN \n",
+ "4 NaN \n",
+ "5 NaN "
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- We can see that we were able to create all the desired features in an easy way\n",
+ "- The **historical feature set** holds all the data, and we can see that it is partitioned by year, month and day (columns added in the `HistoricalFeatureStoreWriter`)\n",
+ "- In the **online feature set** there is only the latest data for each id"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The second run, on the other hand, will use both a `start_date` and `end_date` as parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "result_df = pipeline.run(end_date=\"2020-05-21\", start_date=\"2020-05-06\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ " year | \n",
+ " month | \n",
+ " day | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-06 | \n",
+ " 2000.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 2020-05-07 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 2 | \n",
+ " 2020-05-07 | \n",
+ " 1500.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 7 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 2020-05-08 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 8 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 2 | \n",
+ " 2020-05-13 | \n",
+ " 2500.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 13 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 2 | \n",
+ " 2020-05-14 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 14 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 3 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 4 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 5 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 6 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-06 2000.0 \n",
+ "1 1 2020-05-07 NaN \n",
+ "2 2 2020-05-06 NaN \n",
+ "3 2 2020-05-07 1500.0 \n",
+ "4 2 2020-05-08 NaN \n",
+ "5 2 2020-05-13 2500.0 \n",
+ "6 2 2020-05-14 NaN \n",
+ "7 3 2020-05-06 NaN \n",
+ "8 3 2020-05-21 3000.0 \n",
+ "9 4 2020-05-06 NaN \n",
+ "10 4 2020-05-21 3200.0 \n",
+ "11 5 2020-05-06 NaN \n",
+ "12 6 2020-05-06 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows year month day \n",
+ "0 0.0 2020 5 6 \n",
+ "1 NaN 2020 5 7 \n",
+ "2 NaN 2020 5 6 \n",
+ "3 0.0 2020 5 7 \n",
+ "4 NaN 2020 5 8 \n",
+ "5 0.0 2020 5 13 \n",
+ "6 NaN 2020 5 14 \n",
+ "7 NaN 2020 5 6 \n",
+ "8 0.0 2020 5 21 \n",
+ "9 NaN 2020 5 6 \n",
+ "10 0.0 2020 5 21 \n",
+ "11 NaN 2020 5 6 \n",
+ "12 NaN 2020 5 6 "
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "spark.table(\"historical_feature_store__house_listings\").orderBy(\n",
+ " \"id\", \"timestamp\"\n",
+ ").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-07 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2020-05-14 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 2020-05-06 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-07 NaN \n",
+ "1 2 2020-05-14 NaN \n",
+ "2 3 2020-05-21 3000.0 \n",
+ "3 4 2020-05-21 3200.0 \n",
+ "4 5 2020-05-06 NaN \n",
+ "5 6 2020-05-06 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows \n",
+ "0 NaN \n",
+ "1 NaN \n",
+ "2 0.0 \n",
+ "3 0.0 \n",
+ "4 NaN \n",
+ "5 NaN "
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, the third run, will use only an `execution_date` as a parameter."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "result_df = pipeline.run_for_date(execution_date=\"2020-05-21\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ " year | \n",
+ " month | \n",
+ " day | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 2020 | \n",
+ " 5 | \n",
+ " 21 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-21 NaN \n",
+ "1 2 2020-05-21 NaN \n",
+ "2 3 2020-05-21 3000.0 \n",
+ "3 4 2020-05-21 3200.0 \n",
+ "4 5 2020-05-21 NaN \n",
+ "5 6 2020-05-21 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows year month day \n",
+ "0 NaN 2020 5 21 \n",
+ "1 NaN 2020 5 21 \n",
+ "2 0.0 2020 5 21 \n",
+ "3 0.0 2020 5 21 \n",
+ "4 NaN 2020 5 21 \n",
+ "5 NaN 2020 5 21 "
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "spark.table(\"historical_feature_store__house_listings\").orderBy(\n",
+ " \"id\", \"timestamp\"\n",
+ ").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " timestamp | \n",
+ " rent__avg_over_1_day_rolling_windows | \n",
+ " rent__stddev_pop_over_1_day_rolling_windows | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 2020-05-21 | \n",
+ " 3000.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 2020-05-21 | \n",
+ " 3200.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 2020-05-21 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id timestamp rent__avg_over_1_day_rolling_windows \\\n",
+ "0 1 2020-05-21 NaN \n",
+ "1 2 2020-05-21 NaN \n",
+ "2 3 2020-05-21 3000.0 \n",
+ "3 4 2020-05-21 3200.0 \n",
+ "4 5 2020-05-21 NaN \n",
+ "5 6 2020-05-21 NaN \n",
+ "\n",
+ " rent__stddev_pop_over_1_day_rolling_windows \n",
+ "0 NaN \n",
+ "1 NaN \n",
+ "2 0.0 \n",
+ "3 0.0 \n",
+ "4 NaN \n",
+ "5 NaN "
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "spark.table(\"online_feature_store__house_listings\").orderBy(\"id\", \"timestamp\").toPandas()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/setup.py b/setup.py
index bf471fec..4adcbce9 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
from setuptools import find_packages, setup
__package_name__ = "butterfree"
-__version__ = "1.1.3.dev0"
+__version__ = "1.2.0.dev0"
__repository_url__ = "https://github.com/quintoandar/butterfree"
with open("requirements.txt") as f:
diff --git a/tests/integration/butterfree/load/test_sink.py b/tests/integration/butterfree/load/test_sink.py
index d00f4806..f507a335 100644
--- a/tests/integration/butterfree/load/test_sink.py
+++ b/tests/integration/butterfree/load/test_sink.py
@@ -9,9 +9,10 @@
)
-def test_sink(input_dataframe, feature_set):
+def test_sink(input_dataframe, feature_set, mocker):
# arrange
client = SparkClient()
+ client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
feature_set_df = feature_set.construct(input_dataframe, client)
target_latest_df = OnlineFeatureStoreWriter.filter_latest(
feature_set_df, id_columns=[key.name for key in feature_set.keys]
@@ -20,14 +21,23 @@ def test_sink(input_dataframe, feature_set):
# setup historical writer
s3config = Mock()
+ s3config.mode = "overwrite"
+ s3config.format_ = "parquet"
s3config.get_options = Mock(
- return_value={
- "mode": "overwrite",
- "format_": "parquet",
- "path": "test_folder/historical/entity/feature_set",
- }
+ return_value={"path": "test_folder/historical/entity/feature_set"}
+ )
+ s3config.get_path_with_partitions = Mock(
+ return_value="test_folder/historical/entity/feature_set"
+ )
+
+ historical_writer = HistoricalFeatureStoreWriter(
+ db_config=s3config, interval_mode=True
)
- historical_writer = HistoricalFeatureStoreWriter(db_config=s3config)
+
+ schema_dataframe = historical_writer._create_partitions(feature_set_df)
+ historical_writer.check_schema_hook = mocker.stub("check_schema_hook")
+ historical_writer.check_schema_hook.run = mocker.stub("run")
+ historical_writer.check_schema_hook.run.return_value = schema_dataframe
# setup online writer
# TODO: Change for CassandraConfig when Cassandra for test is ready
@@ -39,6 +49,10 @@ def test_sink(input_dataframe, feature_set):
)
online_writer = OnlineFeatureStoreWriter(db_config=online_config)
+ online_writer.check_schema_hook = mocker.stub("check_schema_hook")
+ online_writer.check_schema_hook.run = mocker.stub("run")
+ online_writer.check_schema_hook.run.return_value = feature_set_df
+
writers = [historical_writer, online_writer]
sink = Sink(writers)
@@ -47,13 +61,14 @@ def test_sink(input_dataframe, feature_set):
sink.flush(feature_set, feature_set_df, client)
# get historical results
- historical_result_df = client.read_table(
- feature_set.name, historical_writer.database
+ historical_result_df = client.read(
+ s3config.format_,
+ path=s3config.get_path_with_partitions(feature_set.name, feature_set_df),
)
# get online results
online_result_df = client.read(
- online_config.format_, options=online_config.get_options(feature_set.name)
+ online_config.format_, **online_config.get_options(feature_set.name)
)
# assert
diff --git a/tests/integration/butterfree/pipelines/conftest.py b/tests/integration/butterfree/pipelines/conftest.py
index 79894176..73da163e 100644
--- a/tests/integration/butterfree/pipelines/conftest.py
+++ b/tests/integration/butterfree/pipelines/conftest.py
@@ -1,7 +1,19 @@
import pytest
+from pyspark.sql import DataFrame
+from pyspark.sql import functions as F
from butterfree.constants import DataType
from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy
+from butterfree.extract import Source
+from butterfree.extract.readers import TableReader
+from butterfree.load import Sink
+from butterfree.load.writers import HistoricalFeatureStoreWriter
+from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline
+from butterfree.transform import FeatureSet
+from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
+from butterfree.transform.transformations import SparkFunctionTransform
+from butterfree.transform.utils import Function
@pytest.fixture()
@@ -74,3 +86,193 @@ def fixed_windows_output_feature_set_dataframe(spark_context, spark_session):
df = df.withColumn(TIMESTAMP_COLUMN, df.timestamp.cast(DataType.TIMESTAMP.spark))
return df
+
+
+@pytest.fixture()
+def mocked_date_df(spark_context, spark_session):
+ data = [
+ {"id": 1, "ts": "2016-04-11 11:31:11", "feature": 200},
+ {"id": 1, "ts": "2016-04-12 11:44:12", "feature": 300},
+ {"id": 1, "ts": "2016-04-13 11:46:24", "feature": 400},
+ {"id": 1, "ts": "2016-04-14 12:03:21", "feature": 500},
+ ]
+ df = spark_session.read.json(spark_context.parallelize(data, 1))
+ df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
+
+ return df
+
+
+@pytest.fixture()
+def fixed_windows_output_feature_set_date_dataframe(spark_context, spark_session):
+ data = [
+ {
+ "id": 1,
+ "timestamp": "2016-04-12 11:44:12",
+ "feature__avg_over_1_day_fixed_windows": 300,
+ "feature__stddev_pop_over_1_day_fixed_windows": 0,
+ "year": 2016,
+ "month": 4,
+ "day": 12,
+ },
+ {
+ "id": 1,
+ "timestamp": "2016-04-13 11:46:24",
+ "feature__avg_over_1_day_fixed_windows": 400,
+ "feature__stddev_pop_over_1_day_fixed_windows": 0,
+ "year": 2016,
+ "month": 4,
+ "day": 13,
+ },
+ ]
+ df = spark_session.read.json(spark_context.parallelize(data, 1))
+ df = df.withColumn(TIMESTAMP_COLUMN, df.timestamp.cast(DataType.TIMESTAMP.spark))
+
+ return df
+
+
+@pytest.fixture()
+def feature_set_pipeline(
+ spark_context, spark_session,
+):
+
+ feature_set_pipeline = FeatureSetPipeline(
+ source=Source(
+ readers=[
+ TableReader(id="b_source", table="b_table",).with_incremental_strategy(
+ incremental_strategy=IncrementalStrategy(column="timestamp")
+ ),
+ ],
+ query=f"select * from b_source ", # noqa
+ ),
+ feature_set=FeatureSet(
+ name="feature_set",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(
+ name="feature",
+ description="test",
+ transformation=SparkFunctionTransform(
+ functions=[
+ Function(F.avg, DataType.FLOAT),
+ Function(F.stddev_pop, DataType.FLOAT),
+ ],
+ ).with_window(
+ partition_by="id",
+ order_by=TIMESTAMP_COLUMN,
+ mode="fixed_windows",
+ window_definition=["1 day"],
+ ),
+ ),
+ ],
+ keys=[
+ KeyFeature(
+ name="id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.INTEGER,
+ )
+ ],
+ timestamp=TimestampFeature(),
+ ),
+ sink=Sink(writers=[HistoricalFeatureStoreWriter(debug_mode=True)]),
+ )
+
+ return feature_set_pipeline
+
+
+@pytest.fixture()
+def pipeline_interval_run_target_dfs(
+ spark_session, spark_context
+) -> (DataFrame, DataFrame, DataFrame):
+ first_data = [
+ {
+ "id": 1,
+ "timestamp": "2016-04-11 11:31:11",
+ "feature": 200,
+ "run_id": 1,
+ "year": 2016,
+ "month": 4,
+ "day": 11,
+ },
+ {
+ "id": 1,
+ "timestamp": "2016-04-12 11:44:12",
+ "feature": 300,
+ "run_id": 1,
+ "year": 2016,
+ "month": 4,
+ "day": 12,
+ },
+ {
+ "id": 1,
+ "timestamp": "2016-04-13 11:46:24",
+ "feature": 400,
+ "run_id": 1,
+ "year": 2016,
+ "month": 4,
+ "day": 13,
+ },
+ ]
+
+ second_data = first_data + [
+ {
+ "id": 1,
+ "timestamp": "2016-04-14 12:03:21",
+ "feature": 500,
+ "run_id": 2,
+ "year": 2016,
+ "month": 4,
+ "day": 14,
+ },
+ ]
+
+ third_data = [
+ {
+ "id": 1,
+ "timestamp": "2016-04-11 11:31:11",
+ "feature": 200,
+ "run_id": 3,
+ "year": 2016,
+ "month": 4,
+ "day": 11,
+ },
+ {
+ "id": 1,
+ "timestamp": "2016-04-12 11:44:12",
+ "feature": 300,
+ "run_id": 1,
+ "year": 2016,
+ "month": 4,
+ "day": 12,
+ },
+ {
+ "id": 1,
+ "timestamp": "2016-04-13 11:46:24",
+ "feature": 400,
+ "run_id": 1,
+ "year": 2016,
+ "month": 4,
+ "day": 13,
+ },
+ {
+ "id": 1,
+ "timestamp": "2016-04-14 12:03:21",
+ "feature": 500,
+ "run_id": 2,
+ "year": 2016,
+ "month": 4,
+ "day": 14,
+ },
+ ]
+
+ first_run_df = spark_session.read.json(
+ spark_context.parallelize(first_data, 1)
+ ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark))
+ second_run_df = spark_session.read.json(
+ spark_context.parallelize(second_data, 1)
+ ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark))
+ third_run_df = spark_session.read.json(
+ spark_context.parallelize(third_data, 1)
+ ).withColumn("timestamp", F.col("timestamp").cast(DataType.TIMESTAMP.spark))
+
+ return first_run_df, second_run_df, third_run_df
diff --git a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
index 23d200c1..a302dc9e 100644
--- a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
+++ b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
@@ -4,21 +4,48 @@
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
+from butterfree.clients import SparkClient
from butterfree.configs import environment
+from butterfree.configs.db import MetastoreConfig
from butterfree.constants import DataType
from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.dataframe_service.incremental_strategy import IncrementalStrategy
from butterfree.extract import Source
from butterfree.extract.readers import TableReader
+from butterfree.hooks import Hook
from butterfree.load import Sink
from butterfree.load.writers import HistoricalFeatureStoreWriter
from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline
from butterfree.testing.dataframe import assert_dataframe_equality
from butterfree.transform import FeatureSet
from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
-from butterfree.transform.transformations import CustomTransform, SparkFunctionTransform
+from butterfree.transform.transformations import (
+ CustomTransform,
+ SparkFunctionTransform,
+ SQLExpressionTransform,
+)
from butterfree.transform.utils import Function
+class AddHook(Hook):
+ def __init__(self, value):
+ self.value = value
+
+ def run(self, dataframe):
+ return dataframe.withColumn("feature", F.expr(f"feature + {self.value}"))
+
+
+class RunHook(Hook):
+ def __init__(self, id):
+ self.id = id
+
+ def run(self, dataframe):
+ return dataframe.withColumn(
+ "run_id",
+ F.when(F.lit(self.id).isNotNull(), F.lit(self.id)).otherwise(F.lit(None)),
+ )
+
+
def create_temp_view(dataframe: DataFrame, name):
dataframe.createOrReplaceTempView(name)
@@ -38,9 +65,21 @@ def divide(df, fs, column1, column2):
return df
+def create_ymd(dataframe):
+ return (
+ dataframe.withColumn("year", F.year(F.col("timestamp")))
+ .withColumn("month", F.month(F.col("timestamp")))
+ .withColumn("day", F.dayofmonth(F.col("timestamp")))
+ )
+
+
class TestFeatureSetPipeline:
def test_feature_set_pipeline(
- self, mocked_df, spark_session, fixed_windows_output_feature_set_dataframe
+ self,
+ mocked_df,
+ spark_session,
+ fixed_windows_output_feature_set_dataframe,
+ mocker,
):
# arrange
table_reader_id = "a_source"
@@ -53,13 +92,25 @@ def test_feature_set_pipeline(
table_reader_db=table_reader_db,
table_reader_table=table_reader_table,
)
+
+ spark_client = SparkClient()
+ spark_client.conn.conf.set(
+ "spark.sql.sources.partitionOverwriteMode", "dynamic"
+ )
+
dbconfig = Mock()
+ dbconfig.mode = "overwrite"
+ dbconfig.format_ = "parquet"
dbconfig.get_options = Mock(
- return_value={
- "mode": "overwrite",
- "format_": "parquet",
- "path": "test_folder/historical/entity/feature_set",
- }
+ return_value={"path": "test_folder/historical/entity/feature_set"}
+ )
+
+ historical_writer = HistoricalFeatureStoreWriter(db_config=dbconfig)
+
+ historical_writer.check_schema_hook = mocker.stub("check_schema_hook")
+ historical_writer.check_schema_hook.run = mocker.stub("run")
+ historical_writer.check_schema_hook.run.return_value = (
+ fixed_windows_output_feature_set_dataframe
)
# act
@@ -112,7 +163,7 @@ def test_feature_set_pipeline(
],
timestamp=TimestampFeature(),
),
- sink=Sink(writers=[HistoricalFeatureStoreWriter(db_config=dbconfig)],),
+ sink=Sink(writers=[historical_writer]),
)
test_pipeline.run()
@@ -129,3 +180,247 @@ def test_feature_set_pipeline(
# tear down
shutil.rmtree("test_folder")
+
+ def test_feature_set_pipeline_with_dates(
+ self,
+ mocked_date_df,
+ spark_session,
+ fixed_windows_output_feature_set_date_dataframe,
+ feature_set_pipeline,
+ mocker,
+ ):
+ # arrange
+ table_reader_table = "b_table"
+ create_temp_view(dataframe=mocked_date_df, name=table_reader_table)
+
+ historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)
+
+ feature_set_pipeline.sink.writers = [historical_writer]
+
+ # act
+ feature_set_pipeline.run(start_date="2016-04-12", end_date="2016-04-13")
+
+ df = spark_session.sql("select * from historical_feature_store__feature_set")
+
+ # assert
+ assert_dataframe_equality(df, fixed_windows_output_feature_set_date_dataframe)
+
+ def test_feature_set_pipeline_with_execution_date(
+ self,
+ mocked_date_df,
+ spark_session,
+ fixed_windows_output_feature_set_date_dataframe,
+ feature_set_pipeline,
+ mocker,
+ ):
+ # arrange
+ table_reader_table = "b_table"
+ create_temp_view(dataframe=mocked_date_df, name=table_reader_table)
+
+ target_df = fixed_windows_output_feature_set_date_dataframe.filter(
+ "timestamp < '2016-04-13'"
+ )
+
+ historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)
+
+ feature_set_pipeline.sink.writers = [historical_writer]
+
+ # act
+ feature_set_pipeline.run_for_date(execution_date="2016-04-12")
+
+ df = spark_session.sql("select * from historical_feature_store__feature_set")
+
+ # assert
+ assert_dataframe_equality(df, target_df)
+
+ def test_pipeline_with_hooks(self, spark_session, mocker):
+ # arrange
+ hook1 = AddHook(value=1)
+
+ spark_session.sql(
+ "select 1 as id, timestamp('2020-01-01') as timestamp, 0 as feature"
+ ).createOrReplaceTempView("test")
+
+ target_df = spark_session.sql(
+ "select 1 as id, timestamp('2020-01-01') as timestamp, 6 as feature, 2020 "
+ "as year, 1 as month, 1 as day"
+ )
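+ # expected feature value: 0 + 1 for each of the 5 hook applications + 1 from the SQL transform = 6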
+
+ historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)
+
+ test_pipeline = FeatureSetPipeline(
+ source=Source(
+ readers=[TableReader(id="reader", table="test",).add_post_hook(hook1)],
+ query="select * from reader",
+ ).add_post_hook(hook1),
+ feature_set=FeatureSet(
+ name="feature_set",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(
+ name="feature",
+ description="test",
+ transformation=SQLExpressionTransform(expression="feature + 1"),
+ dtype=DataType.INTEGER,
+ ),
+ ],
+ keys=[
+ KeyFeature(
+ name="id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.INTEGER,
+ )
+ ],
+ timestamp=TimestampFeature(),
+ )
+ .add_pre_hook(hook1)
+ .add_post_hook(hook1),
+ sink=Sink(writers=[historical_writer],).add_pre_hook(hook1),
+ )
+
+ # act
+ test_pipeline.run()
+ output_df = spark_session.table("historical_feature_store__feature_set")
+
+ # assert
+ output_df.show()
+ assert_dataframe_equality(output_df, target_df)
+
+ def test_pipeline_interval_run(
+ self, mocked_date_df, pipeline_interval_run_target_dfs, spark_session
+ ):
+ """Testing pipeline's idempotent interval run feature.
+ Source data:
+ +-------+---+-------------------+-------------------+
+ |feature| id| ts| timestamp|
+ +-------+---+-------------------+-------------------+
+ | 200| 1|2016-04-11 11:31:11|2016-04-11 11:31:11|
+ | 300| 1|2016-04-12 11:44:12|2016-04-12 11:44:12|
+ | 400| 1|2016-04-13 11:46:24|2016-04-13 11:46:24|
+ | 500| 1|2016-04-14 12:03:21|2016-04-14 12:03:21|
+ +-------+---+-------------------+-------------------+
+ The test executes 3 runs for different time intervals. The input data has 4 data
+ points: 2016-04-11, 2016-04-12, 2016-04-13 and 2016-04-14. The runs are specified
+ as follows:
+ 1) Interval: from 2016-04-11 to 2016-04-13
+ Target table result:
+ +---+-------+---+-----+------+-------------------+----+
+ |day|feature| id|month|run_id| timestamp|year|
+ +---+-------+---+-----+------+-------------------+----+
+ | 11| 200| 1| 4| 1|2016-04-11 11:31:11|2016|
+ | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016|
+ | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016|
+ +---+-------+---+-----+------+-------------------+----+
+ 2) Interval: only 2016-04-14.
+ Target table result:
+ +---+-------+---+-----+------+-------------------+----+
+ |day|feature| id|month|run_id| timestamp|year|
+ +---+-------+---+-----+------+-------------------+----+
+ | 11| 200| 1| 4| 1|2016-04-11 11:31:11|2016|
+ | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016|
+ | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016|
+ | 14| 500| 1| 4| 2|2016-04-14 12:03:21|2016|
+ +---+-------+---+-----+------+-------------------+----+
+ 3) Interval: only 2016-04-11.
+ Target table result:
+ +---+-------+---+-----+------+-------------------+----+
+ |day|feature| id|month|run_id| timestamp|year|
+ +---+-------+---+-----+------+-------------------+----+
+ | 11| 200| 1| 4| 3|2016-04-11 11:31:11|2016|
+ | 12| 300| 1| 4| 1|2016-04-12 11:44:12|2016|
+ | 13| 400| 1| 4| 1|2016-04-13 11:46:24|2016|
+ | 14| 500| 1| 4| 2|2016-04-14 12:03:21|2016|
+ +---+-------+---+-----+------+-------------------+----+
+ """
+ # arrange
+ create_temp_view(dataframe=mocked_date_df, name="input_data")
+
+ db = environment.get_variable("FEATURE_STORE_HISTORICAL_DATABASE")
+ path = "test_folder/historical/entity/feature_set"
+
+ spark_session.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
+ spark_session.sql(f"create database if not exists {db}")
+ spark_session.sql(
+ f"create table if not exists {db}.feature_set_interval "
+ f"(id int, timestamp timestamp, feature int, "
+ f"run_id int, year int, month int, day int);"
+ )
+
+ dbconfig = MetastoreConfig()
+ dbconfig.get_options = Mock(
+ return_value={"mode": "overwrite", "format_": "parquet", "path": path}
+ )
+
+ historical_writer = HistoricalFeatureStoreWriter(
+ db_config=dbconfig, interval_mode=True
+ )
+
+ first_run_hook = RunHook(id=1)
+ second_run_hook = RunHook(id=2)
+ third_run_hook = RunHook(id=3)
+
+ (
+ first_run_target_df,
+ second_run_target_df,
+ third_run_target_df,
+ ) = pipeline_interval_run_target_dfs
+
+ test_pipeline = FeatureSetPipeline(
+ source=Source(
+ readers=[
+ TableReader(id="id", table="input_data",).with_incremental_strategy(
+ IncrementalStrategy("ts")
+ ),
+ ],
+ query="select * from id ",
+ ),
+ feature_set=FeatureSet(
+ name="feature_set_interval",
+ entity="entity",
+ description="",
+ keys=[KeyFeature(name="id", description="", dtype=DataType.INTEGER,)],
+ timestamp=TimestampFeature(from_column="ts"),
+ features=[
+ Feature(name="feature", description="", dtype=DataType.INTEGER),
+ Feature(name="run_id", description="", dtype=DataType.INTEGER),
+ ],
+ ),
+ sink=Sink([historical_writer],),
+ )
+
+ # act and assert
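+ # first run: backfill 2016-04-11 to 2016-04-13, writing day partitions 11-13 with run_id=1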
+ dbconfig.get_path_with_partitions = Mock(
+ return_value=[
+ "test_folder/historical/entity/feature_set/year=2016/month=4/day=11",
+ "test_folder/historical/entity/feature_set/year=2016/month=4/day=12",
+ "test_folder/historical/entity/feature_set/year=2016/month=4/day=13",
+ ]
+ )
+ test_pipeline.feature_set.add_pre_hook(first_run_hook)
+ test_pipeline.run(end_date="2016-04-13", start_date="2016-04-11")
+ first_run_output_df = spark_session.read.parquet(path)
+ assert_dataframe_equality(first_run_output_df, first_run_target_df)
+
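+ # second run: process only 2016-04-14, adding the day 14 partition with run_id=2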
+ dbconfig.get_path_with_partitions = Mock(
+ return_value=[
+ "test_folder/historical/entity/feature_set/year=2016/month=4/day=14",
+ ]
+ )
+ test_pipeline.feature_set.add_pre_hook(second_run_hook)
+ test_pipeline.run_for_date("2016-04-14")
+ second_run_output_df = spark_session.read.parquet(path)
+ assert_dataframe_equality(second_run_output_df, second_run_target_df)
+
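+ # third run: reprocess 2016-04-11, overwriting only the day 11 partition with run_id=3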
+ dbconfig.get_path_with_partitions = Mock(
+ return_value=[
+ "test_folder/historical/entity/feature_set/year=2016/month=4/day=11",
+ ]
+ )
+ test_pipeline.feature_set.add_pre_hook(third_run_hook)
+ test_pipeline.run_for_date("2016-04-11")
+ third_run_output_df = spark_session.read.parquet(path)
+ assert_dataframe_equality(third_run_output_df, third_run_target_df)
+
+ # tear down
+ shutil.rmtree("test_folder")
diff --git a/tests/integration/butterfree/transform/conftest.py b/tests/integration/butterfree/transform/conftest.py
index 6621c9a3..fe0cc572 100644
--- a/tests/integration/butterfree/transform/conftest.py
+++ b/tests/integration/butterfree/transform/conftest.py
@@ -395,3 +395,58 @@ def rolling_windows_output_feature_set_dataframe_base_date(
df = df.withColumn(TIMESTAMP_COLUMN, df.origin_ts.cast(DataType.TIMESTAMP.spark))
return df
+
+
+@fixture
+def feature_set_dates_dataframe(spark_context, spark_session):
+ data = [
+ {"id": 1, "ts": "2016-04-11 11:31:11", "feature": 200},
+ {"id": 1, "ts": "2016-04-12 11:44:12", "feature": 300},
+ {"id": 1, "ts": "2016-04-13 11:46:24", "feature": 400},
+ {"id": 1, "ts": "2016-04-14 12:03:21", "feature": 500},
+ ]
+ df = spark_session.read.json(spark_context.parallelize(data, 1))
+ df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
+ df = df.withColumn("ts", df.ts.cast(DataType.TIMESTAMP.spark))
+
+ return df
+
+
+@fixture
+def feature_set_dates_output_dataframe(spark_context, spark_session):
+ data = [
+ {"id": 1, "timestamp": "2016-04-11 11:31:11", "feature": 200},
+ {"id": 1, "timestamp": "2016-04-12 11:44:12", "feature": 300},
+ ]
+ df = spark_session.read.json(spark_context.parallelize(data, 1))
+ df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+
+ return df
+
+
+@fixture
+def rolling_windows_output_date_boundaries(spark_context, spark_session):
+ data = [
+ {
+ "id": 1,
+ "ts": "2016-04-11 00:00:00",
+ "feature__avg_over_1_day_rolling_windows": None,
+ "feature__avg_over_1_week_rolling_windows": None,
+ "feature__stddev_pop_over_1_day_rolling_windows": None,
+ "feature__stddev_pop_over_1_week_rolling_windows": None,
+ },
+ {
+ "id": 1,
+ "ts": "2016-04-12 00:00:00",
+ "feature__avg_over_1_day_rolling_windows": 200.0,
+ "feature__avg_over_1_week_rolling_windows": 200.0,
+ "feature__stddev_pop_over_1_day_rolling_windows": 0.0,
+ "feature__stddev_pop_over_1_week_rolling_windows": 0.0,
+ },
+ ]
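+ # the 2016-04-11 row is null because no earlier data falls inside its 1-day/1-week windows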
+ df = spark_session.read.json(
+ spark_context.parallelize(data).map(lambda x: json.dumps(x))
+ )
+ df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
+
+ return df
diff --git a/tests/integration/butterfree/transform/test_aggregated_feature_set.py b/tests/integration/butterfree/transform/test_aggregated_feature_set.py
index 559dbcb8..bc3ebb6c 100644
--- a/tests/integration/butterfree/transform/test_aggregated_feature_set.py
+++ b/tests/integration/butterfree/transform/test_aggregated_feature_set.py
@@ -241,3 +241,53 @@ def test_construct_with_pivot(
# assert
assert_dataframe_equality(output_df, target_df_pivot_agg)
+
+ def test_construct_rolling_windows_with_date_boundaries(
+ self, feature_set_dates_dataframe, rolling_windows_output_date_boundaries,
+ ):
+ # given
+
+ spark_client = SparkClient()
+
+ # arrange
+
+ feature_set = AggregatedFeatureSet(
+ name="feature_set",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(
+ name="feature",
+ description="test",
+ transformation=AggregatedTransform(
+ functions=[
+ Function(F.avg, DataType.DOUBLE),
+ Function(F.stddev_pop, DataType.DOUBLE),
+ ],
+ ),
+ ),
+ ],
+ keys=[
+ KeyFeature(
+ name="id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.INTEGER,
+ )
+ ],
+ timestamp=TimestampFeature(),
+ ).with_windows(definitions=["1 day", "1 week"])
+
+ # act
+ output_df = feature_set.construct(
+ feature_set_dates_dataframe,
+ client=spark_client,
+ start_date="2016-04-11",
+ end_date="2016-04-12",
+ ).orderBy("timestamp")
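+ # with these boundaries, only the 2016-04-11 and 2016-04-12 window rows should be produced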
+
+ target_df = rolling_windows_output_date_boundaries.orderBy(
+ feature_set.timestamp_column
+ ).select(feature_set.columns)
+
+ # assert
+ assert_dataframe_equality(output_df, target_df)
diff --git a/tests/integration/butterfree/transform/test_feature_set.py b/tests/integration/butterfree/transform/test_feature_set.py
index 4872ded2..25f70b6e 100644
--- a/tests/integration/butterfree/transform/test_feature_set.py
+++ b/tests/integration/butterfree/transform/test_feature_set.py
@@ -77,3 +77,47 @@ def test_construct(
# assert
assert_dataframe_equality(output_df, target_df)
+
+ def test_construct_with_date_boundaries(
+ self, feature_set_dates_dataframe, feature_set_dates_output_dataframe
+ ):
+ # given
+
+ spark_client = SparkClient()
+
+ # arrange
+
+ feature_set = FeatureSet(
+ name="feature_set",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(name="feature", description="test", dtype=DataType.FLOAT,),
+ ],
+ keys=[
+ KeyFeature(
+ name="id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.INTEGER,
+ )
+ ],
+ timestamp=TimestampFeature(),
+ )
+
+ output_df = (
+ feature_set.construct(
+ feature_set_dates_dataframe,
+ client=spark_client,
+ start_date="2016-04-11",
+ end_date="2016-04-12",
+ )
+ .orderBy(feature_set.timestamp_column)
+ .select(feature_set.columns)
+ )
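+ # rows with timestamps after the 2016-04-12 end date should be filtered out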
+
+ target_df = feature_set_dates_output_dataframe.orderBy(
+ feature_set.timestamp_column
+ ).select(feature_set.columns)
+
+ # assert
+ assert_dataframe_equality(output_df, target_df)
diff --git a/tests/unit/butterfree/clients/conftest.py b/tests/unit/butterfree/clients/conftest.py
index fda11f8e..ffb2db88 100644
--- a/tests/unit/butterfree/clients/conftest.py
+++ b/tests/unit/butterfree/clients/conftest.py
@@ -46,11 +46,16 @@ def mocked_stream_df() -> Mock:
return mock
+@pytest.fixture()
+def mock_spark_sql() -> Mock:
+ mock = Mock()
+ mock.sql = mock
+ return mock
+
+
@pytest.fixture
def cassandra_client() -> CassandraClient:
- return CassandraClient(
- cassandra_host=["mock"], cassandra_key_space="dummy_keyspace"
- )
+ return CassandraClient(host=["mock"], keyspace="dummy_keyspace")
@pytest.fixture
diff --git a/tests/unit/butterfree/clients/test_cassandra_client.py b/tests/unit/butterfree/clients/test_cassandra_client.py
index 8785485b..aa52e6f8 100644
--- a/tests/unit/butterfree/clients/test_cassandra_client.py
+++ b/tests/unit/butterfree/clients/test_cassandra_client.py
@@ -15,9 +15,7 @@ def sanitize_string(query: str) -> str:
class TestCassandraClient:
def test_conn(self, cassandra_client: CassandraClient) -> None:
# arrange
- cassandra_client = CassandraClient(
- cassandra_host=["mock"], cassandra_key_space="dummy_keyspace"
- )
+ cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace")
# act
start_conn = cassandra_client._session
diff --git a/tests/unit/butterfree/clients/test_spark_client.py b/tests/unit/butterfree/clients/test_spark_client.py
index 58d53a40..9f641506 100644
--- a/tests/unit/butterfree/clients/test_spark_client.py
+++ b/tests/unit/butterfree/clients/test_spark_client.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Union
+from datetime import datetime
+from typing import Any, Optional, Union
from unittest.mock import Mock
import pytest
@@ -26,19 +27,20 @@ def test_conn(self) -> None:
assert start_conn is None
@pytest.mark.parametrize(
- "format, options, stream, schema",
+ "format, path, stream, schema, options",
[
- ("parquet", {"path": "path/to/file"}, False, None),
- ("csv", {"path": "path/to/file", "header": True}, False, None),
- ("json", {"path": "path/to/file"}, True, None),
+ ("parquet", ["path/to/file"], False, None, {}),
+ ("csv", "path/to/file", False, None, {"header": True}),
+ ("json", "path/to/file", True, None, {}),
],
)
def test_read(
self,
format: str,
- options: Dict[str, Any],
stream: bool,
schema: Optional[StructType],
+ path: Any,
+ options: Any,
target_df: DataFrame,
mocked_spark_read: Mock,
) -> None:
@@ -48,26 +50,25 @@ def test_read(
spark_client._session = mocked_spark_read
# act
- result_df = spark_client.read(format, options, schema, stream)
+ result_df = spark_client.read(
+ format=format, schema=schema, stream=stream, path=path, **options
+ )
# assert
mocked_spark_read.format.assert_called_once_with(format)
- mocked_spark_read.options.assert_called_once_with(**options)
+ mocked_spark_read.load.assert_called_once_with(path, **options)
assert target_df.collect() == result_df.collect()
@pytest.mark.parametrize(
- "format, options",
- [(None, {"path": "path/to/file"}), ("csv", "not a valid options")],
+ "format, path", [(None, "path/to/file"), ("csv", 123)],
)
- def test_read_invalid_params(
- self, format: Optional[str], options: Union[Dict[str, Any], str]
- ) -> None:
+ def test_read_invalid_params(self, format: Optional[str], path: Any) -> None:
# arrange
spark_client = SparkClient()
# act and assert
with pytest.raises(ValueError):
- spark_client.read(format, options) # type: ignore
+ spark_client.read(format=format, path=path) # type: ignore
def test_sql(self, target_df: DataFrame) -> None:
# arrange
@@ -252,3 +253,43 @@ def test_create_temporary_view(
# assert
assert_dataframe_equality(target_df, result_df)
+
+ def test_add_table_partitions(self, mock_spark_sql: Mock):
+ # arrange
+ target_command = (
+ f"ALTER TABLE `db`.`table` ADD IF NOT EXISTS "
+ f"PARTITION ( year = 2020, month = 8, day = 14 ) "
+ f"PARTITION ( year = 2020, month = 8, day = 15 ) "
+ f"PARTITION ( year = 2020, month = 8, day = 16 )"
+ )
+
+ spark_client = SparkClient()
+ spark_client._session = mock_spark_sql
+ partitions = [
+ {"year": 2020, "month": 8, "day": 14},
+ {"year": 2020, "month": 8, "day": 15},
+ {"year": 2020, "month": 8, "day": 16},
+ ]
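+ # one PARTITION clause should be generated per dict, in the given order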
+
+ # act
+ spark_client.add_table_partitions(partitions, "table", "db")
+
+ # assert
+ mock_spark_sql.assert_called_once_with(target_command)
+
+ @pytest.mark.parametrize(
+ "partition",
+ [
+ [{"float_partition": 2.72}],
+ [{123: 2020}],
+ [{"date": datetime(year=2020, month=8, day=18)}],
+ ],
+ )
+ def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition):
+ # arrange
+ spark_client = SparkClient()
+ spark_client._session = mock_spark_sql
+
+ # act and assert
+ with pytest.raises(ValueError):
+ spark_client.add_table_partitions(partition, "table", "db")
diff --git a/tests/unit/butterfree/dataframe_service/conftest.py b/tests/unit/butterfree/dataframe_service/conftest.py
index 867bc80a..09470c9a 100644
--- a/tests/unit/butterfree/dataframe_service/conftest.py
+++ b/tests/unit/butterfree/dataframe_service/conftest.py
@@ -25,3 +25,17 @@ def input_df(spark_context, spark_session):
return spark_session.read.json(
spark_context.parallelize(data, 1), schema="timestamp timestamp"
)
+
+
+@pytest.fixture()
+def test_partitioning_input_df(spark_context, spark_session):
+ data = [
+ {"feature": 1, "year": 2009, "month": 8, "day": 20},
+ {"feature": 2, "year": 2009, "month": 8, "day": 20},
+ {"feature": 3, "year": 2020, "month": 8, "day": 20},
+ {"feature": 4, "year": 2020, "month": 9, "day": 20},
+ {"feature": 5, "year": 2020, "month": 9, "day": 20},
+ {"feature": 6, "year": 2020, "month": 8, "day": 20},
+ {"feature": 7, "year": 2020, "month": 8, "day": 21},
+ ]
+ return spark_session.read.json(spark_context.parallelize(data, 1))
diff --git a/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py b/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py
new file mode 100644
index 00000000..a140ceb3
--- /dev/null
+++ b/tests/unit/butterfree/dataframe_service/test_incremental_srategy.py
@@ -0,0 +1,70 @@
+from butterfree.dataframe_service import IncrementalStrategy
+
+
+class TestIncrementalStrategy:
+ def test_from_milliseconds(self):
+ # arrange
+ incremental_strategy = IncrementalStrategy().from_milliseconds("ts")
+ target_expression = "date(from_unixtime(ts/ 1000.0)) >= date('2020-01-01')"
+
+ # act
+ result_expression = incremental_strategy.get_expression(start_date="2020-01-01")
+
+ # assert
+ assert target_expression.split() == result_expression.split()
+
+ def test_from_string(self):
+ # arrange
+ incremental_strategy = IncrementalStrategy().from_string(
+ "dt", mask="dd/MM/yyyy"
+ )
+ target_expression = "date(to_date(dt, 'dd/MM/yyyy')) >= date('2020-01-01')"
+
+ # act
+ result_expression = incremental_strategy.get_expression(start_date="2020-01-01")
+
+ # assert
+ assert target_expression.split() == result_expression.split()
+
+ def test_from_year_month_day_partitions(self):
+ # arrange
+ incremental_strategy = IncrementalStrategy().from_year_month_day_partitions(
+ year_column="y", month_column="m", day_column="d"
+ )
+ target_expression = (
+ "date(concat(string(y), "
+ "'-', string(m), "
+ "'-', string(d))) >= date('2020-01-01')"
+ )
+
+ # act
+ result_expression = incremental_strategy.get_expression(start_date="2020-01-01")
+
+ # assert
+ assert target_expression.split() == result_expression.split()
+
+ def test_get_expression_with_just_end_date(self):
+ # arrange
+ incremental_strategy = IncrementalStrategy(column="dt")
+ target_expression = "date(dt) <= date('2020-01-01')"
+
+ # act
+ result_expression = incremental_strategy.get_expression(end_date="2020-01-01")
+
+ # assert
+ assert target_expression.split() == result_expression.split()
+
+ def test_get_expression_with_start_and_end_date(self):
+ # arrange
+ incremental_strategy = IncrementalStrategy(column="dt")
+ target_expression = (
+ "date(dt) >= date('2019-12-30') and date(dt) <= date('2020-01-01')"
+ )
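+ # both start and end boundaries are inclusive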
+
+ # act
+ result_expression = incremental_strategy.get_expression(
+ start_date="2019-12-30", end_date="2020-01-01"
+ )
+
+ # assert
+ assert target_expression.split() == result_expression.split()
diff --git a/tests/unit/butterfree/dataframe_service/test_partitioning.py b/tests/unit/butterfree/dataframe_service/test_partitioning.py
new file mode 100644
index 00000000..3a6b5b40
--- /dev/null
+++ b/tests/unit/butterfree/dataframe_service/test_partitioning.py
@@ -0,0 +1,20 @@
+from butterfree.dataframe_service import extract_partition_values
+
+
+class TestPartitioning:
+ def test_extract_partition_values(self, test_partitioning_input_df):
+ # arrange
+ target_values = [
+ {"year": 2009, "month": 8, "day": 20},
+ {"year": 2020, "month": 8, "day": 20},
+ {"year": 2020, "month": 9, "day": 20},
+ {"year": 2020, "month": 8, "day": 21},
+ ]
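+ # only distinct (year, month, day) combinations are expected; duplicate rows collapse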
+
+ # act
+ result_values = extract_partition_values(
+ test_partitioning_input_df, partition_columns=["year", "month", "day"]
+ )
+
+ # assert
+ assert result_values == target_values
diff --git a/tests/unit/butterfree/extract/conftest.py b/tests/unit/butterfree/extract/conftest.py
index ab6f525c..3d0e763d 100644
--- a/tests/unit/butterfree/extract/conftest.py
+++ b/tests/unit/butterfree/extract/conftest.py
@@ -1,6 +1,7 @@
from unittest.mock import Mock
import pytest
+from pyspark.sql.functions import col, to_date
from butterfree.constants.columns import TIMESTAMP_COLUMN
@@ -17,6 +18,60 @@ def target_df(spark_context, spark_session):
return spark_session.read.json(spark_context.parallelize(data, 1))
+@pytest.fixture()
+def incremental_source_df(spark_context, spark_session):
+ data = [
+ {
+ "id": 1,
+ "feature": 100,
+ "date_str": "28/07/2020",
+ "milliseconds": 1595894400000,
+ "year": 2020,
+ "month": 7,
+ "day": 28,
+ },
+ {
+ "id": 1,
+ "feature": 110,
+ "date_str": "29/07/2020",
+ "milliseconds": 1595980800000,
+ "year": 2020,
+ "month": 7,
+ "day": 29,
+ },
+ {
+ "id": 1,
+ "feature": 120,
+ "date_str": "30/07/2020",
+ "milliseconds": 1596067200000,
+ "year": 2020,
+ "month": 7,
+ "day": 30,
+ },
+ {
+ "id": 2,
+ "feature": 150,
+ "date_str": "31/07/2020",
+ "milliseconds": 1596153600000,
+ "year": 2020,
+ "month": 7,
+ "day": 31,
+ },
+ {
+ "id": 2,
+ "feature": 200,
+ "date_str": "01/08/2020",
+ "milliseconds": 1596240000000,
+ "year": 2020,
+ "month": 8,
+ "day": 1,
+ },
+ ]
+ return spark_session.read.json(spark_context.parallelize(data, 1)).withColumn(
+ "date", to_date(col("date_str"), "dd/MM/yyyy")
+ )
+
+
@pytest.fixture()
def spark_client():
return Mock()
diff --git a/tests/unit/butterfree/extract/readers/test_file_reader.py b/tests/unit/butterfree/extract/readers/test_file_reader.py
index d337d4fe..9e1c42bc 100644
--- a/tests/unit/butterfree/extract/readers/test_file_reader.py
+++ b/tests/unit/butterfree/extract/readers/test_file_reader.py
@@ -36,11 +36,11 @@ def test_consume(
# act
output_df = file_reader.consume(spark_client)
- options = dict({"path": path}, **format_options if format_options else {})
+ options = dict(format_options if format_options else {})
# assert
spark_client.read.assert_called_once_with(
- format=format, options=options, schema=schema, stream=False
+ format=format, schema=schema, stream=False, path=path, **options
)
assert target_df.collect() == output_df.collect()
@@ -51,7 +51,7 @@ def test_consume_with_stream_without_schema(self, spark_client, target_df):
schema = None
format_options = None
stream = True
- options = dict({"path": path})
+ options = {}
spark_client.read.return_value = target_df
file_reader = FileReader(
@@ -64,11 +64,11 @@ def test_consume_with_stream_without_schema(self, spark_client, target_df):
# assert
# assert call for schema infer
- spark_client.read.assert_any_call(format=format, options=options)
+ spark_client.read.assert_any_call(format=format, path=path, **options)
# assert call for stream read
# stream
spark_client.read.assert_called_with(
- format=format, options=options, schema=output_df.schema, stream=stream
+ format=format, schema=output_df.schema, stream=stream, path=path, **options
)
assert target_df.collect() == output_df.collect()
diff --git a/tests/unit/butterfree/extract/readers/test_reader.py b/tests/unit/butterfree/extract/readers/test_reader.py
index c210a756..78160553 100644
--- a/tests/unit/butterfree/extract/readers/test_reader.py
+++ b/tests/unit/butterfree/extract/readers/test_reader.py
@@ -1,7 +1,9 @@
import pytest
from pyspark.sql.functions import expr
+from butterfree.dataframe_service import IncrementalStrategy
from butterfree.extract.readers import FileReader
+from butterfree.testing.dataframe import assert_dataframe_equality
def add_value_transformer(df, column, value):
@@ -152,3 +154,59 @@ def test_build_with_columns(
# assert
assert column_target_df.collect() == result_df.collect()
+
+ def test_build_with_incremental_strategy(
+ self, incremental_source_df, spark_client, spark_session
+ ):
+ # arrange
+ readers = [
+ # directly from column
+ FileReader(
+ id="test_1", path="path/to/file", format="format"
+ ).with_incremental_strategy(
+ incremental_strategy=IncrementalStrategy(column="date")
+ ),
+ # from milliseconds
+ FileReader(
+ id="test_2", path="path/to/file", format="format"
+ ).with_incremental_strategy(
+ incremental_strategy=IncrementalStrategy().from_milliseconds(
+ column_name="milliseconds"
+ )
+ ),
+ # from str
+ FileReader(
+ id="test_3", path="path/to/file", format="format"
+ ).with_incremental_strategy(
+ incremental_strategy=IncrementalStrategy().from_string(
+ column_name="date_str", mask="dd/MM/yyyy"
+ )
+ ),
+ # from year, month, day partitions
+ FileReader(
+ id="test_4", path="path/to/file", format="format"
+ ).with_incremental_strategy(
+ incremental_strategy=(
+ IncrementalStrategy().from_year_month_day_partitions()
+ )
+ ),
+ ]
+
+ spark_client.read.return_value = incremental_source_df
+ target_df = incremental_source_df.where(
+ "date >= date('2020-07-29') and date <= date('2020-07-31')"
+ )
+
+ # act
+ for reader in readers:
+ reader.build(
+ client=spark_client, start_date="2020-07-29", end_date="2020-07-31"
+ )
+
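+ # build() is expected to register each reader's output as a temp view named after its id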
+ output_dfs = [
+ spark_session.table(f"test_{i + 1}") for i, _ in enumerate(readers)
+ ]
+
+ # assert
+ for output_df in output_dfs:
+ assert_dataframe_equality(output_df=output_df, target_df=target_df)
diff --git a/tests/unit/butterfree/hooks/__init__.py b/tests/unit/butterfree/hooks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/butterfree/hooks/schema_compatibility/__init__.py b/tests/unit/butterfree/hooks/schema_compatibility/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py b/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py
new file mode 100644
index 00000000..eccb8d8c
--- /dev/null
+++ b/tests/unit/butterfree/hooks/schema_compatibility/test_cassandra_table_schema_compatibility_hook.py
@@ -0,0 +1,49 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from butterfree.clients import CassandraClient
+from butterfree.hooks.schema_compatibility import CassandraTableSchemaCompatibilityHook
+
+
+class TestCassandraTableSchemaCompatibilityHook:
+ def test_run_compatible_schema(self, spark_session):
+ cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace")
+
+ cassandra_client.sql = MagicMock( # type: ignore
+ return_value=[
+ {"column_name": "feature1", "type": "text"},
+ {"column_name": "feature2", "type": "int"},
+ ]
+ )
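+ # "text" and "int" line up with the dataframe's string and integer columns, so the schema should be accepted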
+
+ table = "table"
+
+ input_dataframe = spark_session.sql("select 'abc' as feature1, 1 as feature2")
+
+ hook = CassandraTableSchemaCompatibilityHook(cassandra_client, table)
+
+ # act and assert
+ assert hook.run(input_dataframe) == input_dataframe
+
+ def test_run_incompatible_schema(self, spark_session):
+ cassandra_client = CassandraClient(host=["mock"], keyspace="dummy_keyspace")
+
+ cassandra_client.sql = MagicMock( # type: ignore
+ return_value=[
+ {"column_name": "feature1", "type": "text"},
+ {"column_name": "feature2", "type": "bigint"},
+ ]
+ )
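+ # feature2 is an integer column while the table declares bigint, so the hook should raise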
+
+ table = "table"
+
+ input_dataframe = spark_session.sql("select 'abc' as feature1, 1 as feature2")
+
+ hook = CassandraTableSchemaCompatibilityHook(cassandra_client, table)
+
+ # act and assert
+ with pytest.raises(
+ ValueError, match="There's a schema incompatibility between"
+ ):
+ hook.run(input_dataframe)
diff --git a/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py b/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py
new file mode 100644
index 00000000..3a31b600
--- /dev/null
+++ b/tests/unit/butterfree/hooks/schema_compatibility/test_spark_table_schema_compatibility_hook.py
@@ -0,0 +1,53 @@
+import pytest
+
+from butterfree.clients import SparkClient
+from butterfree.hooks.schema_compatibility import SparkTableSchemaCompatibilityHook
+
+
+class TestSparkTableSchemaCompatibilityHook:
+ @pytest.mark.parametrize(
+ "table, database, target_table_expression",
+ [("table", "database", "`database`.`table`"), ("table", None, "`table`")],
+ )
+ def test_build_table_expression(self, table, database, target_table_expression):
+ # arrange
+ spark_client = SparkClient()
+
+ # act
+ result_table_expression = SparkTableSchemaCompatibilityHook(
+ spark_client, table, database
+ ).table_expression
+
+ # assert
+ assert target_table_expression == result_table_expression
+
+ def test_run_compatible_schema(self, spark_session):
+ # arrange
+ spark_client = SparkClient()
+ target_table = spark_session.sql(
+ "select 1 as feature_a, 'abc' as feature_b, true as other_feature"
+ )
+ input_dataframe = spark_session.sql("select 1 as feature_a, 'abc' as feature_b")
+ target_table.registerTempTable("test")
+
+ hook = SparkTableSchemaCompatibilityHook(spark_client, "test")
+
+ # act and assert
+ assert hook.run(input_dataframe) == input_dataframe
+
+ def test_run_incompatible_schema(self, spark_session):
+ # arrange
+ spark_client = SparkClient()
+ target_table = spark_session.sql(
+ "select 1 as feature_a, 'abc' as feature_b, true as other_feature"
+ )
+ input_dataframe = spark_session.sql(
+ "select 1 as feature_a, 'abc' as feature_b, true as unregisted_column"
+ )
+ target_table.registerTempTable("test")
+
+ hook = SparkTableSchemaCompatibilityHook(spark_client, "test")
+
+ # act and assert
+ with pytest.raises(ValueError, match="The dataframe has a schema incompatible"):
+ hook.run(input_dataframe)
diff --git a/tests/unit/butterfree/hooks/test_hookable_component.py b/tests/unit/butterfree/hooks/test_hookable_component.py
new file mode 100644
index 00000000..37e34e69
--- /dev/null
+++ b/tests/unit/butterfree/hooks/test_hookable_component.py
@@ -0,0 +1,107 @@
+import pytest
+from pyspark.sql.functions import expr
+
+from butterfree.hooks import Hook, HookableComponent
+from butterfree.testing.dataframe import assert_dataframe_equality
+
+
+class TestComponent(HookableComponent):
+ def construct(self, dataframe):
+ pre_hook_df = self.run_pre_hooks(dataframe)
+ construct_df = pre_hook_df.withColumn("feature", expr("feature * feature"))
+ return self.run_post_hooks(construct_df)
+
+
+class AddHook(Hook):
+ def __init__(self, value):
+ self.value = value
+
+ def run(self, dataframe):
+ return dataframe.withColumn("feature", expr(f"feature + {self.value}"))
+
+
+class TestHookableComponent:
+ def test_add_hooks(self):
+ # arrange
+ hook1 = AddHook(value=1)
+ hook2 = AddHook(value=2)
+ hook3 = AddHook(value=3)
+ hook4 = AddHook(value=4)
+ hookable_component = HookableComponent()
+
+ # act
+ hookable_component.add_pre_hook(hook1, hook2)
+ hookable_component.add_post_hook(hook3, hook4)
+
+ # assert
+ assert hookable_component.pre_hooks == [hook1, hook2]
+ assert hookable_component.post_hooks == [hook3, hook4]
+
+ @pytest.mark.parametrize(
+ "enable_pre_hooks, enable_post_hooks",
+ [("not boolean", False), (False, "not boolean")],
+ )
+ def test_invalid_enable_hook(self, enable_pre_hooks, enable_post_hooks):
+ # arrange
+ hookable_component = HookableComponent()
+
+ # act and assert
+ with pytest.raises(ValueError):
+ hookable_component.enable_pre_hooks = enable_pre_hooks
+ hookable_component.enable_post_hooks = enable_post_hooks
+
+ @pytest.mark.parametrize(
+ "pre_hooks, post_hooks",
+ [
+ ([AddHook(1)], "not a list of hooks"),
+ ([AddHook(1)], [AddHook(1), 2, 3]),
+ ("not a list of hooks", [AddHook(1)]),
+ ([AddHook(1), 2, 3], [AddHook(1)]),
+ ],
+ )
+ def test_invalid_hooks(self, pre_hooks, post_hooks):
+ # arrange
+ hookable_component = HookableComponent()
+
+ # act and assert
+ with pytest.raises(ValueError):
+ hookable_component.pre_hooks = pre_hooks
+ hookable_component.post_hooks = post_hooks
+
+ @pytest.mark.parametrize(
+ "pre_hook, enable_pre_hooks, post_hook, enable_post_hooks",
+ [
+ (AddHook(value=1), False, AddHook(value=1), True),
+ (AddHook(value=1), True, AddHook(value=1), False),
+ ("not a pre-hook", True, AddHook(value=1), True),
+ (AddHook(value=1), True, "not a post-hook", True),
+ ],
+ )
+ def test_add_invalid_hooks(
+ self, pre_hook, enable_pre_hooks, post_hook, enable_post_hooks
+ ):
+ # arrange
+ hookable_component = HookableComponent()
+ hookable_component.enable_pre_hooks = enable_pre_hooks
+ hookable_component.enable_post_hooks = enable_post_hooks
+
+ # act and assert
+ with pytest.raises(ValueError):
+ hookable_component.add_pre_hook(pre_hook)
+ hookable_component.add_post_hook(post_hook)
+
+ def test_run_hooks(self, spark_session):
+ # arrange
+ input_dataframe = spark_session.sql("select 2 as feature")
+ test_component = (
+ TestComponent()
+ .add_pre_hook(AddHook(value=1))
+ .add_post_hook(AddHook(value=1))
+ )
+ target_table = spark_session.sql("select 10 as feature")
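+ # (2 + 1 from the pre-hook) squared is 9; the post-hook adds 1, giving 10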
+
+ # act
+ output_df = test_component.construct(input_dataframe)
+
+ # assert
+ assert_dataframe_equality(output_df, target_table)
diff --git a/tests/unit/butterfree/load/conftest.py b/tests/unit/butterfree/load/conftest.py
index 7c2549c5..4dcf25c9 100644
--- a/tests/unit/butterfree/load/conftest.py
+++ b/tests/unit/butterfree/load/conftest.py
@@ -32,6 +32,31 @@ def feature_set():
)
+@fixture
+def feature_set_incremental():
+ key_features = [
+ KeyFeature(name="id", description="Description", dtype=DataType.INTEGER)
+ ]
+ ts_feature = TimestampFeature(from_column=TIMESTAMP_COLUMN)
+ features = [
+ Feature(
+ name="feature",
+ description="test",
+ transformation=AggregatedTransform(
+ functions=[Function(functions.sum, DataType.INTEGER)]
+ ),
+ ),
+ ]
+ return AggregatedFeatureSet(
+ "feature_set",
+ "entity",
+ "description",
+ keys=key_features,
+ timestamp=ts_feature,
+ features=features,
+ )
+
+
@fixture
def feature_set_dataframe(spark_context, spark_session):
data = [
diff --git a/tests/unit/butterfree/load/test_sink.py b/tests/unit/butterfree/load/test_sink.py
index 93b5e279..ef377f67 100644
--- a/tests/unit/butterfree/load/test_sink.py
+++ b/tests/unit/butterfree/load/test_sink.py
@@ -120,7 +120,7 @@ def test_flush_with_writers_list_empty(self):
with pytest.raises(ValueError):
Sink(writers=writer)
- def test_flush_streaming_df(self, feature_set):
+ def test_flush_streaming_df(self, feature_set, mocker):
"""Testing the return of the streaming handlers by the sink."""
# arrange
spark_client = SparkClient()
@@ -136,10 +136,25 @@ def test_flush_streaming_df(self, feature_set):
mocked_stream_df.start.return_value = Mock(spec=StreamingQuery)
online_feature_store_writer = OnlineFeatureStoreWriter()
+
+ online_feature_store_writer.check_schema_hook = mocker.stub("check_schema_hook")
+ online_feature_store_writer.check_schema_hook.run = mocker.stub("run")
+ online_feature_store_writer.check_schema_hook.run.return_value = (
+ mocked_stream_df
+ )
+
online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
write_to_entity=True
)
+ online_feature_store_writer_on_entity.check_schema_hook = mocker.stub(
+ "check_schema_hook"
+ )
+ online_feature_store_writer_on_entity.check_schema_hook.run = mocker.stub("run")
+ online_feature_store_writer_on_entity.check_schema_hook.run.return_value = (
+ mocked_stream_df
+ )
+
sink = Sink(
writers=[
online_feature_store_writer,
@@ -162,7 +177,7 @@ def test_flush_streaming_df(self, feature_set):
assert isinstance(handler, StreamingQuery)
def test_flush_with_multiple_online_writers(
- self, feature_set, feature_set_dataframe
+ self, feature_set, feature_set_dataframe, mocker
):
"""Testing the flow of writing to a feature-set table and to an entity table."""
# arrange
@@ -173,10 +188,25 @@ def test_flush_with_multiple_online_writers(
feature_set.name = "my_feature_set"
online_feature_store_writer = OnlineFeatureStoreWriter()
+
+ online_feature_store_writer.check_schema_hook = mocker.stub("check_schema_hook")
+ online_feature_store_writer.check_schema_hook.run = mocker.stub("run")
+ online_feature_store_writer.check_schema_hook.run.return_value = (
+ feature_set_dataframe
+ )
+
online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
write_to_entity=True
)
+ online_feature_store_writer_on_entity.check_schema_hook = mocker.stub(
+ "check_schema_hook"
+ )
+ online_feature_store_writer_on_entity.check_schema_hook.run = mocker.stub("run")
+ online_feature_store_writer_on_entity.check_schema_hook.run.return_value = (
+ feature_set_dataframe
+ )
+
sink = Sink(
writers=[online_feature_store_writer, online_feature_store_writer_on_entity]
)
diff --git a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py
index 14c067f9..aac806f7 100644
--- a/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py
+++ b/tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py
@@ -19,10 +19,15 @@ def test_write(
feature_set,
):
# given
- spark_client = mocker.stub("spark_client")
+ spark_client = SparkClient()
spark_client.write_table = mocker.stub("write_table")
writer = HistoricalFeatureStoreWriter()
+ schema_dataframe = writer._create_partitions(feature_set_dataframe)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = schema_dataframe
+
# when
writer.write(
feature_set=feature_set,
@@ -41,7 +46,76 @@ def test_write(
assert (
writer.PARTITION_BY == spark_client.write_table.call_args[1]["partition_by"]
)
- assert feature_set.name == spark_client.write_table.call_args[1]["table_name"]
+
+ def test_write_interval_mode(
+ self,
+ feature_set_dataframe,
+ historical_feature_set_dataframe,
+ mocker,
+ feature_set,
+ ):
+ # given
+ spark_client = SparkClient()
+ spark_client.write_dataframe = mocker.stub("write_dataframe")
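+ # interval mode relies on dynamic partition overwrite so only the written partitions are replaced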
+ spark_client.conn.conf.set(
+ "spark.sql.sources.partitionOverwriteMode", "dynamic"
+ )
+ writer = HistoricalFeatureStoreWriter(interval_mode=True)
+
+ schema_dataframe = writer._create_partitions(feature_set_dataframe)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = schema_dataframe
+
+ # when
+ writer.write(
+ feature_set=feature_set,
+ dataframe=feature_set_dataframe,
+ spark_client=spark_client,
+ )
+ result_df = spark_client.write_dataframe.call_args[1]["dataframe"]
+
+ # then
+ assert_dataframe_equality(historical_feature_set_dataframe, result_df)
+
+ assert (
+ writer.db_config.format_
+ == spark_client.write_dataframe.call_args[1]["format_"]
+ )
+ assert (
+ writer.db_config.mode == spark_client.write_dataframe.call_args[1]["mode"]
+ )
+ assert (
+ writer.PARTITION_BY
+ == spark_client.write_dataframe.call_args[1]["partitionBy"]
+ )
+
+ def test_write_interval_mode_invalid_partition_mode(
+ self,
+ feature_set_dataframe,
+ historical_feature_set_dataframe,
+ mocker,
+ feature_set,
+ ):
+ # given
+ spark_client = SparkClient()
+ spark_client.write_dataframe = mocker.stub("write_dataframe")
+ spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "static")
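+ # with the "static" overwrite mode, interval-mode writing is expected to fail with RuntimeError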
+
+ writer = HistoricalFeatureStoreWriter(interval_mode=True)
+
+ schema_dataframe = writer._create_partitions(feature_set_dataframe)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = schema_dataframe
+
+ # when
+ with pytest.raises(RuntimeError):
+ _ = writer.write(
+ feature_set=feature_set,
+ dataframe=feature_set_dataframe,
+ spark_client=spark_client,
+ )
def test_write_in_debug_mode(
self,
@@ -49,6 +123,7 @@ def test_write_in_debug_mode(
historical_feature_set_dataframe,
feature_set,
spark_session,
+ mocker,
):
# given
spark_client = SparkClient()
@@ -65,33 +140,75 @@ def test_write_in_debug_mode(
# then
assert_dataframe_equality(historical_feature_set_dataframe, result_df)
- def test_validate(self, feature_set_dataframe, mocker, feature_set):
+ def test_write_in_debug_mode_with_interval_mode(
+ self,
+ feature_set_dataframe,
+ historical_feature_set_dataframe,
+ feature_set,
+ spark_session,
+ ):
+ # given
+ spark_client = SparkClient()
+ writer = HistoricalFeatureStoreWriter(debug_mode=True, interval_mode=True)
+
+ # when
+ writer.write(
+ feature_set=feature_set,
+ dataframe=feature_set_dataframe,
+ spark_client=spark_client,
+ )
+ result_df = spark_session.table(f"historical_feature_store__{feature_set.name}")
+
+ # then
+ assert_dataframe_equality(historical_feature_set_dataframe, result_df)
+
+ def test_validate(self, historical_feature_set_dataframe, mocker, feature_set):
# given
spark_client = mocker.stub("spark_client")
spark_client.read_table = mocker.stub("read_table")
- spark_client.read_table.return_value = feature_set_dataframe
+ spark_client.read_table.return_value = historical_feature_set_dataframe
writer = HistoricalFeatureStoreWriter()
# when
- writer.validate(feature_set, feature_set_dataframe, spark_client)
+ writer.validate(feature_set, historical_feature_set_dataframe, spark_client)
# then
spark_client.read_table.assert_called_once()
- def test_validate_false(self, feature_set_dataframe, mocker, feature_set):
+ def test_validate_interval_mode(
+ self, historical_feature_set_dataframe, mocker, feature_set
+ ):
# given
spark_client = mocker.stub("spark_client")
- spark_client.read_table = mocker.stub("read_table")
+ spark_client.read = mocker.stub("read")
+ spark_client.read.return_value = historical_feature_set_dataframe
+
+ writer = HistoricalFeatureStoreWriter(interval_mode=True)
+
+ # when
+ writer.validate(feature_set, historical_feature_set_dataframe, spark_client)
+
+ # then
+ spark_client.read.assert_called_once()
+
+ def test_validate_false(
+ self, historical_feature_set_dataframe, mocker, feature_set
+ ):
+ # given
+ spark_client = mocker.stub("spark_client")
+ spark_client.read = mocker.stub("read")
# limiting df to 1 row, now the counts shouldn't be the same
- spark_client.read_table.return_value = feature_set_dataframe.limit(1)
+ spark_client.read.return_value = historical_feature_set_dataframe.limit(1)
- writer = HistoricalFeatureStoreWriter()
+ writer = HistoricalFeatureStoreWriter(interval_mode=True)
# when
with pytest.raises(AssertionError):
- _ = writer.validate(feature_set, feature_set_dataframe, spark_client)
+ _ = writer.validate(
+ feature_set, historical_feature_set_dataframe, spark_client
+ )
def test__create_partitions(self, spark_session, spark_context):
# arrange
@@ -201,8 +318,15 @@ def test_write_with_transform(
# given
spark_client = mocker.stub("spark_client")
spark_client.write_table = mocker.stub("write_table")
+
writer = HistoricalFeatureStoreWriter().with_(json_transform)
+ schema_dataframe = writer._create_partitions(feature_set_dataframe)
+ json_dataframe = writer._apply_transformations(schema_dataframe)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = json_dataframe
+
# when
writer.write(
feature_set=feature_set,
diff --git a/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py b/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py
index 87823c55..384ec152 100644
--- a/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py
+++ b/tests/unit/butterfree/load/writers/test_online_feature_store_writer.py
@@ -68,6 +68,10 @@ def test_write(
spark_client.write_dataframe = mocker.stub("write_dataframe")
writer = OnlineFeatureStoreWriter(cassandra_config)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = feature_set_dataframe
+
# when
writer.write(feature_set, feature_set_dataframe, spark_client)
@@ -94,11 +98,16 @@ def test_write_in_debug_mode(
latest_feature_set_dataframe,
feature_set,
spark_session,
+ mocker,
):
# given
spark_client = SparkClient()
writer = OnlineFeatureStoreWriter(debug_mode=True)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = feature_set_dataframe
+
# when
writer.write(
feature_set=feature_set,
@@ -110,9 +119,7 @@ def test_write_in_debug_mode(
# then
assert_dataframe_equality(latest_feature_set_dataframe, result_df)
- def test_write_in_debug_and_stream_mode(
- self, feature_set, spark_session,
- ):
+ def test_write_in_debug_and_stream_mode(self, feature_set, spark_session, mocker):
# arrange
spark_client = SparkClient()
@@ -125,6 +132,10 @@ def test_write_in_debug_and_stream_mode(
writer = OnlineFeatureStoreWriter(debug_mode=True)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = mocked_stream_df
+
# act
handler = writer.write(
feature_set=feature_set,
@@ -140,7 +151,7 @@ def test_write_in_debug_and_stream_mode(
assert isinstance(handler, StreamingQuery)
@pytest.mark.parametrize("has_checkpoint", [True, False])
- def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
+ def test_write_stream(self, feature_set, has_checkpoint, monkeypatch, mocker):
# arrange
spark_client = SparkClient()
spark_client.write_stream = Mock()
@@ -163,6 +174,10 @@ def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
writer = OnlineFeatureStoreWriter(cassandra_config)
writer.filter_latest = Mock()
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = dataframe
+
# act
stream_handler = writer.write(feature_set, dataframe, spark_client)
@@ -186,7 +201,7 @@ def test_get_db_schema(self, cassandra_config, test_feature_set, expected_schema
assert schema == expected_schema
- def test_write_stream_on_entity(self, feature_set, monkeypatch):
+ def test_write_stream_on_entity(self, feature_set, monkeypatch, mocker):
"""Test write method with stream dataframe and write_to_entity enabled.
The main purpose of this test is assert the correct setup of stream checkpoint
@@ -209,6 +224,10 @@ def test_write_stream_on_entity(self, feature_set, monkeypatch):
writer = OnlineFeatureStoreWriter(write_to_entity=True)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = dataframe
+
# act
stream_handler = writer.write(feature_set, dataframe, spark_client)
@@ -237,6 +256,10 @@ def test_write_with_transform(
spark_client.write_dataframe = mocker.stub("write_dataframe")
writer = OnlineFeatureStoreWriter(cassandra_config).with_(json_transform)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = feature_set_dataframe
+
# when
writer.write(feature_set, feature_set_dataframe, spark_client)
@@ -270,6 +293,10 @@ def test_write_with_kafka_config(
kafka_config = KafkaConfig()
writer = OnlineFeatureStoreWriter(kafka_config).with_(json_transform)
+ writer.check_schema_hook = mocker.stub("check_schema_hook")
+ writer.check_schema_hook.run = mocker.stub("run")
+ writer.check_schema_hook.run.return_value = feature_set_dataframe
+
# when
writer.write(feature_set, feature_set_dataframe, spark_client)
@@ -293,6 +320,10 @@ def test_write_with_custom_kafka_config(
json_transform
)
+ custom_writer.check_schema_hook = mocker.stub("check_schema_hook")
+ custom_writer.check_schema_hook.run = mocker.stub("run")
+ custom_writer.check_schema_hook.run.return_value = feature_set_dataframe
+
# when
custom_writer.write(feature_set, feature_set_dataframe, spark_client)
diff --git a/tests/unit/butterfree/pipelines/conftest.py b/tests/unit/butterfree/pipelines/conftest.py
new file mode 100644
index 00000000..47e65efb
--- /dev/null
+++ b/tests/unit/butterfree/pipelines/conftest.py
@@ -0,0 +1,63 @@
+from unittest.mock import Mock
+
+from pyspark.sql import functions
+from pytest import fixture
+
+from butterfree.clients import SparkClient
+from butterfree.constants import DataType
+from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.extract import Source
+from butterfree.extract.readers import TableReader
+from butterfree.load import Sink
+from butterfree.load.writers import HistoricalFeatureStoreWriter
+from butterfree.pipelines import FeatureSetPipeline
+from butterfree.transform import FeatureSet
+from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
+from butterfree.transform.transformations import SparkFunctionTransform
+from butterfree.transform.utils import Function
+
+
+@fixture()
+def feature_set_pipeline():
+ test_pipeline = FeatureSetPipeline(
+ spark_client=SparkClient(),
+ source=Mock(
+ spec=Source,
+ readers=[TableReader(id="source_a", database="db", table="table",)],
+ query="select * from source_a",
+ ),
+ feature_set=Mock(
+ spec=FeatureSet,
+ name="feature_set",
+ entity="entity",
+ description="description",
+ keys=[
+ KeyFeature(
+ name="user_id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.INTEGER,
+ )
+ ],
+ timestamp=TimestampFeature(from_column="ts"),
+ features=[
+ Feature(
+ name="listing_page_viewed__rent_per_month",
+ description="Average of something.",
+ transformation=SparkFunctionTransform(
+ functions=[
+ Function(functions.avg, DataType.FLOAT),
+ Function(functions.stddev_pop, DataType.FLOAT),
+ ],
+ ).with_window(
+ partition_by="user_id",
+ order_by=TIMESTAMP_COLUMN,
+ window_definition=["7 days", "2 weeks"],
+ mode="fixed_windows",
+ ),
+ ),
+ ],
+ ),
+ sink=Mock(spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],),
+ )
+
+ return test_pipeline
diff --git a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py
index 1bc3c707..7bae6606 100644
--- a/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py
+++ b/tests/unit/butterfree/pipelines/test_feature_set_pipeline.py
@@ -17,12 +17,8 @@
from butterfree.load.writers.writer import Writer
from butterfree.pipelines.feature_set_pipeline import FeatureSetPipeline
from butterfree.transform import FeatureSet
-from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet
from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
-from butterfree.transform.transformations import (
- AggregatedTransform,
- SparkFunctionTransform,
-)
+from butterfree.transform.transformations import SparkFunctionTransform
from butterfree.transform.utils import Function
@@ -104,115 +100,29 @@ def test_feature_set_args(self):
assert len(pipeline.sink.writers) == 2
assert all(isinstance(writer, Writer) for writer in pipeline.sink.writers)
- def test_run(self, spark_session):
- test_pipeline = FeatureSetPipeline(
- spark_client=SparkClient(),
- source=Mock(
- spec=Source,
- readers=[TableReader(id="source_a", database="db", table="table",)],
- query="select * from source_a",
- ),
- feature_set=Mock(
- spec=FeatureSet,
- name="feature_set",
- entity="entity",
- description="description",
- keys=[
- KeyFeature(
- name="user_id",
- description="The user's Main ID or device ID",
- dtype=DataType.INTEGER,
- )
- ],
- timestamp=TimestampFeature(from_column="ts"),
- features=[
- Feature(
- name="listing_page_viewed__rent_per_month",
- description="Average of something.",
- transformation=SparkFunctionTransform(
- functions=[
- Function(functions.avg, DataType.FLOAT),
- Function(functions.stddev_pop, DataType.FLOAT),
- ],
- ).with_window(
- partition_by="user_id",
- order_by=TIMESTAMP_COLUMN,
- window_definition=["7 days", "2 weeks"],
- mode="fixed_windows",
- ),
- ),
- ],
- ),
- sink=Mock(
- spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],
- ),
- )
-
+ def test_run(self, spark_session, feature_set_pipeline):
# feature_set need to return a real df for streaming validation
sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
- test_pipeline.feature_set.construct.return_value = sample_df
+ feature_set_pipeline.feature_set.construct.return_value = sample_df
- test_pipeline.run()
+ feature_set_pipeline.run()
- test_pipeline.source.construct.assert_called_once()
- test_pipeline.feature_set.construct.assert_called_once()
- test_pipeline.sink.flush.assert_called_once()
- test_pipeline.sink.validate.assert_called_once()
-
- def test_run_with_repartition(self, spark_session):
- test_pipeline = FeatureSetPipeline(
- spark_client=SparkClient(),
- source=Mock(
- spec=Source,
- readers=[TableReader(id="source_a", database="db", table="table",)],
- query="select * from source_a",
- ),
- feature_set=Mock(
- spec=FeatureSet,
- name="feature_set",
- entity="entity",
- description="description",
- keys=[
- KeyFeature(
- name="user_id",
- description="The user's Main ID or device ID",
- dtype=DataType.INTEGER,
- )
- ],
- timestamp=TimestampFeature(from_column="ts"),
- features=[
- Feature(
- name="listing_page_viewed__rent_per_month",
- description="Average of something.",
- transformation=SparkFunctionTransform(
- functions=[
- Function(functions.avg, DataType.FLOAT),
- Function(functions.stddev_pop, DataType.FLOAT),
- ],
- ).with_window(
- partition_by="user_id",
- order_by=TIMESTAMP_COLUMN,
- window_definition=["7 days", "2 weeks"],
- mode="fixed_windows",
- ),
- ),
- ],
- ),
- sink=Mock(
- spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],
- ),
- )
+ feature_set_pipeline.source.construct.assert_called_once()
+ feature_set_pipeline.feature_set.construct.assert_called_once()
+ feature_set_pipeline.sink.flush.assert_called_once()
+ feature_set_pipeline.sink.validate.assert_called_once()
+ def test_run_with_repartition(self, spark_session, feature_set_pipeline):
# feature_set need to return a real df for streaming validation
sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
- test_pipeline.feature_set.construct.return_value = sample_df
+ feature_set_pipeline.feature_set.construct.return_value = sample_df
- test_pipeline.run(partition_by=["id"])
+ feature_set_pipeline.run(partition_by=["id"])
- test_pipeline.source.construct.assert_called_once()
- test_pipeline.feature_set.construct.assert_called_once()
- test_pipeline.sink.flush.assert_called_once()
- test_pipeline.sink.validate.assert_called_once()
+ feature_set_pipeline.source.construct.assert_called_once()
+ feature_set_pipeline.feature_set.construct.assert_called_once()
+ feature_set_pipeline.sink.flush.assert_called_once()
+ feature_set_pipeline.sink.validate.assert_called_once()
def test_source_raise(self):
with pytest.raises(ValueError, match="source must be a Source instance"):
@@ -343,52 +253,26 @@ def test_sink_raise(self):
sink=Mock(writers=[HistoricalFeatureStoreWriter(db_config=None)],),
)
- def test_run_agg_with_end_date(self, spark_session):
- test_pipeline = FeatureSetPipeline(
- spark_client=SparkClient(),
- source=Mock(
- spec=Source,
- readers=[TableReader(id="source_a", database="db", table="table",)],
- query="select * from source_a",
- ),
- feature_set=Mock(
- spec=AggregatedFeatureSet,
- name="feature_set",
- entity="entity",
- description="description",
- keys=[
- KeyFeature(
- name="user_id",
- description="The user's Main ID or device ID",
- dtype=DataType.INTEGER,
- )
- ],
- timestamp=TimestampFeature(from_column="ts"),
- features=[
- Feature(
- name="listing_page_viewed__rent_per_month",
- description="Average of something.",
- transformation=AggregatedTransform(
- functions=[
- Function(functions.avg, DataType.FLOAT),
- Function(functions.stddev_pop, DataType.FLOAT),
- ],
- ),
- ),
- ],
- ),
- sink=Mock(
- spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)],
- ),
- )
+ def test_run_agg_with_end_date(self, spark_session, feature_set_pipeline):
+ # feature_set need to return a real df for streaming validation
+ sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
+ feature_set_pipeline.feature_set.construct.return_value = sample_df
+
+ feature_set_pipeline.run(end_date="2016-04-18")
+
+ feature_set_pipeline.source.construct.assert_called_once()
+ feature_set_pipeline.feature_set.construct.assert_called_once()
+ feature_set_pipeline.sink.flush.assert_called_once()
+ feature_set_pipeline.sink.validate.assert_called_once()
+ def test_run_agg_with_start_date(self, spark_session, feature_set_pipeline):
# feature_set need to return a real df for streaming validation
sample_df = spark_session.createDataFrame([{"a": "x", "b": "y", "c": "3"}])
- test_pipeline.feature_set.construct.return_value = sample_df
+ feature_set_pipeline.feature_set.construct.return_value = sample_df
- test_pipeline.run(end_date="2016-04-18")
+ feature_set_pipeline.run(start_date="2020-08-04")
- test_pipeline.source.construct.assert_called_once()
- test_pipeline.feature_set.construct.assert_called_once()
- test_pipeline.sink.flush.assert_called_once()
- test_pipeline.sink.validate.assert_called_once()
+ feature_set_pipeline.source.construct.assert_called_once()
+ feature_set_pipeline.feature_set.construct.assert_called_once()
+ feature_set_pipeline.sink.flush.assert_called_once()
+ feature_set_pipeline.sink.validate.assert_called_once()
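Note: these tests now rely on a feature_set_pipeline fixture defined in a conftest elsewhere in this patch rather than in this hunk. Purely as a hypothetical sketch (names and arguments below are inferred from the inline construction deleted above, not copied from the fixture's real definition), such a fixture could be as small as:

    from unittest.mock import Mock

    from pytest import fixture

    from butterfree.clients import SparkClient
    from butterfree.extract import Source
    from butterfree.load import Sink
    from butterfree.load.writers import HistoricalFeatureStoreWriter
    from butterfree.pipelines import FeatureSetPipeline
    from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet


    @fixture
    def feature_set_pipeline():
        # Mock-based pipeline standing in for the construction the tests used to inline;
        # only construct(), flush() and validate() are asserted on, so specced mocks suffice.
        return FeatureSetPipeline(
            spark_client=SparkClient(),
            source=Mock(spec=Source),
            feature_set=Mock(spec=AggregatedFeatureSet),
            sink=Mock(spec=Sink, writers=[HistoricalFeatureStoreWriter(db_config=None)]),
        )
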
diff --git a/tests/unit/butterfree/transform/conftest.py b/tests/unit/butterfree/transform/conftest.py
index 2d7d3e50..febc8bbc 100644
--- a/tests/unit/butterfree/transform/conftest.py
+++ b/tests/unit/butterfree/transform/conftest.py
@@ -1,11 +1,19 @@
import json
from unittest.mock import Mock
+from pyspark.sql import functions
from pytest import fixture
from butterfree.constants import DataType
from butterfree.constants.columns import TIMESTAMP_COLUMN
+from butterfree.transform import FeatureSet
+from butterfree.transform.aggregated_feature_set import AggregatedFeatureSet
from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
+from butterfree.transform.transformations import (
+ AggregatedTransform,
+ SparkFunctionTransform,
+)
+from butterfree.transform.utils import Function
def make_dataframe(spark_context, spark_session):
@@ -297,3 +305,77 @@ def key_id():
@fixture
def timestamp_c():
return TimestampFeature()
+
+
+@fixture
+def feature_set():
+ feature_set = FeatureSet(
+ name="feature_set",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(
+ name="feature1",
+ description="test",
+ transformation=SparkFunctionTransform(
+ functions=[
+ Function(functions.avg, DataType.FLOAT),
+ Function(functions.stddev_pop, DataType.DOUBLE),
+ ]
+ ).with_window(
+ partition_by="id",
+ order_by=TIMESTAMP_COLUMN,
+ mode="fixed_windows",
+ window_definition=["2 minutes", "15 minutes"],
+ ),
+ ),
+ ],
+ keys=[
+ KeyFeature(
+ name="id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.BIGINT,
+ )
+ ],
+ timestamp=TimestampFeature(),
+ )
+
+ return feature_set
+
+
+@fixture
+def agg_feature_set():
+ feature_set = AggregatedFeatureSet(
+ name="feature_set",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(
+ name="feature1",
+ description="test",
+ transformation=AggregatedTransform(
+ functions=[
+ Function(functions.avg, DataType.DOUBLE),
+ Function(functions.stddev_pop, DataType.FLOAT),
+ ],
+ ),
+ ),
+ Feature(
+ name="feature2",
+ description="test",
+ transformation=AggregatedTransform(
+ functions=[Function(functions.count, DataType.ARRAY_STRING)]
+ ),
+ ),
+ ],
+ keys=[
+ KeyFeature(
+ name="id",
+ description="The user's Main ID or device ID",
+ dtype=DataType.BIGINT,
+ )
+ ],
+ timestamp=TimestampFeature(),
+ ).with_windows(definitions=["1 week", "2 days"])
+
+ return feature_set
diff --git a/tests/unit/butterfree/transform/test_aggregated_feature_set.py b/tests/unit/butterfree/transform/test_aggregated_feature_set.py
index 2c404fea..8025d6f8 100644
--- a/tests/unit/butterfree/transform/test_aggregated_feature_set.py
+++ b/tests/unit/butterfree/transform/test_aggregated_feature_set.py
@@ -89,7 +89,7 @@ def test_agg_feature_set_with_window(
output_df = fs.construct(dataframe, spark_client, end_date="2016-05-01")
assert_dataframe_equality(output_df, rolling_windows_agg_dataframe)
- def test_get_schema(self):
+ def test_get_schema(self, agg_feature_set):
expected_schema = [
{"column_name": "id", "type": LongType(), "primary_key": True},
{"column_name": "timestamp", "type": TimestampType(), "primary_key": False},
@@ -125,40 +125,7 @@ def test_get_schema(self):
},
]
- feature_set = AggregatedFeatureSet(
- name="feature_set",
- entity="entity",
- description="description",
- features=[
- Feature(
- name="feature1",
- description="test",
- transformation=AggregatedTransform(
- functions=[
- Function(functions.avg, DataType.DOUBLE),
- Function(functions.stddev_pop, DataType.FLOAT),
- ],
- ),
- ),
- Feature(
- name="feature2",
- description="test",
- transformation=AggregatedTransform(
- functions=[Function(functions.count, DataType.ARRAY_STRING)]
- ),
- ),
- ],
- keys=[
- KeyFeature(
- name="id",
- description="The user's Main ID or device ID",
- dtype=DataType.BIGINT,
- )
- ],
- timestamp=TimestampFeature(),
- ).with_windows(definitions=["1 week", "2 days"])
-
- schema = feature_set.get_schema()
+ schema = agg_feature_set.get_schema()
assert schema == expected_schema
@@ -389,3 +356,34 @@ def test_feature_transform_with_data_type_array(self, spark_context, spark_sessi
# assert
assert_dataframe_equality(target_df, output_df)
+
+ def test_define_start_date(self, agg_feature_set):
+ start_date = agg_feature_set.define_start_date("2020-08-04")
+
+ assert isinstance(start_date, str)
+ assert start_date == "2020-07-27"
+
+ def test_feature_set_start_date(
+ self, timestamp_c, feature_set_with_distinct_dataframe,
+ ):
+ fs = AggregatedFeatureSet(
+ name="name",
+ entity="entity",
+ description="description",
+ features=[
+ Feature(
+ name="feature",
+ description="test",
+ transformation=AggregatedTransform(
+ functions=[Function(functions.sum, DataType.INTEGER)]
+ ),
+ ),
+ ],
+ keys=[KeyFeature(name="h3", description="test", dtype=DataType.STRING)],
+ timestamp=timestamp_c,
+ ).with_windows(["10 days", "3 weeks", "90 days"])
+
+ # act
+ start_date = fs.define_start_date("2016-04-14")
+
+ assert start_date == "2016-01-14"
diff --git a/tests/unit/butterfree/transform/test_feature_set.py b/tests/unit/butterfree/transform/test_feature_set.py
index bdb1ff7d..43d937be 100644
--- a/tests/unit/butterfree/transform/test_feature_set.py
+++ b/tests/unit/butterfree/transform/test_feature_set.py
@@ -12,13 +12,11 @@
from butterfree.clients import SparkClient
from butterfree.constants import DataType
-from butterfree.constants.columns import TIMESTAMP_COLUMN
from butterfree.testing.dataframe import assert_dataframe_equality
from butterfree.transform import FeatureSet
-from butterfree.transform.features import Feature, KeyFeature, TimestampFeature
+from butterfree.transform.features import Feature
from butterfree.transform.transformations import (
AggregatedTransform,
- SparkFunctionTransform,
SQLExpressionTransform,
)
from butterfree.transform.utils import Function
@@ -341,7 +339,7 @@ def test_feature_set_with_invalid_feature(self, key_id, timestamp_c, dataframe):
timestamp=timestamp_c,
).construct(dataframe, spark_client)
- def test_get_schema(self):
+ def test_get_schema(self, feature_set):
expected_schema = [
{"column_name": "id", "type": LongType(), "primary_key": True},
{"column_name": "timestamp", "type": TimestampType(), "primary_key": False},
@@ -367,37 +365,6 @@ def test_get_schema(self):
},
]
- feature_set = FeatureSet(
- name="feature_set",
- entity="entity",
- description="description",
- features=[
- Feature(
- name="feature1",
- description="test",
- transformation=SparkFunctionTransform(
- functions=[
- Function(F.avg, DataType.FLOAT),
- Function(F.stddev_pop, DataType.DOUBLE),
- ]
- ).with_window(
- partition_by="id",
- order_by=TIMESTAMP_COLUMN,
- mode="fixed_windows",
- window_definition=["2 minutes", "15 minutes"],
- ),
- ),
- ],
- keys=[
- KeyFeature(
- name="id",
- description="The user's Main ID or device ID",
- dtype=DataType.BIGINT,
- )
- ],
- timestamp=TimestampFeature(),
- )
-
schema = feature_set.get_schema()
assert schema == expected_schema
@@ -421,3 +388,9 @@ def test_feature_without_datatype(self, key_id, timestamp_c, dataframe):
keys=[key_id],
timestamp=timestamp_c,
).construct(dataframe, spark_client)
+
+ def test_define_start_date(self, feature_set):
+ start_date = feature_set.define_start_date("2020-08-04")
+
+ assert isinstance(start_date, str)
+ assert start_date == "2020-08-04"
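
By contrast, the plain FeatureSet above has no rolling windows, so define_start_date is expected to hand the date back untouched. If both behaviours were ever wanted in one place, a parametrized variant along these lines would cover them (a suggestion only, not part of this patch; it reuses the feature_set and agg_feature_set fixtures added to the transform conftest.py above):

    import pytest


    @pytest.mark.parametrize(
        "fixture_name, date, expected",
        [
            ("feature_set", "2020-08-04", "2020-08-04"),  # no windows: identity
            ("agg_feature_set", "2020-08-04", "2020-07-27"),  # widest window is "1 week"
        ],
    )
    def test_define_start_date_parametrized(fixture_name, date, expected, request):
        fs = request.getfixturevalue(fixture_name)
        assert fs.define_start_date(date) == expected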