Skip to content

Commit

Permalink
Implement VARIANT functions TYPEOF, VARIANTNULL; add variant.iq
Browse files Browse the repository at this point in the history
Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
  • Loading branch information
mihaibudiu committed Dec 17, 2024
1 parent 2fefd69 commit 87485a9
Show file tree
Hide file tree
Showing 25 changed files with 1,099 additions and 274 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,12 @@
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.TRIM;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.TRUNCATE;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.TUMBLE;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.TYPEOF;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.UNARY_MINUS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.UNARY_PLUS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.UPPER;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.USER;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.VARIANTNULL;
import static org.apache.calcite.util.ReflectUtil.isStatic;

import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -866,6 +868,8 @@ void populate1() {
defineMethod(TRUNC_BIG_QUERY, BuiltInMethod.STRUNCATE.method, NullPolicy.STRICT);
defineMethod(TRUNCATE, BuiltInMethod.STRUNCATE.method, NullPolicy.STRICT);
defineMethod(LOG1P, BuiltInMethod.LOG1P.method, NullPolicy.STRICT);
defineMethod(TYPEOF, BuiltInMethod.TYPEOF.method, NullPolicy.STRICT);
defineMethod(VARIANTNULL, BuiltInMethod.VARIANTNULL.method, NullPolicy.STRICT);

define(SAFE_ADD,
new SafeArithmeticImplementor(BuiltInMethod.SAFE_ADD.method));
Expand Down Expand Up @@ -3825,6 +3829,8 @@ private static class ItemImplementor extends AbstractRexCallImplementor {
// use the general MethodImplementor.
private AbstractRexCallImplementor getImplementor(SqlTypeName sqlTypeName) {
switch (sqlTypeName) {
case VARIANT:
return new MethodImplementor(BuiltInMethod.VARIANT_ITEM.method, nullPolicy, false);
case ARRAY:
return new ArrayItemImplementor();
case MAP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.runtime.SpatialTypeFunctions;
import org.apache.calcite.runtime.rtti.RuntimeTypeInformation;
import org.apache.calcite.runtime.variant.VariantValue;
import org.apache.calcite.schema.FunctionContext;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlOperator;
Expand All @@ -66,8 +68,6 @@
import org.apache.calcite.util.ControlFlowException;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.Variant;
import org.apache.calcite.util.rtti.RuntimeTypeInformation;

import com.google.common.base.CaseFormat;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -321,10 +321,12 @@ private Expression getConvertExpression(
return defaultExpression.get();
}
// Converting a VARIANT to any other type calls the Variant.cast method
// First cast operand to a VariantValue (it may be an Object)
Expression operandCast = Expressions.convert_(operand, VariantValue.class);
Expression cast =
Expressions.call(BuiltInMethod.VARIANT_CAST.method, operand,
Expressions.call(operandCast, BuiltInMethod.VARIANT_CAST.method,
RuntimeTypeInformation.createExpression(targetType));
// The cast returns an Object, so we need a convert too
// The cast returns an Object, so we need a convert to the expected Java type
RelDataType nullableTarget = typeFactory.createTypeWithNullability(targetType, true);
return Expressions.convert_(cast, typeFactory.getJavaClass(nullableTarget));
}
Expand All @@ -333,7 +335,8 @@ private Expression getConvertExpression(
case VARIANT:
// Converting any type to a VARIANT invokes the Variant constructor
Expression rtti = RuntimeTypeInformation.createExpression(sourceType);
return Expressions.new_(Variant.class, operand, rtti);
Expression roundingMode = Expressions.constant(typeFactory.getTypeSystem().roundingMode());
return Expressions.call(BuiltInMethod.VARIANT_CREATE.method, roundingMode, operand, rtti);
case ANY:
return operand;

Expand Down
3 changes: 1 addition & 2 deletions core/src/main/java/org/apache/calcite/rex/RexBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -859,8 +859,7 @@ boolean canRemoveCastFromLiteral(RelDataType toType,
return true;
}
final SqlTypeName sqlType = toType.getSqlTypeName();
if (sqlType == SqlTypeName.MEASURE
|| sqlType == SqlTypeName.VARIANT) {
if (sqlType == SqlTypeName.MEASURE || sqlType == SqlTypeName.VARIANT) {
return false;
}
if (!RexLiteral.valueMatchesType(value, sqlType, false)) {
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/org/apache/calcite/rex/RexLiteral.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.runtime.SpatialTypeFunctions;
import org.apache.calcite.runtime.variant.VariantValue;
import org.apache.calcite.sql.SqlCollation;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
Expand All @@ -43,7 +44,6 @@
import org.apache.calcite.util.TimestampString;
import org.apache.calcite.util.TimestampWithTimeZoneString;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.Variant;

import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -317,7 +317,7 @@ public static boolean valueMatchesType(
}
switch (typeName) {
case VARIANT:
return value instanceof Variant;
return value instanceof VariantValue;
case BOOLEAN:
// Unlike SqlLiteral, we do not allow boolean null.
return value instanceof Boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.calcite.rel.type.TimeFrame;
import org.apache.calcite.rel.type.TimeFrameSet;
import org.apache.calcite.runtime.FlatLists.ComparableList;
import org.apache.calcite.runtime.variant.VariantValue;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
Expand All @@ -47,7 +48,6 @@
import org.apache.calcite.util.TryThreadLocal;
import org.apache.calcite.util.Unsafe;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.Variant;
import org.apache.calcite.util.format.FormatElement;
import org.apache.calcite.util.format.FormatModel;
import org.apache.calcite.util.format.FormatModels;
Expand Down Expand Up @@ -5767,8 +5767,8 @@ public static String replace(String s, String search, String replacement) {
* known until runtime.
*/
public static @Nullable Object item(Object object, Object index) {
if (object instanceof Variant) {
return ((Variant) object).item(index);
if (object instanceof VariantValue) {
return ((VariantValue) object).item(index);
}
if (object instanceof Map) {
return mapItem((Map) object, index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.util.rtti;
package org.apache.calcite.runtime.rtti;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.Objects;

/** Runtime type information about a base (primitive) SQL type. */
public class BasicSqlTypeRtti extends RuntimeTypeInformation {
private final int precision;
private final int scale;

public BasicSqlTypeRtti(RuntimeSqlTypeName typeName, int precision, int scale) {
public BasicSqlTypeRtti(RuntimeSqlTypeName typeName) {
super(typeName);
assert typeName.isScalar() : "Base SQL type must be a scalar type " + typeName;
this.precision = precision;
this.scale = scale;
}

@Override public boolean equals(@Nullable Object o) {
Expand All @@ -39,13 +36,11 @@ public BasicSqlTypeRtti(RuntimeSqlTypeName typeName, int precision, int scale) {
}

BasicSqlTypeRtti that = (BasicSqlTypeRtti) o;
return typeName == that.typeName && precision == that.precision && scale == that.scale;
return typeName == that.typeName;
}

@Override public int hashCode() {
int result = precision;
result = 31 * result + scale;
return result;
return Objects.hashCode(typeName);
}

@Override public String getTypeString() {
Expand All @@ -62,8 +57,6 @@ public BasicSqlTypeRtti(RuntimeSqlTypeName typeName, int precision, int scale) {
return "BIGINT";
case DECIMAL:
return "DECIMAL";
case FLOAT:
return "FLOAT";
case REAL:
return "REAL";
case DOUBLE:
Expand All @@ -85,12 +78,8 @@ public BasicSqlTypeRtti(RuntimeSqlTypeName typeName, int precision, int scale) {
case INTERVAL_LONG:
case INTERVAL_SHORT:
return "INTERVAL";
case CHAR:
return "CHAR";
case VARCHAR:
return "VARCHAR";
case BINARY:
return "BINARY";
case VARBINARY:
return "VARBINARY";
case NULL:
Expand All @@ -107,7 +96,6 @@ public BasicSqlTypeRtti(RuntimeSqlTypeName typeName, int precision, int scale) {
// This method is used to serialize the type in Java code implementations,
// so it should produce a computation that reconstructs the type at runtime
@Override public String toString() {
return "new BasicSqlTypeRtti("
+ this.getTypeString() + ", " + this.precision + ", " + this.scale + ")";
return "new BasicSqlTypeRtti(" + this.getTypeString() + ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.util.rtti;
package org.apache.calcite.runtime.rtti;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand Down Expand Up @@ -72,4 +72,12 @@ public GenericSqlTypeRtti(RuntimeSqlTypeName typeName, RuntimeTypeInformation...
@Override public int hashCode() {
return Arrays.hashCode(typeArguments);
}

public RuntimeTypeInformation getTypeArgument(int index) {
return typeArguments[index];
}

public int getArgumentCount() {
return typeArguments.length;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.util.rtti;
package org.apache.calcite.runtime.rtti;

import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -23,12 +23,12 @@

/** Runtime type information for a ROW type. */
public class RowSqlTypeRtti extends RuntimeTypeInformation {
private final Map.Entry<String, RuntimeTypeInformation>[] fieldNames;
private final Map.Entry<String, RuntimeTypeInformation>[] fields;

@SafeVarargs
public RowSqlTypeRtti(Map.Entry<String, RuntimeTypeInformation>... fieldNames) {
public RowSqlTypeRtti(Map.Entry<String, RuntimeTypeInformation>... fields) {
super(RuntimeSqlTypeName.ROW);
this.fieldNames = fieldNames;
this.fields = fields;
}

@Override public String getTypeString() {
Expand All @@ -41,7 +41,7 @@ public RowSqlTypeRtti(Map.Entry<String, RuntimeTypeInformation>... fieldNames) {
StringBuilder builder = new StringBuilder();
builder.append("new RowSqlTypeRtti(");
boolean first = true;
for (Map.Entry<String, RuntimeTypeInformation> arg : this.fieldNames) {
for (Map.Entry<String, RuntimeTypeInformation> arg : this.fields) {
if (!first) {
builder.append(", ");
}
Expand All @@ -61,10 +61,42 @@ public RowSqlTypeRtti(Map.Entry<String, RuntimeTypeInformation>... fieldNames) {
}

RowSqlTypeRtti that = (RowSqlTypeRtti) o;
return Arrays.equals(fieldNames, that.fieldNames);
return Arrays.equals(fields, that.fields);
}

@Override public int hashCode() {
return Arrays.hashCode(fieldNames);
return Arrays.hashCode(fields);
}

/** Get the field with the specified index. */
public Map.Entry<String, RuntimeTypeInformation> getField(int index) {
return this.fields[index];
}

public int size() {
return this.fields.length;
}

/** Return the runtime type information of the associated field,
* or null if no such field exists.
*
* @param index Field index, starting from 0
*/
public @Nullable RuntimeTypeInformation getFieldType(Object index) {
if (index instanceof Integer) {
int intIndex = (Integer) index;
if (intIndex < 0 || intIndex >= this.fields.length) {
return null;
}
return this.fields[intIndex].getValue();
} else if (index instanceof String) {
String stringIndex = (String) index;
for (Map.Entry<String, RuntimeTypeInformation> field : this.fields) {
if (field.getKey().equalsIgnoreCase(stringIndex)) {
return field.getValue();
}
}
}
return null;
}
}
Loading

0 comments on commit 87485a9

Please sign in to comment.