diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java index e38ac3210b1..d63c512e88a 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java @@ -563,9 +563,14 @@ private Expression getConvertExpression( case TINYINT: case SMALLINT: { if (SqlTypeName.NUMERIC_TYPES.contains(sourceType.getSqlTypeName())) { + Type javaClass = typeFactory.getJavaClass(targetType); + Primitive primitive = Primitive.of(javaClass); + if (primitive == null) { + primitive = Primitive.ofBox(javaClass); + } return Expressions.call( BuiltInMethod.INTEGER_CAST_ROUNDING_MODE.method, - Expressions.constant(Primitive.of(typeFactory.getJavaClass(targetType))), + Expressions.constant(primitive), operand, Expressions.constant(typeFactory.getTypeSystem().roundingMode())); } return defaultExpression.get(); diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java index 7d9c807cd03..ff0afa6037b 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java @@ -90,6 +90,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -100,6 +101,7 @@ import static com.google.common.base.Preconditions.checkArgument; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS; import static org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow; import static org.apache.calcite.util.Util.first; @@ -143,7 +145,7 @@ private StandardConvertletTable() { addAlias(SqlLibraryOperators.BITOR_AGG, SqlStdOperatorTable.BIT_OR); // Register convertlets for specific objects. - registerOp(SqlStdOperatorTable.CAST, this::convertCast); + registerOp(CAST, this::convertCast); registerOp(SqlLibraryOperators.SAFE_CAST, this::convertCast); registerOp(SqlLibraryOperators.TRY_CAST, this::convertCast); registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast); @@ -1151,14 +1153,61 @@ public RexNode convertCall( // Expand 'ROW (x0, x1, ...) = ROW (y0, y1, ...)' // to 'x0 = y0 AND x1 = y1 AND ...' + // If there are casts to ROW, apply them fieldwise: + // 'CAST(ROW(x0, x1) AS (t0, t1)) = ROW(y0, y1)' becomes + // 'CAST(x0 as t0) = y0 AND CAST(x1 as t1) = y1' if (op.kind == SqlKind.EQUALS) { - final RexNode expr0 = RexUtil.removeCast(exprs.get(0)); - final RexNode expr1 = RexUtil.removeCast(exprs.get(1)); + // For every cast one list with the types of all fields + ArrayDeque> leftTypes = new ArrayDeque<>(); + ArrayDeque> rightTypes = new ArrayDeque<>(); + RexNode expr0 = exprs.get(0); + RexNode expr1 = exprs.get(1); + while (expr0.getKind() == SqlKind.CAST) { + RexCall cast = (RexCall) expr0; + if (cast.type.getSqlTypeName() == SqlTypeName.ROW) { + expr0 = ((RexCall) expr0).operands.get(0); + leftTypes.add(cast.type.getFieldList()); + } else { + break; + } + } + while (expr1.getKind() == SqlKind.CAST) { + RexCall cast = (RexCall) expr1; + if (cast.type.getSqlTypeName() == SqlTypeName.ROW) { + expr1 = ((RexCall) expr1).operands.get(0); + rightTypes.add(cast.type.getFieldList()); + } else { + break; + } + } if (expr0.getKind() == SqlKind.ROW && expr1.getKind() == SqlKind.ROW) { final RexCall call0 = (RexCall) expr0; final RexCall call1 = (RexCall) expr1; + + List expr0Operands = call0.getOperands(); + // Insert the casts in reverse order + while (!leftTypes.isEmpty()) { + List converted = new ArrayList<>(); + List types = leftTypes.removeLast(); + Pair.forEach(types, expr0Operands, (x, y) -> + converted.add( + rexBuilder.makeAbstractCast( + call.getParserPosition(), x.getType(), y, false))); + expr0Operands = converted; + } + List expr1Operands = call1.getOperands(); + while (!rightTypes.isEmpty()) { + List converted = new ArrayList<>(); + List types = rightTypes.removeLast(); + Pair.forEach(types, expr1Operands, (x, y) -> + converted.add( + rexBuilder.makeAbstractCast( + call.getParserPosition(), x.getType(), y, false))); + expr1Operands = converted; + } + final List eqList = new ArrayList<>(); - Pair.forEach(call0.getOperands(), call1.getOperands(), (x, y) -> + Pair.forEach(expr0Operands, expr1Operands, (x, y) -> eqList.add(rexBuilder.makeCall(call.getParserPosition(), op, x, y))); return RexUtil.composeConjunction(rexBuilder, eqList); } @@ -1680,7 +1729,7 @@ private static SqlNode getCastedSqlNode(SqlNode argInput, if (argRex == null || argRex.getType().equals(varType)) { return argInput; } - return SqlStdOperatorTable.CAST.createCall(pos, argInput, + return CAST.createCall(pos, argInput, SqlTypeUtil.convertTypeToSpec(varType)); } } @@ -1852,7 +1901,7 @@ private static SqlNode getCastedSqlNode(SqlNode argInput, if (argRex == null || argRex.getType().equals(varType)) { return argInput; } - return SqlStdOperatorTable.CAST.createCall(pos, argInput, + return CAST.createCall(pos, argInput, SqlTypeUtil.convertTypeToSpec(varType)); } } diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java index 90ed3b69eb7..d97a73f163d 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java @@ -852,6 +852,23 @@ public static void checkActualAndReferenceFiles() { sql(sql).ok(); } + /** Test case for + * [CALCITE-6742] + * StandardConvertletTable.convertCall loses casts from ROW comparisons. */ + @Test void testStructCast() { + final String sql = "select ROW(1, 'x') = ROW('y', 1)"; + sql(sql).ok(); + } + + /** Test case for + * [CALCITE-6742] + * StandardConvertletTable.convertCall loses casts from ROW comparisons. */ + @Test void testStructCast1() { + final String sql = "select CAST(CAST(ROW('x', 1) AS " + + "ROW(l INTEGER, r DOUBLE)) AS ROW(l BIGINT, r INTEGER)) = ROW(RAND(), RAND())"; + sql(sql).ok(); + } + /** As {@link #testSelectOverDistinct()} but for streaming queries. */ @Test void testSelectStreamPartitionDistinct() { final String sql = "select stream\n" diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml index 98ab6f3fe53..49a6200e7de 100644 --- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml @@ -7691,6 +7691,28 @@ from orders]]> LogicalDelta LogicalProject(ROWTIME=[$0], PRODUCTID=[$1], ORDERID=[$2], C=[COUNT() OVER (PARTITION BY $1 ORDER BY $0 RANGE 1000:INTERVAL SECOND PRECEDING)]) LogicalTableScan(table=[[CATALOG, SALES, ORDERS]]) +]]> + + + + + + + + + + + + + + + +