diff --git a/docs/bigquery.md b/docs/bigquery.md index 00fec59..72c2fc2 100644 --- a/docs/bigquery.md +++ b/docs/bigquery.md @@ -425,6 +425,7 @@ See something that you would like to see supported? [Open an issue](https://gith * [max_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.max_by.html) * [md5](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.md5.html) * [mean](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.mean.html) +* [median](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.median.html) * [min](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min.html) * [min_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min_by.html) * [minute](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.minute.html) diff --git a/docs/duckdb.md b/docs/duckdb.md index bd5d621..4b24354 100644 --- a/docs/duckdb.md +++ b/docs/duckdb.md @@ -308,6 +308,8 @@ See something that you would like to see supported? [Open an issue](https://gith * [concat](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat.html) * Only works on strings (does not work on arrays) * [concat_ws](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat_ws.html) +* [contains](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.contains.html) + * Only works on strings (does not support binary) * [convert_timezone](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.convert_timezone.html) * [corr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.corr.html) * [cos](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.cos.html) @@ -395,6 +397,7 @@ See something that you would like to see supported? [Open an issue](https://gith * [max_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.max_by.html) * [md5](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.md5.html) * [mean](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.mean.html) +* [median](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.median.html) * [min](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min.html) * [min_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min_by.html) * [minute](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.minute.html) diff --git a/docs/snowflake.md b/docs/snowflake.md index 47609d8..ed856ef 100644 --- a/docs/snowflake.md +++ b/docs/snowflake.md @@ -344,6 +344,8 @@ See something that you would like to see supported? [Open an issue](https://gith * [concat](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat.html) * Can only concat strings not arrays * [concat_ws](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.concat_ws.html) +* [contains](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.contains.html) + * * Only works on strings (does not support binary) * [convert_timezone](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.convert_timezone.html) * [corr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.corr.html) * [cos](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.cos.html) @@ -428,6 +430,7 @@ See something that you would like to see supported? [Open an issue](https://gith * [max_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.max_by.html) * [md5](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.md5.html) * [mean](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.mean.html) +* [median](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.median.html) * [min](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min.html) * [min_by](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.min_by.html) * [minute](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.minute.html) @@ -510,6 +513,7 @@ See something that you would like to see supported? [Open an issue](https://gith * [ucase](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ucase.html) * [unbase64](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unbase64.html) * [unhex](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unhex.html) +* [unix_seconds](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_seconds.html) * [unix_timestamp](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.unix_timestamp.html) * [upper](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.upper.html) * [var_pop](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.var_pop.html) diff --git a/setup.py b/setup.py index 2840b33..6974043 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ python_requires=">=3.8", install_requires=[ "prettytable<3.12.1", - "sqlglot>=24.0.0,<25.29", + "sqlglot>=24.0.0,<25.32", "typing_extensions>=4.8,<5", ], extras_require={ diff --git a/sqlframe/base/function_alternatives.py b/sqlframe/base/function_alternatives.py index b418584..3c7bdad 100644 --- a/sqlframe/base/function_alternatives.py +++ b/sqlframe/base/function_alternatives.py @@ -1593,7 +1593,7 @@ def try_to_timestamp_pgtemp(col: ColumnOrName, format: t.Optional[ColumnOrName] def typeof_pg_typeof(col: ColumnOrName) -> Column: return ( Column.invoke_anonymous_function(col, "pg_typeof") - .cast(expression.DataType.build("regtype", dialect="postgres")) + .cast(expression.DataType(this=expression.DataType.Type.USERDEFINED, kind="regtype")) .cast("text") ) diff --git a/sqlframe/base/functions.py b/sqlframe/base/functions.py index 7bcf29f..f6c52f2 100644 --- a/sqlframe/base/functions.py +++ b/sqlframe/base/functions.py @@ -2069,9 +2069,11 @@ def character_length(str: ColumnOrName) -> Column: return Column.invoke_anonymous_function(str, "character_length") -@meta() +@meta(unsupported_engines=["bigquery", "postgres"]) def contains(left: ColumnOrName, right: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(left, "contains", right) + return Column.invoke_expression_over_column( + left, expression.Contains, expression=Column.ensure_col(right).expression + ) @meta(unsupported_engines=["bigquery", "postgres"]) @@ -3484,7 +3486,7 @@ def mask( ) -@meta(unsupported_engines="*") +@meta(unsupported_engines=["bigquery"]) def median(col: ColumnOrName) -> Column: """ Returns the median of the values in a group. @@ -3520,7 +3522,7 @@ def median(col: ColumnOrName) -> Column: |dotNET| 10000.0| +------+----------------+ """ - return Column.invoke_anonymous_function(col, "median") + return Column.invoke_expression_over_column(col, expression.Median) @meta(unsupported_engines="*") @@ -4106,11 +4108,9 @@ def regexp_extract_all( >>> df.select(regexp_extract_all('str', col("regexp")).alias('d')).collect() [Row(d=['100', '300'])] """ - if idx is None: - return Column.invoke_anonymous_function(str, "regexp_extract_all", regexp) - else: - idx = lit(idx) if isinstance(idx, int) else idx - return Column.invoke_anonymous_function(str, "regexp_extract_all", regexp, idx) + return Column.invoke_expression_over_column( + str, expression.RegexpExtractAll, expression=regexp, group=idx + ) @meta(unsupported_engines="*") @@ -5426,7 +5426,7 @@ def unix_millis(col: ColumnOrName) -> Column: return Column.invoke_anonymous_function(col, "unix_millis") -@meta(unsupported_engines="*") +@meta(unsupported_engines=["bigquery", "duckdb", "postgres"]) def unix_seconds(col: ColumnOrName) -> Column: """Returns the number of seconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of precision. @@ -5441,7 +5441,7 @@ def unix_seconds(col: ColumnOrName) -> Column: [Row(n=1437584400)] >>> spark.conf.unset("spark.sql.session.timeZone") """ - return Column.invoke_anonymous_function(col, "unix_seconds") + return Column.invoke_expression_over_column(col, expression.UnixSeconds) @meta(unsupported_engines="*") diff --git a/tests/unit/standalone/test_functions.py b/tests/unit/standalone/test_functions.py index 8e7d282..757d236 100644 --- a/tests/unit/standalone/test_functions.py +++ b/tests/unit/standalone/test_functions.py @@ -13,9 +13,11 @@ @pytest.mark.parametrize("name,func", inspect.getmembers(SF, inspect.isfunction)) def test_invoke_anonymous(name, func): # array_size - converts to `size` but `array_size` and `size` behave differently + # exists - the spark exists takes a lambda function and the exists in SQLGlot seems more basic + # make_interval - SQLGlot doesn't support week # to_char - convert to a cast that ignores the format provided # ltrim/rtrim - don't seem to convert correctly on some engines - ignore_funcs = {"array_size", "to_char", "ltrim", "rtrim"} + ignore_funcs = {"array_size", "exists", "make_interval", "to_char", "ltrim", "rtrim"} if "invoke_anonymous_function" in inspect.getsource(func) and name not in ignore_funcs: func = parse_one(f"{name}()", read="spark", error_level=ErrorLevel.IGNORE) assert isinstance(func, exp.Anonymous) @@ -4224,7 +4226,7 @@ def test_regexp_count(expression, expected): SF.regexp_extract_all("cola", "colb", SF.col("colc")), "REGEXP_EXTRACT_ALL(cola, colb, colc)", ), - (SF.regexp_extract_all("cola", "colb", 1), "REGEXP_EXTRACT_ALL(cola, colb, 1)"), + (SF.regexp_extract_all("cola", "colb", 2), "REGEXP_EXTRACT_ALL(cola, colb, 2)"), ], ) def test_regexp_extract_all(expression, expected):