Skip to content

Commit

Permalink
[CALCITE-6736] Validator accepts comparisons between arrays, multiset…
Browse files Browse the repository at this point in the history
…s, maps without regard to element types

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed Dec 19, 2024
1 parent c1ca226 commit 01cfeee
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 9 deletions.
51 changes: 45 additions & 6 deletions core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -595,17 +595,34 @@ public static boolean sameNamedType(RelDataType t1, RelDataType t2) {
}
return true;
}
RelDataType comp1 = t1.getComponentType();
RelDataType comp2 = t2.getComponentType();
if ((comp1 != null) || (comp2 != null)) {
if ((comp1 == null) || (comp2 == null)) {
SqlTypeName t1Name = t1.getSqlTypeName();
SqlTypeName t2Name = t2.getSqlTypeName();
if (t1Name == SqlTypeName.ARRAY || t1Name == SqlTypeName.MULTISET) {
if (t1Name != t2Name) {
return false;
}
if (!sameNamedType(comp1, comp2)) {

RelDataType comp1 = requireNonNull(t1.getComponentType());
RelDataType comp2 = requireNonNull(t2.getComponentType());
return sameNamedType(comp1, comp2);
}

if (t1Name == SqlTypeName.MAP) {
if (t1Name != t2Name) {
return false;
}

RelDataType keyType1 = requireNonNull(t1.getKeyType());
RelDataType keyType2 = requireNonNull(t2.getKeyType());
if (!sameNamedType(keyType1, keyType2)) {
return false;
}
RelDataType valueType1 = requireNonNull(t1.getValueType());
RelDataType valueType2 = requireNonNull(t2.getValueType());
return sameNamedType(valueType1, valueType2);
}
return t1.getSqlTypeName() == t2.getSqlTypeName();

return t1Name == t2Name;
}

/**
Expand Down Expand Up @@ -1504,6 +1521,28 @@ public static boolean isComparable(RelDataType type1, RelDataType type2) {
return true;
}

SqlTypeName type1Name = type1.getSqlTypeName();
SqlTypeName type2Name = type2.getSqlTypeName();
if (type1Name == SqlTypeName.ARRAY || type1Name == SqlTypeName.MULTISET) {
if (type2Name != type1Name) {
return false;
}
RelDataType elementType1 = requireNonNull(type1.getComponentType());
RelDataType elementType2 = requireNonNull(type2.getComponentType());
return isComparable(elementType1, elementType2);
}

if (type1Name == SqlTypeName.MAP) {
if (type2Name != type1Name) {
return false;
}
RelDataType keyType1 = requireNonNull(type1.getKeyType());
RelDataType keyType2 = requireNonNull(type2.getKeyType());
RelDataType valueType1 = requireNonNull(type1.getValueType());
RelDataType valueType2 = requireNonNull(type2.getValueType());
return isComparable(keyType1, keyType2) && isComparable(valueType1, valueType2);
}

final RelDataTypeFamily family1 = family(type1);
final RelDataTypeFamily family2 = family(type2);
if (family1 == family2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ private RelDataType getTightestCommonTypeOrThrow(
SqlTypeName typeName1 = type1.getSqlTypeName();
SqlTypeName typeName2 = type2.getSqlTypeName();

// The following can never be true, but there is no harm leaving it here.
if (typeName1 == null || typeName2 == null) {
return null;
}
Expand Down Expand Up @@ -629,6 +630,43 @@ private RelDataType getTightestCommonTypeOrThrow(
}
}

if (typeName1 == SqlTypeName.ARRAY || typeName1 == SqlTypeName.MULTISET) {
if (typeName2 != typeName1) {
return null;
}
RelDataType elementType1 = type1.getComponentType();
RelDataType elementType2 = type2.getComponentType();
RelDataType type = commonTypeForBinaryComparison(elementType1, elementType2);
if (type == null) {
return null;
}
// The only maxCardinality that seems to be supported is -1, i.e., unlimited.
RelDataType resultType = factory.createArrayType(type, -1);
return factory.createTypeWithNullability(
resultType, type1.isNullable() || type2.isNullable());
}

if (typeName1 == SqlTypeName.MAP) {
if (typeName2 != SqlTypeName.MAP) {
return null;
}
RelDataType keyType1 = type1.getKeyType();
RelDataType keyType2 = type2.getKeyType();
RelDataType keyType = commonTypeForBinaryComparison(keyType1, keyType2);
if (keyType == null) {
return null;
}
RelDataType valueType1 = type1.getValueType();
RelDataType valueType2 = type2.getValueType();
RelDataType valueType = commonTypeForBinaryComparison(valueType1, valueType2);
if (valueType == null) {
return null;
}
RelDataType resultType = factory.createMapType(keyType, valueType);
return factory.createTypeWithNullability(
resultType, type1.isNullable() || type2.isNullable());
}

return null;
}

Expand Down
13 changes: 13 additions & 0 deletions core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,19 @@ static SqlOperatorTable operatorTableFor(SqlLibrary library) {
expr("^x'abcd'<>1^")
.fails("(?s).*Cannot apply '<>' to arguments of type "
+ "'<BINARY.2.> <> <INTEGER>'.*");
// Test cases for [CALCITE-6736] Validator accepts comparisons between arrays, multisets, maps
// without regard to element types
expr("^array[x'a4'] = array[1]^")
.fails("(?s).*Cannot apply '=' to arguments of type "
+ "'<BINARY.1. ARRAY> = <INTEGER ARRAY>'.*");
expr("^MAP[x'a4', 1] = MAP[1, 1]^")
.fails("(?s).*Cannot apply '=' to arguments of type "
+ "'<.BINARY.1., INTEGER. MAP> = <.INTEGER, INTEGER. MAP>'.*");
expr("^array[x'a4'] = 1^")
.fails("(?s).*Cannot apply '=' to arguments of type '<BINARY.1. ARRAY> = <INTEGER>'.*");
expr("^multiset[x'a4'] = multiset[1]^")
.fails("(?s).*Cannot apply '=' to arguments of type "
+ "'<BINARY.1. MULTISET> = <INTEGER MULTISET>'.*");
}

@Test void testBinaryString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ private Fluent pig(String script) {
final String plan = ""
+ "LogicalSort(sort0=[$6], dir0=[ASC])\n"
+ " LogicalProject(DEPTNO=[$0.DEPTNO], MGR=[$0.MGR], HIREDATE=[$0.HIREDATE], "
+ "$f3=[COUNT(PIG_BAG($1))], newCol=[1:BIGINT], comArray=[MULTISET_PROJECTION($1, 6)], "
+ "$f3=[COUNT(PIG_BAG($1))], newCol=[1:BIGINT], comArray=[CAST(MULTISET_PROJECTION($1, 6)):RecordType(DECIMAL(19, 0) COMM) MULTISET], "
+ "salSum=[BigDecimalSum(PIG_BAG(MULTISET_PROJECTION($1, 5)))])\n"
+ " LogicalProject(group=[ROW($0, $1, $2)], A=[$3])\n"
+ " LogicalAggregate(group=[{0, 1, 2}], A=[COLLECT($3)])\n"
Expand All @@ -1541,14 +1541,14 @@ private Fluent pig(String script) {
final String optimizedPlan = ""
+ "LogicalSort(sort0=[$6], dir0=[ASC])\n"
+ " LogicalProject(DEPTNO=[$0], MGR=[$1], HIREDATE=[$2], $f3=[CAST($3):BIGINT], "
+ "newCol=[1:BIGINT], comArray=[$4], salSum=[CAST($5):DECIMAL(19, 0)])\n"
+ "newCol=[1:BIGINT], comArray=[CAST($4):RecordType(DECIMAL(19, 0) COMM) MULTISET], salSum=[CAST($5):DECIMAL(19, 0)])\n"
+ " LogicalAggregate(group=[{0, 1, 2}], agg#0=[COUNT()], agg#1=[COLLECT($3)], "
+ "agg#2=[SUM($4)])\n"
+ " LogicalProject(DEPTNO=[$7], MGR=[$3], HIREDATE=[$4], COMM=[$6], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
final String sql = ""
+ "SELECT DEPTNO, MGR, HIREDATE, CAST(COUNT(*) AS BIGINT) AS $f3, 1 AS newCol, "
+ "COLLECT(COMM) AS comArray, CAST(SUM(SAL) AS DECIMAL(19, 0)) AS salSum\n"
+ "CAST(COLLECT(COMM) AS ROW(COMM DECIMAL(19, 0)) MULTISET) AS comArray, CAST(SUM(SAL) AS DECIMAL(19, 0)) AS salSum\n"
+ "FROM scott.EMP\n"
+ "GROUP BY DEPTNO, MGR, HIREDATE\n"
+ "ORDER BY 7";
Expand Down

0 comments on commit 01cfeee

Please sign in to comment.