From 4a52fac5693836b4d1673ad9bac8b8dded53dfa0 Mon Sep 17 00:00:00 2001 From: shekhar-rudder <85345786+shekhar-rudder@users.noreply.github.com> Date: Fri, 5 Jul 2024 11:16:38 +0530 Subject: [PATCH] fix: unable to split material name for quoted input (#358) --- .../profiles_mlcorelib/utils/utils.py | 14 +-- .../profiles_mlcorelib/wht/pythonWHT.py | 1 + tests/unit/utils_tests.py | 107 +++++++----------- 3 files changed, 50 insertions(+), 72 deletions(-) diff --git a/src/predictions/profiles_mlcorelib/utils/utils.py b/src/predictions/profiles_mlcorelib/utils/utils.py index 74878c6c6..2ee2c07c1 100644 --- a/src/predictions/profiles_mlcorelib/utils/utils.py +++ b/src/predictions/profiles_mlcorelib/utils/utils.py @@ -1010,14 +1010,14 @@ def plot_user_feature_importance( def replace_seq_no_in_query(query: str, seq_no: int) -> str: - match = re.search(r"(_\d+)(`|$)", query) - if match: - replaced_query = ( - query[: match.start(1)] + "_" + str(seq_no) + query[match.end(1) :] - ) - return replaced_query - else: + matches = list(re.finditer(r"(_\d+)(`|'|\"|$)", query)) + if len(matches) == 0: + raise Exception(f"Couldn't find an integer seq_no in the input query: {query}") + match = matches[-1] + if match is None: raise Exception(f"Couldn't find an integer seq_no in the input query: {query}") + replaced_query = query[: match.start(1)] + "_" + str(seq_no) + query[match.end(1) :] + return replaced_query def extract_seq_no_from_select_query(select_query: str) -> int: diff --git a/src/predictions/profiles_mlcorelib/wht/pythonWHT.py b/src/predictions/profiles_mlcorelib/wht/pythonWHT.py index b4090db75..5f86a906c 100644 --- a/src/predictions/profiles_mlcorelib/wht/pythonWHT.py +++ b/src/predictions/profiles_mlcorelib/wht/pythonWHT.py @@ -376,6 +376,7 @@ def split_material_name(self, name: str) -> dict: else: table_name = mlower.split()[-1] table_suffix = table_name.split(MATERIAL_PREFIX.lower())[-1] + table_suffix = table_suffix.strip("\"'") split_parts = table_suffix.split("_") try: seq_no = int(split_parts[-1]) diff --git a/tests/unit/utils_tests.py b/tests/unit/utils_tests.py index 5ce286d64..eaa66efad 100644 --- a/tests/unit/utils_tests.py +++ b/tests/unit/utils_tests.py @@ -14,11 +14,32 @@ class TestReplaceSeqNoInQuery(unittest.TestCase): def test_replaces_seq_no_correctly(self): - query = "SELECT * FROM material_user_var_table_123" - seq_no = 567 - expected_result = "SELECT * FROM material_user_var_table_567" - actual_result = replace_seq_no_in_query(query, seq_no) - self.assertEqual(expected_result, actual_result) + test_cases = [ + { + "input": "SELECT * FROM material_user_var_table_123", + "output": "SELECT * FROM material_user_var_table_567", + }, + { + "input": "SELECT * FROM `schema`.`material_user_var_table_123`", + "output": "SELECT * FROM `schema`.`material_user_var_table_567`", + }, + { + "input": '''SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_383"''', + "output": '''SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_567"''', + }, + { + "input": "SELECT days_since_last_seen FROM 'rs_profiles_3'.'material_user_var_table_54ddc22a_383'", + "output": "SELECT days_since_last_seen FROM 'rs_profiles_3'.'material_user_var_table_54ddc22a_567'", + }, + { + "input": 'SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_383"', + "output": 'SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_567"', + }, + ] + for case in test_cases: + with self.subTest(case=case["input"]): + result = replace_seq_no_in_query(case["input"], 567) + self.assertEqual(result, case["output"]) def test_handles_empty_query(self): query = "" @@ -42,34 +63,30 @@ def test_handles_missing_seq_no(self): expected_result, ) - def test_replaces_seq_no_correctly_with_bigquery_input(self): - query = "SELECT * FROM `schema`.`material_user_var_table_123`" - seq_no = 567 - expected_result = "SELECT * FROM `schema`.`material_user_var_table_567`" - actual_result = replace_seq_no_in_query(query, seq_no) - self.assertEqual(expected_result, actual_result) - class TestSplitMaterialTable(unittest.TestCase): def test_valid_table_name(self): - table_name = "Material_user_var_table_54ddc22a_383" - expected_result = { - "model_name": "user_var_table", - "model_hash": "54ddc22a", - "seq_no": 383, - } - actual_result = PythonWHT().split_material_name(table_name) - self.assertEqual(actual_result, expected_result) - - def test_missing_prefix(self): - table_name = "user_var_table_54ddc22a_383" + test_cases = [ + "Material_user_var_table_54ddc22a_383", + "user_var_table_54ddc22a_383", + "SELECT * FROM SCHEMA.Material_user_var_table_54ddc22a_383", + "SELECT * FROM `SCHEMA.Material_user_var_table_54ddc22a_383`", + "SELECT * FROM Material_user_var_table_54ddc22a_383", + "SELECT * FROM material_user_var_table_54ddc22a_383", # redshift_input + "SELECT * FROM `SCHEMA`.`Material_user_var_table_54ddc22a_383`", # bigquery_input + 'SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_383"' + '''SELECT days_since_last_seen FROM "rs_profiles_3"."material_user_var_table_54ddc22a_383"''', + "SELECT days_since_last_seen FROM 'rs_profiles_3'.'material_user_var_table_54ddc22a_383'", + ] expected_result = { "model_name": "user_var_table", "model_hash": "54ddc22a", "seq_no": 383, } - actual_result = PythonWHT().split_material_name(table_name) - self.assertEqual(actual_result, expected_result) + for case in test_cases: + with self.subTest(case=case): + result = PythonWHT().split_material_name(case) + self.assertEqual(result, expected_result) def test_missing_seq_no(self): table_name = "Material_user_var_table_54ddc22a" @@ -86,46 +103,6 @@ def test_invalid_table_name(self): with self.assertRaises(Exception): PythonWHT().split_material_name(table_name) - def test_material_query(self): - table_name = "SELECT * FROM SCHEMA.Material_user_var_table_54ddc22a_383" - expected_result = { - "model_name": "user_var_table", - "model_hash": "54ddc22a", - "seq_no": 383, - } - actual_result = PythonWHT().split_material_name(table_name) - self.assertEqual(actual_result, expected_result) - - def test_material_query_with_snowflake_input(self): - table_name = "SELECT * FROM Material_user_var_table_54ddc22a_383" - expected_result = { - "model_name": "user_var_table", - "model_hash": "54ddc22a", - "seq_no": 383, - } - actual_result = PythonWHT().split_material_name(table_name) - self.assertEqual(actual_result, expected_result) - - def test_material_query_with_redshift_input(self): - table_name = "SELECT * FROM material_user_var_table_54ddc22a_383" - expected_result = { - "model_name": "user_var_table", - "model_hash": "54ddc22a", - "seq_no": 383, - } - actual_result = PythonWHT().split_material_name(table_name) - self.assertEqual(actual_result, expected_result) - - def test_material_query_with_bigquery_input(self): - table_name = "SELECT * FROM `SCHEMA`.`Material_user_var_table_54ddc22a_383`" - expected_result = { - "model_name": "user_var_table", - "model_hash": "54ddc22a", - "seq_no": 383, - } - actual_result = PythonWHT().split_material_name(table_name) - self.assertEqual(actual_result, expected_result) - class TestDateDiff(unittest.TestCase): def test_same_date(self):