refactored test code data format to allow for better separation of DF creation from TSDF constructor args
tnixon committed May 15, 2024
1 parent 8a7eb5e commit 7469a50
Showing 3 changed files with 497 additions and 297 deletions.
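
The new data format (implied by the TestDataFrameBuilder added to python/tests/base.py below) splits each test dataset into a "df" section, describing how to build the raw Spark DataFrame, and a "tsdf" section, carrying the TSDF constructor arguments. A hedged sketch of one entry in a test-data JSON file — the field names and the class → test → dataset nesting come from the builder and __loadTestData code below; the class/test/dataset names and values are illustrative:

    # hypothetical test-data JSON entry, shown as a Python dict
    {
        "AsOfJoinTest": {                 # class name (assumed)
            "test_asof_join": {           # test method name
                "left": {                 # dataset name, as passed to get_test_df_builder()
                    "df": {
                        "schema": "symbol string, event_ts string, trade_pr double",
                        "data": [["S1", "2020-08-01 00:00:10", 349.21]],
                        "ts_convert": ["event_ts"],
                    },
                    "tsdf": {"ts_col": "event_ts", "partition_cols": ["symbol"]},
                }
            }
        }
    }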
52 changes: 27 additions & 25 deletions python/tests/as_of_join_tests.py
@@ -9,10 +9,10 @@ def test_asof_join(self):
         """AS-OF Join without a time-partition test"""

         # Construct dataframes
-        tsdf_left = self.get_data_as_tsdf("left")
-        tsdf_right = self.get_data_as_tsdf("right")
-        dfExpected = self.get_data_as_sdf("expected")
-        noRightPrefixdfExpected = self.get_data_as_sdf("expected_no_right_prefix")
+        tsdf_left = self.get_test_df_builder("left").as_tsdf()
+        tsdf_right = self.get_test_df_builder("right").as_tsdf()
+        dfExpected = self.get_test_df_builder("expected").as_sdf()
+        noRightPrefixdfExpected = self.get_test_df_builder("expected_no_right_prefix").as_sdf()

         # perform the join
         joined_df = tsdf_left.asofJoin(
@@ -35,12 +35,12 @@ def test_asof_join_skip_nulls_disabled(self):
         """AS-OF Join with skip nulls disabled"""

         # fetch test data
-        tsdf_left = self.get_data_as_tsdf("left")
-        tsdf_right = self.get_data_as_tsdf("right")
-        dfExpectedSkipNulls = self.get_data_as_sdf("expected_skip_nulls")
-        dfExpectedSkipNullsDisabled = self.get_data_as_sdf(
+        tsdf_left = self.get_test_df_builder("left").as_tsdf()
+        tsdf_right = self.get_test_df_builder("right").as_tsdf()
+        dfExpectedSkipNulls = self.get_test_df_builder("expected_skip_nulls").as_sdf()
+        dfExpectedSkipNullsDisabled = self.get_test_df_builder(
             "expected_skip_nulls_disabled"
-        )
+        ).as_sdf()

         # perform the join with skip nulls enabled (default)
         joined_df = tsdf_left.asofJoin(
@@ -62,9 +62,9 @@ def test_sequence_number_sort(self):
         """Skew AS-OF Join with Partition Window Test"""

         # fetch test data
-        tsdf_left = self.get_data_as_tsdf("left")
-        tsdf_right = self.get_data_as_tsdf("right")
-        dfExpected = self.get_data_as_sdf("expected")
+        tsdf_left = self.get_test_df_builder("left").as_tsdf()
+        tsdf_right = self.get_test_df_builder("right").as_tsdf()
+        dfExpected = self.get_test_df_builder("expected").as_sdf()

         # perform the join
         joined_df = tsdf_left.asofJoin(tsdf_right, right_prefix="right").df
@@ -76,9 +76,9 @@ def test_partitioned_asof_join(self):
         """AS-OF Join with a time-partition"""
         with self.assertLogs(level="WARNING") as warning_captured:
             # fetch test data
-            tsdf_left = self.get_data_as_tsdf("left")
-            tsdf_right = self.get_data_as_tsdf("right")
-            dfExpected = self.get_data_as_sdf("expected")
+            tsdf_left = self.get_test_df_builder("left").as_tsdf()
+            tsdf_right = self.get_test_df_builder("right").as_tsdf()
+            dfExpected = self.get_test_df_builder("expected").as_sdf()

             joined_df = tsdf_left.asofJoin(
                 tsdf_right,
@@ -103,24 +103,26 @@ def test_asof_join_nanos(self):
         """As of join with nanosecond timestamps"""

         # fetch test data
-        tsdf_left = self.get_data_as_tsdf("left")
-        tsdf_right = self.get_data_as_tsdf("right")
-        dfExpected = self.get_data_as_sdf("expected")
+        tsdf_left = self.get_test_df_builder("left").as_tsdf()
+        tsdf_right = self.get_test_df_builder("right").as_tsdf()
+        dfExpected = self.get_test_df_builder("expected").as_sdf()

         # perform join
         joined_df = tsdf_left.asofJoin(
             tsdf_right, left_prefix="left", right_prefix="right"
         ).df

+        joined_df.show()
+
         # compare
         self.assertDataFrameEquality(joined_df, dfExpected)

     def test_asof_join_tolerance(self):
         """As of join with tolerance band"""

         # fetch test data
-        tsdf_left = self.get_data_as_tsdf("left")
-        tsdf_right = self.get_data_as_tsdf("right")
+        tsdf_left = self.get_test_df_builder("left").as_tsdf()
+        tsdf_right = self.get_test_df_builder("right").as_tsdf()

         tolerance_test_values = [None, 0, 5.5, 7, 10]
         for tolerance in tolerance_test_values:
@@ -133,17 +135,17 @@ def test_asof_join_tolerance(self):
             ).df

             # compare
-            expected_tolerance = self.get_data_as_sdf(f"expected_tolerance_{tolerance}")
+            expected_tolerance = self.get_test_df_builder(f"expected_tolerance_{tolerance}").as_sdf()
             self.assertDataFrameEquality(joined_df, expected_tolerance)

     def test_asof_join_sql_join_opt_and_bytes_threshold(self):
         """AS-OF Join without a time-partition test"""
         with patch("tempo.tsdf.TSDF._TSDF__getBytesFromPlan", return_value=1000):
             # Construct dataframes
-            tsdf_left = self.get_data_as_tsdf("left")
-            tsdf_right = self.get_data_as_tsdf("right")
-            dfExpected = self.get_data_as_sdf("expected")
-            noRightPrefixdfExpected = self.get_data_as_sdf("expected_no_right_prefix")
+            tsdf_left = self.get_test_df_builder("left").as_tsdf()
+            tsdf_right = self.get_test_df_builder("right").as_tsdf()
+            dfExpected = self.get_test_df_builder("expected").as_sdf()
+            noRightPrefixdfExpected = self.get_test_df_builder("expected_no_right_prefix").as_sdf()

             # perform the join
             joined_df = tsdf_left.asofJoin(
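The same one-line migration repeats across every test in this file; the before/after pattern, taken directly from the hunks above:

    # before: one helper per target type
    tsdf_left = self.get_data_as_tsdf("left")
    dfExpected = self.get_data_as_sdf("expected")

    # after: fetch a builder for the named dataset, then choose the output type
    tsdf_left = self.get_test_df_builder("left").as_tsdf()
    dfExpected = self.get_test_df_builder("expected").as_sdf()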
219 changes: 161 additions & 58 deletions python/tests/base.py
@@ -2,7 +2,7 @@
 import re
 import unittest
 import warnings
-from typing import Union
+from typing import Union, Optional

 import jsonref
 import pyspark.sql.functions as sfn
@@ -14,6 +14,132 @@
 from tempo.tsdf import TSDF


+class TestDataFrameBuilder:
+    """
+    A class to hold test data and metadata, and to build
+    Spark DataFrames and TSDFs from it
+    """
+
+    def __init__(self, spark: SparkSession, test_data: dict):
+        """
+        :param spark: the SparkSession to use
+        :param test_data: a dictionary containing the test data & metadata
+        """
+        self.spark = spark
+        self.__test_data = test_data
+
+    # Spark DataFrame metadata
+
+    @property
+    def df(self) -> dict:
+        """
+        :return: the DataFrame component of the test data
+        """
+        return self.__test_data["df"]
+
+    @property
+    def df_schema(self) -> str:
+        """
+        :return: the schema component of the test data
+        """
+        return self.df["schema"]
+
+    def df_data(self) -> list:
+        """
+        :return: the data component of the test data
+        """
+        return self.df["data"]
+
+    # TSDF metadata
+
+    @property
+    def tsdf_constructor(self) -> Optional[str]:
+        """
+        :return: the name of the TSDF constructor to use
+        """
+        return self.__test_data.get("tsdf_constructor", None)
+
+    @property
+    def tsdf(self) -> dict:
+        """
+        :return: the timestamp index metadata component of the test data
+        """
+        return self.__test_data["tsdf"]
+
+    @property
+    def ts_schema(self) -> Optional[dict]:
+        """
+        :return: the timestamp index schema component of the test data
+        """
+        return self.tsdf.get("ts_schema", None)
+
+    @property
+    def ts_idx_class(self) -> str:
+        """
+        :return: the timestamp index class component of the test data
+        """
+        return self.ts_schema["ts_idx_class"]
+
+    @property
+    def ts_col(self) -> str:
+        """
+        :return: the timestamp column component of the test data
+        """
+        return self.ts_schema["ts_col"]
+
+    @property
+    def ts_idx(self) -> dict:
+        """
+        :return: the timestamp index data component of the test data
+        """
+        return self.ts_schema["ts_idx"]
+
+    # Builder functions
+
+    def as_sdf(self) -> DataFrame:
+        """
+        Constructs a Spark DataFrame from the test data
+        """
+        # build dataframe
+        df = self.spark.createDataFrame(self.df_data(), self.df_schema)
+
+        # convert timestamp columns
+        if "ts_convert" in self.df:
+            for ts_col in self.df["ts_convert"]:
+                # handle nested columns
+                if "." in ts_col:
+                    col, field = ts_col.split(".")
+                    convert_field_expr = sfn.to_timestamp(sfn.col(col).getField(field))
+                    df = df.withColumn(
+                        col, sfn.col(col).withField(field, convert_field_expr)
+                    )
+                else:
+                    df = df.withColumn(ts_col, sfn.to_timestamp(ts_col))
+        # convert date columns
+        if "date_convert" in self.df:
+            for date_col in self.df["date_convert"]:
+                # handle nested columns
+                if "." in date_col:
+                    col, field = date_col.split(".")
+                    convert_field_expr = sfn.to_date(sfn.col(col).getField(field))
+                    df = df.withColumn(
+                        col, sfn.col(col).withField(field, convert_field_expr)
+                    )
+                else:
+                    df = df.withColumn(date_col, sfn.to_date(date_col))
+
+        return df
+
+    def as_tsdf(self) -> TSDF:
+        """
+        Constructs a TSDF from the test data
+        """
+        sdf = self.as_sdf()
+        if self.tsdf_constructor is not None:
+            return getattr(TSDF, self.tsdf_constructor)(sdf, **self.tsdf)
+        else:
+            return TSDF(sdf, **self.tsdf)
+
+
 class SparkTest(unittest.TestCase):
     #
     # Fixtures
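A minimal usage sketch of the builder above, covering both the top-level and the nested ("dotted") conversion paths. The schema and values are illustrative, not taken from the repo's fixtures, and the sketch assumes the TSDF constructor still accepts ts_col/partition_cols keyword arguments as the old get_data_as_tsdf helper did:

    test_data = {
        "df": {
            "schema": "symbol string, event_ts string, meta struct<created: string>",
            "data": [["S1", "2020-08-01 00:00:10", ("2020-07-31 12:00:00",)]],
            # "meta.created" takes the nested withField + getField branch of as_sdf()
            "ts_convert": ["event_ts", "meta.created"],
        },
        "tsdf": {"ts_col": "event_ts", "partition_cols": ["symbol"]},
    }
    builder = TestDataFrameBuilder(spark, test_data)  # spark: an active SparkSession
    sdf = builder.as_sdf()    # Spark DataFrame, both timestamp columns cast from strings
    tsdf = builder.as_tsdf()  # TSDF(sdf, ts_col="event_ts", partition_cols=["symbol"])

When an entry also carries a "tsdf_constructor" name, as_tsdf() dispatches to that TSDF classmethod via getattr rather than calling the constructor directly.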
@@ -68,24 +194,24 @@ def tearDown(self) -> None:
     # Utility Functions
     #

-    def get_data_as_sdf(self, name: str, convert_ts_col=True):
-        td = self.test_data[name]
-        ts_cols = []
-        if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])):
-            ts_cols = [td["ts_col"]] if "ts_col" in td else []
-            ts_cols.extend(td.get("other_ts_cols", []))
-        return self.buildTestDF(td["schema"], td["data"], ts_cols)
-
-    def get_data_as_tsdf(self, name: str, convert_ts_col=True):
-        df = self.get_data_as_sdf(name, convert_ts_col)
-        td = self.test_data[name]
-        tsdf = TSDF(
-            df,
-            ts_col=td["ts_col"],
-            partition_cols=td.get("partition_cols", None),
-            sequence_col=td.get("sequence_col", None),
-        )
-        return tsdf
+    # def get_data_as_sdf(self, name: str, convert_ts_col=True):
+    #     td = self.test_data[name]
+    #     ts_cols = []
+    #     if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])):
+    #         ts_cols = [td["ts_col"]] if "ts_col" in td else []
+    #         ts_cols.extend(td.get("other_ts_cols", []))
+    #     return self.buildTestDF(td["schema"], td["data"], ts_cols)
+    #
+    # def get_data_as_tsdf(self, name: str, convert_ts_col=True):
+    #     df = self.get_data_as_sdf(name, convert_ts_col)
+    #     td = self.test_data[name]
+    #     tsdf = TSDF(
+    #         df,
+    #         ts_col=td["ts_col"],
+    #         partition_cols=td.get("partition_cols", None),
+    #         sequence_col=td.get("sequence_col", None),
+    #     )
+    #     return tsdf

     def get_data_as_idf(self, name: str, convert_ts_col=True):
         df = self.get_data_as_sdf(name, convert_ts_col)
@@ -112,7 +238,8 @@ def __getTestDataFilePath(self, test_file_name: str) -> str:
             dir_path = "./tests"
         elif cwd != "tests":
             raise RuntimeError(
-                f"Cannot locate test data file {test_file_name}, running from dir {os.getcwd()}"
+                f"Cannot locate test data file {test_file_name}, running from dir"
+                f" {os.getcwd()}"
             )

         # return appropriate path
@@ -136,40 +263,11 @@ def __loadTestData(self, test_case_path: str) -> dict:
         # process the data file
         with open(test_data_file, "r") as f:
             data_metadata_from_json = jsonref.load(f)
-        # warn if data not present
-        if class_name not in data_metadata_from_json:
-            warnings.warn(f"Could not load test data for {file_name}.{class_name}")
-            return {}
-        if func_name not in data_metadata_from_json[class_name]:
-            warnings.warn(
-                f"Could not load test data for {file_name}.{class_name}.{func_name}"
-            )
-            return {}
         # return the data
         return data_metadata_from_json[class_name][func_name]

-    def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
-        """
-        Constructs a Spark Dataframe from the given components
-        :param schema: the schema to use for the Dataframe
-        :param data: values to use for the Dataframe
-        :param ts_cols: list of column names to be converted to Timestamp values
-        :return: a Spark Dataframe, constructed from the given schema and values
-        """
-        # build dataframe
-        df = self.spark.createDataFrame(data, schema)
-
-        # check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds
-        for tsc in ts_cols:
-            ts_value = str(df.select(ts_cols).limit(1).collect()[0][0])
-            ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
-            decimal_pattern = r"[.]\d+"
-            if re.match(ts_pattern, str(ts_value)) is not None:
-                if (
-                    re.search(decimal_pattern, ts_value) is None
-                    or len(re.search(decimal_pattern, ts_value)[0]) <= 4
-                ):
-                    df = df.withColumn(tsc, sfn.to_timestamp(sfn.col(tsc)))
-        return df
+    def get_test_df_builder(self, name: str) -> TestDataFrameBuilder:
+        return TestDataFrameBuilder(self.spark, self.test_data[name])

     #
     # Assertion Functions
Expand Down Expand Up @@ -201,12 +299,10 @@ def assertSchemaContainsField(self, schema, field):
# the attributes of the fields must be equal
self.assertFieldsEqual(field, schema[field.name])

@staticmethod
def assertDataFrameEquality(
df1: Union[IntervalsDF, TSDF, DataFrame],
df2: Union[IntervalsDF, TSDF, DataFrame],
from_tsdf: bool = False,
from_idf: bool = False,
self,
df1: Union[TSDF, DataFrame],
df2: Union[TSDF, DataFrame],
ignore_row_order: bool = False,
ignore_column_order: bool = True,
ignore_nullable: bool = True,
@@ -216,10 +312,17 @@ def assertDataFrameEquality(
         That is, they have equivalent schemas, and both contain the same values
         """

-        if from_tsdf or from_idf:
+        # handle TSDFs
+        if isinstance(df1, TSDF):
+            # df2 must also be a TSDF
+            self.assertIsInstance(df2, TSDF)
+            # should have the same schemas
+            self.assertEqual(df1.ts_schema, df2.ts_schema)
+            # get the underlying Spark DataFrames
             df1 = df1.df
             df2 = df2.df

+        # handle DataFrames
         assert_df_equality(
             df1,
             df2,
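With the from_tsdf/from_idf flags gone, the assertion now dispatches on the argument type, so calls read the same for TSDFs and plain DataFrames. Usage as in the tests above (the TSDF variable names here are illustrative):

    # TSDF inputs: ts_schema equality is asserted first, then the wrapped DataFrames are compared
    self.assertDataFrameEquality(joined_tsdf, expected_tsdf)

    # plain DataFrame inputs go straight to the underlying assert_df_equality helper
    self.assertDataFrameEquality(joined_df, dfExpected)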
