Skip to content

Commit

Permalink
fix: unable to split material name for quoted input (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhar-rudder authored Jul 5, 2024
1 parent 925ef43 commit 4a52fac
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 72 deletions.
14 changes: 7 additions & 7 deletions src/predictions/profiles_mlcorelib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,14 +1010,14 @@ def plot_user_feature_importance(


def replace_seq_no_in_query(query: str, seq_no: int) -> str:
    """Return ``query`` with its trailing sequence number replaced by ``seq_no``.

    A sequence number is an underscore followed by digits that sits either at
    the very end of the query or directly before a closing identifier quote
    (backtick, single quote or double quote). When several such suffixes are
    present, e.g. a quoted schema like ``"rs_profiles_3"`` ahead of the table
    name, only the last one is rewritten.

    Args:
        query: Query text or table name that ends in ``_<digits>``.
        seq_no: Sequence number to substitute.

    Returns:
        ``query`` with the last ``_<digits>`` suffix replaced by ``_<seq_no>``.

    Raises:
        Exception: If ``query`` contains no ``_<digits>`` suffix.
    """
    # Walk every candidate and keep the final one: the material table name is
    # the last identifier, and an earlier identifier (a schema) may itself end
    # in "_<digits>" and must be left alone.
    match = None
    for match in re.finditer(r"(_\d+)(`|'|\"|$)", query):
        pass
    if match is None:
        raise Exception(f"Couldn't find an integer seq_no in the input query: {query}")
    # Only group 1 ("_<digits>") is replaced; the closing quote is preserved.
    return f"{query[: match.start(1)]}_{seq_no}{query[match.end(1) :]}"


def extract_seq_no_from_select_query(select_query: str) -> int:
Expand Down
1 change: 1 addition & 0 deletions src/predictions/profiles_mlcorelib/wht/pythonWHT.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def split_material_name(self, name: str) -> dict:
else:
table_name = mlower.split()[-1]
table_suffix = table_name.split(MATERIAL_PREFIX.lower())[-1]
table_suffix = table_suffix.strip("\"'")
split_parts = table_suffix.split("_")
try:
seq_no = int(split_parts[-1])
Expand Down
107 changes: 42 additions & 65 deletions tests/unit/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,32 @@

class TestReplaceSeqNoInQuery(unittest.TestCase):
def test_replaces_seq_no_correctly(self):
    """Check the trailing seq_no is swapped for each supported quoting style.

    Covers an unquoted table name plus backtick-, double-quote- and
    single-quote-delimited identifiers. The quoted schema in the last two
    cases also ends in ``_<digits>`` and must stay untouched.
    """
    test_cases = [
        {
            "input": "SELECT * FROM material_user_var_table_123",
            "output": "SELECT * FROM material_user_var_table_567",
        },
        {
            "input": "SELECT * FROM `schema`.`material_user_var_table_123`",
            "output": "SELECT * FROM `schema`.`material_user_var_table_567`",
        },
        {
            "input": 'SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_383"',
            "output": 'SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_567"',
        },
        {
            "input": "SELECT days_since_last_seen FROM 'rs_profiles_3'.'material_user_var_table_54ddc22a_383'",
            "output": "SELECT days_since_last_seen FROM 'rs_profiles_3'.'material_user_var_table_54ddc22a_567'",
        },
    ]
    for case in test_cases:
        # subTest keeps the remaining cases running when one of them fails.
        with self.subTest(case=case["input"]):
            result = replace_seq_no_in_query(case["input"], 567)
            self.assertEqual(result, case["output"])

def test_handles_empty_query(self):
query = ""
Expand All @@ -42,34 +63,30 @@ def test_handles_missing_seq_no(self):
expected_result,
)

def test_replaces_seq_no_correctly_with_bigquery_input(self):
query = "SELECT * FROM `schema`.`material_user_var_table_123`"
seq_no = 567
expected_result = "SELECT * FROM `schema`.`material_user_var_table_567`"
actual_result = replace_seq_no_in_query(query, seq_no)
self.assertEqual(expected_result, actual_result)


class TestSplitMaterialTable(unittest.TestCase):
def test_valid_table_name(self):
    """Check that every supported input form yields the same split parts.

    Inputs range from a bare material name (with and without the material
    prefix) to full SELECT statements whose table identifier is unquoted or
    wrapped in backticks, double quotes or single quotes.
    """
    test_cases = [
        "Material_user_var_table_54ddc22a_383",
        "user_var_table_54ddc22a_383",
        "SELECT * FROM SCHEMA.Material_user_var_table_54ddc22a_383",
        "SELECT * FROM `SCHEMA.Material_user_var_table_54ddc22a_383`",
        "SELECT * FROM Material_user_var_table_54ddc22a_383",
        "SELECT * FROM material_user_var_table_54ddc22a_383",  # redshift_input
        "SELECT * FROM `SCHEMA`.`Material_user_var_table_54ddc22a_383`",  # bigquery_input
        # Each entry needs its own trailing comma: adjacent string literals
        # are concatenated implicitly, which would silently merge two cases.
        'SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_383"',
        "SELECT days_since_last_seen FROM 'rs_profiles_3'.'material_user_var_table_54ddc22a_383'",
    ]
    expected_result = {
        "model_name": "user_var_table",
        "model_hash": "54ddc22a",
        "seq_no": 383,
    }
    for case in test_cases:
        # subTest keeps the remaining cases running when one of them fails.
        with self.subTest(case=case):
            result = PythonWHT().split_material_name(case)
            self.assertEqual(result, expected_result)

def test_missing_seq_no(self):
table_name = "Material_user_var_table_54ddc22a"
Expand All @@ -86,46 +103,6 @@ def test_invalid_table_name(self):
with self.assertRaises(Exception):
PythonWHT().split_material_name(table_name)

def test_material_query(self):
table_name = "SELECT * FROM SCHEMA.Material_user_var_table_54ddc22a_383"
expected_result = {
"model_name": "user_var_table",
"model_hash": "54ddc22a",
"seq_no": 383,
}
actual_result = PythonWHT().split_material_name(table_name)
self.assertEqual(actual_result, expected_result)

def test_material_query_with_snowflake_input(self):
table_name = "SELECT * FROM Material_user_var_table_54ddc22a_383"
expected_result = {
"model_name": "user_var_table",
"model_hash": "54ddc22a",
"seq_no": 383,
}
actual_result = PythonWHT().split_material_name(table_name)
self.assertEqual(actual_result, expected_result)

def test_material_query_with_redshift_input(self):
table_name = "SELECT * FROM material_user_var_table_54ddc22a_383"
expected_result = {
"model_name": "user_var_table",
"model_hash": "54ddc22a",
"seq_no": 383,
}
actual_result = PythonWHT().split_material_name(table_name)
self.assertEqual(actual_result, expected_result)

def test_material_query_with_bigquery_input(self):
table_name = "SELECT * FROM `SCHEMA`.`Material_user_var_table_54ddc22a_383`"
expected_result = {
"model_name": "user_var_table",
"model_hash": "54ddc22a",
"seq_no": 383,
}
actual_result = PythonWHT().split_material_name(table_name)
self.assertEqual(actual_result, expected_result)


class TestDateDiff(unittest.TestCase):
def test_same_date(self):
Expand Down

0 comments on commit 4a52fac

Please sign in to comment.