Skip to content

Commit

Permalink
update to support latest sqlglot
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq committed Nov 19, 2024
1 parent 6c8e705 commit 920a004
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/bigquery.md
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [max_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.max_by.html)
* [md5](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.md5.html)
* [mean](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.mean.html)
* [median](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.median.html)
* [min](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min.html)
* [min_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min_by.html)
* [minute](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.minute.html)
Expand Down
3 changes: 3 additions & 0 deletions docs/duckdb.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ See something that you would like to see supported? [Open an issue](https://gith
* [concat](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat.html)
* Only works on strings (does not work on arrays)
* [concat_ws](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat_ws.html)
* [contains](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.contains.html)
* Only works on strings (does not support binary)
* [convert_timezone](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.convert_timezone.html)
* [corr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.corr.html)
* [cos](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.cos.html)
Expand Down Expand Up @@ -395,6 +397,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [max_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.max_by.html)
* [md5](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.md5.html)
* [mean](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.mean.html)
* [median](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.median.html)
* [min](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min.html)
* [min_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min_by.html)
* [minute](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.minute.html)
Expand Down
4 changes: 4 additions & 0 deletions docs/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ See something that you would like to see supported? [Open an issue](https://gith
* [concat](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat.html)
* Can only concat strings not arrays
* [concat_ws](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat_ws.html)
* [contains](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.contains.html)
* Only works on strings (does not support binary)
* [convert_timezone](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.convert_timezone.html)
* [corr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.corr.html)
* [cos](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.cos.html)
Expand Down Expand Up @@ -428,6 +430,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [max_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.max_by.html)
* [md5](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.md5.html)
* [mean](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.mean.html)
* [median](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.median.html)
* [min](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min.html)
* [min_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min_by.html)
* [minute](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.minute.html)
Expand Down Expand Up @@ -510,6 +513,7 @@ See something that you would like to see supported? [Open an issue](https://gith
* [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html)
* [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html)
* [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html)
* [unix_seconds](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_seconds.html)
* [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html)
* [upper](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.upper.html)
* [var_pop](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_pop.html)
Expand Down
2 changes: 1 addition & 1 deletion sqlframe/base/function_alternatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,7 +1593,7 @@ def try_to_timestamp_pgtemp(col: ColumnOrName, format: t.Optional[ColumnOrName]
def typeof_pg_typeof(col: ColumnOrName) -> Column:
return (
Column.invoke_anonymous_function(col, "pg_typeof")
.cast(expression.DataType.build("regtype", dialect="postgres"))
.cast(expression.DataType(this=expression.DataType.Type.USERDEFINED, kind="regtype"))
.cast("text")
)

Expand Down
22 changes: 11 additions & 11 deletions sqlframe/base/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,9 +2069,11 @@ def character_length(str: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(str, "character_length")


@meta()
@meta(unsupported_engines=["bigquery", "postgres"])
def contains(left: ColumnOrName, right: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(left, "contains", right)
return Column.invoke_expression_over_column(
left, expression.Contains, expression=Column.ensure_col(right).expression
)


@meta(unsupported_engines=["bigquery", "postgres"])
Expand Down Expand Up @@ -3484,7 +3486,7 @@ def mask(
)


@meta(unsupported_engines="*")
@meta(unsupported_engines=["bigquery"])
def median(col: ColumnOrName) -> Column:
"""
Returns the median of the values in a group.
Expand Down Expand Up @@ -3520,7 +3522,7 @@ def median(col: ColumnOrName) -> Column:
|dotNET| 10000.0|
+------+----------------+
"""
return Column.invoke_anonymous_function(col, "median")
return Column.invoke_expression_over_column(col, expression.Median)


@meta(unsupported_engines="*")
Expand Down Expand Up @@ -4106,11 +4108,9 @@ def regexp_extract_all(
>>> df.select(regexp_extract_all('str', col("regexp")).alias('d')).collect()
[Row(d=['100', '300'])]
"""
if idx is None:
return Column.invoke_anonymous_function(str, "regexp_extract_all", regexp)
else:
idx = lit(idx) if isinstance(idx, int) else idx
return Column.invoke_anonymous_function(str, "regexp_extract_all", regexp, idx)
return Column.invoke_expression_over_column(
str, expression.RegexpExtractAll, expression=regexp, group=idx
)


@meta(unsupported_engines="*")
Expand Down Expand Up @@ -5426,7 +5426,7 @@ def unix_millis(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "unix_millis")


@meta(unsupported_engines="*")
@meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
def unix_seconds(col: ColumnOrName) -> Column:
"""Returns the number of seconds since 1970-01-01 00:00:00 UTC.
Truncates higher levels of precision.
Expand All @@ -5441,7 +5441,7 @@ def unix_seconds(col: ColumnOrName) -> Column:
[Row(n=1437584400)]
>>> spark.conf.unset("spark.sql.session.timeZone")
"""
return Column.invoke_anonymous_function(col, "unix_seconds")
return Column.invoke_expression_over_column(col, expression.UnixSeconds)


@meta(unsupported_engines="*")
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/standalone/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
@pytest.mark.parametrize("name,func", inspect.getmembers(SF, inspect.isfunction))
def test_invoke_anonymous(name, func):
# array_size - converts to `size` but `array_size` and `size` behave differently
# exists - the spark exists takes a lambda function and the exists in SQLGlot seems more basic
# make_interval - SQLGlot doesn't support week
# to_char - converts to a cast that ignores the format provided
# ltrim/rtrim - don't seem to convert correctly on some engines
ignore_funcs = {"array_size", "to_char", "ltrim", "rtrim"}
ignore_funcs = {"array_size", "exists", "make_interval", "to_char", "ltrim", "rtrim"}
if "invoke_anonymous_function" in inspect.getsource(func) and name not in ignore_funcs:
func = parse_one(f"{name}()", read="spark", error_level=ErrorLevel.IGNORE)
assert isinstance(func, exp.Anonymous)
Expand Down Expand Up @@ -4224,7 +4226,7 @@ def test_regexp_count(expression, expected):
SF.regexp_extract_all("cola", "colb", SF.col("colc")),
"REGEXP_EXTRACT_ALL(cola, colb, colc)",
),
(SF.regexp_extract_all("cola", "colb", 1), "REGEXP_EXTRACT_ALL(cola, colb, 1)"),
(SF.regexp_extract_all("cola", "colb", 2), "REGEXP_EXTRACT_ALL(cola, colb, 2)"),
],
)
def test_regexp_extract_all(expression, expected):
Expand Down

0 comments on commit 920a004

Please sign in to comment.