Commit d1cc17c

[multistage][hotfix] use UTF-8 as default CharSet, this is also true for v1 engine (apache#12213)

* [hotfix] use UTF-8 as default CharSet, this is also true for v1 engine

---------

Co-authored-by: Rong Rong <rongr@startree.ai>
walterddr and Rong Rong authored Jan 3, 2024
1 parent add2236 commit d1cc17c
Showing 10 changed files with 164 additions and 142 deletions.
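Background: Calcite's stock type factories default to ISO-8859-1 as the character set for string types, while Pinot's v1 engine has always treated strings as UTF-8. This commit overrides getDefaultCharset() in the multistage engine's TypeFactory so both engines agree, which is also why the expected plan outputs further down now print string literals with the _UTF-8 introducer.

A minimal sketch of the effect (illustrative only; the DefaultCharsetDemo class is not part of this commit, but every API it calls appears in the diff below or is standard Calcite):

import java.nio.charset.StandardCharsets;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCollation;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.pinot.query.type.TypeFactory;
import org.apache.pinot.query.type.TypeSystem;

public class DefaultCharsetDemo {
  public static void main(String[] args) {
    TypeFactory factory = new TypeFactory(new TypeSystem());
    // Character types built through the factory now carry UTF-8 rather than
    // Calcite's ISO-8859-1 fallback.
    RelDataType varchar = factory.createTypeWithCharsetAndCollation(
        factory.createSqlType(SqlTypeName.VARCHAR),
        factory.getDefaultCharset(), SqlCollation.IMPLICIT);
    // The type digest should read something like: VARCHAR CHARACTER SET "UTF-8"
    System.out.println(varchar.getFullTypeString());
    assert StandardCharsets.UTF_8.equals(varchar.getCharset());
  }
}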
TypeFactory.java
@@ -18,6 +18,8 @@
*/
package org.apache.pinot.query.type;

+ import java.nio.charset.Charset;
+ import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
@@ -39,11 +41,17 @@
* upgrading Calcite versions.
*/
public class TypeFactory extends JavaTypeFactoryImpl {
+ private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;

public TypeFactory(RelDataTypeSystem typeSystem) {
super(typeSystem);
}

+ @Override
+ public Charset getDefaultCharset() {
+   return DEFAULT_CHARSET;
+ }

public RelDataType createRelDataTypeFromSchema(Schema schema) {
Builder builder = new Builder(this);
Predicate<FieldSpec> isNullable;
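The accompanying test swaps the plain JavaTypeFactoryImpl for a TestJavaTypeFactoryImpl that applies the same UTF-8 override, and wraps VARCHAR expectations in createTypeWithCharsetAndCollation(…, StandardCharsets.UTF_8, SqlCollation.IMPLICIT) so they match the types the production factory now produces.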
TypeFactoryTest.java
@@ -18,13 +18,16 @@
*/
package org.apache.pinot.query.type;

+ import java.nio.charset.Charset;
+ import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
+ import org.apache.calcite.sql.SqlCollation;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.SqlTypeName;
@@ -37,23 +40,22 @@

public class TypeFactoryTest {
private static final TypeSystem TYPE_SYSTEM = new TypeSystem();
+ private static final JavaTypeFactory TYPE_FACTORY = new TestJavaTypeFactoryImpl(TYPE_SYSTEM);

@DataProvider(name = "relDataTypeConversion")
public Iterator<Object[]> relDataTypeConversion() {
ArrayList<Object[]> cases = new ArrayList<>();

- JavaTypeFactory javaTypeFactory = new JavaTypeFactoryImpl(TYPE_SYSTEM);
for (FieldSpec.DataType dataType : FieldSpec.DataType.values()) {
RelDataType basicType;
RelDataType arrayType = null;
switch (dataType) {
case INT: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.INTEGER);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER);
break;
}
case LONG: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.BIGINT);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.BIGINT);
break;
}
// Map float and double to the same RelDataType so that queries like
@@ -68,33 +70,33 @@ public Iterator<Object[]> relDataTypeConversion() {
// With float and double mapped to the same RelDataType, the behavior in multi-stage query engine will be the
// same as the query in v1 query engine.
case FLOAT: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.DOUBLE);
- arrayType = javaTypeFactory.createSqlType(SqlTypeName.REAL);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE);
+ arrayType = TYPE_FACTORY.createSqlType(SqlTypeName.REAL);
break;
}
case DOUBLE: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.DOUBLE);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE);
break;
}
case BOOLEAN: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.BOOLEAN);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN);
break;
}
case TIMESTAMP: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.TIMESTAMP);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.TIMESTAMP);
break;
}
case STRING:
case JSON: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.VARCHAR);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.VARCHAR);
break;
}
case BYTES: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.VARBINARY);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.VARBINARY);
break;
}
case BIG_DECIMAL: {
- basicType = javaTypeFactory.createSqlType(SqlTypeName.DECIMAL);
+ basicType = TYPE_FACTORY.createSqlType(SqlTypeName.DECIMAL);
break;
}
case LIST:
@@ -268,7 +270,9 @@ public void testRelDataTypeConversion() {
break;
case "STRING_COL":
case "JSON_COL":
- Assert.assertEquals(field.getType(), new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARCHAR));
+ Assert.assertEquals(field.getType(),
+     TYPE_FACTORY.createTypeWithCharsetAndCollation(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARCHAR),
+         StandardCharsets.UTF_8, SqlCollation.IMPLICIT));
break;
case "BYTES_COL":
Assert.assertEquals(field.getType(), new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARBINARY));
@@ -290,8 +294,9 @@ public void testRelDataTypeConversion() {
new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.DOUBLE), false));
break;
case "STRING_ARRAY_COL":
- Assert.assertEquals(field.getType(),
-     new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARCHAR), false));
+ Assert.assertEquals(field.getType(), new ArraySqlType(
+     TYPE_FACTORY.createTypeWithCharsetAndCollation(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARCHAR),
+         StandardCharsets.UTF_8, SqlCollation.IMPLICIT), false));
break;
case "BYTES_ARRAY_COL":
Assert.assertEquals(field.getType(),
@@ -304,6 +309,17 @@
}
}

+ private static class TestJavaTypeFactoryImpl extends JavaTypeFactoryImpl {
+   public TestJavaTypeFactoryImpl(TypeSystem typeSystem) {
+     super(typeSystem);
+   }
+
+   @Override
+   public Charset getDefaultCharset() {
+     return StandardCharsets.UTF_8;
+   }
+ }

//tests precision and scale for numeric data type
private void checkPrecisionScale(RelDataTypeField field, BasicSqlType basicSqlType) {
Assert.assertEquals(field.getValue().getPrecision(), basicSqlType.getPrecision());
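In the plan expectations that follow, character literals pick up Calcite's charset introducer ('pink floyd' becomes _UTF-8'pink floyd'), and CAST digests spell the charset out (VARCHAR CHARACTER SET "UTF-8"), because literal types are now created against the UTF-8 default instead of ISO-8859-1. A sketch of where that digest comes from (illustrative only; LiteralDigestDemo is not part of this commit, while RexBuilder and makeLiteral are standard Calcite APIs):

import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.pinot.query.type.TypeFactory;
import org.apache.pinot.query.type.TypeSystem;

public class LiteralDigestDemo {
  public static void main(String[] args) {
    // RexBuilder types character literals using the factory's default charset.
    RexBuilder rexBuilder = new RexBuilder(new TypeFactory(new TypeSystem()));
    RexNode literal = rexBuilder.makeLiteral("pink floyd");
    // With UTF-8 as the default, the literal digest should print as
    // _UTF-8'pink floyd' rather than 'pink floyd'.
    System.out.println(literal);
  }
}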
(query plan JSON test resource)
@@ -32,7 +32,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -46,7 +46,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($3)], agg#1=[COUNT()], agg#2=[MAX($3)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -60,7 +60,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -87,7 +87,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -101,7 +101,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -115,7 +115,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -129,7 +129,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'pink floyd'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
(query plan JSON test resource)
@@ -40,7 +40,7 @@
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], EXPR$1=[+($2, $6)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -51,7 +51,7 @@
"output": [
"Execution Plan",
"\nLogicalProject(col1=[$0], colsum=[+($2, $6)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -64,7 +64,7 @@
"\nLogicalSort(offset=[0], fetch=[10])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
"\n LogicalSort(fetch=[10])",
"\n LogicalProject(EXPR$0=[DATETRUNC('DAY', $6)])",
"\n LogicalProject(EXPR$0=[DATETRUNC(_UTF-8'DAY', $6)])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -77,7 +77,7 @@
"\nLogicalSort(offset=[0], fetch=[10])",
"\n PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
"\n LogicalSort(fetch=[10])",
"\n LogicalProject(day=[DATETRUNC('DAY', $6)])",
"\n LogicalProject(day=[DATETRUNC(_UTF-8'DAY', $6)])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -91,7 +91,7 @@
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash])",
"\n LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT()])",
"\n LogicalProject($f0=[CAST(CASE(>($2, 10), '1':VARCHAR, >($2, 20), '2':VARCHAR, >($2, 30), '3':VARCHAR, >($2, 40), '4':VARCHAR, >($2, 50), '5':VARCHAR, '0':VARCHAR)):DECIMAL(1000, 500) NOT NULL])",
"\n LogicalProject($f0=[CAST(CASE(>($2, 10), _UTF-8'1':VARCHAR CHARACTER SET \"UTF-8\", >($2, 20), _UTF-8'2':VARCHAR CHARACTER SET \"UTF-8\", >($2, 30), _UTF-8'3':VARCHAR CHARACTER SET \"UTF-8\", >($2, 40), _UTF-8'4':VARCHAR CHARACTER SET \"UTF-8\", >($2, 50), _UTF-8'5':VARCHAR CHARACTER SET \"UTF-8\", _UTF-8'0':VARCHAR CHARACTER SET \"UTF-8\")):DECIMAL(1000, 500) NOT NULL])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
24 changes: 12 additions & 12 deletions pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -34,7 +34,7 @@
"\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -48,7 +48,7 @@
"\nLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -62,7 +62,7 @@
"\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -77,7 +77,7 @@
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -92,7 +92,7 @@
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -143,7 +143,7 @@
"\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -156,7 +156,7 @@
"\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], EXPR$2=[MAX($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -170,7 +170,7 @@
"\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -184,7 +184,7 @@
"\n LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
"\n PinotLogicalExchange(distribution=[hash[0, 1]])",
"\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($0, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -199,7 +199,7 @@
"\n LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -214,7 +214,7 @@
"\n LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
@@ -229,7 +229,7 @@
"\n LogicalAggregate(group=[{0}], count=[COUNT()], SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
"\n PinotLogicalExchange(distribution=[hash[0]])",
"\n LogicalProject(col1=[$0], col3=[$2])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
"\n LogicalFilter(condition=[AND(>=($2, 0), =($1, _UTF-8'a'))])",
"\n LogicalTableScan(table=[[a]])",
"\n"
]
(Diffs for the remaining 5 of the 10 changed files were not loaded.)
