Skip to content

Commit

Permalink
fix: handling scalar subquery in function
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed May 20, 2024
1 parent 4c30955 commit 233be5d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
4 changes: 4 additions & 0 deletions sqllineage/core/parser/sqlfluff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def list_subqueries(segment: BaseSegment) -> List[SubQueryTuple]:
else None
)
subquery.append(SubQueryTuple(bracketed_segment, alias))
elif function := select_clause_element.get_child("function"):
for bracketed in function.recursive_crawl("bracketed"):
if is_subquery(bracketed):
subquery.append(SubQueryTuple(bracketed, None))
elif segment.type == "from_expression_element":
as_segment, target = extract_as_and_target_segment(segment)
if is_subquery(target):
Expand Down
23 changes: 20 additions & 3 deletions sqllineage/core/parser/sqlparse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Comparison,
Function,
Identifier,
IdentifierList,
Over,
Parenthesis,
TokenList,
Expand Down Expand Up @@ -114,6 +115,9 @@ def get_subquery_parentheses(
if isinstance(token, Function):
# CTE without AS: tbl (SELECT 1)
target = token.tokens[-1]
# fallback to Function: LEAST((SELECT MIN(dt) FROM tab1), (SELECT MIN(dt) FROM tab2))
if len([p for p in target.tokens if isinstance(p, Parenthesis)]) > 0:
target = token
elif isinstance(token, (Values, Where)):
# WHERE col1 IN (SELECT max(col1) FROM tab2)
target = token
Expand All @@ -130,6 +134,15 @@ def get_subquery_parentheses(
subquery.append(SubQueryTuple(tk.right, tk.right.get_real_name()))
elif is_subquery(tk):
subquery.append(SubQueryTuple(tk, token.get_real_name()))
elif isinstance(target, Function):
# recursively check function parameter to possible scalar subquery
parameters = get_parameters(target)
while parameters:
parameter = parameters.pop(0)
if is_subquery(parameter):
subquery.append(SubQueryTuple(parameter, None))
elif isinstance(parameter, Function):
parameters.extend(get_parameters(parameter))
elif isinstance(target, Values):
for row in target.get_sublists():
for col in row:
Expand All @@ -154,7 +167,11 @@ def get_parameters(token: Function):
if isinstance(last_token, Over):
# special handling for window function
parameters = token.get_parameters()
parameters += [
tk for tk in last_token.tokens if tk.is_group or tk.ttype == Wildcard
]
for tk in last_token.tokens:
if isinstance(tk, IdentifierList):
# special handling when multiple parameters are grouped as IdentifierList incorrectly
for identifier in tk.get_sublists():
parameters.append(identifier)
elif tk.is_group or tk.ttype == Wildcard:
parameters.append(tk)
return parameters
25 changes: 25 additions & 0 deletions tests/sql/table/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,31 @@ def test_select_subquery_in_where_clause():
)


def test_select_subquery_in_function():
assert_table_lineage_equal(
"SELECT TO_DATE((SELECT MIN(dt) FROM tab1))",
{"tab1"},
)


def test_select_multiple_subquery_in_function():
assert_table_lineage_equal(
"SELECT LEAST((SELECT MIN(dt) FROM tab1), (SELECT MIN(dt) FROM tab2))",
{"tab1", "tab2"},
)


def test_select_subquery_in_function_nested():
assert_table_lineage_equal(
"""SELECT EXPLODE(SEQUENCE(
TO_DATE((SELECT MIN(dt) FROM tab1)),
TO_DATE((SELECT MAX(dt) FROM tab2)),
INTERVAL 1 DAY
)) AS result""",
{"tab1", "tab2"},
)


def test_select_inner_join():
assert_table_lineage_equal("SELECT * FROM tab1 INNER JOIN tab2", {"tab1", "tab2"})

Expand Down

0 comments on commit 233be5d

Please sign in to comment.