diff --git a/python/tempo/ml.py b/python/tempo/ml.py
index 1c9f14ac..3db858af 100644
--- a/python/tempo/ml.py
+++ b/python/tempo/ml.py
@@ -12,32 +12,35 @@
 TMP_SPLIT_COL = "__tmp_split_col"
 TMP_GAP_COL = "__tmp_gap_row"
 
+
 class TimeSeriesCrossValidator(CrossValidator):
     # some additional parameters
     timeSeriesCol: Param[str] = Param(
         Params._dummy(),
         "timeSeriesCol",
         "The name of the time series column",
-        typeConverter=TypeConverters.toString
+        typeConverter=TypeConverters.toString,
     )
     seriesIdCols: Param[List[str]] = Param(
         Params._dummy(),
         "seriesIdCols",
         "The name of the series id columns",
-        typeConverter=TypeConverters.toListString
+        typeConverter=TypeConverters.toListString,
     )
     gap: Param[int] = Param(
         Params._dummy(),
         "gap",
         "The gap between training and test set",
-        typeConverter=TypeConverters.toInt
+        typeConverter=TypeConverters.toInt,
     )
 
-    def __init__(self,
-                 timeSeriesCol: str = "event_ts",
-                 seriesIdCols: List[str] = [],
-                 gap: int = 0,
-                 **other_kwargs) -> None:
+    def __init__(
+        self,
+        timeSeriesCol: str = "event_ts",
+        seriesIdCols: List[str] = [],
+        gap: int = 0,
+        **other_kwargs
+    ) -> None:
         super(TimeSeriesCrossValidator, self).__init__(**other_kwargs)
         self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0)
         self._set(timeSeriesCol=timeSeriesCol, seriesIdCols=seriesIdCols, gap=gap)
@@ -72,19 +75,24 @@ def _get_split_win(self, desc: bool = False) -> WindowSpec:
 
     def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
         nFolds = self.getOrDefault(self.numFolds)
-        nSplits = nFolds+1
+        nSplits = nFolds + 1
 
         # split the data into nSplits subsets by timeseries order
-        split_df = dataset.withColumn(TMP_SPLIT_COL,
-                                      sfn.ntile(nSplits).over(self._get_split_win()))
-        all_splits = [split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
-                      for i in range(1, nSplits+1)]
+        split_df = dataset.withColumn(
+            TMP_SPLIT_COL, sfn.ntile(nSplits).over(self._get_split_win())
+        )
+        all_splits = [
+            split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
+            for i in range(1, nSplits + 1)
+        ]
         assert len(all_splits) == nSplits
 
         # compose the k folds by including all previous splits in the training set,
         # and the next split in the test set
-        kFolds = [(reduce(lambda a, b: a.union(b), all_splits[:i+1]), all_splits[i+1])
-                  for i in range(nFolds)]
+        kFolds = [
+            (reduce(lambda a, b: a.union(b), all_splits[: i + 1]), all_splits[i + 1])
+            for i in range(nFolds)
+        ]
         assert len(kFolds) == nFolds
         for tv in kFolds:
             assert len(tv) == 2
@@ -94,13 +102,21 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
         if gap > 0:
             order_cols = self.getSeriesIdCols() + [self.getTimeSeriesCol()]
             # trim each training dataset by the specified gap
-            kFolds = [((train_df.withColumn(TMP_GAP_COL,
-                                            sfn.row_number().over(self._get_split_win(desc=True)))
-                        .where(sfn.col(TMP_GAP_COL) > gap)
-                        .drop(TMP_GAP_COL)
-                        .orderBy(*order_cols)),
-                       test_df)
-                      for (train_df, test_df) in kFolds]
+            kFolds = [
+                (
+                    (
+                        train_df.withColumn(
+                            TMP_GAP_COL,
+                            sfn.row_number().over(self._get_split_win(desc=True)),
+                        )
+                        .where(sfn.col(TMP_GAP_COL) > gap)
+                        .drop(TMP_GAP_COL)
+                        .orderBy(*order_cols)
+                    ),
+                    test_df,
+                )
+                for (train_df, test_df) in kFolds
+            ]
 
         # return the k folds (training, test) datasets
         return kFolds