Improvements to UTF-8 statistics truncation #6870

Merged
merged 11 commits on Dec 16, 2024
293 changes: 236 additions & 57 deletions parquet/src/column/writer/mod.rs
@@ -878,24 +878,67 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
}
}

    /// Returns `true` if this column's logical type is a UTF-8 string.
    fn is_utf8(&self) -> bool {
        self.get_descriptor().logical_type() == Some(LogicalType::String)
            || self.get_descriptor().converted_type() == ConvertedType::UTF8
    }
Comment on lines +881 to +885
Contributor:

I assume this works for dictionary-encoded columns as well, right?

Contributor (author):

Yes, regardless of the encoding, the statistics are for the data itself. You wouldn't see a dictionary key here.
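
To make that concrete, here is a minimal sketch (the schema string, values, and variable names are made up for illustration; the writer APIs are the same ones used by the new test at the bottom of this diff). Even with dictionary encoding enabled, the chunk statistics come out in value space:

```rust
use std::sync::Arc;

use parquet::{
    data_type::{ByteArray, ByteArrayType},
    file::{properties::WriterProperties, writer::SerializedFileWriter},
    schema::parser::parse_message_type,
};

fn main() {
    let schema = Arc::new(parse_message_type("message t { REQUIRED BYTE_ARRAY a (UTF8); }").unwrap());
    // Dictionary encoding is on by default; made explicit here for emphasis.
    let props = Arc::new(WriterProperties::builder().set_dictionary_enabled(true).build());
    let file = tempfile::tempfile().unwrap();

    let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap();
    let mut rg = writer.next_row_group().unwrap();
    let mut col = rg.next_column().unwrap().unwrap();
    // Heavily repeated values produce a small dictionary, but the statistics
    // are still computed over the values themselves, not the dictionary keys.
    let data = vec![ByteArray::from("pear"), ByteArray::from("apple"), ByteArray::from("pear")];
    col.typed::<ByteArrayType>().write_batch(&data, None, None).unwrap();
    col.close().unwrap();
    rg.close().unwrap();

    let meta = writer.close().unwrap();
    let stats = meta.row_groups[0].columns[0].meta_data.as_ref().unwrap().statistics.as_ref().unwrap();
    assert_eq!(stats.min_value.as_deref(), Some("apple".as_bytes()));
    assert_eq!(stats.max_value.as_deref(), Some("pear".as_bytes()));
}
```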


    /// Truncates a binary statistic to at most `truncation_length` bytes.
    ///
    /// If truncation is not possible, returns `data`.
    ///
    /// The `bool` in the returned tuple indicates whether truncation occurred or not.
    ///
    /// UTF-8 Note:
    /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
    /// also remain valid UTF-8, but may be less than `truncation_length` bytes to avoid splitting
    /// on non-character boundaries.
    fn truncate_min_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
        truncation_length
            .filter(|l| data.len() > *l)
            .and_then(|l| match str::from_utf8(data) {
                Ok(str_data) => truncate_utf8(str_data, l),
                Err(_) => Some(data[..l].to_vec()),
            })
            .and_then(|l|
                // don't do extra work if this column isn't UTF-8
Contributor:

💯

                if self.is_utf8() {
                    match str::from_utf8(data) {
                        Ok(str_data) => truncate_utf8(str_data, l),
                        Err(_) => Some(data[..l].to_vec()),
Contributor:

it is a somewhat questionable move to truncate this on invalid data, but I see that is what the code used to do, so this seems good to me

Contributor (author):

Hmm, good point. The old code simply tried UTF-8 first, and then fell back. Here we're actually expecting valid UTF-8, so perhaps it's better to return an error. I'd hope some string validation was done before getting this far. I'll think on this some more.

Contributor:

I think we should leave it as is and maybe document that if non-UTF-8 data is passed in, it will be truncated as bytes

Contributor (author):

I've managed to write a test that exercises this path via the SerializedFileWriter API. It truncates as expected (i.e. as binary, not UTF-8), but now I worry that it's possible to create invalid data. Oddly, both parquet-java and pyarrow seem OK with non-UTF-8 string data.

To paraphrase a wise man I know: Every day I wake up. And then I remember Parquet exists. 🫤

Contributor (author):

I've left the logic here as is, but added documentation and a test. We can revisit if this ever becomes an issue.

Contributor:

Thank you

Contributor:

> To paraphrase a wise man I know: Every day I wake up. And then I remember Parquet exists. 🫤

I solace myself with this quote from a former coworker:

"Legacy Code, n: code that is getting the job done, and pretty well at that"

Not that we can't / shouldn't improve it of course 🤣

thanks again for all the help here

                    }
                } else {
                    Some(data[..l].to_vec())
                }
            )
            .map(|truncated| (truncated, true))
            .unwrap_or_else(|| (data.to_vec(), false))
    }
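
To make the UTF-8 note above concrete, here is a standalone sketch of the min-bound semantics (the demo function and values are illustrative; it restates the private `truncate_utf8` helper further down in this diff):

```rust
// Standalone restatement of the min-bound truncation semantics.
fn truncate_utf8_demo(data: &str, length: usize) -> Option<Vec<u8>> {
    let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?;
    Some(data.as_bytes()[..split].to_vec())
}

fn main() {
    // "é" is 2 bytes: a 3-byte budget keeps one whole "é", not a split character.
    assert_eq!(truncate_utf8_demo("ééé", 3).unwrap(), "é".as_bytes());
    // A lone 3-byte code point can't shrink to 1 byte, so no truncation happens
    // and the caller falls back to the untruncated value.
    assert!(truncate_utf8_demo("\u{0836}", 1).is_none());
}
```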

    /// Truncates a binary statistic to at most `truncation_length` bytes, then increments the
    /// final byte(s) to yield a valid upper bound. The result may be shorter than
    /// `truncation_length` bytes if the last byte(s) overflow.
    ///
    /// If truncation is not possible, returns `data`.
    ///
    /// The `bool` in the returned tuple indicates whether truncation occurred or not.
    ///
    /// UTF-8 Note:
    /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
    /// also remain valid UTF-8 (but again may be shorter than `truncation_length` bytes). If
    /// `data` does not contain valid UTF-8, then truncation will occur as if the column is
    /// non-string binary.
    fn truncate_max_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
        truncation_length
            .filter(|l| data.len() > *l)
            .and_then(|l| match str::from_utf8(data) {
                Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8),
                Err(_) => increment(data[..l].to_vec()),
            })
            .and_then(|l|
                // don't do extra work if this column isn't UTF-8
                if self.is_utf8() {
                    match str::from_utf8(data) {
                        Ok(str_data) => truncate_and_increment_utf8(str_data, l),
                        Err(_) => increment(data[..l].to_vec()),
                    }
                } else {
                    increment(data[..l].to_vec())
                }
            )
            .map(|truncated| (truncated, true))
            .unwrap_or_else(|| (data.to_vec(), false))
    }
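
The max-bound counterpart truncates and then increments, so the result can be both shorter than the original and still greater. A quick sketch of the ordering argument, with made-up values:

```rust
fn main() {
    // truncate_max_value(Some(3), "ééé") conceptually yields ("ê", true):
    // cut at the last char boundary <= 3 (keeping "é"), then U+00E9 -> U+00EA.
    // Byte-wise comparison confirms the 2-byte result still upper-bounds the
    // 6-byte original:
    assert!("ê".as_bytes() > "ééé".as_bytes());
    assert!("ê" > "ééé");
}
```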
@@ -1418,13 +1461,50 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool {
    (a[1..]) > (b[1..])
}

/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string,
/// while being less than `length` bytes and non-empty
/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string,
/// while being less than `length` bytes and non-empty. Returns `None` if truncation
/// is not possible within those constraints.
///
/// The caller guarantees that data.len() > length.
fn truncate_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
    let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?;
    Some(data.as_bytes()[..split].to_vec())
}

/// Truncate a UTF-8 slice and increment its final character. The returned value is the
/// longest such slice that is still a valid UTF-8 string while being less than `length`
/// bytes and non-empty. Returns `None` if no such transformation is possible.
///
/// The caller guarantees that data.len() > length.
fn truncate_and_increment_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
    // UTF-8 encodes a code point in at most 4 bytes, so start the search 3 back
    // from the desired length
    let lower_bound = length.saturating_sub(3);
    let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?;
    increment_utf8(data.get(..split)?)
}
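
Why searching only `length - 3 ..= length` suffices: a UTF-8 code point occupies at most 4 bytes, so any 4 consecutive byte positions of a valid string contain at least one character boundary. A sketch of that arithmetic (values are illustrative):

```rust
fn main() {
    let s = "𐀀𐀀"; // two 4-byte code points; boundaries at 0, 4, and 8
    let length = 6;
    let lower_bound = length.saturating_sub(3); // 3
    // Searching 3..=6 finds boundary 4: keep the first code point, drop the rest.
    let split = (lower_bound..=length).rfind(|x| s.is_char_boundary(*x));
    assert_eq!(split, Some(4));
}
```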

/// Increment the final character in a UTF-8 string in such a way that the returned result
/// is still a valid UTF-8 string. The returned string may be shorter than the input if the
/// last character(s) cannot be incremented (due to overflow or producing invalid code points).
/// Returns `None` if the string cannot be incremented.
///
/// Note that this implementation will not promote an N-byte code point to (N+1) bytes.
fn increment_utf8(data: &str) -> Option<Vec<u8>> {
    for (idx, original_char) in data.char_indices().rev() {
        let original_len = original_char.len_utf8();
        if let Some(next_char) = char::from_u32(original_char as u32 + 1) {
            // do not allow increasing byte width of incremented char
            if next_char.len_utf8() == original_len {
                let mut result = data.as_bytes()[..idx + original_len].to_vec();
                next_char.encode_utf8(&mut result[idx..]);
                return Some(result);
            }
        }
    }

    None
}
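
One non-obvious consequence of the carry: when the last character can't be incremented, the whole tail is dropped, and the shorter result is still a valid upper bound under byte-wise ordering. A small check (illustrative):

```rust
fn main() {
    // increment_utf8("a\u{10ffff}") yields "b": U+10FFFF can't be incremented,
    // so the carry lands on 'a' and everything after it is dropped.
    // The shorter result still compares greater than the original:
    assert!("b".as_bytes() > "a\u{10ffff}".as_bytes());
}
```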

/// Try and increment the bytes from right to left.
///
/// Returns `None` if all bytes are set to `u8::MAX`.
@@ -1441,29 +1521,15 @@ fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {
    None
}

/// Try and increment the string's bytes from right to left, returning when the result
/// is a valid UTF8 string. Returns `None` when it can't increment any byte.
fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
    for idx in (0..data.len()).rev() {
        let original = data[idx];
        let (byte, overflow) = original.overflowing_add(1);
        if !overflow {
            data[idx] = byte;
            if str::from_utf8(&data).is_ok() {
                return Some(data);
            }
            data[idx] = original;
        }
    }

    None
}

#[cfg(test)]
mod tests {
    use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH;
    use crate::{
        file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter},
        schema::parser::parse_message_type,
    };
    use core::str;
    use rand::distributions::uniform::SampleUniform;
    use std::sync::Arc;
    use std::{fs::File, sync::Arc};

    use crate::column::{
        page::PageReader,
@@ -3140,39 +3206,69 @@ mod tests {

    #[test]
    fn test_increment_utf8() {
        let test_inc = |o: &str, expected: &str| {
            if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) {
                // Got the expected result...
                assert_eq!(v, expected);
                // and it's greater than the original string
                assert!(*v > *o);
                // Also show that BinaryArray level comparison works here
                let mut greater = ByteArray::new();
                greater.set_data(Bytes::from(v));
                let mut original = ByteArray::new();
                original.set_data(Bytes::from(o.as_bytes().to_vec()));
                assert!(greater > original);
            } else {
                panic!("Expected incremented UTF8 string to also be valid.");
            }
        };

        // Basic ASCII case
        let v = increment_utf8("hello".as_bytes().to_vec()).unwrap();
        assert_eq!(&v, "hellp".as_bytes());
        test_inc("hello", "hellp");

        // 1-byte ending in max 1-byte
        test_inc("a\u{7f}", "b");

        // Also show that BinaryArray level comparison works here
        let mut greater = ByteArray::new();
        greater.set_data(Bytes::from(v));
        let mut original = ByteArray::new();
        original.set_data(Bytes::from("hello".as_bytes().to_vec()));
        assert!(greater > original);
        // Max 1-byte should not increment, as that would need a 2-byte code point
        assert!(increment_utf8("\u{7f}\u{7f}").is_none());

        // UTF8 string
        let s = "❤️🧡💛💚💙💜";
        let v = increment_utf8(s.as_bytes().to_vec()).unwrap();
        test_inc("❤️🧡💛💚💙💜", "❤️🧡💛💚💙💝");

        if let Ok(new) = String::from_utf8(v) {
            assert_ne!(&new, s);
            assert_eq!(new, "❤️🧡💛💚💙💝");
            assert!(new.as_bytes().last().unwrap() > s.as_bytes().last().unwrap());
        } else {
            panic!("Expected incremented UTF8 string to also be valid.")
        }
        // 2-byte without overflow
        test_inc("éééé", "éééê");

        // Max UTF8 character - should be a no-op
        let s = char::MAX.to_string();
        assert_eq!(s.len(), 4);
        let v = increment_utf8(s.as_bytes().to_vec());
        assert!(v.is_none());
        // 2-byte that overflows lowest byte
        test_inc("\u{ff}\u{ff}", "\u{ff}\u{100}");

        // 2-byte ending in max 2-byte
        test_inc("a\u{7ff}", "b");

        // Max 2-byte should not increment, as that would need a 3-byte code point
        assert!(increment_utf8("\u{7ff}\u{7ff}").is_none());

        // 3-byte without overflow [U+800, U+800] -> [U+800, U+801] (note that these
        // characters should render right to left).
        test_inc("ࠀࠀ", "ࠀࠁ");

        // 3-byte ending in max 3-byte
        test_inc("a\u{ffff}", "b");

        // Max 3-byte should not increment, as that would need a 4-byte code point
        assert!(increment_utf8("\u{ffff}\u{ffff}").is_none());

        // Handle multi-byte UTF8 characters
        let s = "a\u{10ffff}";
        let v = increment_utf8(s.as_bytes().to_vec());
        assert_eq!(&v.unwrap(), "b\u{10ffff}".as_bytes());
        // 4-byte without overflow
        test_inc("𐀀𐀀", "𐀀𐀁");

        // 4-byte ending in max unicode
        test_inc("a\u{10ffff}", "b");

        // Max 4-byte should not increment
        assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none());

        // Incrementing skips over the surrogate pair range (0xD800..=0xDFFF): U+D7FF + 1
        // would be a surrogate, so the increment carries into the preceding character.
        //test_inc("a\u{D7FF}", "a\u{e000}");
        test_inc("a\u{D7FF}", "b");
    }

#[test]
Expand All @@ -3182,7 +3278,6 @@ mod tests {
        let r = truncate_utf8(data, data.as_bytes().len()).unwrap();
        assert_eq!(r.len(), data.as_bytes().len());
        assert_eq!(&r, data.as_bytes());
        println!("len is {}", data.len());

        // We slice it away from the UTF8 boundary
        let r = truncate_utf8(data, 13).unwrap();
@@ -3192,6 +3287,90 @@
        // One multi-byte code point, and a length shorter than it, so we can't slice it
        let r = truncate_utf8("\u{0836}", 1);
        assert!(r.is_none());

        // Test truncate and increment for max bounds on UTF-8 statistics
        // 7-bit (i.e. ASCII)
        let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap();
        assert_eq!(&r, "yyyyyyyz".as_bytes());

        // 2-byte without overflow
        let r = truncate_and_increment_utf8("ééééé", 7).unwrap();
        assert_eq!(&r, "ééê".as_bytes());

        // 2-byte that overflows lowest byte
        let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap();
        assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes());

        // Max 2-byte should not truncate as it would need 3-byte code points
        let r = truncate_and_increment_utf8("߿߿߿߿߿", 8);
        assert!(r.is_none());

        // 3-byte without overflow [U+800, U+800, U+800] -> [U+800, U+801] (note that these
        // characters should render right to left).
        let r = truncate_and_increment_utf8("ࠀࠀࠀࠀ", 8).unwrap();
        assert_eq!(&r, "ࠀࠁ".as_bytes());

        // Max 3-byte should not truncate as it would need 4-byte code points
        let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8);
        assert!(r.is_none());

        // 4-byte without overflow
        let r = truncate_and_increment_utf8("𐀀𐀀𐀀𐀀", 9).unwrap();
        assert_eq!(&r, "𐀀𐀁".as_bytes());

        // Max 4-byte should not truncate
        let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8);
        assert!(r.is_none());
    }

    #[test]
    // Check fallback truncation of statistics that should be UTF-8, but aren't
    // (see https://github.com/apache/arrow-rs/pull/6870).
    fn test_byte_array_truncate_invalid_utf8_statistics() {
        let message_type = "
            message test_schema {
                OPTIONAL BYTE_ARRAY a (UTF8);
            }
        ";
        let schema = Arc::new(parse_message_type(message_type).unwrap());

        // Create Vec<ByteArray> containing non-UTF8 bytes
        let data = vec![ByteArray::from(vec![128u8; 32]); 7];
        let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1];
        let file: File = tempfile::tempfile().unwrap();
        let props = Arc::new(
            WriterProperties::builder()
                .set_statistics_enabled(EnabledStatistics::Chunk)
                .set_statistics_truncate_length(Some(8))
                .build(),
        );

        let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap();
        let mut row_group_writer = writer.next_row_group().unwrap();

        let mut col_writer = row_group_writer.next_column().unwrap().unwrap();
        col_writer
            .typed::<ByteArrayType>()
            .write_batch(&data, Some(&def_levels), None)
            .unwrap();
        col_writer.close().unwrap();
        row_group_writer.close().unwrap();
        let file_metadata = writer.close().unwrap();
        assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some());
        let stats = file_metadata.row_groups[0].columns[0]
            .meta_data
            .as_ref()
            .unwrap()
            .statistics
            .as_ref()
            .unwrap();
        assert!(!stats.is_max_value_exact.unwrap());
        // Truncation of invalid UTF-8 should fall back to binary truncation, so the last byte
        // should be incremented by 1.
        assert_eq!(
            stats.max_value,
            Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec())
        );
    }

#[test]