Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(function): Handle unescaped UTF-8 characters in Presto url_extract_* UDFs #11535

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions velox/functions/lib/Utf8Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ int firstByteCharLength(const char* u_input) {

} // namespace

int32_t tryGetCharLength(const char* input, int64_t size) {
int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint) {
VELOX_DCHECK_NOT_NULL(input);
VELOX_DCHECK_GT(size, 0);

Expand All @@ -72,6 +72,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 1) {
// Normal ASCII: 0xxx_xxxx.
codePoint = input[0];
return 1;
}

Expand All @@ -89,7 +90,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 2) {
// 110x_xxxx 10xx_xxxx
int codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111);
codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111);
// Fail if overlong encoding.
return codePoint < 0x80 ? -2 : 2;
}
Expand All @@ -106,7 +107,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 3) {
// 1110_xxxx 10xx_xxxx 10xx_xxxx
int codePoint = ((firstByte & 0b00001111) << 12) |
codePoint = ((firstByte & 0b00001111) << 12) |
((secondByte & 0b00111111) << 6) | (thirdByte & 0b00111111);

// Surrogates are invalid.
Expand All @@ -132,7 +133,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 4) {
// 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx
int codePoint = ((firstByte & 0b00000111) << 18) |
codePoint = ((firstByte & 0b00000111) << 18) |
((secondByte & 0b00111111) << 12) | ((thirdByte & 0b00111111) << 6) |
(forthByte & 0b00111111);
// Fail if overlong encoding or above upper bound of Unicode.
Expand Down
4 changes: 3 additions & 1 deletion velox/functions/lib/Utf8Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ namespace facebook::velox::functions {
///
/// @param input Pointer to the first byte of the code point. Must not be null.
/// @param size Number of available bytes. Must be greater than zero.
/// @param codePoint Set to the decoded code point. The value is only
/// meaningful when the return value is positive.
/// @return The length in bytes of the code point if it is valid; otherwise
/// the negated number of bytes in the invalid UTF-8 sequence.
///
/// Adapted from tryGetCodePointAt in
/// https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/SliceUtf8.java
int32_t tryGetCharLength(const char* input, int64_t size);
int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint);

/// Return the length in byte of the next UTF-8 encoded character at the
/// beginning of `string`. If the beginning of `string` is not valid UTF-8
Expand Down
51 changes: 30 additions & 21 deletions velox/functions/lib/tests/Utf8Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,53 +21,62 @@ namespace facebook::velox::functions {
namespace {

TEST(Utf8Test, tryCharLength) {
int32_t codepoint;
// Single-byte ASCII character.
ASSERT_EQ(1, tryGetCharLength("Hello", 5));
ASSERT_EQ(1, tryGetCharLength("Hello", 5, codepoint));
ASSERT_EQ('H', codepoint);

// 2-byte character. British pound sign.
static const char* kPound = "\u00A3tail";
ASSERT_EQ(2, tryGetCharLength(kPound, 5));
ASSERT_EQ(2, tryGetCharLength(kPound, 5, codepoint));
ASSERT_EQ(0xA3, codepoint);
// First byte alone is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kPound, 1));
ASSERT_EQ(-1, tryGetCharLength(kPound, 1, codepoint));
// Second byte alone is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5));
ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5, codepoint));
// ASCII character 't' after the pound sign is valid.
ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5));
ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5, codepoint));

// 3-byte character. Euro sign.
static const char* kEuro = "\u20ACtail";
ASSERT_EQ(3, tryGetCharLength(kEuro, 5));
ASSERT_EQ(3, tryGetCharLength(kEuro, 5, codepoint));
ASSERT_EQ(0x20AC, codepoint);
// First byte or first 2 bytes alone are not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kEuro, 1));
ASSERT_EQ(-2, tryGetCharLength(kEuro, 2));
ASSERT_EQ(-1, tryGetCharLength(kEuro, 1, codepoint));
ASSERT_EQ(-2, tryGetCharLength(kEuro, 2, codepoint));
// Byte sequence starting from 2nd or 3rd byte is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5));
ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5));
ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5));
ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5, codepoint));
ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5, codepoint));
// ASCII character 't' after the euro sign is valid.
ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5));
ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5, codepoint));
ASSERT_EQ('t', codepoint);
ASSERT_EQ(1, tryGetCharLength(kEuro + 4, 5, codepoint));
ASSERT_EQ('a', codepoint);

// 4-byte character. Musical symbol F CLEF.
static const char* kClef = "\U0001D122tail";
ASSERT_EQ(4, tryGetCharLength(kClef, 5));
ASSERT_EQ(4, tryGetCharLength(kClef, 5, codepoint));
ASSERT_EQ(0x1D122, codepoint);
// First byte, first 2 bytes, or first 3 bytes alone are not a valid
// character.
ASSERT_EQ(-1, tryGetCharLength(kClef, 1));
ASSERT_EQ(-2, tryGetCharLength(kClef, 2));
ASSERT_EQ(-3, tryGetCharLength(kClef, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef, 1, codepoint));
ASSERT_EQ(-2, tryGetCharLength(kClef, 2, codepoint));
ASSERT_EQ(-3, tryGetCharLength(kClef, 3, codepoint));
// Byte sequence starting from 2nd, 3rd or 4th byte is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3, codepoint));
ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3, codepoint));
ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3, codepoint));

// ASCII character 't' after the clef sign is valid.
ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5));
ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5, codepoint));
ASSERT_EQ('t', codepoint);

// Test overlong encoding.

auto tryCharLength = [](const std::vector<unsigned char>& bytes) {
int32_t codepoint;
return tryGetCharLength(
reinterpret_cast<const char*>(bytes.data()), bytes.size());
reinterpret_cast<const char*>(bytes.data()), bytes.size(), codepoint);
};

// 2-byte encoding of 0x2F.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ velox_add_library(
TransformKeys.cpp
TransformValues.cpp
TypeOf.cpp
URIParser.cpp
URLFunctions.cpp
VectorArithmetic.cpp
WidthBucketArray.cpp
Expand Down
9 changes: 6 additions & 3 deletions velox/functions/prestosql/FromUtf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ class FromUtf8Function : public exec::VectorFunction {

auto replacement = decoded.valueAt<StringView>(row);
if (!replacement.empty()) {
int32_t codePoint;
auto charLength =
tryGetCharLength(replacement.data(), replacement.size());
tryGetCharLength(replacement.data(), replacement.size(), codePoint);
VELOX_USER_CHECK_GT(
charLength, 0, "Replacement is not a valid UTF-8 character");
VELOX_USER_CHECK_EQ(
Expand All @@ -188,8 +189,9 @@ class FromUtf8Function : public exec::VectorFunction {

int32_t pos = 0;
while (pos < value.size()) {
int32_t codePoint;
auto charLength =
tryGetCharLength(value.data() + pos, value.size() - pos);
tryGetCharLength(value.data() + pos, value.size() - pos, codePoint);
if (charLength < 0) {
firstInvalidRow = row;
return false;
Expand Down Expand Up @@ -267,8 +269,9 @@ class FromUtf8Function : public exec::VectorFunction {

int32_t pos = 0;
while (pos < input.size()) {
int32_t codePoint;
auto charLength =
tryGetCharLength(input.data() + pos, input.size() - pos);
tryGetCharLength(input.data() + pos, input.size() - pos, codePoint);
if (charLength > 0) {
fixedWriter.append(std::string_view(input.data() + pos, charLength));
pos += charLength;
Expand Down
Loading
Loading