Fix Validation Step (#302)
AlvaroMarquesAndrade authored Mar 22, 2021
1 parent 5fe4c40 commit e8fc0da
Showing 9 changed files with 12 additions and 144 deletions.
2 changes: 1 addition & 1 deletion butterfree/configs/db/metastore_config.py
@@ -117,4 +117,4 @@ def get_path_with_partitions(self, key: str, dataframe: DataFrame) -> List:

     def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Translate feature set spark schema to the corresponding database."""
-        pass
+        return schema
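
Why this one-line change matters: `translate` previously ended in `pass` and therefore returned `None`, so any caller iterating over the translated schema crashed; returning the input makes it an identity mapping. A self-contained sketch with toy functions (not the real `MetastoreConfig`) and a placeholder schema entry:

from typing import Any, Dict, List

# Placeholder schema entry; real entries come from a feature set's Spark schema.
schema: List[Dict[str, Any]] = [
    {"column_name": "id", "type": "IntegerType", "primary_key": True},
]


def translate_old(schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    pass  # implicit `return None`


def translate_new(schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    return schema  # identity translation


# list(translate_old(schema))  # TypeError: 'NoneType' object is not iterable
assert list(translate_new(schema)) == schema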
9 changes: 1 addition & 8 deletions butterfree/load/writers/historical_feature_store_writer.py
@@ -140,14 +140,7 @@ def write(
         """
         dataframe = self._create_partitions(dataframe)

-        partition_df = self._apply_transformations(dataframe)
-
-        if self.debug_mode:
-            dataframe = partition_df
-        else:
-            dataframe = self.check_schema(
-                spark_client, partition_df, feature_set.name, self.database
-            )
+        dataframe = self._apply_transformations(dataframe)

         if self.interval_mode:
             if self.debug_mode:
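
With the inline `check_schema` call removed, `write` no longer validates the dataframe against the existing table before writing. If a pipeline still wants a guard at this point, one option is a pre-hook. The sketch below assumes butterfree's `Hook` interface (a single `run(dataframe)` method, as imported in the online writer below) and that writers expose `add_pre_hook` via the hookable-component machinery; the column check itself is purely illustrative:

from pyspark.sql import DataFrame

from butterfree.hooks import Hook


class AssertColumnsHook(Hook):
    """Illustrative guard: fail fast when expected columns are missing.

    `expected_columns` is a hypothetical parameter, not part of butterfree.
    """

    def __init__(self, expected_columns):
        self.expected_columns = set(expected_columns)

    def run(self, dataframe: DataFrame) -> DataFrame:
        missing = self.expected_columns - set(dataframe.columns)
        if missing:
            raise ValueError(f"dataframe is missing columns: {sorted(missing)}")
        return dataframe


# Usage sketch: historical_writer.add_pre_hook(AssertColumnsHook(["id", "timestamp"]))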
18 changes: 1 addition & 17 deletions butterfree/load/writers/online_feature_store_writer.py
@@ -7,7 +7,7 @@
 from pyspark.sql.functions import col, row_number
 from pyspark.sql.streaming import StreamingQuery

-from butterfree.clients import CassandraClient, SparkClient
+from butterfree.clients import SparkClient
 from butterfree.configs.db import AbstractWriteConfig, CassandraConfig
 from butterfree.constants.columns import TIMESTAMP_COLUMN
 from butterfree.hooks import Hook
@@ -180,22 +180,6 @@ def write(
         """
         table_name = feature_set.entity if self.write_to_entity else feature_set.name

-        if not self.debug_mode:
-            config = (
-                self.db_config
-                if self.db_config == CassandraConfig
-                else CassandraConfig()
-            )
-
-            cassandra_client = CassandraClient(
-                host=[config.host],
-                keyspace=config.keyspace,
-                user=config.username,
-                password=config.password,
-            )
-
-            dataframe = self.check_schema(cassandra_client, dataframe, table_name)
-
         if dataframe.isStreaming:
             dataframe = self._apply_transformations(dataframe)
             if self.debug_mode:
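
A note on the block removed here: `self.db_config == CassandraConfig` compared an instance to the class itself, so the condition was always False and a fresh `CassandraConfig()` was built on every non-debug write, ignoring the user-supplied config; the step also required a live Cassandra connection just to validate a schema. A self-contained sketch of the always-False comparison, using a stand-in class:

class CassandraConfig:  # stand-in for butterfree.configs.db.CassandraConfig
    pass


db_config = CassandraConfig()

# An instance never compares equal to its own class, so the removed branch
# always fell through to `CassandraConfig()`.
assert (db_config == CassandraConfig) is False
assert isinstance(db_config, CassandraConfig)  # presumably the intended test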
2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
 from setuptools import find_packages, setup

 __package_name__ = "butterfree"
-__version__ = "1.2.0.dev2"
+__version__ = "1.2.0.dev3"
 __repository_url__ = "https://github.com/quintoandar/butterfree"

 with open("requirements.txt") as f:
11 changes: 1 addition & 10 deletions tests/integration/butterfree/load/test_sink.py
@@ -9,7 +9,7 @@
 )


-def test_sink(input_dataframe, feature_set, mocker):
+def test_sink(input_dataframe, feature_set):
     # arrange
     client = SparkClient()
     client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
@@ -34,11 +34,6 @@ def test_sink(input_dataframe, feature_set, mocker):
         db_config=s3config, interval_mode=True
     )

-    schema_dataframe = historical_writer._create_partitions(feature_set_df)
-    historical_writer.check_schema_hook = mocker.stub("check_schema_hook")
-    historical_writer.check_schema_hook.run = mocker.stub("run")
-    historical_writer.check_schema_hook.run.return_value = schema_dataframe
-
     # setup online writer
     # TODO: Change for CassandraConfig when Cassandra for test is ready
     online_config = Mock()
@@ -49,10 +44,6 @@ def test_sink(input_dataframe, feature_set, mocker):
     )
     online_writer = OnlineFeatureStoreWriter(db_config=online_config)

-    online_writer.check_schema_hook = mocker.stub("check_schema_hook")
-    online_writer.check_schema_hook.run = mocker.stub("run")
-    online_writer.check_schema_hook.run.return_value = feature_set_df
-
     writers = [historical_writer, online_writer]
     sink = Sink(writers)
tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
@@ -4,7 +4,6 @@
 from pyspark.sql import DataFrame
 from pyspark.sql import functions as F

-from butterfree.clients import SparkClient
 from butterfree.configs import environment
 from butterfree.configs.db import MetastoreConfig
 from butterfree.constants import DataType
@@ -75,11 +74,7 @@ def create_ymd(dataframe):

 class TestFeatureSetPipeline:
     def test_feature_set_pipeline(
-        self,
-        mocked_df,
-        spark_session,
-        fixed_windows_output_feature_set_dataframe,
-        mocker,
+        self, mocked_df, spark_session, fixed_windows_output_feature_set_dataframe,
     ):
         # arrange
         table_reader_id = "a_source"
@@ -93,11 +88,6 @@ def test_feature_set_pipeline(
             table_reader_table=table_reader_table,
         )

-        spark_client = SparkClient()
-        spark_client.conn.conf.set(
-            "spark.sql.sources.partitionOverwriteMode", "dynamic"
-        )
-
         dbconfig = Mock()
         dbconfig.mode = "overwrite"
         dbconfig.format_ = "parquet"
@@ -107,12 +97,6 @@

         historical_writer = HistoricalFeatureStoreWriter(db_config=dbconfig)

-        historical_writer.check_schema_hook = mocker.stub("check_schema_hook")
-        historical_writer.check_schema_hook.run = mocker.stub("run")
-        historical_writer.check_schema_hook.run.return_value = (
-            fixed_windows_output_feature_set_dataframe
-        )
-
         # act
         test_pipeline = FeatureSetPipeline(
             source=Source(
@@ -187,7 +171,6 @@ def test_feature_set_pipeline_with_dates(
         spark_session,
         fixed_windows_output_feature_set_date_dataframe,
         feature_set_pipeline,
-        mocker,
     ):
         # arrange
         table_reader_table = "b_table"
@@ -211,7 +194,6 @@ def test_feature_set_pipeline_with_execution_date(
         spark_session,
         fixed_windows_output_feature_set_date_dataframe,
         feature_set_pipeline,
-        mocker,
     ):
         # arrange
         table_reader_table = "b_table"
@@ -233,7 +215,7 @@ def test_feature_set_pipeline_with_execution_date(
         # assert
         assert_dataframe_equality(df, target_df)

-    def test_pipeline_with_hooks(self, spark_session, mocker):
+    def test_pipeline_with_hooks(self, spark_session):
         # arrange
         hook1 = AddHook(value=1)
32 changes: 2 additions & 30 deletions tests/unit/butterfree/load/test_sink.py
@@ -120,7 +120,7 @@ def test_flush_with_writers_list_empty(self):
         with pytest.raises(ValueError):
             Sink(writers=writer)

-    def test_flush_streaming_df(self, feature_set, mocker):
+    def test_flush_streaming_df(self, feature_set):
         """Testing the return of the streaming handlers by the sink."""
         # arrange
         spark_client = SparkClient()
@@ -137,24 +137,10 @@ def test_flush_streaming_df(self, feature_set, mocker):

         online_feature_store_writer = OnlineFeatureStoreWriter()

-        online_feature_store_writer.check_schema_hook = mocker.stub("check_schema_hook")
-        online_feature_store_writer.check_schema_hook.run = mocker.stub("run")
-        online_feature_store_writer.check_schema_hook.run.return_value = (
-            mocked_stream_df
-        )
-
         online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
             write_to_entity=True
         )

-        online_feature_store_writer_on_entity.check_schema_hook = mocker.stub(
-            "check_schema_hook"
-        )
-        online_feature_store_writer_on_entity.check_schema_hook.run = mocker.stub("run")
-        online_feature_store_writer_on_entity.check_schema_hook.run.return_value = (
-            mocked_stream_df
-        )
-
         sink = Sink(
             writers=[
                 online_feature_store_writer,
@@ -177,7 +163,7 @@ def test_flush_streaming_df(self, feature_set, mocker):
         assert isinstance(handler, StreamingQuery)

     def test_flush_with_multiple_online_writers(
-        self, feature_set, feature_set_dataframe, mocker
+        self, feature_set, feature_set_dataframe
     ):
         """Testing the flow of writing to a feature-set table and to an entity table."""
         # arrange
@@ -189,24 +175,10 @@ def test_flush_with_multiple_online_writers(

         online_feature_store_writer = OnlineFeatureStoreWriter()

-        online_feature_store_writer.check_schema_hook = mocker.stub("check_schema_hook")
-        online_feature_store_writer.check_schema_hook.run = mocker.stub("run")
-        online_feature_store_writer.check_schema_hook.run.return_value = (
-            feature_set_dataframe
-        )
-
         online_feature_store_writer_on_entity = OnlineFeatureStoreWriter(
             write_to_entity=True
         )

-        online_feature_store_writer_on_entity.check_schema_hook = mocker.stub(
-            "check_schema_hook"
-        )
-        online_feature_store_writer_on_entity.check_schema_hook.run = mocker.stub("run")
-        online_feature_store_writer_on_entity.check_schema_hook.run.return_value = (
-            feature_set_dataframe
-        )
-
         sink = Sink(
             writers=[online_feature_store_writer, online_feature_store_writer_on_entity]
         )
tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py
@@ -23,11 +23,6 @@ def test_write(
         spark_client.write_table = mocker.stub("write_table")
         writer = HistoricalFeatureStoreWriter()

-        schema_dataframe = writer._create_partitions(feature_set_dataframe)
-        writer.check_schema_hook = mocker.stub("check_schema_hook")
-        writer.check_schema_hook.run = mocker.stub("run")
-        writer.check_schema_hook.run.return_value = schema_dataframe
-
         # when
         writer.write(
             feature_set=feature_set,
@@ -62,11 +57,6 @@ def test_write_interval_mode(
         )
         writer = HistoricalFeatureStoreWriter(interval_mode=True)

-        schema_dataframe = writer._create_partitions(feature_set_dataframe)
-        writer.check_schema_hook = mocker.stub("check_schema_hook")
-        writer.check_schema_hook.run = mocker.stub("run")
-        writer.check_schema_hook.run.return_value = schema_dataframe
-
         # when
         writer.write(
             feature_set=feature_set,
@@ -104,11 +94,6 @@ def test_write_interval_mode_invalid_partition_mode(

         writer = HistoricalFeatureStoreWriter(interval_mode=True)

-        schema_dataframe = writer._create_partitions(feature_set_dataframe)
-        writer.check_schema_hook = mocker.stub("check_schema_hook")
-        writer.check_schema_hook.run = mocker.stub("run")
-        writer.check_schema_hook.run.return_value = schema_dataframe
-
         # when
         with pytest.raises(RuntimeError):
             _ = writer.write(
@@ -123,7 +108,6 @@ def test_write_in_debug_mode(
         historical_feature_set_dataframe,
         feature_set,
         spark_session,
-        mocker,
     ):
         # given
         spark_client = SparkClient()
@@ -321,12 +305,6 @@ def test_write_with_transform(

         writer = HistoricalFeatureStoreWriter().with_(json_transform)

-        schema_dataframe = writer._create_partitions(feature_set_dataframe)
-        json_dataframe = writer._apply_transformations(schema_dataframe)
-        writer.check_schema_hook = mocker.stub("check_schema_hook")
-        writer.check_schema_hook.run = mocker.stub("run")
-        writer.check_schema_hook.run.return_value = json_dataframe
-
         # when
         writer.write(
             feature_set=feature_set,