diff --git a/velox/serializers/PrestoSerializer.cpp b/velox/serializers/PrestoSerializer.cpp
index 80301dde0a7f..e5211b3346f4 100644
--- a/velox/serializers/PrestoSerializer.cpp
+++ b/velox/serializers/PrestoSerializer.cpp
@@ -220,7 +220,13 @@ struct PrestoHeader {
   int32_t compressedSize;
   int64_t checksum;
 
-  static PrestoHeader read(ByteInputStream* source) {
+  /// Reads the page header from 'source'. Returns an error if fewer than
+  /// kHeaderSize bytes remain or if any of the size fields is negative.
+  static Expected<PrestoHeader> read(ByteInputStream* source) {
+    if (source->remainingSize() < kHeaderSize) {
+      return folly::makeUnexpected(Status::Invalid(
+          fmt::format("{} bytes for header", source->remainingSize())));
+    }
     PrestoHeader header;
     header.numRows = source->read<int32_t>();
     header.pageCodecMarker = source->read<int8_t>();
@@ -228,9 +234,18 @@ struct PrestoHeader {
     header.compressedSize = source->read<int32_t>();
     header.checksum = source->read<int64_t>();
 
-    VELOX_CHECK_GE(header.numRows, 0);
-    VELOX_CHECK_GE(header.uncompressedSize, 0);
-    VELOX_CHECK_GE(header.compressedSize, 0);
+    if (header.numRows < 0) {
+      return folly::makeUnexpected(
+          Status::Invalid(fmt::format("negative numRows: {}", header.numRows)));
+    }
+    if (header.uncompressedSize < 0) {
+      return folly::makeUnexpected(Status::Invalid(fmt::format(
+          "negative uncompressedSize: {}", header.uncompressedSize)));
+    }
+    if (header.compressedSize < 0) {
+      return folly::makeUnexpected(Status::Invalid(
+          fmt::format("negative compressedSize: {}", header.compressedSize)));
+    }
 
     return header;
   }
@@ -4193,7 +4208,12 @@ void PrestoVectorSerde::deserialize(
   const auto prestoOptions = toPrestoOptions(options);
   const auto codec =
       common::compressionKindToCodec(prestoOptions.compressionKind);
-  auto const header = PrestoHeader::read(source);
+  auto maybeHeader = PrestoHeader::read(source);
+  VELOX_CHECK(
+      maybeHeader.hasValue(),
+      fmt::format(
+          "PrestoPage header is invalid: {}", maybeHeader.error().message()));
+  auto const header = std::move(maybeHeader.value());
 
   int64_t actualCheckSum = 0;
   if (isChecksumBitSet(header.pageCodecMarker)) {
diff --git a/velox/serializers/tests/PrestoSerializerTest.cpp b/velox/serializers/tests/PrestoSerializerTest.cpp
index 472d6b05debf..f898f8ae289d 100644
--- a/velox/serializers/tests/PrestoSerializerTest.cpp
+++ b/velox/serializers/tests/PrestoSerializerTest.cpp
@@ -201,11 +201,14 @@ class PrestoSerializerTest
   RowVectorPtr deserialize(
       const RowTypePtr& rowType,
       const std::string& input,
-      const serializer::presto::PrestoVectorSerde::PrestoOptions*
-          serdeOptions) {
+      const serializer::presto::PrestoVectorSerde::PrestoOptions* serdeOptions,
+      bool skipLexer = false) {
     auto byteStream = toByteStream(input);
     auto paramOptions = getParamSerdeOptions(serdeOptions);
-    validateLexer(input, paramOptions);
+    // 'skipLexer' lets tests feed malformed input directly to deserialize().
+    if (!skipLexer) {
+      validateLexer(input, paramOptions);
+    }
     RowVectorPtr result;
     serde_->deserialize(
         byteStream.get(), pool_.get(), rowType, &result, 0, &paramOptions);
@@ -838,6 +841,22 @@ TEST_P(PrestoSerializerTest, emptyPage) {
   assertEqualVectors(deserialized, rowVector);
 }
 
+// Deserializing an input that is too short to hold a page header must fail
+// with a header-validation error.
+TEST_P(PrestoSerializerTest, invalidPage) {
+  auto rowVector = makeEmptyTestVector();
+
+  std::ostringstream out;
+  serialize(rowVector, &out, nullptr);
+
+  auto invalidPage = ""; // empty string
+  auto rowType = asRowType(rowVector->type());
+  // Bypass validateLexer() so the empty input reaches serde_->deserialize().
+  VELOX_ASSERT_THROW(
+      deserialize(rowType, invalidPage, nullptr, true /*skipLexer*/),
+      "PrestoPage header is invalid: 0 bytes for header");
+}
+
 TEST_P(PrestoSerializerTest, initMemory) {
   const auto numRows = 100;
   auto testFunc = [&](TypePtr type, int64_t expectedBytes) {