Skip to content

Commit

Permalink
applying black formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
tnixon committed Oct 31, 2024
1 parent e17ffce commit 11f51cd
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions python/tempo/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,35 @@
TMP_SPLIT_COL = "__tmp_split_col"
TMP_GAP_COL = "__tmp_gap_row"


class TimeSeriesCrossValidator(CrossValidator):
# some additional parameters
timeSeriesCol: Param[str] = Param(
Params._dummy(),
"timeSeriesCol",
"The name of the time series column",
typeConverter=TypeConverters.toString
typeConverter=TypeConverters.toString,
)
seriesIdCols: Param[List[str]] = Param(
Params._dummy(),
"seriesIdCols",
"The name of the series id columns",
typeConverter=TypeConverters.toListString
typeConverter=TypeConverters.toListString,
)
gap: Param[int] = Param(
Params._dummy(),
"gap",
"The gap between training and test set",
typeConverter=TypeConverters.toInt
typeConverter=TypeConverters.toInt,
)

def __init__(self,
timeSeriesCol: str = "event_ts",
seriesIdCols: List[str] = [],
gap: int = 0,
**other_kwargs) -> None:
def __init__(
self,
timeSeriesCol: str = "event_ts",
seriesIdCols: List[str] = [],
gap: int = 0,
**other_kwargs
) -> None:
super(TimeSeriesCrossValidator, self).__init__(**other_kwargs)
self._setDefault(timeSeriesCol="event_ts", seriesIdCols=[], gap=0)
self._set(timeSeriesCol=timeSeriesCol, seriesIdCols=seriesIdCols, gap=gap)
Expand Down Expand Up @@ -72,19 +75,24 @@ def _get_split_win(self, desc: bool = False) -> WindowSpec:

def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
nFolds = self.getOrDefault(self.numFolds)
nSplits = nFolds+1
nSplits = nFolds + 1

# split the data into nSplits subsets by timeseries order
split_df = dataset.withColumn(TMP_SPLIT_COL,
sfn.ntile(nSplits).over(self._get_split_win()))
all_splits = [split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
for i in range(1, nSplits+1)]
split_df = dataset.withColumn(
TMP_SPLIT_COL, sfn.ntile(nSplits).over(self._get_split_win())
)
all_splits = [
split_df.filter(sfn.col(TMP_SPLIT_COL) == i).drop(TMP_SPLIT_COL)
for i in range(1, nSplits + 1)
]
assert len(all_splits) == nSplits

# compose the k folds by including all previous splits in the training set,
# and the next split in the test set
kFolds = [(reduce(lambda a, b: a.union(b), all_splits[:i+1]), all_splits[i+1])
for i in range(nFolds)]
kFolds = [
(reduce(lambda a, b: a.union(b), all_splits[: i + 1]), all_splits[i + 1])
for i in range(nFolds)
]
assert len(kFolds) == nFolds
for tv in kFolds:
assert len(tv) == 2
Expand All @@ -94,13 +102,21 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
if gap > 0:
order_cols = self.getSeriesIdCols() + [self.getTimeSeriesCol()]
# trim each training dataset by the specified gap
kFolds = [((train_df.withColumn(TMP_GAP_COL,
sfn.row_number().over(self._get_split_win(desc=True)))
kFolds = [
(
(
train_df.withColumn(
TMP_GAP_COL,
sfn.row_number().over(self._get_split_win(desc=True)),
)
.where(sfn.col(TMP_GAP_COL) > gap)
.drop(TMP_GAP_COL)
.orderBy(*order_cols)),
test_df)
for (train_df, test_df) in kFolds]
.orderBy(*order_cols)
),
test_df,
)
for (train_df, test_df) in kFolds
]

# return the k folds (training, test) datasets
return kFolds

0 comments on commit 11f51cd

Please sign in to comment.