Optimize Aggregated Feature Sets with Repartition (#147)
Igor Gustavo Hoelscher committed May 6, 2020
1 parent 8ec83fe commit 77db644
Showing 2 changed files with 29 additions and 37 deletions.
15 changes: 10 additions & 5 deletions butterfree/core/dataframe_service/repartition.py
@@ -20,20 +20,24 @@ def _num_partitions_definition(num_processors, num_partitions):


 def repartition_df(
-    dataframe: DataFrame, partition_by: List[str], num_partitions: int = None
+    dataframe: DataFrame,
+    partition_by: List[str],
+    num_partitions: int = None,
+    num_processors: int = None,
 ):
     """Partition the DataFrame.
     Args:
         dataframe: Spark DataFrame.
-        num_partitions: number of partitions.
         partition_by: list of partitions.
+        num_processors: number of processors.
+        num_partitions: number of partitions.
     Returns:
         Partitioned dataframe.
     """
-    num_partitions = num_partitions or DEFAULT_NUM_PARTITIONS
+    num_partitions = _num_partitions_definition(num_processors, num_partitions)
     return dataframe.repartition(num_partitions, *partition_by)
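
The new body delegates the partition count to _num_partitions_definition, whose implementation is collapsed above. A minimal sketch of how such a helper could behave, assuming an explicit num_partitions takes precedence, num_processors is otherwise scaled by a small multiplier, and a module-level default is the fallback (the constant values and the multiplier are assumptions, not the repository's actual ones):

DEFAULT_NUM_PARTITIONS = 200      # assumed default, mirroring Spark's shuffle default
PARTITIONS_PER_PROCESSOR = 4      # hypothetical multiplier per processor

def _num_partitions_definition(num_processors=None, num_partitions=None):
    # an explicit partition count takes precedence
    if num_partitions:
        return num_partitions
    # otherwise derive the count from the available processors
    if num_processors:
        return num_processors * PARTITIONS_PER_PROCESSOR
    # fall back to the module default
    return DEFAULT_NUM_PARTITIONS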


@@ -48,9 +52,10 @@ def repartition_sort_df(
     Args:
         dataframe: Spark DataFrame.
-        num_partitions: number of partitions.
-        partition_by: list of partitions.
+        partition_by: list of columns to partition by.
+        order_by: list of columns to order by.
         num_processors: number of processors.
+        num_partitions: number of partitions.
     Returns:
         Partitioned and sorted dataframe.
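
For context, a usage sketch of the two helpers with their updated signatures, assuming a local SparkSession and a toy DataFrame (the session setup and column names are illustrative, not taken from the repository):

from pyspark.sql import SparkSession

from butterfree.core.dataframe_service import repartition_df, repartition_sort_df

spark = SparkSession.builder.master("local[4]").getOrCreate()
df = spark.createDataFrame(
    [("id1", "2020-05-04", 10.0), ("id1", "2020-05-05", 15.0), ("id2", "2020-05-04", 20.0)],
    ["id", "timestamp", "feature"],
)

# co-locate every row of each id in the same partition
partitioned = repartition_df(df, partition_by=["id"], num_processors=4)

# same co-location, additionally sorting rows inside each partition by timestamp
partitioned_sorted = repartition_sort_df(
    df, partition_by=["id"], order_by=["timestamp"], num_processors=4
)
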
51 changes: 19 additions & 32 deletions butterfree/core/transform/aggregated_feature_set.py
@@ -8,7 +8,7 @@
 from pyspark.sql import DataFrame, functions

 from butterfree.core.clients import SparkClient
-from butterfree.core.dataframe_service import repartition_sort_df
+from butterfree.core.dataframe_service import repartition_df
 from butterfree.core.transform import FeatureSet
 from butterfree.core.transform.features import Feature, KeyFeature, TimestampFeature
 from butterfree.core.transform.transformations import AggregatedTransform
@@ -337,10 +337,13 @@ def _get_base_dataframe(self, client, dataframe, end_date):
         return unique_keys.crossJoin(date_df)

     @staticmethod
-    def _dataframe_join(left, right, on, how):
+    def _dataframe_join(left, right, on, how, num_processors=None):
+        # make both tables co-partitioned to improve join performance
+        left = repartition_df(left, partition_by=on, num_processors=num_processors)
+        right = repartition_df(right, partition_by=on, num_processors=num_processors)
         return left.join(right, on=on, how=how)

-    def _aggregate(self, dataframe, features, window=None):
+    def _aggregate(self, dataframe, features, window=None, num_processors=None):
         aggregations = list(
             itertools.chain.from_iterable(
                 [f.transformation.aggregations for f in features]
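
The comment added to _dataframe_join motivates the change: when both sides of a join are hash-partitioned on the join keys, matching rows already live in matching partitions and the join needs less shuffling. A standalone sketch of the same idea in plain PySpark (the DataFrames and column names are made up for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()

left = spark.createDataFrame([("id1", 1.0), ("id2", 2.0)], ["id", "feature_a"])
right = spark.createDataFrame([("id1", 10.0), ("id3", 30.0)], ["id", "feature_b"])

num_partitions = 8  # illustrative; butterfree derives this from num_processors

# hash-partition both sides on the join key so matching keys land
# in matching partitions before the join executes
left = left.repartition(num_partitions, "id")
right = right.repartition(num_partitions, "id")

joined = left.join(right, on=["id"], how="left")
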
@@ -371,6 +374,11 @@ def _aggregate(self, dataframe, features, window=None):
                 "keep_rn", functions.row_number().over(partition_window)
             ).filter("keep_rn = 1")

+        # repartition to have all rows for each group at the same partition
+        # by doing that, we won't have to shuffle data on grouping by id
+        dataframe = repartition_df(
+            dataframe, partition_by=groupby, num_processors=num_processors,
+        )
         grouped_data = dataframe.groupby(*groupby)

         if self._pivot_column:
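
The repartition added before the groupby follows the comment above: once all rows of a group sit in one partition, the aggregation can reuse that partitioning instead of shuffling the data again. A minimal self-contained sketch of the pattern (the toy data and column names are assumptions):

from pyspark.sql import SparkSession, functions

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [("id1", 1.0), ("id1", 3.0), ("id2", 2.0)], ["id", "feature"]
)

# put every row of each id into the same partition before grouping
df = df.repartition(8, "id")

# the groupBy that follows can reuse the hash partitioning on id
agg = df.groupby("id").agg(functions.avg("feature").alias("feature__avg"))
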
@@ -482,33 +490,21 @@ def construct(
             dataframe,
         )

-        # repartition to have all ids at the same partition
-        # by doing that, we won't have to shuffle data on grouping by id
-        output_df = repartition_sort_df(
-            output_df,
-            partition_by=self.keys_columns,
-            order_by=[self.timestamp_column],
-            num_processors=num_processors,
-        )
-
         if self._windows:
             # prepare our left table, a cartesian product between distinct keys
             # and dates in range for this feature set
-            # make the left table co-partitioned with the aggregations' result
-            # improving our upcoming joins
             base_df = self._get_base_dataframe(
                 client=client, dataframe=output_df, end_date=end_date
             )
-            base_df = repartition_sort_df(
-                base_df,
-                partition_by=self.keys_columns,
-                order_by=[self.timestamp_column],
-                num_processors=num_processors,
-            )

             # run aggregations for each window
             agg_list = [
-                self._aggregate(dataframe=output_df, features=self.features, window=w)
+                self._aggregate(
+                    dataframe=output_df,
+                    features=self.features,
+                    window=w,
+                    num_processors=num_processors,
+                )
                 for w in self._windows
             ]

@@ -519,24 +515,15 @@
                     right,
                     on=self.keys_columns + [self.timestamp_column],
                     how="left",
+                    num_processors=num_processors,
                 ),
                 agg_list,
                 base_df,
             )
         else:
             output_df = self._aggregate(output_df, features=self.features)

-        output_df = (
-            repartition_sort_df(
-                output_df,
-                partition_by=self.keys_columns,
-                order_by=[self.timestamp_column],
-                num_processors=num_processors,
-            )
-            .select(*self.columns)
-            .replace(float("nan"), None)
-        )
-
+        output_df = output_df.select(*self.columns).replace(float("nan"), None)
         if not output_df.isStreaming:
             output_df = self._filter_duplicated_rows(output_df)
             output_df.cache().count()
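
Taken together, construct still folds the windowed aggregations onto the base DataFrame with functools.reduce, but each join now receives num_processors so _dataframe_join co-partitions both sides first. A simplified sketch of that folding pattern outside the class (the base_df/agg_list setup and the co_partitioned_join helper are illustrative stand-ins, not butterfree code):

from functools import reduce

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()

keys = ["id"]
base_df = spark.createDataFrame([("id1",), ("id2",)], keys)
agg_list = [
    spark.createDataFrame([("id1", 1.0)], ["id", "feature__avg_over_1_day"]),
    spark.createDataFrame([("id1", 7.0)], ["id", "feature__avg_over_7_days"]),
]

def co_partitioned_join(left, right, on, how="left", num_partitions=8):
    # repartition both sides on the join keys first, mirroring what
    # _dataframe_join now does through repartition_df and num_processors
    left = left.repartition(num_partitions, *on)
    right = right.repartition(num_partitions, *on)
    return left.join(right, on=on, how=how)

output_df = reduce(
    lambda left, right: co_partitioned_join(left, right, on=keys),
    agg_list,
    base_df,
)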
