From 9ffa06543be51613ea1f509e63f6e7405b7d9989 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Mon, 16 Dec 2024 10:32:18 -0800 Subject: [PATCH] Improvements to UTF-8 statistics truncation (#6870) * fix a few edge cases with utf-8 incrementing * add todo * simplify truncation * add another test * note case where string should render right to left * rework entirely, also avoid UTF8 processing if not required by the schema * more consistent naming * modify some tests to truncate in the middle of a multibyte char * add test and docstring * document truncate_min_value too --- parquet/src/column/writer/mod.rs | 293 +++++++++++++++++++++++++------ 1 file changed, 236 insertions(+), 57 deletions(-) diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 16de0ba7898..8dc1d0db447 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -878,24 +878,67 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { } } + /// Returns `true` if this column's logical type is a UTF-8 string. + fn is_utf8(&self) -> bool { + self.get_descriptor().logical_type() == Some(LogicalType::String) + || self.get_descriptor().converted_type() == ConvertedType::UTF8 + } + + /// Truncates a binary statistic to at most `truncation_length` bytes. + /// + /// If truncation is not possible, returns `data`. + /// + /// The `bool` in the returned tuple indicates whether truncation occurred or not. + /// + /// UTF-8 Note: + /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will + /// also remain valid UTF-8, but may be less than `truncation_length` bytes to avoid splitting + /// on non-character boundaries. 
fn truncate_min_value(&self, truncation_length: Option, data: &[u8]) -> (Vec, bool) { truncation_length .filter(|l| data.len() > *l) - .and_then(|l| match str::from_utf8(data) { - Ok(str_data) => truncate_utf8(str_data, l), - Err(_) => Some(data[..l].to_vec()), - }) + .and_then(|l| + // don't do extra work if this column isn't UTF-8 + if self.is_utf8() { + match str::from_utf8(data) { + Ok(str_data) => truncate_utf8(str_data, l), + Err(_) => Some(data[..l].to_vec()), + } + } else { + Some(data[..l].to_vec()) + } + ) .map(|truncated| (truncated, true)) .unwrap_or_else(|| (data.to_vec(), false)) } + /// Truncates a binary statistic to at most `truncation_length` bytes, and then increments the + /// final byte(s) to yield a valid upper bound. This may result in a value of less than + /// `truncation_length` bytes if the last byte(s) overflows. + /// + /// If truncation is not possible, returns `data`. + /// + /// The `bool` in the returned tuple indicates whether truncation occurred or not. + /// + /// UTF-8 Note: + /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will + /// also remain valid UTF-8 (but again may be less than `truncation_length` bytes). If `data` + /// does not contain valid UTF-8, then truncation will occur as if the column is non-string + /// binary. 
fn truncate_max_value(&self, truncation_length: Option, data: &[u8]) -> (Vec, bool) { truncation_length .filter(|l| data.len() > *l) - .and_then(|l| match str::from_utf8(data) { - Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8), - Err(_) => increment(data[..l].to_vec()), - }) + .and_then(|l| + // don't do extra work if this column isn't UTF-8 + if self.is_utf8() { + match str::from_utf8(data) { + Ok(str_data) => truncate_and_increment_utf8(str_data, l), + Err(_) => increment(data[..l].to_vec()), + } + } else { + increment(data[..l].to_vec()) + } + ) .map(|truncated| (truncated, true)) .unwrap_or_else(|| (data.to_vec(), false)) } @@ -1418,13 +1461,50 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool { (a[1..]) > (b[1..]) } -/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string, -/// while being less than `length` bytes and non-empty +/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string, +/// while being less than `length` bytes and non-empty. Returns `None` if truncation +/// is not possible within those constraints. +/// +/// The caller guarantees that data.len() > length. fn truncate_utf8(data: &str, length: usize) -> Option> { let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?; Some(data.as_bytes()[..split].to_vec()) } +/// Truncate a UTF-8 slice and increment its final character. The returned value is the +/// longest such slice that is still a valid UTF-8 string while being less than `length` +/// bytes and non-empty. Returns `None` if no such transformation is possible. +/// +/// The caller guarantees that data.len() > length. +fn truncate_and_increment_utf8(data: &str, length: usize) -> Option> { + // UTF-8 is max 4 bytes, so start search 3 back from desired length + let lower_bound = length.saturating_sub(3); + let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?; + increment_utf8(data.get(..split)?) 
+} + +/// Increment the final character in a UTF-8 string in such a way that the returned result +/// is still a valid UTF-8 string. The returned string may be shorter than the input if the +/// last character(s) cannot be incremented (due to overflow or producing invalid code points). +/// Returns `None` if the string cannot be incremented. +/// +/// Note that this implementation will not promote an N-byte code point to (N+1) bytes. +fn increment_utf8(data: &str) -> Option> { + for (idx, original_char) in data.char_indices().rev() { + let original_len = original_char.len_utf8(); + if let Some(next_char) = char::from_u32(original_char as u32 + 1) { + // do not allow increasing byte width of incremented char + if next_char.len_utf8() == original_len { + let mut result = data.as_bytes()[..idx + original_len].to_vec(); + next_char.encode_utf8(&mut result[idx..]); + return Some(result); + } + } + } + + None +} + /// Try and increment the bytes from right to left. /// /// Returns `None` if all bytes are set to `u8::MAX`. @@ -1441,29 +1521,15 @@ fn increment(mut data: Vec) -> Option> { None } -/// Try and increment the the string's bytes from right to left, returning when the result -/// is a valid UTF8 string. Returns `None` when it can't increment any byte. 
-fn increment_utf8(mut data: Vec) -> Option> { - for idx in (0..data.len()).rev() { - let original = data[idx]; - let (byte, overflow) = original.overflowing_add(1); - if !overflow { - data[idx] = byte; - if str::from_utf8(&data).is_ok() { - return Some(data); - } - data[idx] = original; - } - } - - None -} - #[cfg(test)] mod tests { - use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH; + use crate::{ + file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter}, + schema::parser::parse_message_type, + }; + use core::str; use rand::distributions::uniform::SampleUniform; - use std::sync::Arc; + use std::{fs::File, sync::Arc}; use crate::column::{ page::PageReader, @@ -3140,39 +3206,69 @@ mod tests { #[test] fn test_increment_utf8() { + let test_inc = |o: &str, expected: &str| { + if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) { + // Got the expected result... + assert_eq!(v, expected); + // and it's greater than the original string + assert!(*v > *o); + // Also show that BinaryArray level comparison works here + let mut greater = ByteArray::new(); + greater.set_data(Bytes::from(v)); + let mut original = ByteArray::new(); + original.set_data(Bytes::from(o.as_bytes().to_vec())); + assert!(greater > original); + } else { + panic!("Expected incremented UTF8 string to also be valid."); + } + }; + // Basic ASCII case - let v = increment_utf8("hello".as_bytes().to_vec()).unwrap(); - assert_eq!(&v, "hellp".as_bytes()); + test_inc("hello", "hellp"); + + // 1-byte ending in max 1-byte + test_inc("a\u{7f}", "b"); - // Also show that BinaryArray level comparison works here - let mut greater = ByteArray::new(); - greater.set_data(Bytes::from(v)); - let mut original = ByteArray::new(); - original.set_data(Bytes::from("hello".as_bytes().to_vec())); - assert!(greater > original); + // 1-byte max should not truncate as it would need 2-byte code points + assert!(increment_utf8("\u{7f}\u{7f}").is_none()); // UTF8 string - let s = 
"โค๏ธ๐Ÿงก๐Ÿ’›๐Ÿ’š๐Ÿ’™๐Ÿ’œ"; - let v = increment_utf8(s.as_bytes().to_vec()).unwrap(); + test_inc("โค๏ธ๐Ÿงก๐Ÿ’›๐Ÿ’š๐Ÿ’™๐Ÿ’œ", "โค๏ธ๐Ÿงก๐Ÿ’›๐Ÿ’š๐Ÿ’™๐Ÿ’"); - if let Ok(new) = String::from_utf8(v) { - assert_ne!(&new, s); - assert_eq!(new, "โค๏ธ๐Ÿงก๐Ÿ’›๐Ÿ’š๐Ÿ’™๐Ÿ’"); - assert!(new.as_bytes().last().unwrap() > s.as_bytes().last().unwrap()); - } else { - panic!("Expected incremented UTF8 string to also be valid.") - } + // 2-byte without overflow + test_inc("รฉรฉรฉรฉ", "รฉรฉรฉรช"); - // Max UTF8 character - should be a No-Op - let s = char::MAX.to_string(); - assert_eq!(s.len(), 4); - let v = increment_utf8(s.as_bytes().to_vec()); - assert!(v.is_none()); + // 2-byte that overflows lowest byte + test_inc("\u{ff}\u{ff}", "\u{ff}\u{100}"); + + // 2-byte ending in max 2-byte + test_inc("a\u{7ff}", "b"); + + // Max 2-byte should not truncate as it would need 3-byte code points + assert!(increment_utf8("\u{7ff}\u{7ff}").is_none()); + + // 3-byte without overflow [U+800, U+800] -> [U+800, U+801] (note that these + // characters should render right to left). 
+ test_inc("เ €เ €", "เ €เ "); + + // 3-byte ending in max 3-byte + test_inc("a\u{ffff}", "b"); + + // Max 3-byte should not truncate as it would need 4-byte code points + assert!(increment_utf8("\u{ffff}\u{ffff}").is_none()); - // Handle multi-byte UTF8 characters - let s = "a\u{10ffff}"; - let v = increment_utf8(s.as_bytes().to_vec()); - assert_eq!(&v.unwrap(), "b\u{10ffff}".as_bytes()); + // 4-byte without overflow + test_inc("๐€€๐€€", "๐€€๐€"); + + // 4-byte ending in max unicode + test_inc("a\u{10ffff}", "b"); + + // Max 4-byte should not truncate + assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none()); + + // Skip over surrogate pair range (0xD800..=0xDFFF) + //test_inc("a\u{D7FF}", "a\u{e000}"); + test_inc("a\u{D7FF}", "b"); } #[test] @@ -3182,7 +3278,6 @@ mod tests { let r = truncate_utf8(data, data.as_bytes().len()).unwrap(); assert_eq!(r.len(), data.as_bytes().len()); assert_eq!(&r, data.as_bytes()); - println!("len is {}", data.len()); // We slice it away from the UTF8 boundary let r = truncate_utf8(data, 13).unwrap(); @@ -3192,6 +3287,90 @@ mod tests { // One multi-byte code point, and a length shorter than it, so we can't slice it let r = truncate_utf8("\u{0836}", 1); assert!(r.is_none()); + + // Test truncate and increment for max bounds on UTF-8 statistics + // 7-bit (i.e. 
ASCII) + let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap(); + assert_eq!(&r, "yyyyyyyz".as_bytes()); + + // 2-byte without overflow + let r = truncate_and_increment_utf8("รฉรฉรฉรฉรฉ", 7).unwrap(); + assert_eq!(&r, "รฉรฉรช".as_bytes()); + + // 2-byte that overflows lowest byte + let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap(); + assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes()); + + // max 2-byte should not truncate as it would need 3-byte code points + let r = truncate_and_increment_utf8("฿ฟ฿ฟ฿ฟ฿ฟ฿ฟ", 8); + assert!(r.is_none()); + + // 3-byte without overflow [U+800, U+800, U+800] -> [U+800, U+801] (note that these + // characters should render right to left). + let r = truncate_and_increment_utf8("เ €เ €เ €เ €", 8).unwrap(); + assert_eq!(&r, "เ €เ ".as_bytes()); + + // max 3-byte should not truncate as it would need 4-byte code points + let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8); + assert!(r.is_none()); + + // 4-byte without overflow + let r = truncate_and_increment_utf8("๐€€๐€€๐€€๐€€", 9).unwrap(); + assert_eq!(&r, "๐€€๐€".as_bytes()); + + // max 4-byte should not truncate + let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8); + assert!(r.is_none()); + } + + #[test] + // Check fallback truncation of statistics that should be UTF-8, but aren't + // (see https://github.com/apache/arrow-rs/pull/6870). 
+ fn test_byte_array_truncate_invalid_utf8_statistics() { + let message_type = " + message test_schema { + OPTIONAL BYTE_ARRAY a (UTF8); + } + "; + let schema = Arc::new(parse_message_type(message_type).unwrap()); + + // Create Vec containing non-UTF8 bytes + let data = vec![ByteArray::from(vec![128u8; 32]); 7]; + let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1]; + let file: File = tempfile::tempfile().unwrap(); + let props = Arc::new( + WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) + .set_statistics_truncate_length(Some(8)) + .build(), + ); + + let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap(); + let mut row_group_writer = writer.next_row_group().unwrap(); + + let mut col_writer = row_group_writer.next_column().unwrap().unwrap(); + col_writer + .typed::() + .write_batch(&data, Some(&def_levels), None) + .unwrap(); + col_writer.close().unwrap(); + row_group_writer.close().unwrap(); + let file_metadata = writer.close().unwrap(); + assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some()); + let stats = file_metadata.row_groups[0].columns[0] + .meta_data + .as_ref() + .unwrap() + .statistics + .as_ref() + .unwrap(); + assert!(!stats.is_max_value_exact.unwrap()); + // Truncation of invalid UTF-8 should fall back to binary truncation, so last byte should + // be incremented by 1. + assert_eq!( + stats.max_value, + Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec()) + ); } #[test]