Skip to content

Commit

Permalink
[CALCITE-6742] StandardConvertletTable.convertCall loses casts from R…
Browse files Browse the repository at this point in the history
…OW comparisons

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed Dec 21, 2024
1 parent 0badcce commit 044095e
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,15 @@ private Expression getConvertExpression(
case TINYINT:
case SMALLINT: {
if (SqlTypeName.NUMERIC_TYPES.contains(sourceType.getSqlTypeName())) {
Type javaClass = typeFactory.getJavaClass(targetType);
Primitive primitive = Primitive.of(javaClass);
if (primitive == null) {
primitive = Primitive.ofBox(javaClass);
}
assert primitive != null;
return Expressions.call(
BuiltInMethod.INTEGER_CAST_ROUNDING_MODE.method,
Expressions.constant(Primitive.of(typeFactory.getJavaClass(targetType))),
Expressions.constant(primitive),
operand, Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
}
return defaultExpression.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Stack;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;

import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS;
import static org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow;
import static org.apache.calcite.util.Util.first;
Expand Down Expand Up @@ -143,7 +145,7 @@ private StandardConvertletTable() {
addAlias(SqlLibraryOperators.BITOR_AGG, SqlStdOperatorTable.BIT_OR);

// Register convertlets for specific objects.
registerOp(SqlStdOperatorTable.CAST, this::convertCast);
registerOp(CAST, this::convertCast);
registerOp(SqlLibraryOperators.SAFE_CAST, this::convertCast);
registerOp(SqlLibraryOperators.TRY_CAST, this::convertCast);
registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
Expand Down Expand Up @@ -1151,14 +1153,57 @@ public RexNode convertCall(

// Expand 'ROW (x0, x1, ...) = ROW (y0, y1, ...)'
// to 'x0 = y0 AND x1 = y1 AND ...'
// If there are casts to ROW, apply them fieldwise:
// 'CAST(ROW(x0, x1) AS (t0, t1)) = ROW(y0, y1)' becomes
// 'CAST(x0 as t0) = y0 AND CAST(x1 as t1) = y1'
if (op.kind == SqlKind.EQUALS) {
final RexNode expr0 = RexUtil.removeCast(exprs.get(0));
final RexNode expr1 = RexUtil.removeCast(exprs.get(1));
// For every cast one list with the types of all fields
Stack<List<RelDataTypeField>> leftTypes = new Stack<>();
Stack<List<RelDataTypeField>> rightTypes = new Stack<>();
RexNode expr0 = exprs.get(0);
RexNode expr1 = exprs.get(1);
while (expr0.getKind() == SqlKind.CAST) {
RexCall cast = (RexCall) expr0;
if (cast.type.getSqlTypeName() == SqlTypeName.ROW) {
expr0 = ((RexCall) expr0).operands.get(0);
leftTypes.add(cast.type.getFieldList());
}
}
while (expr1.getKind() == SqlKind.CAST) {
RexCall cast = (RexCall) expr1;
if (cast.type.getSqlTypeName() == SqlTypeName.ROW) {
expr1 = ((RexCall) expr1).operands.get(0);
rightTypes.add(cast.type.getFieldList());
}
}
if (expr0.getKind() == SqlKind.ROW && expr1.getKind() == SqlKind.ROW) {
final RexCall call0 = (RexCall) expr0;
final RexCall call1 = (RexCall) expr1;

List<RexNode> expr0Operands = call0.getOperands();
// Insert the casts in reverse order
while (!leftTypes.isEmpty()) {
List<RexNode> converted = new ArrayList<>();
List<RelDataTypeField> types = leftTypes.pop();
Pair.forEach(types, expr0Operands, (x, y) ->
converted.add(
rexBuilder.makeAbstractCast(
call.getParserPosition(), x.getType(), y, false)));
expr0Operands = converted;
}
List<RexNode> expr1Operands = call1.getOperands();
while (!rightTypes.isEmpty()) {
List<RexNode> converted = new ArrayList<>();
List<RelDataTypeField> types = rightTypes.pop();
Pair.forEach(types, expr1Operands, (x, y) ->
converted.add(
rexBuilder.makeAbstractCast(
call.getParserPosition(), x.getType(), y, false)));
expr1Operands = converted;
}

final List<RexNode> eqList = new ArrayList<>();
Pair.forEach(call0.getOperands(), call1.getOperands(), (x, y) ->
Pair.forEach(expr0Operands, expr1Operands, (x, y) ->
eqList.add(rexBuilder.makeCall(call.getParserPosition(), op, x, y)));
return RexUtil.composeConjunction(rexBuilder, eqList);
}
Expand Down Expand Up @@ -1680,7 +1725,7 @@ private static SqlNode getCastedSqlNode(SqlNode argInput,
if (argRex == null || argRex.getType().equals(varType)) {
return argInput;
}
return SqlStdOperatorTable.CAST.createCall(pos, argInput,
return CAST.createCall(pos, argInput,
SqlTypeUtil.convertTypeToSpec(varType));
}
}
Expand Down Expand Up @@ -1852,7 +1897,7 @@ private static SqlNode getCastedSqlNode(SqlNode argInput,
if (argRex == null || argRex.getType().equals(varType)) {
return argInput;
}
return SqlStdOperatorTable.CAST.createCall(pos, argInput,
return CAST.createCall(pos, argInput,
SqlTypeUtil.convertTypeToSpec(varType));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,23 @@ public static void checkActualAndReferenceFiles() {
sql(sql).ok();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6742">[CALCITE-6742]
* StandardConvertletTable.convertCall loses casts from ROW comparisons</a>. */
@Test void testStructCast() {
final String sql = "select ROW(1, 'x') = ROW('x', 1)";
sql(sql).ok();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6742">[CALCITE-6742]
* StandardConvertletTable.convertCall loses casts from ROW comparisons</a>. */
@Test void testStructCast1() {
final String sql = "select CAST(CAST(ROW('x', 1) AS "
+ "ROW(l INTEGER, r DOUBLE)) AS ROW(l BIGINT, r INTEGER)) = ROW(RAND(), RAND())";
sql(sql).ok();
}

/** As {@link #testSelectOverDistinct()} but for streaming queries. */
@Test void testSelectStreamPartitionDistinct() {
final String sql = "select stream\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7654,6 +7654,28 @@ from orders]]>
LogicalDelta
LogicalProject(ROWTIME=[$0], PRODUCTID=[$1], ORDERID=[$2], C=[COUNT() OVER (PARTITION BY $1 ORDER BY $0 RANGE 1000:INTERVAL SECOND PRECEDING)])
LogicalTableScan(table=[[CATALOG, SALES, ORDERS]])
]]>
</Resource>
</TestCase>
<TestCase name="testStructCast">
<Resource name="sql">
<![CDATA[select ROW(1, 'x') = ROW('x', 1)]]>
</Resource>
<Resource name="plan">
<![CDATA[
LogicalProject(EXPR$0=[=(1, CAST('x'):INTEGER NOT NULL)])
LogicalValues(tuples=[[{ 0 }]])
]]>
</Resource>
</TestCase>
<TestCase name="testStructCast1">
<Resource name="sql">
<![CDATA[select CAST(CAST(ROW('x', 1) AS ROW(l INTEGER, r DOUBLE)) AS ROW(l BIGINT, r INTEGER)) = ROW(RAND(), RAND())]]>
</Resource>
<Resource name="plan">
<![CDATA[
LogicalProject(EXPR$0=[AND(=(CAST(CAST('x'):INTEGER NOT NULL):DOUBLE NOT NULL, RAND()), =(1.0E0, RAND()))])
LogicalValues(tuples=[[{ 0 }]])
]]>
</Resource>
</TestCase>
Expand Down

0 comments on commit 044095e

Please sign in to comment.