Skip to content

Commit

Permalink
Handle unescaped UTF-8 characters in Presto url_extract_* UDFs (faceb…
Browse files Browse the repository at this point in the history
…ookincubator#11535)

Summary:

Presto Java supports UTF-8 characters that are not control or whitespace characters appearing
anywhere in a URL where a % escaped character can appear.  This change modifies Velox's 
URIParser to do the same.

Velox's URIParser would produce incorrect results when any non-ASCII character appeared
anywhere in the URL and this has been fixed as well.

In order to facilitate this I modified the tryGetCharLength helper function in UTF8Utils to take in a
int32_t reference which it populates with the code point if the UTF-8 character is valid. It was
already calculating this value and throwing it away, returning it allows me to avoid an additional call
to repeat those steps and is consistent with the Airlift function on which it's based.

Differential Revision: D65927918
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Nov 14, 2024
1 parent 4a0d2c5 commit 18ee41a
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 51 deletions.
9 changes: 5 additions & 4 deletions velox/functions/lib/Utf8Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ int firstByteCharLength(const char* u_input) {

} // namespace

int32_t tryGetCharLength(const char* input, int64_t size) {
int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint) {
VELOX_DCHECK_NOT_NULL(input);
VELOX_DCHECK_GT(size, 0);

Expand All @@ -72,6 +72,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 1) {
// Normal ASCII: 0xxx_xxxx.
codePoint = input[0];
return 1;
}

Expand All @@ -89,7 +90,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 2) {
// 110x_xxxx 10xx_xxxx
int codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111);
codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111);
// Fail if overlong encoding.
return codePoint < 0x80 ? -2 : 2;
}
Expand All @@ -106,7 +107,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 3) {
// 1110_xxxx 10xx_xxxx 10xx_xxxx
int codePoint = ((firstByte & 0b00001111) << 12) |
codePoint = ((firstByte & 0b00001111) << 12) |
((secondByte & 0b00111111) << 6) | (thirdByte & 0b00111111);

// Surrogates are invalid.
Expand All @@ -132,7 +133,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) {

if (charLength == 4) {
// 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx
int codePoint = ((firstByte & 0b00000111) << 18) |
codePoint = ((firstByte & 0b00000111) << 18) |
((secondByte & 0b00111111) << 12) | ((thirdByte & 0b00111111) << 6) |
(forthByte & 0b00111111);
// Fail if overlong encoding or above upper bound of Unicode.
Expand Down
4 changes: 3 additions & 1 deletion velox/functions/lib/Utf8Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ namespace facebook::velox::functions {
///
/// @param input Pointer to the first byte of the code point. Must not be null.
/// @param size Number of available bytes. Must be greater than zero.
/// @param codePoint Populated with the code point it refers to. This is only
/// valid if the return value is positive.
/// @return the length of the code point or negative the number of bytes in the
/// invalid UTF-8 sequence.
///
/// Adapted from tryGetCodePointAt in
/// https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/SliceUtf8.java
int32_t tryGetCharLength(const char* input, int64_t size);
int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint);

/// Return the length in byte of the next UTF-8 encoded character at the
/// beginning of `string`. If the beginning of `string` is not valid UTF-8
Expand Down
51 changes: 30 additions & 21 deletions velox/functions/lib/tests/Utf8Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,53 +21,62 @@ namespace facebook::velox::functions {
namespace {

TEST(Utf8Test, tryCharLength) {
int32_t codepoint;
// Single-byte ASCII character.
ASSERT_EQ(1, tryGetCharLength("Hello", 5));
ASSERT_EQ(1, tryGetCharLength("Hello", 5, codepoint));
ASSERT_EQ('H', codepoint);

// 2-byte character. British pound sign.
static const char* kPound = "\u00A3tail";
ASSERT_EQ(2, tryGetCharLength(kPound, 5));
ASSERT_EQ(2, tryGetCharLength(kPound, 5, codepoint));
ASSERT_EQ(0xA3, codepoint);
// First byte alone is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kPound, 1));
ASSERT_EQ(-1, tryGetCharLength(kPound, 1, codepoint));
// Second byte alone is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5));
ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5, codepoint));
// ASCII character 't' after the pound sign is valid.
ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5));
ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5, codepoint));

// 3-byte character. Euro sign.
static const char* kEuro = "\u20ACtail";
ASSERT_EQ(3, tryGetCharLength(kEuro, 5));
ASSERT_EQ(3, tryGetCharLength(kEuro, 5, codepoint));
ASSERT_EQ(0x20AC, codepoint);
// First byte or first 2 bytes alone are not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kEuro, 1));
ASSERT_EQ(-2, tryGetCharLength(kEuro, 2));
ASSERT_EQ(-1, tryGetCharLength(kEuro, 1, codepoint));
ASSERT_EQ(-2, tryGetCharLength(kEuro, 2, codepoint));
// Byte sequence starting from 2nd or 3rd byte is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5));
ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5));
ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5));
ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5, codepoint));
ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5, codepoint));
ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5, codepoint));
ASSERT_EQ('t', codepoint);
// ASCII character 't' after the euro sign is valid.
ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5));
ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5, codepoint));
ASSERT_EQ('a', codepoint);

// 4-byte character. Musical symbol F CLEF.
static const char* kClef = "\U0001D122tail";
ASSERT_EQ(4, tryGetCharLength(kClef, 5));
ASSERT_EQ(4, tryGetCharLength(kClef, 5, codepoint));
ASSERT_EQ(0x1D122, codepoint);
// First byte, first 2 bytes, or first 3 bytes alone are not a valid
// character.
ASSERT_EQ(-1, tryGetCharLength(kClef, 1));
ASSERT_EQ(-2, tryGetCharLength(kClef, 2));
ASSERT_EQ(-3, tryGetCharLength(kClef, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef, 1, codepoint));
ASSERT_EQ(-2, tryGetCharLength(kClef, 2, codepoint));
ASSERT_EQ(-3, tryGetCharLength(kClef, 3, codepoint));
// Byte sequence starting from 2nd, 3rd or 4th byte is not a valid character.
ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3));
ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3, codepoint));
ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3, codepoint));
ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3, codepoint));

// ASCII character 't' after the clef sign is valid.
ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5));
ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5, codepoint));
ASSERT_EQ('t', codepoint);

// Test overlong encoding.

auto tryCharLength = [](const std::vector<unsigned char>& bytes) {
int32_t codepoint;
return tryGetCharLength(
reinterpret_cast<const char*>(bytes.data()), bytes.size());
reinterpret_cast<const char*>(bytes.data()), bytes.size(), codepoint);
};

// 2-byte encoding of 0x2F.
Expand Down
9 changes: 6 additions & 3 deletions velox/functions/prestosql/FromUtf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ class FromUtf8Function : public exec::VectorFunction {

auto replacement = decoded.valueAt<StringView>(row);
if (!replacement.empty()) {
int32_t codePoint;
auto charLength =
tryGetCharLength(replacement.data(), replacement.size());
tryGetCharLength(replacement.data(), replacement.size(), codePoint);
VELOX_USER_CHECK_GT(
charLength, 0, "Replacement is not a valid UTF-8 character");
VELOX_USER_CHECK_EQ(
Expand All @@ -188,8 +189,9 @@ class FromUtf8Function : public exec::VectorFunction {

int32_t pos = 0;
while (pos < value.size()) {
int32_t codePoint;
auto charLength =
tryGetCharLength(value.data() + pos, value.size() - pos);
tryGetCharLength(value.data() + pos, value.size() - pos, codePoint);
if (charLength < 0) {
firstInvalidRow = row;
return false;
Expand Down Expand Up @@ -267,8 +269,9 @@ class FromUtf8Function : public exec::VectorFunction {

int32_t pos = 0;
while (pos < input.size()) {
int32_t codePoint;
auto charLength =
tryGetCharLength(input.data() + pos, input.size() - pos);
tryGetCharLength(input.data() + pos, input.size() - pos, codePoint);
if (charLength > 0) {
fixedWriter.append(std::string_view(input.data() + pos, charLength));
pos += charLength;
Expand Down
54 changes: 43 additions & 11 deletions velox/functions/prestosql/URIParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

#include "velox/functions/prestosql/URIParser.h"
#include "velox/external/utf8proc/utf8procImpl.h"
#include "velox/functions/lib/Utf8Utils.h"

namespace facebook::velox::functions {

Expand All @@ -40,6 +42,11 @@ Mask createMask(const std::vector<size_t>& values) {

return mask;
}

bool test(const Mask& mask, char value) {
return value < mask.size() && mask.test(value);
}

// a-z or A-Z.
const Mask kAlpha = createMask('a', 'z') | createMask('A', 'Z');
// 0-9.
Expand Down Expand Up @@ -128,7 +135,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) {
return false;
}

if (str[pos] != '%' || !kHex.test(str[pos + 1]) || !kHex.test(str[pos + 2])) {
if (str[pos] != '%' || !test(kHex, str[pos + 1]) ||
!test(kHex, str[pos + 2])) {
return false;
}

Expand All @@ -138,7 +146,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) {
}

// Helper function that consumes as much of `str` from `pos` as possible where a
// character passes mask or is part of a percent encoded character.
// character passes mask, is part of a percent encoded character, or is an
// allowed UTF-8 character.
//
// `pos` is updated to the first character in `str` that was not consumed and
// `hasEncoded` is set to true if any percent encoded characters were
Expand All @@ -150,7 +159,7 @@ void consume(
int32_t& pos,
bool& hasEncoded) {
while (pos < len) {
if (mask.test(str[pos])) {
if (test(mask, str[pos])) {
pos++;
continue;
}
Expand All @@ -160,6 +169,29 @@ void consume(
continue;
}

// Masks cover all ASCII characters, check if this is an allowed UTF-8
// character.
// The range after ASCII characters up to 159 covers control characters
// which are not allowed.
if ((unsigned char)str[pos] > 159) {
// Get the UTF-8 code point.
int32_t codePoint;
auto valid = tryGetCharLength(str + pos, len - pos, codePoint);

// Check if it's a valid UTF-8 character.
if (valid > 0) {
const auto category = utf8proc_get_property(codePoint)->category;
// White space characters are also not allowed. The range of categories
// excluded here are categories of white space.
if (category < UTF8PROC_CATEGORY_ZS ||
category > UTF8PROC_CATEGORY_ZP) {
// Increment over the whole (potentially multi-byte) character.
pos += valid;
continue;
}
}
}

break;
}
}
Expand Down Expand Up @@ -297,7 +329,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) {
while (posInAddress < len && numBytes < 16) {
int32_t posInHex = posInAddress;
for (int i = 0; i < 4; i++) {
if (posInHex == len || !kHex.test(str[posInHex])) {
if (posInHex == len || !test(kHex, str[posInHex])) {
break;
}

Expand Down Expand Up @@ -333,7 +365,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) {
posInAddress = posInHex + 2;
}
} else {
if (posInHex == len || !kHex.test(str[posInHex + 1])) {
if (posInHex == len || !test(kHex, str[posInHex + 1])) {
// Peak ahead, we can't end on a single ':'.
return false;
}
Expand Down Expand Up @@ -375,7 +407,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) {
// Consume a string of hex digits.
int32_t posInHex = posInAddress;
while (posInHex < len) {
if (kHex.test(str[posInHex])) {
if (test(kHex, str[posInHex])) {
posInHex++;
} else {
break;
Expand All @@ -399,7 +431,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) {

int32_t posInSuffix = posInAddress;
while (posInSuffix < len) {
if (kIPVFutureSuffixOrUserInfo.test(str[posInSuffix])) {
if (test(kIPVFutureSuffixOrUserInfo, str[posInSuffix])) {
posInSuffix++;
} else {
break;
Expand Down Expand Up @@ -450,7 +482,7 @@ void consumePort(const char* str, const size_t len, int32_t& pos, URI& uri) {
int32_t posInPort = pos;

while (posInPort < len) {
if (kNum.test(str[posInPort])) {
if (test(kNum, str[posInPort])) {
posInPort++;
continue;
}
Expand All @@ -471,7 +503,7 @@ void consumeHost(const char* str, const size_t len, int32_t& pos, URI& uri) {
int32_t posInIPV4Address = posInHost;
if (tryConsumeIPV4Address(str, len, posInIPV4Address) &&
(posInIPV4Address == len ||
kFollowingHost.test(str[posInIPV4Address]))) {
test(kFollowingHost, str[posInIPV4Address]))) {
// reg-name and IPv4 addresses are hard to distinguish, a reg-name could
// have a valid IPv4 address as a prefix, but treating that prefix as an
// IPv4 address would make this URI invalid. We make sure that if we
Expand Down Expand Up @@ -534,14 +566,14 @@ bool tryConsumeScheme(
int32_t posInScheme = pos;

// The scheme must start with a letter.
if (posInScheme == len || !kAlpha.test(str[posInScheme])) {
if (posInScheme == len || !test(kAlpha, str[posInScheme])) {
return false;
}

// Consume the first letter.
posInScheme++;

while (posInScheme < len && kScheme.test(str[posInScheme])) {
while (posInScheme < len && test(kScheme, str[posInScheme])) {
posInScheme++;
}

Expand Down
5 changes: 3 additions & 2 deletions velox/functions/prestosql/URLFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) {
outputBuffer[outIndex++] = '+';
inputIndex++;
} else {
const auto charLength =
tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex);
int32_t codePoint;
const auto charLength = tryGetCharLength(
inputBuffer + inputIndex, inputSize - inputIndex, codePoint);
if (charLength > 0) {
for (int i = 0; i < charLength; ++i) {
charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex);
Expand Down
Loading

0 comments on commit 18ee41a

Please sign in to comment.