Skip to content

Commit

Permalink
[CALCITE-6690] Arrow adapter support DECIMAL with precision and scale
Browse files Browse the repository at this point in the history
  • Loading branch information
caicancai authored and mihaibudiu committed Dec 22, 2024
1 parent fae49bc commit e621c33
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,84 +17,69 @@
package org.apache.calcite.adapter.arrow;

import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.type.SqlTypeName;

import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;

import java.math.BigDecimal;
import java.sql.Date;
import java.util.List;

import static java.util.Objects.requireNonNull;

/**
* Converts Arrow field types to the corresponding Calcite
* {@link RelDataType}s.
*/
enum ArrowFieldType {
INT(Primitive.INT),
BOOLEAN(Primitive.BOOLEAN),
STRING(String.class),
FLOAT(Primitive.FLOAT),
DOUBLE(Primitive.DOUBLE),
DATE(Date.class),
LIST(List.class),
DECIMAL(BigDecimal.class),
LONG(Primitive.LONG),
BYTE(Primitive.BYTE),
SHORT(Primitive.SHORT);

private final Class<?> clazz;

ArrowFieldType(Primitive primitive) {
this(requireNonNull(primitive.boxClass, "boxClass"));
}
public class ArrowFieldTypeFactory {

ArrowFieldType(Class<?> clazz) {
this.clazz = clazz;
private ArrowFieldTypeFactory() {
throw new UnsupportedOperationException("Utility class");
}

public RelDataType toType(JavaTypeFactory typeFactory) {
RelDataType javaType = typeFactory.createJavaType(clazz);
RelDataType sqlType = typeFactory.createSqlType(javaType.getSqlTypeName());
public static RelDataType toType(ArrowType arrowType, JavaTypeFactory typeFactory) {
RelDataType sqlType = of(arrowType, typeFactory);
return typeFactory.createTypeWithNullability(sqlType, true);
}

public static ArrowFieldType of(ArrowType arrowType) {
/**
* Converts an Arrow type to a Calcite RelDataType.
*
* @param arrowType the Arrow type to convert
* @param typeFactory the factory to create the Calcite type
* @return the corresponding Calcite RelDataType
*/
private static RelDataType of(ArrowType arrowType, JavaTypeFactory typeFactory) {
switch (arrowType.getTypeID()) {
case Int:
int bitWidth = ((ArrowType.Int) arrowType).getBitWidth();
switch (bitWidth) {
case 64:
return LONG;
return typeFactory.createSqlType(SqlTypeName.BIGINT);
case 32:
return INT;
return typeFactory.createSqlType(SqlTypeName.INTEGER);
case 16:
return SHORT;
return typeFactory.createSqlType(SqlTypeName.SMALLINT);
case 8:
return BYTE;
return typeFactory.createSqlType(SqlTypeName.TINYINT);
default:
throw new IllegalArgumentException("Unsupported Int bit width: " + bitWidth);
}
case Bool:
return BOOLEAN;
return typeFactory.createSqlType(SqlTypeName.BOOLEAN);
case Utf8:
return STRING;
return typeFactory.createSqlType(SqlTypeName.VARCHAR);
case FloatingPoint:
FloatingPointPrecision precision = ((ArrowType.FloatingPoint) arrowType).getPrecision();
switch (precision) {
case SINGLE:
return FLOAT;
return typeFactory.createSqlType(SqlTypeName.REAL);
case DOUBLE:
return DOUBLE;
return typeFactory.createSqlType(SqlTypeName.DOUBLE);
default:
throw new IllegalArgumentException("Unsupported Floating point precision: " + precision);
}
case Date:
return DATE;
return typeFactory.createSqlType(SqlTypeName.DATE);
case Decimal:
return DECIMAL;
return typeFactory.createSqlType(SqlTypeName.DECIMAL,
((ArrowType.Decimal) arrowType).getPrecision(),
((ArrowType.Decimal) arrowType).getScale());
default:
throw new IllegalArgumentException("Unsupported type: " + arrowType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ private static RelDataType deduceRowType(Schema schema,
final RelDataTypeFactory.Builder builder = typeFactory.builder();
for (Field field : schema.getFields()) {
builder.add(field.getName(),
ArrowFieldType.of(field.getType()).toType(typeFactory));
ArrowFieldTypeFactory.toType(field.getType(), typeFactory));
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
arrowDataDirectory = arrowFilesDirectory.toFile();

File dataLocationFile = arrowFilesDirectory.resolve("arrowdatatype.arrow").toFile();
ArrowData arrowDataGenerator = new ArrowData();
ArrowDataTest arrowDataGenerator = new ArrowDataTest();
arrowDataGenerator.writeArrowDataType(dataLocationFile);

arrow = ImmutableMap.of("model", modelFileTarget.toAbsolutePath().toString());
Expand All @@ -68,7 +68,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"tinyIntField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(tinyIntField=[$0])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "tinyIntField=0\ntinyIntField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -82,7 +82,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"smallIntField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(smallIntField=[$1])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "smallIntField=0\nsmallIntField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -96,7 +96,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"intField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(intField=[$2])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "intField=0\nintField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -110,7 +110,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"longField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(longField=[$5])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "longField=0\nlongField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -124,7 +124,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"floatField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(floatField=[$4])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "floatField=0.0\nfloatField=1.0\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -138,7 +138,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"doubleField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(doubleField=[$6])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "doubleField=0.0\ndoubleField=1.0\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -152,7 +152,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"decimalField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(decimalField=[$8])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "decimalField=0.00\ndecimalField=1.00\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -166,7 +166,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"dateField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(dateField=[$9])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "dateField=1970-01-01\n"
+ "dateField=1970-01-02\n";
CalciteAssert.that()
Expand All @@ -181,7 +181,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"booleanField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\nbooleanField=true\nbooleanField=false\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -191,4 +191,20 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
.explainContains(plan);
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-6690">[CALCITE-6690]
 * Arrow adapter support DECIMAL with precision and scale</a>.
 *
 * <p>Projects the {@code decimalField2} column (field index 10 in the plan)
 * and checks both the first two returned rows and the generated plan. */
@Test void testDecimalProject2() {
  final String query = "select \"decimalField2\" from arrowdatatype";
  // Only the first two rows are compared; see limit(2) below.
  final String expectedRows = "decimalField2=20.000\ndecimalField2=21.000\n";
  final String expectedPlan = "PLAN=ArrowToEnumerableConverter\n"
      + " ArrowProject(decimalField2=[$10])\n"
      + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
  CalciteAssert.that()
      .with(arrow)
      .query(query)
      .limit(2)
      .returns(expectedRows)
      .explainContains(expectedPlan);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
arrowDataDirectory = arrowFilesDirectory.toFile();

File dataLocationFile = arrowFilesDirectory.resolve("arrowdata.arrow").toFile();
ArrowData arrowDataGenerator = new ArrowData();
ArrowDataTest arrowDataGenerator = new ArrowDataTest();
arrowDataGenerator.writeArrowData(dataLocationFile);
arrowDataGenerator.writeScottEmpData(arrowFilesDirectory);

File datatypeLocationFile = arrowFilesDirectory.resolve("arrowdatatype.arrow").toFile();
ArrowData arrowtypeDataGenerator = new ArrowData();
ArrowDataTest arrowtypeDataGenerator = new ArrowDataTest();
arrowtypeDataGenerator.writeArrowDataType(datatypeLocationFile);

arrow = ImmutableMap.of("model", modelFileTarget.toAbsolutePath().toString());
Expand Down Expand Up @@ -732,8 +732,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
@Test void testFilteredAgg() {
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP";
String plan = "PLAN=EnumerableAggregate(group=[{}], SALESSUM=[SUM($0) FILTER $1])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], SAL=[$t5], $f1=[$t10])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], expr#9=[400.00:DECIMAL(12, 2)], "
+ "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], SAL=[$t5], $f1=[$t11])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
String result = "SALESSUM=2500.00\n";
Expand All @@ -750,8 +750,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP group by EMPNO";
String plan = "PLAN=EnumerableCalc(expr#0..1=[{inputs}], SALESSUM=[$t1])\n"
+ " EnumerableAggregate(group=[{0}], SALESSUM=[SUM($1) FILTER $2])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], EMPNO=[$t0], SAL=[$t5], $f2=[$t10])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(12, 2)], expr#9=[400.00:DECIMAL(12, 2)], "
+ "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], EMPNO=[$t0], SAL=[$t5], $f2=[$t11])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
String result = "SALESSUM=1250.00\nSALESSUM=null\n";
Expand Down Expand Up @@ -860,7 +860,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[$7])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=true\nbooleanField=true\n";

CalciteAssert.that()
Expand All @@ -878,7 +878,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(intField=[$2])\n"
+ " ArrowFilter(condition=[>($2, 10)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "intField=11\nintField=12\n";

CalciteAssert.that()
Expand All @@ -896,7 +896,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[NOT($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=false\nbooleanField=false\n";

CalciteAssert.that()
Expand All @@ -915,7 +915,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[IS NOT TRUE($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\nbooleanField=false\n";

CalciteAssert.that()
Expand All @@ -933,7 +933,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[IS NOT FALSE($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\nbooleanField=true\n";

CalciteAssert.that()
Expand All @@ -951,7 +951,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[IS NULL($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\n";

CalciteAssert.that()
Expand All @@ -971,8 +971,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
+ "where \"decimalField\" = 1.00";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(decimalField=[$8])\n"
+ " ArrowFilter(condition=[=($8, 1)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowFilter(condition=[=($8, 1.00)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "decimalField=1.00\n";

CalciteAssert.that()
Expand All @@ -989,7 +989,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(doubleField=[$6])\n"
+ " ArrowFilter(condition=[=($6, 1.0E0)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "doubleField=1.0\n";

CalciteAssert.that()
Expand All @@ -1006,7 +1006,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(stringField=[$3])\n"
+ " ArrowFilter(condition=[=($3, '1')])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "stringField=1\n";

CalciteAssert.that()
Expand Down
Loading

0 comments on commit e621c33

Please sign in to comment.