diff --git a/sqllineage/core/parser/sqlfluff/utils.py b/sqllineage/core/parser/sqlfluff/utils.py index 56c1689d..6d841c4c 100644 --- a/sqllineage/core/parser/sqlfluff/utils.py +++ b/sqllineage/core/parser/sqlfluff/utils.py @@ -138,6 +138,10 @@ def list_subqueries(segment: BaseSegment) -> List[SubQueryTuple]: else None ) subquery.append(SubQueryTuple(bracketed_segment, alias)) + elif function := select_clause_element.get_child("function"): + for bracketed in function.recursive_crawl("bracketed"): + if is_subquery(bracketed): + subquery.append(SubQueryTuple(bracketed, None)) elif segment.type == "from_expression_element": as_segment, target = extract_as_and_target_segment(segment) if is_subquery(target): diff --git a/sqllineage/core/parser/sqlparse/utils.py b/sqllineage/core/parser/sqlparse/utils.py index 2f973c9f..06dbf63b 100644 --- a/sqllineage/core/parser/sqlparse/utils.py +++ b/sqllineage/core/parser/sqlparse/utils.py @@ -6,6 +6,7 @@ Comparison, Function, Identifier, + IdentifierList, Over, Parenthesis, TokenList, @@ -114,6 +115,9 @@ def get_subquery_parentheses( if isinstance(token, Function): # CTE without AS: tbl (SELECT 1) target = token.tokens[-1] + # fallback to Function: LEAST((SELECT MIN(dt) FROM tab1), (SELECT MIN(dt) FROM tab2)) + if len([p for p in target.tokens if isinstance(p, Parenthesis)]) > 0: + target = token elif isinstance(token, (Values, Where)): # WHERE col1 IN (SELECT max(col1) FROM tab2) target = token @@ -130,6 +134,15 @@ def get_subquery_parentheses( subquery.append(SubQueryTuple(tk.right, tk.right.get_real_name())) elif is_subquery(tk): subquery.append(SubQueryTuple(tk, token.get_real_name())) + elif isinstance(target, Function): + # recursively check function parameter to possible scalar subquery + parameters = get_parameters(target) + while parameters: + parameter = parameters.pop(0) + if is_subquery(parameter): + subquery.append(SubQueryTuple(parameter, None)) + elif isinstance(parameter, Function): + parameters.extend(get_parameters(parameter)) elif isinstance(target, Values): for row in target.get_sublists(): for col in row: @@ -154,7 +167,11 @@ def get_parameters(token: Function): if isinstance(last_token, Over): # special handling for window function parameters = token.get_parameters() - parameters += [ - tk for tk in last_token.tokens if tk.is_group or tk.ttype == Wildcard - ] + for tk in last_token.tokens: + if isinstance(tk, IdentifierList): + # special handling when multiple parameters are grouped as IdentifierList incorrectly + for identifier in tk.get_sublists(): + parameters.append(identifier) + elif tk.is_group or tk.ttype == Wildcard: + parameters.append(tk) return parameters diff --git a/tests/sql/table/test_select.py b/tests/sql/table/test_select.py index fe143dcc..f659f729 100644 --- a/tests/sql/table/test_select.py +++ b/tests/sql/table/test_select.py @@ -193,6 +193,31 @@ def test_select_subquery_in_where_clause(): ) +def test_select_subquery_in_function(): + assert_table_lineage_equal( + "SELECT TO_DATE((SELECT MIN(dt) FROM tab1))", + {"tab1"}, + ) + + +def test_select_multiple_subquery_in_function(): + assert_table_lineage_equal( + "SELECT LEAST((SELECT MIN(dt) FROM tab1), (SELECT MIN(dt) FROM tab2))", + {"tab1", "tab2"}, + ) + + +def test_select_subquery_in_function_nested(): + assert_table_lineage_equal( + """SELECT EXPLODE(SEQUENCE( + TO_DATE((SELECT MIN(dt) FROM tab1)), + TO_DATE((SELECT MAX(dt) FROM tab2)), + INTERVAL 1 DAY +)) AS result""", + {"tab1", "tab2"}, + ) + + def test_select_inner_join(): assert_table_lineage_equal("SELECT * FROM tab1 INNER JOIN tab2", {"tab1", "tab2"})