diff --git a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java index 245756824e64..b3836d17d8d7 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java +++ b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java @@ -595,17 +595,34 @@ public static boolean sameNamedType(RelDataType t1, RelDataType t2) { } return true; } - RelDataType comp1 = t1.getComponentType(); - RelDataType comp2 = t2.getComponentType(); - if ((comp1 != null) || (comp2 != null)) { - if ((comp1 == null) || (comp2 == null)) { + SqlTypeName t1Name = t1.getSqlTypeName(); + SqlTypeName t2Name = t2.getSqlTypeName(); + if (t1Name == SqlTypeName.ARRAY || t1Name == SqlTypeName.MULTISET) { + if (t1Name != t2Name) { return false; } - if (!sameNamedType(comp1, comp2)) { + + RelDataType comp1 = requireNonNull(t1.getComponentType()); + RelDataType comp2 = requireNonNull(t2.getComponentType()); + return sameNamedType(comp1, comp2); + } + + if (t1Name == SqlTypeName.MAP) { + if (t1Name != t2Name) { + return false; + } + + RelDataType keyType1 = requireNonNull(t1.getKeyType()); + RelDataType keyType2 = requireNonNull(t2.getKeyType()); + if (!sameNamedType(keyType1, keyType2)) { return false; } + RelDataType valueType1 = requireNonNull(t1.getValueType()); + RelDataType valueType2 = requireNonNull(t2.getValueType()); + return sameNamedType(valueType1, valueType2); } - return t1.getSqlTypeName() == t2.getSqlTypeName(); + + return t1Name == t2Name; } /** @@ -1504,6 +1521,28 @@ public static boolean isComparable(RelDataType type1, RelDataType type2) { return true; } + SqlTypeName type1Name = type1.getSqlTypeName(); + SqlTypeName type2Name = type2.getSqlTypeName(); + if (type1Name == SqlTypeName.ARRAY || type1Name == SqlTypeName.MULTISET) { + if (type2Name != type1Name) { + return false; + } + RelDataType elementType1 = requireNonNull(type1.getComponentType()); + RelDataType elementType2 = requireNonNull(type2.getComponentType()); + return isComparable(elementType1, elementType2); + } + + if (type1Name == SqlTypeName.MAP) { + if (type2Name != type1Name) { + return false; + } + RelDataType keyType1 = requireNonNull(type1.getKeyType()); + RelDataType keyType2 = requireNonNull(type2.getKeyType()); + RelDataType valueType1 = requireNonNull(type1.getValueType()); + RelDataType valueType2 = requireNonNull(type2.getValueType()); + return isComparable(keyType1, keyType2) && isComparable(valueType1, valueType2); + } + final RelDataTypeFamily family1 = family(type1); final RelDataTypeFamily family2 = family(type2); if (family1 == family2) { diff --git a/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java b/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java index 87106eebb8c8..2d5970440da5 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/implicit/AbstractTypeCoercion.java @@ -504,6 +504,7 @@ private RelDataType getTightestCommonTypeOrThrow( SqlTypeName typeName1 = type1.getSqlTypeName(); SqlTypeName typeName2 = type2.getSqlTypeName(); + // The following can never be true, but there is no harm leaving it here. if (typeName1 == null || typeName2 == null) { return null; } @@ -629,6 +630,43 @@ private RelDataType getTightestCommonTypeOrThrow( } } + if (typeName1 == SqlTypeName.ARRAY || typeName1 == SqlTypeName.MULTISET) { + if (typeName2 != typeName1) { + return null; + } + RelDataType elementType1 = type1.getComponentType(); + RelDataType elementType2 = type2.getComponentType(); + RelDataType type = commonTypeForBinaryComparison(elementType1, elementType2); + if (type == null) { + return null; + } + // The only maxCardinality that seems to be supported is -1, i.e., unlimited. + RelDataType resultType = factory.createArrayType(type, -1); + return factory.createTypeWithNullability( + resultType, type1.isNullable() || type2.isNullable()); + } + + if (typeName1 == SqlTypeName.MAP) { + if (typeName2 != SqlTypeName.MAP) { + return null; + } + RelDataType keyType1 = type1.getKeyType(); + RelDataType keyType2 = type2.getKeyType(); + RelDataType keyType = commonTypeForBinaryComparison(keyType1, keyType2); + if (keyType == null) { + return null; + } + RelDataType valueType1 = type1.getValueType(); + RelDataType valueType2 = type2.getValueType(); + RelDataType valueType = commonTypeForBinaryComparison(valueType1, valueType2); + if (valueType == null) { + return null; + } + RelDataType resultType = factory.createMapType(keyType, valueType); + return factory.createTypeWithNullability( + resultType, type1.isNullable() || type2.isNullable()); + } + return null; } diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java index 1e9ae827a7ee..37a705070187 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java @@ -501,6 +501,19 @@ static SqlOperatorTable operatorTableFor(SqlLibrary library) { expr("^x'abcd'<>1^") .fails("(?s).*Cannot apply '<>' to arguments of type " + "' <> '.*"); + // Test cases for [CALCITE-6736] Validator accepts comparisons between arrays, multisets, maps + // without regard to element types + expr("^array[x'a4'] = array[1]^") + .fails("(?s).*Cannot apply '=' to arguments of type " + + "' = '.*"); + expr("^MAP[x'a4', 1] = MAP[1, 1]^") + .fails("(?s).*Cannot apply '=' to arguments of type " + + "'<.BINARY.1., INTEGER. MAP> = <.INTEGER, INTEGER. MAP>'.*"); + expr("^array[x'a4'] = 1^") + .fails("(?s).*Cannot apply '=' to arguments of type ' = '.*"); + expr("^multiset[x'a4'] = multiset[1]^") + .fails("(?s).*Cannot apply '=' to arguments of type " + + "' = '.*"); } @Test void testBinaryString() { diff --git a/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java b/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java index 95de883788d7..ed6301ee0c05 100644 --- a/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java +++ b/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java @@ -1531,7 +1531,7 @@ private Fluent pig(String script) { final String plan = "" + "LogicalSort(sort0=[$6], dir0=[ASC])\n" + " LogicalProject(DEPTNO=[$0.DEPTNO], MGR=[$0.MGR], HIREDATE=[$0.HIREDATE], " - + "$f3=[COUNT(PIG_BAG($1))], newCol=[1:BIGINT], comArray=[MULTISET_PROJECTION($1, 6)], " + + "$f3=[COUNT(PIG_BAG($1))], newCol=[1:BIGINT], comArray=[CAST(MULTISET_PROJECTION($1, 6)):RecordType(DECIMAL(19, 0) COMM) MULTISET], " + "salSum=[BigDecimalSum(PIG_BAG(MULTISET_PROJECTION($1, 5)))])\n" + " LogicalProject(group=[ROW($0, $1, $2)], A=[$3])\n" + " LogicalAggregate(group=[{0, 1, 2}], A=[COLLECT($3)])\n" @@ -1541,14 +1541,14 @@ private Fluent pig(String script) { final String optimizedPlan = "" + "LogicalSort(sort0=[$6], dir0=[ASC])\n" + " LogicalProject(DEPTNO=[$0], MGR=[$1], HIREDATE=[$2], $f3=[CAST($3):BIGINT], " - + "newCol=[1:BIGINT], comArray=[$4], salSum=[CAST($5):DECIMAL(19, 0)])\n" + + "newCol=[1:BIGINT], comArray=[CAST($4):RecordType(DECIMAL(19, 0) COMM) MULTISET], salSum=[CAST($5):DECIMAL(19, 0)])\n" + " LogicalAggregate(group=[{0, 1, 2}], agg#0=[COUNT()], agg#1=[COLLECT($3)], " + "agg#2=[SUM($4)])\n" + " LogicalProject(DEPTNO=[$7], MGR=[$3], HIREDATE=[$4], COMM=[$6], SAL=[$5])\n" + " LogicalTableScan(table=[[scott, EMP]])\n"; final String sql = "" + "SELECT DEPTNO, MGR, HIREDATE, CAST(COUNT(*) AS BIGINT) AS $f3, 1 AS newCol, " - + "COLLECT(COMM) AS comArray, CAST(SUM(SAL) AS DECIMAL(19, 0)) AS salSum\n" + + "CAST(COLLECT(COMM) AS ROW(COMM DECIMAL(19, 0)) MULTISET) AS comArray, CAST(SUM(SAL) AS DECIMAL(19, 0)) AS salSum\n" + "FROM scott.EMP\n" + "GROUP BY DEPTNO, MGR, HIREDATE\n" + "ORDER BY 7";