Skip to content

Commit

Permalink
Add support for DECIMAL inputs to truncate Presto function (facebooki…
Browse files Browse the repository at this point in the history
…ncubator#10217)

Summary: Pull Request resolved: facebookincubator#10217

Reviewed By: amitkdutta

Differential Revision: D58644512

fbshipit-source-id: f6e593056e2977ea1f38c95df8a47a2e18898d79
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Jun 17, 2024
1 parent 57dc7fe commit 8845ca2
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 0 deletions.
25 changes: 25 additions & 0 deletions velox/docs/functions/presto/decimal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,28 @@ Decimal Functions
SELECT round(123.45, -1); -- 120.00
SELECT round(123.45, -2); -- 100.00
SELECT round(123.45, -10); -- 0.00

.. function:: truncate(x: decimal(p, s)) -> r: decimal(rp, 0)

Returns 'x' rounded to integer by dropping digits after decimal point.
The scale of the result is 0. The precision is calculated as:
::

pr = max(p - s, 1)

.. function:: truncate(x: decimal(p, s), d: integer) -> r: decimal(rp, s)

Returns ``x`` truncated to ``d`` decimal places.
The precision and scale of the result are the same as the precision and scale of the input.
``d`` can be positive, zero or negative.
When ``d`` is negative truncates ``-d`` digits left of the decimal point.
Returns ``x`` unmodified if ``d`` exceeds the scale of the input.
::

SELECT truncate(999.45, 0); -- 999.00
SELECT truncate(999.45, 1); -- 999.40
SELECT truncate(999.45, 2); -- 999.45
SELECT truncate(999.45, 3); -- 999.45
SELECT truncate(999.45, -1); -- 990.00
SELECT truncate(999.45, -2); -- 900.00
SELECT truncate(999.45, -10); -- 0.00
91 changes: 91 additions & 0 deletions velox/functions/prestosql/DecimalFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,54 @@ struct DecimalFloorFunction {
uint8_t scale_;
};

template <typename TExec>
struct DecimalTruncateFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

template <typename A>
void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& /*config*/,
A* /*a*/) {
const auto [precision, scale] = getDecimalPrecisionScale(*inputTypes[0]);
precision_ = precision;
scale_ = scale;
}

template <typename A>
void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
A* a,
const int32_t* /*n*/) {
initialize(inputTypes, config, a);
}

template <typename R, typename A>
void call(R& out, const A& a) {
if UNLIKELY (scale_ == 0 || a == 0) {
out = a;
} else {
out = a / DecimalUtil::kPowersOfTen[scale_];
}
}

template <typename A>
void call(A& out, const A& a, int32_t n) {
if UNLIKELY (a == 0 || (n + precision_ - scale_) <= 0) {
out = 0;
} else if UNLIKELY (scale_ <= n) {
out = a;
} else {
out = a - (a % DecimalUtil::kPowersOfTen[scale_ - n]);
}
}

private:
uint8_t precision_;
uint8_t scale_;
};

template <template <class> typename Func>
void registerDecimalBinary(
const std::string& name,
Expand Down Expand Up @@ -573,4 +621,47 @@ void registerDecimalRound(const std::string& prefix) {
}
}

void registerDecimalTruncate(const std::string& prefix) {
// truncate(decimal) -> decimal
std::vector<exec::SignatureVariable> constraints = {
exec::SignatureVariable(
P2::name(),
fmt::format(
"max({p} - {s}, 1)",
fmt::arg("p", P1::name()),
fmt::arg("s", S1::name())),
exec::ParameterType::kIntegerParameter),
exec::SignatureVariable(
S2::name(), "0", exec::ParameterType::kIntegerParameter),
};

registerFunction<
DecimalTruncateFunction,
ShortDecimal<P2, S2>,
ShortDecimal<P1, S1>>({prefix + "truncate"}, constraints);

registerFunction<
DecimalTruncateFunction,
LongDecimal<P2, S2>,
LongDecimal<P1, S1>>({prefix + "truncate"}, constraints);

registerFunction<
DecimalTruncateFunction,
ShortDecimal<P2, S2>,
LongDecimal<P1, S1>>({prefix + "truncate"}, constraints);

// truncate(decimal, n) -> decimal
registerFunction<
DecimalTruncateFunction,
ShortDecimal<P1, S1>,
ShortDecimal<P1, S1>,
int32_t>({prefix + "truncate"});

registerFunction<
DecimalTruncateFunction,
LongDecimal<P1, S1>,
LongDecimal<P1, S1>,
int32_t>({prefix + "truncate"});
}

} // namespace facebook::velox::functions
5 changes: 5 additions & 0 deletions velox/functions/prestosql/DecimalFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <string>

namespace facebook::velox::functions {

Expand All @@ -30,4 +33,6 @@ void registerDecimalFloor(const std::string& prefix);

void registerDecimalRound(const std::string& prefix);

void registerDecimalTruncate(const std::string& prefix);

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ void registerMathematicalFunctions(const std::string& prefix = "") {

registerDecimalFloor(prefix);
registerDecimalRound(prefix);
registerDecimalTruncate(prefix);
}

} // namespace facebook::velox::functions
101 changes: 101 additions & 0 deletions velox/functions/prestosql/tests/DecimalArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,107 @@ TEST_F(DecimalArithmeticTest, floor) {
DECIMAL(19, 19))});
}

TEST_F(DecimalArithmeticTest, truncate) {
// Truncate short decimals.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({0, 0, 0, 0}, DECIMAL(1, 0))},
"truncate(c0)",
{makeFlatVector<int64_t>({123, 542, -999, 0}, DECIMAL(3, 3))});
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({1111, 1111, -9998, 9999}, DECIMAL(4, 0))},
"truncate(c0)",
{makeFlatVector<int64_t>({11112, 11115, -99989, 99999}, DECIMAL(5, 1))});

// Truncate long decimals.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({0, 0, 0, 0}, DECIMAL(1, 0))},
"truncate(c0)",
{makeFlatVector<int128_t>(
{1234567890123456789, 5000000000000000000, -9000000000000000000, 0},
DECIMAL(19, 19))});
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>(
{DecimalUtil::kPowersOfTen[37] - 1,
-DecimalUtil::kPowersOfTen[37] + 1},
DECIMAL(37, 0))},
"truncate(c0)",
{makeFlatVector<int128_t>(
{DecimalUtil::kLongDecimalMax, DecimalUtil::kLongDecimalMin},
DECIMAL(38, 1))});

// Min and max short decimals.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>(
{DecimalUtil::kShortDecimalMax, DecimalUtil::kShortDecimalMin},
DECIMAL(15, 0))},
"truncate(c0)",
{makeFlatVector<int64_t>(
{DecimalUtil::kShortDecimalMax, DecimalUtil::kShortDecimalMin},
DECIMAL(15, 0))});

// Min and max long decimals.
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>(
{DecimalUtil::kLongDecimalMax, DecimalUtil::kLongDecimalMin},
DECIMAL(38, 0))},
"truncate(c0)",
{makeFlatVector<int128_t>(
{DecimalUtil::kLongDecimalMax, DecimalUtil::kLongDecimalMin},
DECIMAL(38, 0))});
}

TEST_F(DecimalArithmeticTest, truncateN) {
// Truncate to 'scale' decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3))},
"truncate(c0, 3::integer)",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3))});

// Truncate to 'scale' - 1 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({120, 550, -990, 0}, DECIMAL(3, 3))},
"truncate(c0, 2::integer)",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3))});

// Truncate to 0 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({100, 500, -900, 0}, DECIMAL(3, 2))},
"truncate(c0, 0::integer)",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 2))});

// Truncate to -1 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({100, 500, -900, 0}, DECIMAL(3, 1))},
"truncate(c0, '-1'::integer)",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 1))});

// Truncate to -2 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({0, 0, 0, 0}, DECIMAL(3, 1))},
"truncate(c0, '-2'::integer)",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 1))});

// Truncate long decimals to 'scale' - 5 decimal places.
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>(
{1234567890123400000, 5000000000000000000, -999999999999900000, 0},
DECIMAL(19, 19))},
"truncate(c0, 14::integer)",
{makeFlatVector<int128_t>(
{1234567890123456789, 5000000000000000000, -999999999999999999, 0},
DECIMAL(19, 19))});

// Truncate long decimals to -9 decimal places.
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>(
{1234500000000000000, 5555500000000000000, -999900000000000000, 0},
DECIMAL(19, 5))},
"truncate(c0, '-9'::integer)",
{makeFlatVector<int128_t>(
{1234567890123456789, 5555555555555555555, -999999999999999999, 0},
DECIMAL(19, 5))});
}

TEST_F(DecimalArithmeticTest, abs) {
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({1111, 1112, 9999, 0}, DECIMAL(5, 1))},
Expand Down

0 comments on commit 8845ca2

Please sign in to comment.