From 9a9c1f08d7349c4529b82c95da3a3016bfb8e47d Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Mon, 18 Nov 2024 15:22:00 -0800 Subject: [PATCH 1/3] Fix Presto URL functions to more closely match Presto Java's behavior (#11488) Summary: Today the Presto URL functions rely on a regex to extract features of a URL. Due to the limitations of regexes this is not sufficient to validate that a URL is indeed valid. This leads to a number of cases where Presto Java will return NULL due to an invalid URL and Velox will return some substring. To address this I've implemented a parser for URIs based on the RFC 3986 spec which can both validate a URL and extract features of it. While testing this change I noticed a number of other discrepancies between Presto Java's and Velox's implementations of these UDFs (mostly related to unescaping/decoding URLs or portions thereof) that had been missed, likely due to the noise from the different handling of invalid URLs. Those are addressed in this diff as well so that I could effectively test it. Differential Revision: D65695961 --- velox/functions/prestosql/CMakeLists.txt | 1 + velox/functions/prestosql/URIParser.cpp | 675 ++++++++++++++++++ velox/functions/prestosql/URIParser.h | 36 + velox/functions/prestosql/URLFunctions.h | 229 ++---- .../prestosql/tests/URLFunctionsTest.cpp | 311 +++++++- 5 files changed, 1081 insertions(+), 171 deletions(-) create mode 100644 velox/functions/prestosql/URIParser.cpp create mode 100644 velox/functions/prestosql/URIParser.h diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index fb08d3217793..e5174cd8bd08 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -54,6 +54,7 @@ velox_add_library( TransformKeys.cpp TransformValues.cpp TypeOf.cpp + URIParser.cpp URLFunctions.cpp VectorArithmetic.cpp WidthBucketArray.cpp diff --git a/velox/functions/prestosql/URIParser.cpp b/velox/functions/prestosql/URIParser.cpp new file mode 100644 index 000000000000..3f56a7d41f5e --- /dev/null +++ b/velox/functions/prestosql/URIParser.cpp @@ -0,0 +1,675 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/prestosql/URIParser.h" + +namespace facebook::velox::functions { + +namespace detail { +using Mask = std::bitset<128>; + +Mask createMask(size_t low, size_t high) { + Mask mask = 0; + + for (size_t i = low; i <= high; i++) { + mask.set(i); + } + + return mask; +} + +Mask createMask(const std::vector& values) { + Mask mask = 0; + + for (const auto& value : values) { + mask.set(value); + } + + return mask; +} +// a-z or A-Z. +const Mask kAlpha = createMask('a', 'z') | createMask('A', 'Z'); +// 0-9. +const Mask kNum = createMask('0', '9'); +// 0-9, a-f, or A-F. +const Mask kHex = kNum | createMask('a', 'f') | createMask('A', 'F'); +// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" +// / "*" / "+" / "," / ";" / "=" +const Mask kSubDelims = + createMask({'!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '='}); +// gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" +const Mask kGenDelims = createMask({':', '/', '?', '#', '[', ']', '@'}); +// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" +const Mask kUnreserved = kAlpha | kNum | createMask({'-', '.', '_', '~'}); +// pchar = unreserved / pct-encoded / sub-delims / ":" / "@" +const Mask kPChar = kUnreserved | kSubDelims | createMask({':', '@'}); +// query = *( pchar / "/" / "?" ) +// fragment = *( pchar / "/" / "?" ) +const Mask kQueryOrFragment = kPChar | createMask({'/', '?'}); +// path = path-abempty ; begins with "/" or is empty +// / path-absolute ; begins with "/" but not "//" +// / path-noscheme ; begins with a non-colon segment +// / path-rootless ; begins with a segment +// / path-empty ; zero characters +// +// path-abempty = *( "/" segment ) +// path-absolute = "/" [ segment-nz *( "/" segment ) ] +// path-noscheme = segment-nz-nc *( "/" segment ) +// path-rootless = segment-nz *( "/" segment ) +// path-empty = 0 +// +// segment = *pchar +// segment-nz = 1*pchar +// segment-nz-nc = 1*( unreserved / pct-encoded / sub-delims / "@" ) +// ; non-zero-length segment without any colon ":" +const Mask kPath = kPChar | createMask({'/'}); +// segment-nz-nc = 1*( unreserved / pct-encoded / sub-delims / "@" ) +// ; non-zero-length segment without any colon ":" +const Mask kPathNoColonPrefix = kPChar & createMask({':'}).flip(); +// reg-name = *( unreserved / pct-encoded / sub-delims ) +const Mask kRegName = kUnreserved | kSubDelims; +// IPvFuture = "v" 1*HEXDIG "." 1*( unreserved / sub-delims / ":" ) +// userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) +const Mask kIPVFutureSuffixOrUserInfo = + kUnreserved | kSubDelims | createMask({':'}); +// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) +const Mask kScheme = kAlpha | kNum | createMask({'+', '-', '.'}); +// Not explicitly called out in the spec, but these are the only characters that +// can legally follow a host name in a URI (as part of a port, query, fragment, +// or path respectively). Used to differentiate an IPv4 address from a reg-name +// that has an IPv4 address as its prefix. +const Mask kFollowingHost = createMask({':', '?', '#', '/'}); + +// The functions below follow the general conventions: +// +// They are named consumeX or tryConsumeX where X is a part of the URI string +// that generally corresponds to a rule in the ABNF grammar for URIs defined in +// Appendix A of RFC 3986. +// +// Functions will take 3 or 4 arguments, `str` the URI string we are parsing, +// `len` the length of the URI string we are parsing, and `pos` our current +// position in the string at the time this function was called. Some functions +// all take a fourth `uri` argument, the URI struct that is to be populated with +// extracted sections of the URI string as part of parsing. If a function +// successfully parses the string, pos is updated to be the first character +// after the substring that was successfully parsed. In addition, if the +// function takes `uri` it will be populated with the relevant section of the +// URI. +// +// Functions named consumeX will always succeed because they can accept an empty +// string. +// +// Functions named tryConsumeX may not successfully parse a substring, they will +// return a bool, with `true` indicating that it was successful and `false` +// indicating it was not. If it was not successful `pos` and `uri` will not be +// modified. +// +// All rules are greedy and will consume as much of the string as possible +// starting from `pos`. +// +// Naturally these conventions do not apply to helper functions. + +// pct-encoded = "%" HEXDIG HEXDIG +bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { + if (len - pos < 3) { + return false; + } + + if (str[pos] != '%' || !kHex.test(str[pos + 1]) || !kHex.test(str[pos + 2])) { + return false; + } + + pos += 3; + + return true; +} + +// Helper function that consumes as much of `str` from `pos` as possible where a +// character passes mask or is part of a percent encoded character. +// +// `pos` is updated to the first character in `str` that was not consumed and +// `hasEncoded` is set to true if any percent encoded characters were +// encountered. +void consume( + const Mask& mask, + const char* str, + const size_t len, + int32_t& pos, + bool& hasEncoded) { + while (pos < len) { + if (mask.test(str[pos])) { + pos++; + continue; + } + + if (tryConsumePercentEncoded(str, len, pos)) { + hasEncoded = true; + continue; + } + + break; + } +} + +// path = path-abempty ; begins with "/" or is empty +// / path-absolute ; begins with "/" but not "//" +// / path-noscheme ; begins with a non-colon segment +// / path-rootless ; begins with a segment +// / path-empty ; zero characters +// +// path-abempty = *( "/" segment ) +// path-absolute = "/" [ segment-nz *( "/" segment ) ] +// path-noscheme = segment-nz-nc *( "/" segment ) +// path-rootless = segment-nz *( "/" segment ) +// path-empty = 0 +// +// segment = *pchar +// segment-nz = 1*pchar +// segment-nz-nc = 1*( unreserved / pct-encoded / sub-delims / "@" ) +// ; non-zero-length segment without any colon ":" +// +// For our purposes this is just a complicated way of saying a possibly empty +// string of characters that match the kPath Mask or are percent encoded. +// There aren't currently any use cases in Velox for distinguishing the +// different types of paths so we don't bother here. +template +void consumePath(const char* str, const size_t len, int32_t& pos, URI& uri) { + int32_t posInPath = pos; + + if constexpr (restrictColonInPrefix) { + // Consume a prefix without ':' or '/'. + consume(kPathNoColonPrefix, str, len, posInPath, uri.pathHasEncoded); + // The path continues only if the next character is a '/'. + if (posInPath != len && str[posInPath] == '/') { + consume(kPath, str, len, posInPath, uri.pathHasEncoded); + } + } else { + consume(kPath, str, len, posInPath, uri.pathHasEncoded); + } + + uri.path = StringView(str + pos, posInPath - pos); + pos = posInPath; +} + +// Returns whether `probe` is in the range[`low`, `high`]. +bool inRange(char probe, char low, char high) { + return probe >= low && probe <= high; +} + +// IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet +// dec-octet = DIGIT ; 0-9 +// / %x31-39 DIGIT ; 10-99 +// / "1" 2DIGIT ; 100-199 +// / "2" %x30-34 DIGIT ; 200-249 +// / "25" %x30-35 ; 250-255 +bool tryConsumeIPV4Address(const char* str, const size_t len, int32_t& pos) { + int32_t posInAddress = pos; + + for (int i = 0; i < 4; i++) { + if (posInAddress == len) { + return false; + } + + if (str[posInAddress] == '2' && posInAddress < len - 2 && + str[posInAddress + 1] == '5' && + inRange(str[posInAddress + 2], '0', '5')) { + // 250-255 + posInAddress += 3; + } else if ( + str[posInAddress] == '2' && posInAddress < len - 2 && + inRange(str[posInAddress + 1], '0', '4') && + inRange(str[posInAddress + 2], '0', '9')) { + // 200-249 + posInAddress += 3; + } else if ( + str[posInAddress] == '1' && posInAddress < len - 2 && + inRange(str[posInAddress + 1], '0', '9') && + inRange(str[posInAddress + 2], '0', '9')) { + // 100-199 + posInAddress += 3; + } else if (inRange(str[posInAddress], '0', '9')) { + if (posInAddress < len - 1 && inRange(str[posInAddress + 1], '0', '9')) { + // 10-99 + posInAddress += 2; + } else { + // 0-9 + posInAddress += 1; + } + } else { + return false; + } + + if (i < 3) { + // An IPv4 address must have exactly 4 parts. + if (posInAddress == len || str[posInAddress] != '.') { + return false; + } + + // Consume '.'. + posInAddress++; + } + } + + pos = posInAddress; + return true; +} + +// Returns true if the substring starting from `pos` is '::'. +bool isAtCompression(const char* str, const size_t len, const int32_t pos) { + return pos < len - 1 && str[pos] == ':' && str[pos + 1] == ':'; +} + +// IPv6address = 6( h16 ":" ) ls32 +// / "::" 5( h16 ":" ) ls32 +// / [ h16 ] "::" 4( h16 ":" ) ls32 +// / [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32 +// / [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32 +// / [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32 +// / [ *4( h16 ":" ) h16 ] "::" ls32 +// / [ *5( h16 ":" ) h16 ] "::" h16 +// / [ *6( h16 ":" ) h16 ] "::" +// h16 = 1*4HEXDIG +// ls32 = ( h16 ":" h16 ) / IPv4address +bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { + bool hasCompression = false; + uint8_t numBytes = 0; + int32_t posInAddress = pos; + + if (isAtCompression(str, len, posInAddress)) { + hasCompression = true; + // Consume the compression '::'. + posInAddress += 2; + } + + while (posInAddress < len && numBytes < 16) { + int32_t posInHex = posInAddress; + for (int i = 0; i < 4; i++) { + if (posInHex == len || !kHex.test(str[posInHex])) { + break; + } + + posInHex++; + } + + if (posInHex == posInAddress) { + // We need to be able to consume at least one hex digit. + break; + } + + if (posInHex < len) { + if (str[posInHex] == '.') { + // We may be in the IPV4 Address. + if (tryConsumeIPV4Address(str, len, posInAddress)) { + numBytes += 4; + break; + } else { + // A '.' can't appear anywhere except in a valid IPV4 address. + return false; + } + } + if (str[posInHex] == ':') { + if (isAtCompression(str, len, posInHex)) { + if (hasCompression) { + // We can't have two compressions. + return false; + } else { + // We found a 2 byte hex value followed by a compression. + numBytes += 2; + hasCompression = true; + // Consume the hex block and the compression '::'. + posInAddress = posInHex + 2; + } + } else { + if (posInHex == len || !kHex.test(str[posInHex + 1])) { + // Peak ahead, we can't end on a single ':'. + return false; + } + // We found a 2 byte hex value followed by a single ':'. + numBytes += 2; + // Consume the hex block and the ':'. + posInAddress = posInHex + 1; + } + } else { + // We found a 2 byte hex value at the end of the string. + numBytes += 2; + posInAddress = posInHex; + break; + } + } + } + + // A valid IPv6 address must have exactly 16 bytes, or a compression. + if ((numBytes == 16 && !hasCompression) || + (hasCompression && numBytes <= 14 && numBytes % 2 == 0)) { + pos = posInAddress; + return true; + } else { + return false; + } +} + +// IPvFuture = "v" 1*HEXDIG "." 1*( unreserved / sub-delims / ":" ) +bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { + int32_t posInAddress = pos; + + if (posInAddress == len || str[posInAddress] != 'v') { + return false; + } + + // Consume 'v'. + posInAddress++; + + // Consume a string of hex digits. + int32_t posInHex = posInAddress; + while (posInHex < len) { + if (kHex.test(str[posInHex])) { + posInHex++; + } else { + break; + } + } + + // The string of hex digits has to be non-empty. + if (posInHex == posInAddress) { + return false; + } + + posInAddress = posInHex; + + // The string of hex digits must be followed by a '.'. + if (posInAddress == len || str[posInAddress] != '.') { + return false; + } + + // Consume '.'. + posInAddress++; + + int32_t posInSuffix = posInAddress; + while (posInSuffix < len) { + if (kIPVFutureSuffixOrUserInfo.test(str[posInSuffix])) { + posInSuffix++; + } else { + break; + } + } + + // The suffix must be non-empty. + if (posInSuffix == posInAddress) { + return false; + } + + pos = posInSuffix; + return true; +} + +// IP-literal = "[" ( IPv6address / IPvFuture ) "]" +bool tryConsumeIPLiteral(const char* str, const size_t len, int32_t& pos) { + int32_t posInAddress = pos; + + // The IP Literal must start with '['. + if (posInAddress == len || str[posInAddress] != '[') { + return false; + } + + // Consume '['. + posInAddress++; + + // The contents must be an IPv6 address or an IPvFuture. + if (!tryConsumeIPV6Address(str, len, posInAddress) && + !tryConsumeIPVFuture(str, len, posInAddress)) { + return false; + } + + // The IP literal must end with ']'. + if (posInAddress == len || str[posInAddress] != ']') { + return false; + } + + // Consume ']'. + posInAddress++; + pos = posInAddress; + + return true; +} + +// port = *DIGIT +void consumePort(const char* str, const size_t len, int32_t& pos, URI& uri) { + int32_t posInPort = pos; + + while (posInPort < len) { + if (kNum.test(str[posInPort])) { + posInPort++; + continue; + } + + break; + } + + uri.port = StringView(str + pos, posInPort - pos); + pos = posInPort; +} + +// host = IP-literal / IPv4address / reg-name +// reg-name = *( unreserved / pct-encoded / sub-delims ) +void consumeHost(const char* str, const size_t len, int32_t& pos, URI& uri) { + int32_t posInHost = pos; + + if (!tryConsumeIPLiteral(str, len, posInHost)) { + int32_t posInIPV4Address = posInHost; + if (tryConsumeIPV4Address(str, len, posInIPV4Address) && + (posInIPV4Address == len || + kFollowingHost.test(str[posInIPV4Address]))) { + // reg-name and IPv4 addresses are hard to distinguish, a reg-name could + // have a valid IPv4 address as a prefix, but treating that prefix as an + // IPv4 address would make this URI invalid. We make sure that if we + // detect an IPv4 address it either goes to the end of the string, or is + // followed by one of the characters that can appear after a host name + // (and importantly can't appear in a reg-name). + posInHost = posInIPV4Address; + } else { + consume(kRegName, str, len, posInHost, uri.hostHasEncoded); + } + } + + uri.host = StringView(str + pos, posInHost - pos); + pos = posInHost; +} + +// authority = [ userinfo "@" ] host [ ":" port ] +void consumeAuthority( + const char* str, + const size_t len, + int32_t& pos, + URI& uri) { + int32_t posInAuthority = pos; + + // Dummy variable to pass in as reference. + bool authorityHasEncoded; + consume( + kIPVFutureSuffixOrUserInfo, + str, + len, + posInAuthority, + authorityHasEncoded); + + // The user info must be followed by '@'. + if (posInAuthority != len && str[posInAuthority] == '@') { + // Consume '@'. + posInAuthority++; + } else { + posInAuthority = pos; + } + + consumeHost(str, len, posInAuthority, uri); + + // The port must be preceded by a ':'. + if (posInAuthority < len && str[posInAuthority] == ':') { + // Consume ':'. + posInAuthority++; + consumePort(str, len, posInAuthority, uri); + } + + pos = posInAuthority; +} + +// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) +bool tryConsumeScheme( + const char* str, + const size_t len, + int32_t& pos, + URI& uri) { + int32_t posInScheme = pos; + + // The scheme must start with a letter. + if (posInScheme == len || !kAlpha.test(str[posInScheme])) { + return false; + } + + // Consume the first letter. + posInScheme++; + + while (posInScheme < len && kScheme.test(str[posInScheme])) { + posInScheme++; + } + + uri.scheme = StringView(str + pos, posInScheme - pos); + pos = posInScheme; + return true; +} + +// relative-part = "//" authority path-abempty +// / path-absolute +// / path-noscheme +// / path-empty +// +// hier-part = "//" authority path-abempty +// / path-absolute +// / path-rootless +// / path-empty +// +// Since we don't distinguish between path types these are functionally the same +// thing. +template +void consumeRelativePartOrHierPart( + const char* str, + const size_t len, + int32_t& pos, + URI& uri) { + if (pos < len - 1 && str[pos] == '/' && str[pos + 1] == '/') { + // Consume '//'. + pos += 2; + consumeAuthority(str, len, pos, uri); + + if (pos < len && str[pos] == '/') { + consumePath(str, len, pos, uri); + } + + return; + } + + if constexpr (isRelativePart) { + // In the relative part, there's a restriction that there cannot be any ':' + // until the first '/'. + consumePath(str, len, pos, uri); + } else { + consumePath(str, len, pos, uri); + } +} + +// query = *( pchar / "/" / "?" ) +// fragment = *( pchar / "/" / "?" ) +void consumeQueryAndFragment( + const char* str, + const size_t len, + int32_t& pos, + URI& uri) { + if (pos < len && str[pos] == '?') { + int32_t posInQuery = pos; + // Consume '?'. + posInQuery++; + // Consume query. + consume(kQueryOrFragment, str, len, posInQuery, uri.queryHasEncoded); + + // Don't include the '?'. + uri.query = StringView(str + pos + 1, posInQuery - pos - 1); + pos = posInQuery; + } + + if (pos < len && str[pos] == '#') { + int32_t posInFragment = pos; + // Consume '#'. + posInFragment++; + // Consume fragment. + consume(kQueryOrFragment, str, len, posInFragment, uri.fragmentHasEncoded); + + // Don't include the '#'. + uri.fragment = StringView(str + pos + 1, posInFragment - pos - 1); + pos = posInFragment; + } +} + +// relative-ref = relative-part [ "?" query ] [ "#" fragment ] +void consumeRelativeRef( + const char* str, + const size_t len, + int32_t& pos, + URI& uri) { + consumeRelativePartOrHierPart(str, len, pos, uri); + + consumeQueryAndFragment(str, len, pos, uri); +} + +// URI = scheme ":" hier-part [ "?" query ] [ "#" fragment ] +bool tryConsumeUri(const char* str, const size_t len, int32_t& pos, URI& uri) { + URI result; + int32_t posInUri = pos; + + if (!tryConsumeScheme(str, len, posInUri, result)) { + return false; + } + + // Scheme is always followed by ':'. + if (posInUri == len || str[posInUri] != ':') { + return false; + } + + // Consume ':'. + posInUri++; + + consumeRelativePartOrHierPart(str, len, posInUri, result); + consumeQueryAndFragment(str, len, posInUri, result); + + pos = posInUri; + uri = result; + return true; +} + +} // namespace detail + +// URI-reference = URI / relative-ref +bool parseUri(const StringView& uriStr, URI& uri) { + int32_t pos = 0; + if (detail::tryConsumeUri(uriStr.data(), uriStr.size(), pos, uri) && + pos == uriStr.size()) { + return true; + } + + pos = 0; + detail::consumeRelativeRef(uriStr.data(), uriStr.size(), pos, uri); + + return pos == uriStr.size(); +} +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/URIParser.h b/velox/functions/prestosql/URIParser.h new file mode 100644 index 000000000000..6e86a3678fdf --- /dev/null +++ b/velox/functions/prestosql/URIParser.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/type/StringView.h" + +namespace facebook::velox::functions { +struct URI { + StringView scheme; + StringView path; + bool pathHasEncoded = false; + StringView query; + bool queryHasEncoded = false; + StringView fragment; + bool fragmentHasEncoded = false; + StringView host; + bool hostHasEncoded = false; + StringView port; +}; + +// Parse a URI string into a URI struct according to RFC 3986. +bool parseUri(const StringView& uriStr, URI& uri); +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index 97e16a582089..8dc64704d1d7 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -20,77 +20,16 @@ #include #include "velox/functions/Macros.h" #include "velox/functions/lib/string/StringImpl.h" +#include "velox/functions/prestosql/URIParser.h" namespace facebook::velox::functions { namespace detail { - -const auto kScheme = 2; -const auto kAuthority = 3; -const auto kPath = 5; -const auto kQuery = 7; -const auto kFragment = 9; -const auto kHost = 3; // From the authority and path regex. -const auto kPort = 4; // From the authority and path regex. - FOLLY_ALWAYS_INLINE StringView submatch(const boost::cmatch& match, int idx) { const auto& sub = match[idx]; return StringView(sub.first, sub.length()); } -FOLLY_ALWAYS_INLINE bool -parse(const char* rawUrlData, size_t rawUrlsize, boost::cmatch& match) { - /// This regex is taken from RFC - 3986. - /// See: https://www.rfc-editor.org/rfc/rfc3986#appendix-B - /// The basic groups are: - /// scheme = $2 - /// authority = $4 - /// path = $5 - /// query = $7 - /// fragment = $9 - /// For example a URI like below : - /// http://www.ics.uci.edu/pub/ietf/uri/#Related - /// - /// results in the following subexpression matches: - /// - /// $1 = http: - /// $2 = http - /// $3 = //www.ics.uci.edu - /// $4 = www.ics.uci.edu - /// $5 = /pub/ietf/uri/ - /// $6 = - /// $7 = - /// $8 = #Related - /// $9 = Related - static const boost::regex kUriRegex( - "^(([^:\\/?#]+):)?" // scheme: - "(\\/\\/([^\\/?#]*))?([^?#]*)" // authority and path - "(\\?([^#]*))?" // ?query - "(#(.*))?"); // #fragment - - return boost::regex_match( - rawUrlData, rawUrlData + rawUrlsize, match, kUriRegex); -} - -/// Parses the url and returns the matching subgroup if the particular sub group -/// is matched by the call to parse call above. -FOLLY_ALWAYS_INLINE std::optional parse( - StringView rawUrl, - int subGroup) { - boost::cmatch match; - if (!parse(rawUrl.data(), rawUrl.size(), match)) { - return std::nullopt; - } - - VELOX_CHECK_LT(subGroup, match.size()); - - if (match[subGroup].matched) { - return submatch(match, subGroup); - } - - return std::nullopt; -} - FOLLY_ALWAYS_INLINE unsigned char toHex(unsigned char c) { return c < 10 ? (c + '0') : (c + 'A' - 10); } @@ -136,38 +75,7 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { output.resize(outIndex); } -/// Performs initial validation of the URI. -/// Checks if the URI contains ascii whitespaces or -/// unescaped '%' chars. -FOLLY_ALWAYS_INLINE bool isValidURI(StringView input) { - const char* p = input.data(); - const char* end = p + input.size(); - char buf[3]; - buf[2] = '\0'; - char* endptr; - for (; p < end; ++p) { - if (stringImpl::isAsciiWhiteSpace(*p)) { - return false; - } - - if (*p == '%') { - if (p + 2 < end) { - buf[0] = p[1]; - buf[1] = p[2]; - strtol(buf, &endptr, 16); - p += 2; - if (endptr != buf + 2) { - return false; - } - } else { - return false; - } - } - } - return true; -} - -template +template FOLLY_ALWAYS_INLINE void urlUnescape( TOutString& output, const TInString& input) { @@ -181,9 +89,13 @@ FOLLY_ALWAYS_INLINE void urlUnescape( buf[2] = '\0'; char* endptr; for (; p < end; ++p) { - if (*p == '+') { - *outputBuffer++ = ' '; - } else if (*p == '%') { + if constexpr (unescapePlus) { + if (*p == '+') { + *outputBuffer++ = ' '; + continue; + } + } + if (*p == '%') { if (p + 2 < end) { buf[0] = p[1]; buf[1] = p[2]; @@ -204,15 +116,6 @@ FOLLY_ALWAYS_INLINE void urlUnescape( } output.resize(outputBuffer - output.data()); } - -/// Matches the authority (i.e host[:port], ipaddress), and path from a string -/// representing the authority and path. Returns true if the regex matches, and -/// sets the appropriate groups matching authority in authorityMatch. -std::optional matchAuthorityAndPath( - StringView authorityAndPath, - boost::cmatch& authorityMatch, - int subGroup); - } // namespace detail template @@ -228,15 +131,13 @@ struct UrlExtractProtocolFunction { FOLLY_ALWAYS_INLINE bool call( out_type& result, const arg_type& url) { - if (!detail::isValidURI(url)) { + URI uri; + if (!parseUri(url, uri)) { return false; } - if (auto protocol = detail::parse(url, detail::kScheme)) { - result.setNoCopy(protocol.value()); - } else { - result.setEmpty(); - } + result.setNoCopy(uri.scheme); + return true; } }; @@ -248,21 +149,22 @@ struct UrlExtractFragmentFunction { // Results refer to strings in the first argument. static constexpr int32_t reuse_strings_from_arg = 0; - // ASCII input always produces ASCII result. - static constexpr bool is_default_ascii_behavior = true; + // Input is always ASCII, but result may or may not be ASCII. FOLLY_ALWAYS_INLINE bool call( out_type& result, const arg_type& url) { - if (!detail::isValidURI(url)) { + URI uri; + if (!parseUri(url, uri)) { return false; } - if (auto fragment = detail::parse(url, detail::kFragment)) { - result.setNoCopy(fragment.value()); + if (uri.fragmentHasEncoded) { + detail::urlUnescape(result, uri.fragment); } else { - result.setEmpty(); + result.setNoCopy(uri.fragment); } + return true; } }; @@ -274,29 +176,22 @@ struct UrlExtractHostFunction { // Results refer to strings in the first argument. static constexpr int32_t reuse_strings_from_arg = 0; - // ASCII input always produces ASCII result. - static constexpr bool is_default_ascii_behavior = true; + // Input is always ASCII, but result may or may not be ASCII. FOLLY_ALWAYS_INLINE bool call( out_type& result, const arg_type& url) { - if (!detail::isValidURI(url)) { + URI uri; + if (!parseUri(url, uri)) { return false; } - auto authAndPath = detail::parse(url, detail::kAuthority); - if (!authAndPath) { - result.setEmpty(); - return true; - } - boost::cmatch authorityMatch; - - if (auto host = detail::matchAuthorityAndPath( - authAndPath.value(), authorityMatch, detail::kHost)) { - result.setNoCopy(host.value()); + if (uri.hostHasEncoded) { + detail::urlUnescape(result, uri.host); } else { - result.setEmpty(); + result.setNoCopy(uri.host); } + return true; } }; @@ -306,26 +201,19 @@ struct UrlExtractPortFunction { VELOX_DEFINE_FUNCTION_TYPES(T); FOLLY_ALWAYS_INLINE bool call(int64_t& result, const arg_type& url) { - if (!detail::isValidURI(url)) { - return false; - } - - auto authAndPath = detail::parse(url, detail::kAuthority); - if (!authAndPath) { + URI uri; + if (!parseUri(url, uri)) { return false; } - boost::cmatch authorityMatch; - if (auto port = detail::matchAuthorityAndPath( - authAndPath.value(), authorityMatch, detail::kPort)) { - if (!port.value().empty()) { - try { - result = to(port.value()); - return true; - } catch (folly::ConversionError const&) { - } + if (!uri.port.empty()) { + try { + result = to(uri.port); + return true; + } catch (folly::ConversionError const&) { } } + return false; } }; @@ -336,17 +224,22 @@ struct UrlExtractPathFunction { // Input is always ASCII, but result may or may not be ASCII. + // Results refer to strings in the first argument. + static constexpr int32_t reuse_strings_from_arg = 0; + FOLLY_ALWAYS_INLINE bool call( out_type& result, const arg_type& url) { - if (!detail::isValidURI(url)) { + URI uri; + if (!parseUri(url, uri)) { return false; } - auto path = detail::parse(url, detail::kPath); - VELOX_USER_CHECK( - path.has_value(), "Unable to determine path for URL: {}", url); - detail::urlUnescape(result, path.value()); + if (uri.pathHasEncoded) { + detail::urlUnescape(result, uri.path); + } else { + result.setNoCopy(uri.path); + } return true; } @@ -359,20 +252,20 @@ struct UrlExtractQueryFunction { // Results refer to strings in the first argument. static constexpr int32_t reuse_strings_from_arg = 0; - // ASCII input always produces ASCII result. - static constexpr bool is_default_ascii_behavior = true; + // Input is always ASCII, but result may or may not be ASCII. FOLLY_ALWAYS_INLINE bool call( out_type& result, const arg_type& url) { - if (!detail::isValidURI(url)) { + URI uri; + if (!parseUri(url, uri)) { return false; } - if (auto query = detail::parse(url, detail::kQuery)) { - result.setNoCopy(query.value()); + if (uri.queryHasEncoded) { + detail::urlUnescape(result, uri.query); } else { - result.setEmpty(); + result.setNoCopy(uri.query); } return true; @@ -386,23 +279,18 @@ struct UrlExtractParameterFunction { // Results refer to strings in the first argument. static constexpr int32_t reuse_strings_from_arg = 0; - // ASCII input always produces ASCII result. - static constexpr bool is_default_ascii_behavior = true; + // Input is always ASCII, but result may or may not be ASCII. FOLLY_ALWAYS_INLINE bool call( out_type& result, const arg_type& url, const arg_type& param) { - if (!detail::isValidURI(url)) { - return false; - } - - auto query = detail::parse(url, detail::kQuery); - if (!query) { + URI uri; + if (!parseUri(url, uri)) { return false; } - if (!query.value().empty()) { + if (!uri.query.empty()) { // Parse query string. static const boost::regex kQueryParamRegex( "(^|&)" // start of query or start of parameter "&" @@ -413,8 +301,8 @@ struct UrlExtractParameterFunction { ); const boost::cregex_iterator begin( - query.value().data(), - query.value().data() + query.value().size(), + uri.query.data(), + uri.query.data() + uri.query.size(), kQueryParamRegex); boost::cregex_iterator end; @@ -452,7 +340,8 @@ struct UrlDecodeFunction { FOLLY_ALWAYS_INLINE void call( out_type& result, const arg_type& input) { - detail::urlUnescape(result, input); + detail::urlUnescape, arg_type, true>( + result, input); } }; diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 58252664ff5e..0ab32784e33e 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -118,12 +118,250 @@ TEST_F(URLFunctionsTest, validateURL) { "IC6S!8hGVRpo+!,yTaJEy/$RUZpqcr", "", "", - "IC6S!8hGVRpo !,yTaJEy/$RUZpqcr", + "IC6S!8hGVRpo+!,yTaJEy/$RUZpqcr", + "", + "", + std::nullopt); + + // Some examples from Wikipedia. + // https://en.wikipedia.org/wiki/Uniform_Resource_Identifier + validate( + "https://john.doe@www.example.com:1234/forum/questions/?tag=networking&order=newest#top", + "https", + "www.example.com", + "/forum/questions/", + "top", + "tag=networking&order=newest", + 1234); + validate( + "https://john.doe@www.example.com:1234/forum/questions/?tag=networking&order=newest#:~:text=whatever", + "https", + "www.example.com", + "/forum/questions/", + ":~:text=whatever", + "tag=networking&order=newest", + 1234); + validate( + "ldap://[2001:db8::7]/c=GB?objectClass?one", + "ldap", + "[2001:db8::7]", + "/c=GB", + "", + "objectClass?one", + std::nullopt); + validate( + "mailto:John.Doe@example.com", + "mailto", + "", + "John.Doe@example.com", + "", + "", + std::nullopt); + validate( + "news:comp.infosystems.www.servers.unix", + "news", + "", + "comp.infosystems.www.servers.unix", + "", + "", + std::nullopt); + validate( + "tel:+1-816-555-1212", + "tel", + "", + "+1-816-555-1212", + "", + "", + std::nullopt); + validate("telnet://192.0.2.16:80/", "telnet", "192.0.2.16", "/", "", "", 80); + validate( + "urn:oasis:names:specification:docbook:dtd:xml:4.1.2", + "urn", + "", + "oasis:names:specification:docbook:dtd:xml:4.1.2", "", "", std::nullopt); } +TEST_F(URLFunctionsTest, extractProtocol) { + const auto extractProtocol = [&](const std::optional& url) { + return evaluateOnce("url_extract_protocol(c0)", url); + }; + + // Test minimal protocol. + EXPECT_EQ("a", extractProtocol("a://www.yahoo.com")); + // Test all valid characters + EXPECT_EQ( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-.", + extractProtocol( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-.://www.yahoo.com")); + + // Test empty protocol. + EXPECT_EQ(std::nullopt, extractProtocol("://www.yahoo.com/")); + // Test protocol starts with digit. + EXPECT_EQ(std::nullopt, extractProtocol("1abc://www.yahoo.com/")); +} + +TEST_F(URLFunctionsTest, extractHostIPv4) { + const auto extractHost = [&](const std::optional& url) { + return evaluateOnce("url_extract_host(c0)", url); + }; + + // Upper bounds of the ranges of the rules for IPv4 dec-octets. + EXPECT_EQ("255.249.199.99", extractHost("http://255.249.199.99")); + // Lower bounds of the ranges of the rules for IPv4 dec-octets. + EXPECT_EQ("250.200.100.10", extractHost("http://250.200.100.10")); + // All single digits. + EXPECT_EQ("9.8.1.0", extractHost("http://9.8.1.0")); + // All two digits. + EXPECT_EQ("99.98.11.10", extractHost("http://99.98.11.10")); + // All three digits. + EXPECT_EQ("254.237.150.100", extractHost("http://254.237.150.100")); + + // We don't test invalid cases here as they will match the reg-name rule, we + // test them under the IPv6 cases below as these are distinguishable from + // reg-name. +} + +TEST_F(URLFunctionsTest, extractHostIPv6) { + const auto extractHost = [&](const std::optional& url) { + return evaluateOnce("url_extract_host(c0)", url); + }; + + // 8 hex blocks. + EXPECT_EQ( + "[0123:4567:89ab:cdef:0123:4567:89ab:cdef]", + extractHost("http://[0123:4567:89ab:cdef:0123:4567:89ab:cdef]")); + // 6 hex blocks followed by an IPv4 address. + EXPECT_EQ( + "[0123:4567:89ab:cdef:0123:4567:0.1.8.9]", + extractHost("http://[0123:4567:89ab:cdef:0123:4567:0.1.8.9]")); + // compression followed by 7 hex blocks. + EXPECT_EQ( + "[::456:89a:cde:012:456:89a:cde]", + extractHost("http://[::456:89a:cde:012:456:89a:cde]")); + // compression followed by 5 hex blocks followed by an IPv4 address. + EXPECT_EQ( + "[::456:89a:cde:012:456:10.11.98.99]", + extractHost("http://[::456:89a:cde:012:456:10.11.98.99]")); + // 1 hex block flowed by a compression followed by 6 hex blocks. + EXPECT_EQ( + "[12::45:89:cd:01:45:89]", extractHost("http://[12::45:89:cd:01:45:89]")); + // 1 hex block flowed by a compression followed by 4 hex blocks followed by an + // IPv4 address. + EXPECT_EQ( + "[12::45:89:cd:01:254.237.150.100]", + extractHost("http://[12::45:89:cd:01:254.237.150.100]")); + // 7 hex blocks followed by a compression. + EXPECT_EQ("[0:4:8:c:0:4:8::]", extractHost("http://[0:4:8:c:0:4:8::]")); + // 5 hex blocks followed by a compression followed by an IPv4 address. + EXPECT_EQ( + "[0:4:8:c:0::255.249.199.99]", + extractHost("http://[0:4:8:c:0::255.249.199.99]")); + // Compression followed by an IPv4 address. + EXPECT_EQ("[::250.200.100.10]", extractHost("http://[::250.200.100.10]")); + // Just a compression. + EXPECT_EQ("[::]", extractHost("http://[::]")); + + // Too many hex blocks. + EXPECT_EQ( + std::nullopt, + extractHost("http://[0123:4567:89ab:cdef:0123:4567:89ab:cdef:0123]")); + // Too many hex blocks with a compression. + EXPECT_EQ( + std::nullopt, + extractHost("http://[0123:4567:89ab:cdef:0123:4567:89ab::cdef]")); + // Too many hex blocks with an IPv4 address. + EXPECT_EQ( + std::nullopt, + extractHost( + "http://[0123:4567:89ab:cdef:0123:4567:89ab:250.200.100.10]")); + // Too few hex blocks. + EXPECT_EQ( + std::nullopt, extractHost("http://[0123:4567:89ab:cdef:0123:4567:89ab]")); + // Too few hex blocks with an IPv4 address. + EXPECT_EQ( + std::nullopt, + extractHost("http://[0123:4567:89ab:cdef:0123:250.200.100.10]")); + // End on a colon. + EXPECT_EQ( + std::nullopt, + extractHost("http://[0123:4567:89ab:cdef:0123:4567:89ab:cdef:]")); + // Hex blocks after an IPv4 address. + EXPECT_EQ( + std::nullopt, + extractHost("http://[0123:4567:89ab:cdef:250.200.100.10:89ab:cdef]")); + // Compression after an IPv4 address. + EXPECT_EQ( + std::nullopt, + extractHost("http://[0123:4567:89ab:cdef:250.200.100.10::]")); + // Two compressions. + EXPECT_EQ(std::nullopt, extractHost("http://[0123::4567::]")); + EXPECT_EQ(std::nullopt, extractHost("http://[::0123::]")); + + // Invalid IPv4 addresses. + // Too many dec-octets. + EXPECT_EQ(std::nullopt, extractHost("http://[::255.249.199.99.10]")); + // Too few dec-octets. + EXPECT_EQ(std::nullopt, extractHost("http://[::250.200.100]")); + // Dec-octets outside of range + EXPECT_EQ(std::nullopt, extractHost("http://[::256.8.1.0]")); + // Negative dec-octet. + EXPECT_EQ(std::nullopt, extractHost("http://[::99.98.-11.10]")); + // Hex in dec-octet. + EXPECT_EQ(std::nullopt, extractHost("http://[::254.dae.150.100]")); +} + +TEST_F(URLFunctionsTest, extractHostIPvFuture) { + const auto extractHost = [&](const std::optional& url) { + return evaluateOnce("url_extract_host(c0)", url); + }; + + // Test minimal. + EXPECT_EQ("[v0.a]", extractHost("http://[v0.a]")); + // Test all valid characters. + EXPECT_EQ( + "[v0123456789abcdefABCDEF.abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:]", + extractHost( + "http://[v0123456789abcdefABCDEF.abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:]")); + + // Missing v. + EXPECT_EQ(std::nullopt, extractHost("http://[0.a]")); + // Empty hex string. + EXPECT_EQ(std::nullopt, extractHost("http://[v.a]")); + // Invalid hex character. + EXPECT_EQ(std::nullopt, extractHost("http://[v0g.a]")); + // Missing period. + EXPECT_EQ(std::nullopt, extractHost("http://[v0a]")); + // Empty suffix. + EXPECT_EQ(std::nullopt, extractHost("http://[v0.]")); + // Invalid character in suffix. + EXPECT_EQ(std::nullopt, extractHost("http://[v0.a/]")); +} + +TEST_F(URLFunctionsTest, extractHostRegName) { + const auto extractHost = [&](const std::optional& url) { + return evaluateOnce("url_extract_host(c0)", url); + }; + + // Test minimal. + EXPECT_EQ("a", extractHost("http://a")); + // Test all valid characters. + EXPECT_EQ( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=", + extractHost( + "http://abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=")); + // Test prefix is valid IPv4 address. + EXPECT_EQ( + "123.456.789.012.abcdefg", extractHost("http://123.456.789.012.abcdefg")); + // Test percent encoded. + EXPECT_EQ("a b", extractHost("http://a%20b")); + + // Invalid character. + EXPECT_EQ(std::nullopt, extractHost("http://a b")); +} + TEST_F(URLFunctionsTest, extractPath) { const auto extractPath = [&](const std::optional& url) { return evaluateOnce("url_extract_path(c0)", url); @@ -142,6 +380,77 @@ TEST_F(URLFunctionsTest, extractPath) { EXPECT_EQ("foo", extractPath("foo")); EXPECT_EQ(std::nullopt, extractPath("BAD URL!")); EXPECT_EQ("", extractPath("http://www.yahoo.com")); + // All valid characters. + EXPECT_EQ( + "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@", + extractPath( + "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@")); +} + +TEST_F(URLFunctionsTest, extractPort) { + const auto extractPort = [&](const std::optional& url) { + return evaluateOnce("url_extract_port(c0)", url); + }; + + // 0-4 valid. + EXPECT_EQ(43210, extractPort("http://a:43210")); + // 5-9 valid. + EXPECT_EQ(98765, extractPort("http://a:98765")); + + // Empty port. + EXPECT_EQ(std::nullopt, extractPort("http://a:")); + // Hex invalid. + EXPECT_EQ(std::nullopt, extractPort("http://a:deadbeef")); +} + +TEST_F(URLFunctionsTest, extractHostWithUserInfo) { + const auto extractHost = [&](const std::optional& url) { + return evaluateOnce("url_extract_host(c0)", url); + }; + + // Test extracting a host when user info is present. + + // Test empty user info. + EXPECT_EQ("a", extractHost("http://@a")); + // Test all valid characters. + EXPECT_EQ( + "a", + extractHost( + "http://abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:%20@a")); + // Test with user info and port present. + EXPECT_EQ("a", extractHost("http://xyz@a:123")); +} + +TEST_F(URLFunctionsTest, extractQuery) { + const auto extractQuery = [&](const std::optional& url) { + return evaluateOnce("url_extract_query(c0)", url); + }; + + // Test empty query. + EXPECT_EQ("", extractQuery("http://www.yahoo.com?")); + // Test non-empty query. + EXPECT_EQ("a", extractQuery("http://www.yahoo.com?a")); + // Test all valid characters. + EXPECT_EQ( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ ", + extractQuery( + "http://www.yahoo.com?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20")); +} + +TEST_F(URLFunctionsTest, extractFragment) { + const auto extractFragment = [&](const std::optional& url) { + return evaluateOnce("url_extract_fragment(c0)", url); + }; + + // Test empty query. + EXPECT_EQ("", extractFragment("http://www.yahoo.com#")); + // Test non-empty query. + EXPECT_EQ("a", extractFragment("http://www.yahoo.com#a")); + // Test all valid characters. + EXPECT_EQ( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ ", + extractFragment( + "http://www.yahoo.com#abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20")); } TEST_F(URLFunctionsTest, extractParameter) { From 2e5b51cc5e17752314c6b498c3277b3aca71d339 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Mon, 18 Nov 2024 15:22:00 -0800 Subject: [PATCH 2/3] Match Presto's behavior for invalid UTF-8 in url_encode (#11518) Summary: Presto Java converts the URL to a Java String before encoding it in url_encode. Java replaces bytes in an invalid UTF-8 character with 0xEF 0xBF 0xBD. Velox encodes invalid UTF-8 characters as is, which leads to differences in results from Java and C++. This diff adds a check when encoding URLs for invalid UTF-8 characters and does the same replacement as Java. Differential Revision: D65856104 --- velox/functions/prestosql/URLFunctions.h | 82 +++++++++++++++++-- .../prestosql/tests/URLFunctionsTest.cpp | 29 +++++++ 2 files changed, 103 insertions(+), 8 deletions(-) diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index 8dc64704d1d7..a8f0abfde386 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -16,15 +16,15 @@ #pragma once #include -#include -#include #include "velox/functions/Macros.h" -#include "velox/functions/lib/string/StringImpl.h" +#include "velox/functions/lib/Utf8Utils.h" #include "velox/functions/prestosql/URIParser.h" namespace facebook::velox::functions { namespace detail { +constexpr std::string_view kEncodedReplacementCharacter{"%EF%BF%BD"}; + FOLLY_ALWAYS_INLINE StringView submatch(const boost::cmatch& match, int idx) { const auto& sub = match[idx]; return StringView(sub.first, sub.length()); @@ -49,27 +49,93 @@ FOLLY_ALWAYS_INLINE void charEscape(unsigned char c, char* output) { /// * All other characters are converted to UTF-8 and the bytes are encoded /// as the string ``%XX`` where ``XX`` is the uppercase hexadecimal /// value of the UTF-8 byte. +/// * If the character is invalid UTF-8 all bytes of the character are +/// converted to %EF%BF%BD. template FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { auto inputSize = input.size(); - output.reserve(inputSize * 3); + // In the worst case every byte is an invalid UTF-8 character. + output.reserve(inputSize * kEncodedReplacementCharacter.size()); auto inputBuffer = input.data(); auto outputBuffer = output.data(); + size_t inputIndex = 0; size_t outIndex = 0; - for (auto i = 0; i < inputSize; ++i) { - unsigned char p = inputBuffer[i]; + while (inputIndex < inputSize) { + unsigned char p = inputBuffer[inputIndex]; if ((p >= 'a' && p <= 'z') || (p >= 'A' && p <= 'Z') || (p >= '0' && p <= '9') || p == '-' || p == '_' || p == '.' || p == '*') { outputBuffer[outIndex++] = p; + inputIndex++; } else if (p == ' ') { outputBuffer[outIndex++] = '+'; + inputIndex++; } else { - charEscape(p, outputBuffer + outIndex); - outIndex += 3; + const auto charLength = + tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex); + if (charLength > 0) { + for (int i = 0; i < charLength; ++i) { + charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex); + outIndex += 3; + } + + inputIndex += charLength; + } else { + // According to the Unicode standard the "maximal subpart of an + // ill-formed subsequence" is the longest code unit subsequenece that is + // either well-formed or of length 1. A replacement character should be + // written for each of these. In practice tryGetCharLength breaks most + // cases into maximal subparts, the exceptions are overlong encodings or + // subsequences outside the range of valid 4 byte sequences. In both + // these cases we should just write out a replacement character for + // every byte in the sequence. + bool isMultipleInvalidSequences = false; + if (inputIndex < inputSize - 1) { + isMultipleInvalidSequences = + // 0xe0 followed by a value less than 0xe0 or 0xf0 followed by a + // value less than 0x90 is considered an overlong encoding. + (inputBuffer[inputIndex] == '\xe0' && + (inputBuffer[inputIndex + 1] & 0xe0) == 0x80) || + (inputBuffer[inputIndex] == '\xf0' && + (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || + // 0xf4 followed by a byte >= 0x90 looks valid to + // tryGetCharLength, but is actually outside the range of valid + // code points. + (inputBuffer[inputIndex] == '\xf4' && + (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || + // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of + // multi-byte code points to tryGetCharLength, but are not part of + // any valid code point. + (unsigned char)inputBuffer[inputIndex] > 0xf4 || + inputBuffer[inputIndex] == '\xc0' || + inputBuffer[inputIndex] == '\xc1'; + } + + if (isMultipleInvalidSequences) { + // For overlong encodings we write a replacement character for each + // byte. + for (int i = charLength; i < 0; ++i) { + std::memcpy( + outputBuffer + outIndex, + kEncodedReplacementCharacter.data(), + kEncodedReplacementCharacter.size()); + outIndex += kEncodedReplacementCharacter.size(); + } + } else { + // For other invalid encodings we write a single replacement character + // regardless of the number of bytes in the invalid character. + std::memcpy( + outputBuffer + outIndex, + kEncodedReplacementCharacter.data(), + kEncodedReplacementCharacter.size()); + outIndex += kEncodedReplacementCharacter.size(); + } + + inputIndex += -charLength; + } } } output.resize(outIndex); diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 0ab32784e33e..5c4455e22859 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -496,6 +496,35 @@ TEST_F(URLFunctionsTest, urlEncode) { urlEncode("http://\u30c6\u30b9\u30c8")); EXPECT_EQ("%7E%40%3A.-*_%2B+%E2%98%83", urlEncode("~@:.-*_+ \u2603")); EXPECT_EQ("test", urlEncode("test")); + // Test a single byte invalid UTF-8 character. + EXPECT_EQ("te%EF%BF%BDst", urlEncode("te\x88st")); + // Test a multi-byte invalid UTF-8 character. (If the first byte is between + // 0xe0 and 0xef, it should be a 3 byte character, but we only have 2 bytes + // here.) + EXPECT_EQ("te%EF%BF%BDst", urlEncode("te\xe0\xb8st")); + // Test an overlong 3 byte UTF-8 character + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xe0\x94")); + // Test an overlong 3 byte UTF-8 character with a continuation byte. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xe0\x94\x83")); + // Test an overlong 4 byte UTF-8 character + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xf0\x84")); + // Test an overlong 4 byte UTF-8 character with continuation bytes. + EXPECT_EQ( + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xf0\x84\x90\x90")); + // Test a 4 byte UTF-8 character outside the range of valid values. + EXPECT_EQ( + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xfa\x80\x80\x80")); + // Test the beginning of a 4 byte UTF-8 character followed by a + // non-continuation byte. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xf0\xe0")); + // Test the invalid byte 0xc0. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xc0\x83")); + // Test the invalid byte 0xc1. + EXPECT_EQ("%EF%BF%BD%EF%BF%BD", urlEncode("\xc1\x83")); + // Test a 4 byte UTF-8 character that looks valid, but is actually outside the + // range of valid values. + EXPECT_EQ( + "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", urlEncode("\xf4\x92\x83\x83")); } TEST_F(URLFunctionsTest, urlDecode) { From e0d21686beba51fb44031c8cd2741c36810d8e97 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Mon, 18 Nov 2024 15:22:00 -0800 Subject: [PATCH 3/3] feat(function): Handle unescaped UTF-8 characters in Presto url_extract_* UDFs (#11535) Summary: Presto Java supports UTF-8 characters that are not control or whitespace characters appearing anywhere in a URL where a % escaped character can appear. This change modifies Velox's URIParser to do the same. Velox's URIParser would produce incorrect results when any non-ASCII character appeared anywhere in the URL and this has been fixed as well. In order to facilitate this I modified the tryGetCharLength helper function in UTF8Utils to take in a int32_t reference which it populates with the code point if the UTF-8 character is valid. It was already calculating this value and throwing it away, returning it allows me to avoid an additional call to repeat those steps and is consistent with the Airlift function on which it's based. Differential Revision: D65927918 --- velox/functions/lib/Utf8Utils.cpp | 9 +-- velox/functions/lib/Utf8Utils.h | 4 +- velox/functions/lib/tests/Utf8Test.cpp | 51 ++++++++------ velox/functions/prestosql/FromUtf8.cpp | 9 ++- velox/functions/prestosql/URIParser.cpp | 54 +++++++++++---- velox/functions/prestosql/URLFunctions.h | 5 +- .../prestosql/tests/URLFunctionsTest.cpp | 66 +++++++++++++++++-- velox/functions/sparksql/Split.h | 6 +- 8 files changed, 153 insertions(+), 51 deletions(-) diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index 17a26a633f5f..2aa1f31fd6e3 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -61,7 +61,7 @@ int firstByteCharLength(const char* u_input) { } // namespace -int32_t tryGetCharLength(const char* input, int64_t size) { +int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint) { VELOX_DCHECK_NOT_NULL(input); VELOX_DCHECK_GT(size, 0); @@ -72,6 +72,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 1) { // Normal ASCII: 0xxx_xxxx. + codePoint = input[0]; return 1; } @@ -89,7 +90,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 2) { // 110x_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); + codePoint = ((firstByte & 0b00011111) << 6) | (secondByte & 0b00111111); // Fail if overlong encoding. return codePoint < 0x80 ? -2 : 2; } @@ -106,7 +107,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 3) { // 1110_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00001111) << 12) | + codePoint = ((firstByte & 0b00001111) << 12) | ((secondByte & 0b00111111) << 6) | (thirdByte & 0b00111111); // Surrogates are invalid. @@ -132,7 +133,7 @@ int32_t tryGetCharLength(const char* input, int64_t size) { if (charLength == 4) { // 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx - int codePoint = ((firstByte & 0b00000111) << 18) | + codePoint = ((firstByte & 0b00000111) << 18) | ((secondByte & 0b00111111) << 12) | ((thirdByte & 0b00111111) << 6) | (forthByte & 0b00111111); // Fail if overlong encoding or above upper bound of Unicode. diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index 369e1151e93f..6373477887fd 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -45,12 +45,14 @@ namespace facebook::velox::functions { /// /// @param input Pointer to the first byte of the code point. Must not be null. /// @param size Number of available bytes. Must be greater than zero. +/// @param codePoint Populated with the code point it refers to. This is only +/// valid if the return value is positive. /// @return the length of the code point or negative the number of bytes in the /// invalid UTF-8 sequence. /// /// Adapted from tryGetCodePointAt in /// https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/SliceUtf8.java -int32_t tryGetCharLength(const char* input, int64_t size); +int32_t tryGetCharLength(const char* input, int64_t size, int32_t& codePoint); /// Return the length in byte of the next UTF-8 encoded character at the /// beginning of `string`. If the beginning of `string` is not valid UTF-8 diff --git a/velox/functions/lib/tests/Utf8Test.cpp b/velox/functions/lib/tests/Utf8Test.cpp index 4330d1d9bbd2..d2e697cb883f 100644 --- a/velox/functions/lib/tests/Utf8Test.cpp +++ b/velox/functions/lib/tests/Utf8Test.cpp @@ -21,53 +21,62 @@ namespace facebook::velox::functions { namespace { TEST(Utf8Test, tryCharLength) { + int32_t codepoint; // Single-byte ASCII character. - ASSERT_EQ(1, tryGetCharLength("Hello", 5)); + ASSERT_EQ(1, tryGetCharLength("Hello", 5, codepoint)); + ASSERT_EQ('H', codepoint); // 2-byte character. British pound sign. static const char* kPound = "\u00A3tail"; - ASSERT_EQ(2, tryGetCharLength(kPound, 5)); + ASSERT_EQ(2, tryGetCharLength(kPound, 5, codepoint)); + ASSERT_EQ(0xA3, codepoint); // First byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound, 1)); + ASSERT_EQ(-1, tryGetCharLength(kPound, 1, codepoint)); // Second byte alone is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5)); + ASSERT_EQ(-1, tryGetCharLength(kPound + 1, 5, codepoint)); // ASCII character 't' after the pound sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5)); + ASSERT_EQ(1, tryGetCharLength(kPound + 2, 5, codepoint)); // 3-byte character. Euro sign. static const char* kEuro = "\u20ACtail"; - ASSERT_EQ(3, tryGetCharLength(kEuro, 5)); + ASSERT_EQ(3, tryGetCharLength(kEuro, 5, codepoint)); + ASSERT_EQ(0x20AC, codepoint); // First byte or first 2 bytes alone are not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro, 1)); - ASSERT_EQ(-2, tryGetCharLength(kEuro, 2)); + ASSERT_EQ(-1, tryGetCharLength(kEuro, 1, codepoint)); + ASSERT_EQ(-2, tryGetCharLength(kEuro, 2, codepoint)); // Byte sequence starting from 2nd or 3rd byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5)); - ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5)); - ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5)); + ASSERT_EQ(-1, tryGetCharLength(kEuro + 1, 5, codepoint)); + ASSERT_EQ(-1, tryGetCharLength(kEuro + 2, 5, codepoint)); // ASCII character 't' after the euro sign is valid. - ASSERT_EQ(1, tryGetCharLength(kPound + 4, 5)); + ASSERT_EQ(1, tryGetCharLength(kEuro + 3, 5, codepoint)); + ASSERT_EQ('t', codepoint); + ASSERT_EQ(1, tryGetCharLength(kEuro + 4, 5, codepoint)); + ASSERT_EQ('a', codepoint); // 4-byte character. Musical symbol F CLEF. static const char* kClef = "\U0001D122tail"; - ASSERT_EQ(4, tryGetCharLength(kClef, 5)); + ASSERT_EQ(4, tryGetCharLength(kClef, 5, codepoint)); + ASSERT_EQ(0x1D122, codepoint); // First byte, first 2 bytes, or first 3 bytes alone are not a valid // character. - ASSERT_EQ(-1, tryGetCharLength(kClef, 1)); - ASSERT_EQ(-2, tryGetCharLength(kClef, 2)); - ASSERT_EQ(-3, tryGetCharLength(kClef, 3)); + ASSERT_EQ(-1, tryGetCharLength(kClef, 1, codepoint)); + ASSERT_EQ(-2, tryGetCharLength(kClef, 2, codepoint)); + ASSERT_EQ(-3, tryGetCharLength(kClef, 3, codepoint)); // Byte sequence starting from 2nd, 3rd or 4th byte is not a valid character. - ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3)); - ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3)); + ASSERT_EQ(-1, tryGetCharLength(kClef + 1, 3, codepoint)); + ASSERT_EQ(-1, tryGetCharLength(kClef + 2, 3, codepoint)); + ASSERT_EQ(-1, tryGetCharLength(kClef + 3, 3, codepoint)); // ASCII character 't' after the clef sign is valid. - ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5)); + ASSERT_EQ(1, tryGetCharLength(kClef + 4, 5, codepoint)); + ASSERT_EQ('t', codepoint); // Test overlong encoding. auto tryCharLength = [](const std::vector& bytes) { + int32_t codepoint; return tryGetCharLength( - reinterpret_cast(bytes.data()), bytes.size()); + reinterpret_cast(bytes.data()), bytes.size(), codepoint); }; // 2-byte encoding of 0x2F. diff --git a/velox/functions/prestosql/FromUtf8.cpp b/velox/functions/prestosql/FromUtf8.cpp index c538db022961..0db0564bdbbe 100644 --- a/velox/functions/prestosql/FromUtf8.cpp +++ b/velox/functions/prestosql/FromUtf8.cpp @@ -165,8 +165,9 @@ class FromUtf8Function : public exec::VectorFunction { auto replacement = decoded.valueAt(row); if (!replacement.empty()) { + int32_t codePoint; auto charLength = - tryGetCharLength(replacement.data(), replacement.size()); + tryGetCharLength(replacement.data(), replacement.size(), codePoint); VELOX_USER_CHECK_GT( charLength, 0, "Replacement is not a valid UTF-8 character"); VELOX_USER_CHECK_EQ( @@ -188,8 +189,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < value.size()) { + int32_t codePoint; auto charLength = - tryGetCharLength(value.data() + pos, value.size() - pos); + tryGetCharLength(value.data() + pos, value.size() - pos, codePoint); if (charLength < 0) { firstInvalidRow = row; return false; @@ -267,8 +269,9 @@ class FromUtf8Function : public exec::VectorFunction { int32_t pos = 0; while (pos < input.size()) { + int32_t codePoint; auto charLength = - tryGetCharLength(input.data() + pos, input.size() - pos); + tryGetCharLength(input.data() + pos, input.size() - pos, codePoint); if (charLength > 0) { fixedWriter.append(std::string_view(input.data() + pos, charLength)); pos += charLength; diff --git a/velox/functions/prestosql/URIParser.cpp b/velox/functions/prestosql/URIParser.cpp index 3f56a7d41f5e..178c88191fee 100644 --- a/velox/functions/prestosql/URIParser.cpp +++ b/velox/functions/prestosql/URIParser.cpp @@ -15,6 +15,8 @@ */ #include "velox/functions/prestosql/URIParser.h" +#include "velox/external/utf8proc/utf8procImpl.h" +#include "velox/functions/lib/Utf8Utils.h" namespace facebook::velox::functions { @@ -40,6 +42,11 @@ Mask createMask(const std::vector& values) { return mask; } + +bool test(const Mask& mask, char value) { + return value < mask.size() && mask.test(value); +} + // a-z or A-Z. const Mask kAlpha = createMask('a', 'z') | createMask('A', 'Z'); // 0-9. @@ -128,7 +135,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { return false; } - if (str[pos] != '%' || !kHex.test(str[pos + 1]) || !kHex.test(str[pos + 2])) { + if (str[pos] != '%' || !test(kHex, str[pos + 1]) || + !test(kHex, str[pos + 2])) { return false; } @@ -138,7 +146,8 @@ bool tryConsumePercentEncoded(const char* str, const size_t len, int32_t& pos) { } // Helper function that consumes as much of `str` from `pos` as possible where a -// character passes mask or is part of a percent encoded character. +// character passes mask, is part of a percent encoded character, or is an +// allowed UTF-8 character. // // `pos` is updated to the first character in `str` that was not consumed and // `hasEncoded` is set to true if any percent encoded characters were @@ -150,7 +159,7 @@ void consume( int32_t& pos, bool& hasEncoded) { while (pos < len) { - if (mask.test(str[pos])) { + if (test(mask, str[pos])) { pos++; continue; } @@ -160,6 +169,29 @@ void consume( continue; } + // Masks cover all ASCII characters, check if this is an allowed UTF-8 + // character. + if ((unsigned char)str[pos] > 127) { + // Get the UTF-8 code point. + int32_t codePoint; + auto valid = tryGetCharLength(str + pos, len - pos, codePoint); + + // Check if it's a valid UTF-8 character. + // The range after ASCII characters up to 159 covers control characters + // which are not allowed. + if (valid > 0 && codePoint > 159) { + const auto category = utf8proc_get_property(codePoint)->category; + // White space characters are also not allowed. The range of categories + // excluded here are categories of white space. + if (category < UTF8PROC_CATEGORY_ZS || + category > UTF8PROC_CATEGORY_ZP) { + // Increment over the whole (potentially multi-byte) character. + pos += valid; + continue; + } + } + } + break; } } @@ -297,7 +329,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { while (posInAddress < len && numBytes < 16) { int32_t posInHex = posInAddress; for (int i = 0; i < 4; i++) { - if (posInHex == len || !kHex.test(str[posInHex])) { + if (posInHex == len || !test(kHex, str[posInHex])) { break; } @@ -333,7 +365,7 @@ bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos) { posInAddress = posInHex + 2; } } else { - if (posInHex == len || !kHex.test(str[posInHex + 1])) { + if (posInHex == len || !test(kHex, str[posInHex + 1])) { // Peak ahead, we can't end on a single ':'. return false; } @@ -375,7 +407,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { // Consume a string of hex digits. int32_t posInHex = posInAddress; while (posInHex < len) { - if (kHex.test(str[posInHex])) { + if (test(kHex, str[posInHex])) { posInHex++; } else { break; @@ -399,7 +431,7 @@ bool tryConsumeIPVFuture(const char* str, const size_t len, int32_t& pos) { int32_t posInSuffix = posInAddress; while (posInSuffix < len) { - if (kIPVFutureSuffixOrUserInfo.test(str[posInSuffix])) { + if (test(kIPVFutureSuffixOrUserInfo, str[posInSuffix])) { posInSuffix++; } else { break; @@ -450,7 +482,7 @@ void consumePort(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInPort = pos; while (posInPort < len) { - if (kNum.test(str[posInPort])) { + if (test(kNum, str[posInPort])) { posInPort++; continue; } @@ -471,7 +503,7 @@ void consumeHost(const char* str, const size_t len, int32_t& pos, URI& uri) { int32_t posInIPV4Address = posInHost; if (tryConsumeIPV4Address(str, len, posInIPV4Address) && (posInIPV4Address == len || - kFollowingHost.test(str[posInIPV4Address]))) { + test(kFollowingHost, str[posInIPV4Address]))) { // reg-name and IPv4 addresses are hard to distinguish, a reg-name could // have a valid IPv4 address as a prefix, but treating that prefix as an // IPv4 address would make this URI invalid. We make sure that if we @@ -534,14 +566,14 @@ bool tryConsumeScheme( int32_t posInScheme = pos; // The scheme must start with a letter. - if (posInScheme == len || !kAlpha.test(str[posInScheme])) { + if (posInScheme == len || !test(kAlpha, str[posInScheme])) { return false; } // Consume the first letter. posInScheme++; - while (posInScheme < len && kScheme.test(str[posInScheme])) { + while (posInScheme < len && test(kScheme, str[posInScheme])) { posInScheme++; } diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index a8f0abfde386..6397bc268723 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -74,8 +74,9 @@ FOLLY_ALWAYS_INLINE void urlEscape(TOutString& output, const TInString& input) { outputBuffer[outIndex++] = '+'; inputIndex++; } else { - const auto charLength = - tryGetCharLength(inputBuffer + inputIndex, inputSize - inputIndex); + int32_t codePoint; + const auto charLength = tryGetCharLength( + inputBuffer + inputIndex, inputSize - inputIndex, codePoint); if (charLength > 0) { for (int i = 0; i < charLength; ++i) { charEscape(inputBuffer[inputIndex + i], outputBuffer + outIndex); diff --git a/velox/functions/prestosql/tests/URLFunctionsTest.cpp b/velox/functions/prestosql/tests/URLFunctionsTest.cpp index 5c4455e22859..06ed4df79b37 100644 --- a/velox/functions/prestosql/tests/URLFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/URLFunctionsTest.cpp @@ -347,7 +347,7 @@ TEST_F(URLFunctionsTest, extractHostRegName) { // Test minimal. EXPECT_EQ("a", extractHost("http://a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=", extractHost( @@ -357,9 +357,31 @@ TEST_F(URLFunctionsTest, extractHostRegName) { "123.456.789.012.abcdefg", extractHost("http://123.456.789.012.abcdefg")); // Test percent encoded. EXPECT_EQ("a b", extractHost("http://a%20b")); + // Valid UTF-8 in host reg name. + EXPECT_EQ("你好", extractHost("https://你好")); + // Valid UTF-8 in userinfo. + EXPECT_EQ("foo", extractHost("https://你好@foo")); - // Invalid character. + // Invalid ASCII character. EXPECT_EQ(std::nullopt, extractHost("http://a b")); + // Inalid UTF-8 in host reg name (it should be a 3 byte character but there's + // only 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8")); + // Inalid UTF-8 in userinfo (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractHost("https://\xe0\xb8@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82")); + // Valid UTF-8 in userinfo but character is not allowed (it's a control + // character). + EXPECT_EQ(std::nullopt, extractHost("https://\x82@foo")); + // Valid UTF-8 in host reg name but character is not allowed (it's white + // space: THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84")); + // Valid UTF-8 in userinfo but character is not allowed (it's white space: + // THREE-PER-EM SPACE). + EXPECT_EQ(std::nullopt, extractHost("https://\xe2\x80\x84@foo")); } TEST_F(URLFunctionsTest, extractPath) { @@ -380,11 +402,21 @@ TEST_F(URLFunctionsTest, extractPath) { EXPECT_EQ("foo", extractPath("foo")); EXPECT_EQ(std::nullopt, extractPath("BAD URL!")); EXPECT_EQ("", extractPath("http://www.yahoo.com")); - // All valid characters. + // All valid ASCII characters. EXPECT_EQ( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@", extractPath( "/abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@")); + // Valid UTF-8 in path. + EXPECT_EQ("/你好", extractPath("https://foo.com/你好")); + // Inalid UTF-8 in path (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractPath("https://foo.com/\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractPort) { @@ -430,11 +462,21 @@ TEST_F(URLFunctionsTest, extractQuery) { EXPECT_EQ("", extractQuery("http://www.yahoo.com?")); // Test non-empty query. EXPECT_EQ("a", extractQuery("http://www.yahoo.com?a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ ", extractQuery( "http://www.yahoo.com?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20")); + // Valid UTF-8 in query. + EXPECT_EQ("你好", extractQuery("https://foo.com?你好")); + // Inalid UTF-8 in query (it should be a 3 byte character but there's only 2 + // bytes). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractQuery("https://foo.com?\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractFragment) { @@ -442,15 +484,25 @@ TEST_F(URLFunctionsTest, extractFragment) { return evaluateOnce("url_extract_fragment(c0)", url); }; - // Test empty query. + // Test empty fragment. EXPECT_EQ("", extractFragment("http://www.yahoo.com#")); - // Test non-empty query. + // Test non-empty fragment. EXPECT_EQ("a", extractFragment("http://www.yahoo.com#a")); - // Test all valid characters. + // Test all valid ASCII characters. EXPECT_EQ( "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@ ", extractFragment( "http://www.yahoo.com#abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~!$&'()*+,;=:@%20")); + // Valid UTF-8 in fgrament. + EXPECT_EQ("你好", extractFragment("https://foo.com#你好")); + // Inalid UTF-8 in fragment (it should be a 3 byte character but there's only + // 2 bytes). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe0\xb8")); + // Valid UTF-8 but character is not allowed (it's a control character). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xc2\x82")); + // Valid UTF-8 but character is not allowed (it's white space: THREE-PER-EM + // SPACE). + EXPECT_EQ(std::nullopt, extractFragment("https://foo.com#\xe2\x80\x84")); } TEST_F(URLFunctionsTest, extractParameter) { diff --git a/velox/functions/sparksql/Split.h b/velox/functions/sparksql/Split.h index 2cee345f77b2..854ae64008f2 100644 --- a/velox/functions/sparksql/Split.h +++ b/velox/functions/sparksql/Split.h @@ -81,7 +81,8 @@ struct Split { size_t pos = 0; int32_t count = 0; while (pos < end && count < limit) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = tryGetCharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid // character is the absolute value of result of `tryGetCharLength`. @@ -142,7 +143,8 @@ struct Split { // empty tail string at last, e.g., the result array for split('abc','d|') // is ["a","b","c",""]. if (size == 0) { - auto charLength = tryGetCharLength(start + pos, end - pos); + int32_t codePoint; + auto charLength = tryGetCharLength(start + pos, end - pos, codePoint); if (charLength <= 0) { // Invalid UTF-8 character, the length of the invalid // character is the absolute value of result of `tryGetCharLength`.