Skip to content

Commit

Permalink
feat: add overlay and substr support (#181)
Browse files: browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Oct 5, 2024
1 parent 13dd1f9 commit df6e6da
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 23 deletions.
1 change: 0 additions & 1 deletion docs/bigquery.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* sql
* SQLFrame Specific: Get the SQL representation of a given column
* [startswith](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.startswith.html)
* [substr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.substr.html)
* [when](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.when.html)

### DataFrame Class
Expand Down
1 change: 0 additions & 1 deletion docs/duckdb.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* sql
* SQLFrame Specific: Get the SQL representation of a given column
* [startswith](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.startswith.html)
* [substr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.substr.html)
* [when](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.when.html)

### DataFrame Class
Expand Down
1 change: 0 additions & 1 deletion docs/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ See something that you would like to see supported? [Open an issue](https://gith
* sql
* SQLFrame Specific: Get the SQL representation of a given column
* [startswith](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.startswith.html)
* [substr](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.substr.html)
* [when](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Column.when.html)

### DataFrame Class
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
python_requires=">=3.8",
install_requires=[
"prettytable<3.11.1",
"sqlglot>=24.0.0,<25.23",
"sqlglot>=24.0.0,<25.25",
"typing_extensions>=4.8,<5",
],
extras_require={
Expand Down
21 changes: 11 additions & 10 deletions sqlframe/base/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,11 +1129,15 @@ def overlay(
pos: t.Union[ColumnOrName, int],
len: t.Optional[t.Union[ColumnOrName, int]] = None,
) -> Column:
pos_value = lit(pos) if isinstance(pos, int) else pos
if len is not None:
len_value = lit(len) if isinstance(len, int) else len
return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos_value, len_value)
return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos_value)
return Column.invoke_expression_over_column(
src,
expression.Overlay,
**{
"expression": Column(replace).expression,
"from": lit(pos).expression,
"for": lit(len).expression if len is not None else None,
},
)


@meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
Expand Down Expand Up @@ -4834,7 +4838,7 @@ def str_to_map(
)


@meta(unsupported_engines="*")
@meta(unsupported_engines="postgres")
def substr(str: ColumnOrName, pos: ColumnOrName, len: t.Optional[ColumnOrName] = None) -> Column:
"""
Returns the substring of `str` that starts at `pos` and is of length `len`,
Expand Down Expand Up @@ -4873,10 +4877,7 @@ def substr(str: ColumnOrName, pos: ColumnOrName, len: t.Optional[ColumnOrName] =
| k SQL|
+------------------------+
"""
if len is not None:
return Column.invoke_anonymous_function(str, "substr", pos, len)
else:
return Column.invoke_anonymous_function(str, "substr", pos)
return Column.invoke_expression_over_column(str, expression.Substring, start=pos, length=len)


@meta(unsupported_engines="*")
Expand Down
1 change: 1 addition & 0 deletions sqlframe/bigquery/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ from sqlframe.base.functions import stddev as stddev
from sqlframe.base.functions import stddev_pop as stddev_pop
from sqlframe.base.functions import stddev_samp as stddev_samp
from sqlframe.base.functions import struct as struct
from sqlframe.base.functions import substr as substr
from sqlframe.base.functions import substring as substring
from sqlframe.base.functions import sum as sum
from sqlframe.base.functions import sum_distinct as sum_distinct
Expand Down
1 change: 1 addition & 0 deletions sqlframe/duckdb/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ from sqlframe.base.functions import (
stddev_samp as stddev_samp,
struct as struct,
substring as substring,
substr as substr,
sum as sum,
sumDistinct as sumDistinct,
sum_distinct as sum_distinct,
Expand Down
1 change: 1 addition & 0 deletions sqlframe/snowflake/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ from sqlframe.base.functions import (
stddev_pop as stddev_pop,
stddev_samp as stddev_samp,
substring as substring,
substr as substr,
sum as sum,
sumDistinct as sumDistinct,
sum_distinct as sum_distinct,
Expand Down
1 change: 1 addition & 0 deletions sqlframe/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
add_months_by_multiplication as add_months,
arrays_overlap_renamed as arrays_overlap,
_is_string_using_typeof_string_lcase as _is_string,
try_element_at_zero_based as try_element_at,
)
2 changes: 1 addition & 1 deletion sqlframe/spark/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from sqlframe.base.function_alternatives import ( # noqa
percentile_without_disc as percentile,
add_months_by_multiplication as add_months,
arrays_overlap_renamed as arrays_overlap,
try_element_at_zero_based as try_element_at,
)
from sqlframe.base.functions import (
abs as abs,
Expand Down Expand Up @@ -372,7 +373,6 @@ from sqlframe.base.functions import (
try_aes_decrypt as try_aes_decrypt,
try_avg as try_avg,
try_divide as try_divide,
try_element_at as try_element_at,
try_multiply as try_multiply,
try_subtract as try_subtract,
try_sum as try_sum,
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/standalone/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,12 +1823,12 @@ def test_instr(expression, expected):
@pytest.mark.parametrize(
"expression, expected",
[
(SF.overlay("cola", "colb", 3, 7), "OVERLAY(cola, colb, 3, 7)"),
(SF.overlay("cola", "colb", 3, 7), "OVERLAY(cola PLACING colb FROM 3 FOR 7)"),
(
SF.overlay(SF.col("cola"), SF.col("colb"), SF.lit(3), SF.lit(7)),
"OVERLAY(cola, colb, 3, 7)",
"OVERLAY(cola PLACING colb FROM 3 FOR 7)",
),
(SF.overlay("cola", "colb", 3), "OVERLAY(cola, colb, 3)"),
(SF.overlay("cola", "colb", 3), "OVERLAY(cola PLACING colb FROM 3)"),
],
)
def test_overlay(expression, expected):
Expand Down Expand Up @@ -2263,8 +2263,8 @@ def test_array_contains(expression, expected):
@pytest.mark.parametrize(
"expression, expected",
[
(SF.arrays_overlap("cola", "colb"), "ARRAY_OVERLAPS(cola, colb)"),
(SF.arrays_overlap(SF.col("cola"), SF.col("colb")), "ARRAY_OVERLAPS(cola, colb)"),
(SF.arrays_overlap("cola", "colb"), "cola && colb"),
(SF.arrays_overlap(SF.col("cola"), SF.col("colb")), "cola && colb"),
],
)
def test_arrays_overlap(expression, expected):
Expand Down Expand Up @@ -4495,9 +4495,9 @@ def test_str_to_map(expression, expected):
@pytest.mark.parametrize(
"expression, expected",
[
(SF.substr("cola", "colb"), "SUBSTR(cola, colb)"),
(SF.substr(SF.col("cola"), SF.col("colb")), "SUBSTR(cola, colb)"),
(SF.substr("cola", "colb", "colc"), "SUBSTR(cola, colb, colc)"),
(SF.substr("cola", "colb"), "SUBSTRING(cola, colb)"),
(SF.substr(SF.col("cola"), SF.col("colb")), "SUBSTRING(cola, colb)"),
(SF.substr("cola", "colb", "colc"), "SUBSTRING(cola, colb, colc)"),
],
)
def test_substr(expression, expected):
Expand Down

0 comments on commit df6e6da

Please sign in to comment.