diff --git a/butterfree/_cli/migrate.py b/butterfree/_cli/migrate.py
index a718832d..811796d3 100644
--- a/butterfree/_cli/migrate.py
+++ b/butterfree/_cli/migrate.py
@@ -5,7 +5,7 @@
 import os
 import pkgutil
 import sys
-from typing import Set, Type
+from typing import Set
 
 import boto3
 import setuptools
@@ -90,18 +90,8 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]:
 
             instances.add(value)
 
-    def create_instance(cls: Type[FeatureSetPipeline]) -> FeatureSetPipeline:
-        sig = inspect.signature(cls.__init__)
-        parameters = sig.parameters
-
-        if "run_date" in parameters:
-            run_date = datetime.datetime.today().strftime("%Y-%m-%d")
-            return cls(run_date)
-
-        return cls()
-
     logger.info("Creating instances...")
-    return set(create_instance(value) for value in instances)  # type: ignore
+    return set(value() for value in instances)  # type: ignore
 
 
 PATH = typer.Argument(
diff --git a/butterfree/extract/source.py b/butterfree/extract/source.py
index 9d50e94c..bfc15271 100644
--- a/butterfree/extract/source.py
+++ b/butterfree/extract/source.py
@@ -3,7 +3,6 @@
 from typing import List, Optional
 
 from pyspark.sql import DataFrame
-from pyspark.storagelevel import StorageLevel
 
 from butterfree.clients import SparkClient
 from butterfree.extract.readers.reader import Reader
@@ -96,21 +95,16 @@ def construct(
             DataFrame with the query result against all readers.
 
         """
-        # Step 1: Build temporary views for each reader
         for reader in self.readers:
-            reader.build(client=client, start_date=start_date, end_date=end_date)
+            reader.build(
+                client=client, start_date=start_date, end_date=end_date
+            )  # create temporary views for each reader
 
-        # Step 2: Execute SQL query on the combined readers
         dataframe = client.sql(self.query)
 
-        # Step 3: Cache the dataframe if necessary, using memory and disk storage
         if not dataframe.isStreaming and self.eager_evaluation:
-            # Persist to ensure the DataFrame is stored in mem and disk (if necessary)
-            dataframe.persist(StorageLevel.MEMORY_AND_DISK)
-            # Trigger the cache/persist operation by performing an action
-            dataframe.count()
+            dataframe.cache().count()
 
-        # Step 4: Run post-processing hooks on the dataframe
         post_hook_df = self.run_post_hooks(dataframe)
 
         return post_hook_df
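
[Illustration, not part of the patch] Source.construct above goes back to the plain cache-and-count idiom for eager
evaluation of batch sources. A minimal, self-contained sketch of that idiom, using a throwaway local SparkSession and
toy data (all names below are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    dataframe = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "feature"])

    eager_evaluation = True  # stand-in for the Source flag
    if not dataframe.isStreaming and eager_evaluation:
        # cache() only marks the DataFrame for caching; count() is the action
        # that actually materializes it before the post hooks run.
        dataframe.cache().count()
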
""" - temp_column_name = f"{column.column}_temp" - - add_temp_column_query = ( - f"ALTER TABLE {table_name} ADD {temp_column_name} {column.value};" - ) - copy_data_to_temp_query = ( - f"UPDATE {table_name} SET {temp_column_name} = {column.column};" - ) - - drop_old_column_query = f"ALTER TABLE {table_name} DROP {column.column};" - rename_temp_column_query = ( - f"ALTER TABLE {table_name} RENAME {temp_column_name} TO {column.column};" - ) + parsed_columns = self._get_parsed_columns([column]) return ( - f"{add_temp_column_query} {copy_data_to_temp_query} " - f"{drop_old_column_query} {rename_temp_column_query};" + f"ALTER TABLE {table_name} ALTER {parsed_columns.replace(' ', ' TYPE ')};" ) @staticmethod diff --git a/butterfree/pipelines/feature_set_pipeline.py b/butterfree/pipelines/feature_set_pipeline.py index f1c94ec2..8ba1a636 100644 --- a/butterfree/pipelines/feature_set_pipeline.py +++ b/butterfree/pipelines/feature_set_pipeline.py @@ -2,8 +2,6 @@ from typing import List, Optional -from pyspark.storagelevel import StorageLevel - from butterfree.clients import SparkClient from butterfree.dataframe_service import repartition_sort_df from butterfree.extract import Source @@ -211,22 +209,19 @@ def run( soon. Use only if strictly necessary. """ - # Step 1: Construct input dataframe from the source. dataframe = self.source.construct( client=self.spark_client, start_date=self.feature_set.define_start_date(start_date), end_date=end_date, ) - # Step 2: Repartition and sort if required, avoid if not necessary. if partition_by: order_by = order_by or partition_by dataframe = repartition_sort_df( dataframe, partition_by, order_by, num_processors ) - # Step 3: Construct the feature set dataframe using defined transformations. - transformed_dataframe = self.feature_set.construct( + dataframe = self.feature_set.construct( dataframe=dataframe, client=self.spark_client, start_date=start_date, @@ -234,22 +229,15 @@ def run( num_processors=num_processors, ) - if transformed_dataframe.storageLevel != StorageLevel( - False, False, False, False, 1 - ): - dataframe.unpersist() # Clear the data from the cache (disk and memory) - - # Step 4: Load the data into the configured sink. self.sink.flush( - dataframe=transformed_dataframe, + dataframe=dataframe, feature_set=self.feature_set, spark_client=self.spark_client, ) - # Step 5: Validate the output if not streaming and data volume is reasonable. 
diff --git a/butterfree/pipelines/feature_set_pipeline.py b/butterfree/pipelines/feature_set_pipeline.py
index f1c94ec2..8ba1a636 100644
--- a/butterfree/pipelines/feature_set_pipeline.py
+++ b/butterfree/pipelines/feature_set_pipeline.py
@@ -2,8 +2,6 @@
 
 from typing import List, Optional
 
-from pyspark.storagelevel import StorageLevel
-
 from butterfree.clients import SparkClient
 from butterfree.dataframe_service import repartition_sort_df
 from butterfree.extract import Source
@@ -211,22 +209,19 @@ def run(
                 soon. Use only if strictly necessary.
 
         """
-        # Step 1: Construct input dataframe from the source.
         dataframe = self.source.construct(
             client=self.spark_client,
             start_date=self.feature_set.define_start_date(start_date),
             end_date=end_date,
         )
 
-        # Step 2: Repartition and sort if required, avoid if not necessary.
         if partition_by:
             order_by = order_by or partition_by
             dataframe = repartition_sort_df(
                 dataframe, partition_by, order_by, num_processors
             )
 
-        # Step 3: Construct the feature set dataframe using defined transformations.
-        transformed_dataframe = self.feature_set.construct(
+        dataframe = self.feature_set.construct(
             dataframe=dataframe,
             client=self.spark_client,
             start_date=start_date,
@@ -234,22 +229,15 @@ def run(
             num_processors=num_processors,
         )
 
-        if transformed_dataframe.storageLevel != StorageLevel(
-            False, False, False, False, 1
-        ):
-            dataframe.unpersist()  # Clear the data from the cache (disk and memory)
-
-        # Step 4: Load the data into the configured sink.
         self.sink.flush(
-            dataframe=transformed_dataframe,
+            dataframe=dataframe,
             feature_set=self.feature_set,
             spark_client=self.spark_client,
         )
 
-        # Step 5: Validate the output if not streaming and data volume is reasonable.
-        if not transformed_dataframe.isStreaming:
+        if not dataframe.isStreaming:
             self.sink.validate(
-                dataframe=transformed_dataframe,
+                dataframe=dataframe,
                 feature_set=self.feature_set,
                 spark_client=self.spark_client,
             )
diff --git a/butterfree/transform/aggregated_feature_set.py b/butterfree/transform/aggregated_feature_set.py
index 1230cd4d..6706bf8c 100644
--- a/butterfree/transform/aggregated_feature_set.py
+++ b/butterfree/transform/aggregated_feature_set.py
@@ -387,7 +387,6 @@ def _aggregate(
         ]
 
         groupby = self.keys_columns.copy()
-
         if window is not None:
             dataframe = dataframe.withColumn("window", window.get())
             groupby.append("window")
@@ -411,23 +410,19 @@ def _aggregate(
                 "keep_rn", functions.row_number().over(partition_window)
             ).filter("keep_rn = 1")
 
-        current_partitions = dataframe.rdd.getNumPartitions()
-        optimal_partitions = num_processors or current_partitions
-
-        if current_partitions != optimal_partitions:
-            dataframe = repartition_df(
-                dataframe,
-                partition_by=groupby,
-                num_processors=optimal_partitions,
-            )
-
+        # repartition to have all rows for each group at the same partition
+        # by doing that, we won't have to shuffle data on grouping by id
+        dataframe = repartition_df(
+            dataframe,
+            partition_by=groupby,
+            num_processors=num_processors,
+        )
         grouped_data = dataframe.groupby(*groupby)
 
-        if self._pivot_column and self._pivot_values:
+        if self._pivot_column:
             grouped_data = grouped_data.pivot(self._pivot_column, self._pivot_values)
 
         aggregated = grouped_data.agg(*aggregations)
-
         return self._with_renamed_columns(aggregated, features, window)
 
     def _with_renamed_columns(
@@ -576,12 +571,14 @@ def construct(
 
         pre_hook_df = self.run_pre_hooks(dataframe)
 
-        output_df = pre_hook_df
-        for feature in self.keys + [self.timestamp]:
-            output_df = feature.transform(output_df)
+        output_df = reduce(
+            lambda df, feature: feature.transform(df),
+            self.keys + [self.timestamp],
+            pre_hook_df,
+        )
 
         if self._windows and end_date is not None:
-            # Run aggregations for each window
+            # run aggregations for each window
             agg_list = [
                 self._aggregate(
                     dataframe=output_df,
@@ -601,12 +598,13 @@ def construct(
             # keeping this logic to maintain the same behavior for already implemented
             # feature sets
+
             if self._windows[0].slide == "1 day":
                 base_df = self._get_base_dataframe(
                     client=client, dataframe=output_df, end_date=end_date
                 )
 
-                # Left join each aggregation result to our base dataframe
+                # left join each aggregation result to our base dataframe
                 output_df = reduce(
                     lambda left, right: self._dataframe_join(
                         left,
                         right,
@@ -639,18 +637,12 @@ def construct(
         output_df = output_df.select(*self.columns).replace(  # type: ignore
             float("nan"), None
         )
-
-        if not output_df.isStreaming and self.deduplicate_rows:
-            output_df = self._filter_duplicated_rows(output_df)
+        if not output_df.isStreaming:
+            if self.deduplicate_rows:
+                output_df = self._filter_duplicated_rows(output_df)
+            if self.eager_evaluation:
+                output_df.cache().count()
 
         post_hook_df = self.run_post_hooks(output_df)
 
-        # Eager evaluation, only if needed and managable
-        if not output_df.isStreaming and self.eager_evaluation:
-            # Small dataframes only
-            if output_df.count() < 1_000_000:
-                post_hook_df.cache().count()
-            else:
-                post_hook_df.cache()  # Cache without materialization for large volumes
-
         return post_hook_df
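
[Illustration, not part of the patch] AggregatedFeatureSet.construct now folds the key and timestamp transforms with
functools.reduce instead of an explicit loop. A small self-contained sketch showing the two forms are equivalent, with
plain functions standing in for the feature transforms:

    from functools import reduce

    # toy stand-ins for KeyFeature/TimestampFeature .transform(df)
    transforms = [lambda df: df + ["id"], lambda df: df + ["timestamp"]]
    pre_hook_df = []  # stand-in for the pre-hook DataFrame

    reduced = reduce(lambda df, transform: transform(df), transforms, pre_hook_df)

    looped = pre_hook_df
    for transform in transforms:
        looped = transform(looped)

    assert reduced == looped == ["id", "timestamp"]
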
diff --git a/butterfree/transform/feature_set.py b/butterfree/transform/feature_set.py
index 2c4b9b51..369eaf29 100644
--- a/butterfree/transform/feature_set.py
+++ b/butterfree/transform/feature_set.py
@@ -436,8 +436,11 @@ def construct(
             pre_hook_df,
         ).select(*self.columns)
 
-        if not output_df.isStreaming and self.deduplicate_rows:
-            output_df = self._filter_duplicated_rows(output_df)
+        if not output_df.isStreaming:
+            if self.deduplicate_rows:
+                output_df = self._filter_duplicated_rows(output_df)
+            if self.eager_evaluation:
+                output_df.cache().count()
 
         output_df = self.incremental_strategy.filter_with_incremental_strategy(
             dataframe=output_df, start_date=start_date, end_date=end_date
diff --git a/tests/unit/butterfree/transform/test_feature_set.py b/tests/unit/butterfree/transform/test_feature_set.py
index 37a69be2..e907dc0a 100644
--- a/tests/unit/butterfree/transform/test_feature_set.py
+++ b/tests/unit/butterfree/transform/test_feature_set.py
@@ -220,7 +220,7 @@ def test_construct(
             + feature_divide.get_output_columns()
         )
         assert_dataframe_equality(result_df, feature_set_dataframe)
-        assert not result_df.is_cached
+        assert result_df.is_cached
 
     def test_construct_invalid_df(
         self, key_id, timestamp_c, feature_add, feature_divide
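
[Illustration, not part of the patch] The updated unit test now expects the constructed DataFrame to be cached, which
lines up with the eager_evaluation branch reintroduced in FeatureSet.construct. A minimal sketch of what the is_cached
assertion reflects:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1,)], ["id"])

    assert not df.is_cached  # nothing cached yet
    df.cache().count()  # mark for caching and materialize with an action
    assert df.is_cached  # what the updated assertion checks on result_df
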