Skip to content

Commit

Permalink
feat(rust,python): expressify offset and length for str.slice
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdlineluser committed Oct 27, 2023
1 parent 4b2bd83 commit d524bf8
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 90 deletions.
7 changes: 2 additions & 5 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,9 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {
///
/// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
fn str_slice(&self, start: i64, length: Option<u64>) -> Utf8Chunked {
fn str_slice(&self, start: &Int64Chunked, length: &UInt64Chunked) -> Utf8Chunked {
let ca = self.as_utf8();
let iter = ca
.downcast_iter()
.map(|c| substring::utf8_substring(c, start, &length));
Utf8Chunked::from_chunk_iter_like(ca, iter)
super::substring::utf8_substring(ca, start, length).into()
}
}

Expand Down
95 changes: 53 additions & 42 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,62 @@
use arrow::array::Utf8Array;
use polars_core::prelude::arity::ternary_elementwise;

use crate::chunked_array::{Int64Chunked, UInt64Chunked, Utf8Chunked};

/// Returns a Utf8Array<O> with a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
pub(super) fn utf8_substring(
array: &Utf8Array<i64>,
start: i64,
length: &Option<u64>,
) -> Utf8Array<i64> {
let length = length.map(|v| v as usize);
/// `offset` can be negative, in which case the offset counts from the end of the string.
fn utf8_substring_ternary<'a>(
opt_str_val: Option<&'a str>,
opt_offset: Option<i64>,
opt_length: Option<u64>,
) -> Option<&'a str> {
match (opt_str_val, opt_offset) {
(Some(str_val), Some(offset)) => {
// compute where we should offset slicing this entry.
let offset = if offset >= 0 {
offset as usize
} else {
let offset = (0i64 - offset) as usize;
str_val
.char_indices()
.rev()
.nth(offset)
.map(|(idx, _)| idx + 1)
.unwrap_or(0)
};

let iter = array.values_iter().map(|str_val| {
// compute where we should start slicing this entry.
let start = if start >= 0 {
start as usize
} else {
let start = (0i64 - start) as usize;
str_val
.char_indices()
.rev()
.nth(start)
.map(|(idx, _)| idx + 1)
.unwrap_or(0)
};
let mut iter_chars = str_val.char_indices();
if let Some((offset_idx, _)) = iter_chars.nth(offset) {
// length of the str
let len_end = str_val.len() - offset_idx;

let mut iter_chars = str_val.char_indices();
if let Some((start_idx, _)) = iter_chars.nth(start) {
// length of the str
let len_end = str_val.len() - start_idx;
// slice to end of str if no length given
let length = match opt_length {
Some(length) => length as usize,
_ => len_end,
};

// length to slice
let length = length.unwrap_or(len_end);
if length == 0 {
return Some("");
}
// compute
let end_idx = iter_chars
.nth(length.saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(str_val.len());

if length == 0 {
return "";
Some(&str_val[offset_idx..end_idx])
} else {
Some("")
}
// compute
let end_idx = iter_chars
.nth(length.saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(str_val.len());

&str_val[start_idx..end_idx]
} else {
""
}
});
},
_ => None,
}
}

let new = Utf8Array::<i64>::from_trusted_len_values_iter(iter);
new.with_validity(array.validity().cloned())
pub(super) fn utf8_substring(
ca: &Utf8Chunked,
offset: &Int64Chunked,
length: &UInt64Chunked,
) -> Utf8Chunked {
ternary_elementwise(ca, offset, length, utf8_substring_ternary)
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
StripSuffix => map_as_slice!(strings::strip_suffix),
#[cfg(feature = "string_from_radix")]
FromRadix(radix, strict) => map!(strings::from_radix, radix, strict),
Slice(start, length) => map!(strings::str_slice, start, length),
Slice => map_as_slice!(strings::str_slice),
Explode => map!(strings::explode),
#[cfg(feature = "dtype-decimal")]
ToDecimal(infer_len) => map!(strings::to_decimal, infer_len),
Expand Down
61 changes: 48 additions & 13 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub enum StringFunction {
length: usize,
fill_char: char,
},
Slice(i64, Option<u64>),
Slice,
StartsWith,
StripChars,
StripCharsStart,
Expand Down Expand Up @@ -125,14 +125,8 @@ impl StringFunction {
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
Uppercase
| Lowercase
| StripChars
| StripCharsStart
| StripCharsEnd
| StripPrefix
| StripSuffix
| Slice(_, _) => mapper.with_same_dtype(),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill { .. } => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -180,7 +174,7 @@ impl Display for StringFunction {
StringFunction::PadStart { .. } => "pad_start",
#[cfg(feature = "regex")]
StringFunction::Replace { .. } => "replace",
StringFunction::Slice(_, _) => "slice",
StringFunction::Slice => "slice",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::StripChars => "strip_chars",
StringFunction::StripCharsStart => "strip_chars_start",
Expand Down Expand Up @@ -732,9 +726,50 @@ pub(super) fn from_radix(s: &Series, radix: u32, strict: bool) -> PolarsResult<S
let ca = s.utf8()?;
ca.parse_int(radix, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &Series, start: i64, length: Option<u64>) -> PolarsResult<Series> {
let ca = s.utf8()?;
Ok(ca.str_slice(start, length).into_series())

pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].utf8()?;

let s1 = &s[1];
let s2 = &s[2];

polars_ensure!(
s1.len() <= ca.len(),
ComputeError:
"too many `offset` values ({}) for column length ({})",
s1.len(), ca.len(),
);

polars_ensure!(
s2.len() <= ca.len(),
ComputeError:
"too many `length` values ({}) for column length ({})",
s2.len(), ca.len(),
);

let offset = match s1.len() {
1 => {
let offset = s1.get(0).unwrap();
s1.clear().extend_constant(offset, ca.len()).unwrap()
},
_ => s1.clone(),
};

let offset = offset.cast(&DataType::Int64)?;
let offset = offset.i64()?;

let length = match s2.len() {
1 => {
let length = s2.get(0).unwrap();
s2.clear().extend_constant(length, ca.len()).unwrap()
},
_ => s2.clone(),
};

let length = length.cast(&DataType::UInt64)?;
let length = length.u64()?;

Ok(ca.str_slice(offset, length).into_series())
}

pub(super) fn explode(s: &Series) -> PolarsResult<Series> {
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,13 @@ impl StringNameSpace {
}

/// Slice the string values.
pub fn slice(self, start: i64, length: Option<u64>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::Slice(
start, length,
)))
pub fn slice(self, offset: Expr, length: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Slice),
&[offset, length],
false,
false,
)
}

pub fn explode(self) -> Expr {
Expand Down
39 changes: 19 additions & 20 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,12 +702,14 @@ impl SqlFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(0, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Ok(match length {
Expr::Literal(LiteralValue::Int64(_n)) => {
e.str().slice(lit(0), length)
},
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1])
}
}))
})
}),
Length => self.visit_unary(|e| e.str().len_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
Expand Down Expand Up @@ -756,26 +758,23 @@ impl SqlFunctionVisitor<'_> {
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
Substring => match function.args.len() {
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(match start {
Expr::Literal(LiteralValue::Int64(n)) => n,
Ok(match start {
Expr::Literal(LiteralValue::Int64(_)) => {
e.str().slice(start, lit(Null))
},
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1])
}
}, None))
})
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => n,
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
}))
if !matches!(start, Expr::Literal(LiteralValue::Int64(_))) {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
if !matches!(length, Expr::Literal(LiteralValue::Int64(_))) {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
Ok(e.str().slice(start, length))
}),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Substring: {}",
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ def replace_all(
value = parse_as_expression(value, str_as_lit=True)
return wrap_expr(self._pyexpr.str_replace_all(pattern, value, literal))

def slice(self, offset: int, length: int | None = None) -> Expr:
def slice(self, offset: IntoExpr, length: IntoExpr | None = None) -> Expr:
"""
Create subslices of the string values of a Utf8 Series.
Expand Down Expand Up @@ -1905,6 +1905,8 @@ def slice(self, offset: int, length: int | None = None) -> Expr:
└─────────────┴──────────┘
"""
offset = parse_as_expression(offset)
length = parse_as_expression(length)
return wrap_expr(self._pyexpr.str_slice(offset, length))

def explode(self) -> Expr:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ def to_titlecase(self) -> Series:
"""

def slice(self, offset: int, length: int | None = None) -> Series:
def slice(self, offset: IntoExpr, length: IntoExpr | None = None) -> Series:
"""
Create subslices of the string values of a Utf8 Series.
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ impl PyExpr {
self.inner.clone().str().strip_suffix(suffix.inner).into()
}

fn str_slice(&self, start: i64, length: Option<u64>) -> Self {
self.inner.clone().str().slice(start, length).into()
fn str_slice(&self, offset: Self, length: Self) -> Self {
self.inner
.clone()
.str()
.slice(offset.inner, length.inner)
.into()
}

fn str_explode(&self) -> Self {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/namespaces/string/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ def test_str_slice() -> None:
assert df.select([pl.col("a").str.slice(2, 4)])["a"].to_list() == ["obar", "rfoo"]


def test_str_slice_expressions() -> None:
df = pl.DataFrame({"a": ["foobar", "barfoo"], "offset": [1, 3], "length": [3, 4]})

out = df.select(pl.col("a").str.slice("offset", "length"))

expected = pl.DataFrame({"a": ["oob", "foo"]})
assert out.frame_equal(expected)

out = df.select(pl.col("a").str.slice(-3, "length"))

expected = pl.DataFrame({"a": ["bar", "foo"]})
assert out.frame_equal(expected)


def test_str_concat() -> None:
s = pl.Series(["1", None, "2"])
result = s.str.concat()
Expand Down

0 comments on commit d524bf8

Please sign in to comment.