Skip to content

Commit

Permalink
Enhance json index to support regexp and range predicate evaluation (#…
Browse files Browse the repository at this point in the history
…12568)

* Enhance json index to support regexp and range predicate evaluation

* Move to TreeMap for mutable json index

* Review comments

* Simplify subMap call

* Lint

* Review comments

---------

Co-authored-by: Saurabh Dubey <saurabh.dubey@Saurabhs-MacBook-Pro.local>
  • Loading branch information
saurabhd336 and Saurabh Dubey authored Mar 8, 2024
1 parent 5beab96 commit c7cc821
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,23 @@ public static FilterContext getFilter(Function thriftFunction) {
case GREATER_THAN:
return FilterContext.forPredicate(
new RangePredicate(getExpression(operands.get(0)), false, getStringValue(operands.get(1)), false,
RangePredicate.UNBOUNDED));
RangePredicate.UNBOUNDED, new LiteralContext(operands.get(1).getLiteral()).getType()));
case GREATER_THAN_OR_EQUAL:
return FilterContext.forPredicate(
new RangePredicate(getExpression(operands.get(0)), true, getStringValue(operands.get(1)), false,
RangePredicate.UNBOUNDED));
RangePredicate.UNBOUNDED, new LiteralContext(operands.get(1).getLiteral()).getType()));
case LESS_THAN:
return FilterContext.forPredicate(
new RangePredicate(getExpression(operands.get(0)), false, RangePredicate.UNBOUNDED, false,
getStringValue(operands.get(1))));
getStringValue(operands.get(1)), new LiteralContext(operands.get(1).getLiteral()).getType()));
case LESS_THAN_OR_EQUAL:
return FilterContext.forPredicate(
new RangePredicate(getExpression(operands.get(0)), false, RangePredicate.UNBOUNDED, true,
getStringValue(operands.get(1))));
getStringValue(operands.get(1)), new LiteralContext(operands.get(1).getLiteral()).getType()));
case BETWEEN:
return FilterContext.forPredicate(
new RangePredicate(getExpression(operands.get(0)), true, getStringValue(operands.get(1)), true,
getStringValue(operands.get(2))));
getStringValue(operands.get(2)), new LiteralContext(operands.get(1).getLiteral()).getType()));
case RANGE:
return FilterContext.forPredicate(
new RangePredicate(getExpression(operands.get(0)), getStringValue(operands.get(1))));
Expand Down Expand Up @@ -400,22 +400,24 @@ public static FilterContext getFilter(FunctionContext filterFunction) {
}
case GREATER_THAN:
return FilterContext.forPredicate(
new RangePredicate(operands.get(0), false, getStringValue(operands.get(1)), false,
RangePredicate.UNBOUNDED));
new RangePredicate(operands.get(0), false, getStringValue(operands.get(1)), false, RangePredicate.UNBOUNDED,
operands.get(1).getLiteral().getType()));
case GREATER_THAN_OR_EQUAL:
return FilterContext.forPredicate(
new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), false,
RangePredicate.UNBOUNDED));
new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), false, RangePredicate.UNBOUNDED,
operands.get(1).getLiteral().getType()));
case LESS_THAN:
return FilterContext.forPredicate(new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, false,
getStringValue(operands.get(1))));
return FilterContext.forPredicate(
new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, false, getStringValue(operands.get(1)),
operands.get(1).getLiteral().getType()));
case LESS_THAN_OR_EQUAL:
return FilterContext.forPredicate(new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, true,
getStringValue(operands.get(1))));
return FilterContext.forPredicate(
new RangePredicate(operands.get(0), false, RangePredicate.UNBOUNDED, true, getStringValue(operands.get(1)),
operands.get(1).getLiteral().getType()));
case BETWEEN:
return FilterContext.forPredicate(
new RangePredicate(operands.get(0), true, getStringValue(operands.get(1)), true,
getStringValue(operands.get(2))));
getStringValue(operands.get(2)), operands.get(1).getLiteral().getType()));
case RANGE:
return FilterContext.forPredicate(new RangePredicate(operands.get(0), getStringValue(operands.get(1))));
case REGEXP_LIKE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.utils.CommonConstants.Query.Range;


Expand All @@ -43,6 +44,7 @@ public class RangePredicate extends BasePredicate {
private final String _lowerBound;
private final boolean _upperInclusive;
private final String _upperBound;
private final FieldSpec.DataType _rangeDataType;

/**
* The range is formatted as 5 parts:
Expand All @@ -67,15 +69,17 @@ public RangePredicate(ExpressionContext lhs, String range) {
int upperLength = upper.length();
_upperInclusive = upper.charAt(upperLength - 1) == UPPER_INCLUSIVE;
_upperBound = upper.substring(0, upperLength - 1);
_rangeDataType = FieldSpec.DataType.UNKNOWN;
}

/**
 * Creates a range predicate from explicitly supplied bounds instead of a formatted range string.
 *
 * @param lhs expression the range is applied to; passed to the super constructor
 * @param lowerInclusive whether the lower bound itself is part of the range
 * @param lowerBound lower bound as a string
 * @param upperInclusive whether the upper bound itself is part of the range
 * @param upperBound upper bound as a string
 * @param rangeDataType data type recorded for the bounds; exposed through {@code getRangeDataType()}
 */
public RangePredicate(ExpressionContext lhs, boolean lowerInclusive, String lowerBound, boolean upperInclusive,
String upperBound) {
String upperBound, FieldSpec.DataType rangeDataType) {
super(lhs);
_lowerInclusive = lowerInclusive;
_lowerBound = lowerBound;
_upperInclusive = upperInclusive;
_upperBound = upperBound;
_rangeDataType = rangeDataType;
}

@Override
Expand All @@ -99,6 +103,10 @@ public String getUpperBound() {
return _upperBound;
}

/**
 * Returns the data type recorded for the range bounds when this predicate was constructed.
 *
 * @return the stored range data type
 */
public FieldSpec.DataType getRangeDataType() {
return _rangeDataType;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,6 @@ private RangePredicate createPredicate(int lower, boolean inclLower, int upper,
if (upper == DICT_LEN - 1 && inclUpper) {
upperStr = "*";
}
return new RangePredicate(COLUMN_EXPRESSION, inclLower, lowerStr, inclUpper, upperStr);
return new RangePredicate(COLUMN_EXPRESSION, inclLower, lowerStr, inclUpper, upperStr, DataType.STRING);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.pinot.common.request.context.RequestContextUtils;
import org.apache.pinot.common.request.context.predicate.Predicate;
import org.apache.pinot.common.request.context.predicate.RangePredicate;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.sql.parsers.CalciteSqlParser;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -74,7 +75,8 @@ public void testSerDe() {

// Non-standard RangePredicate (merged ranges)
RangePredicate rangePredicate =
new RangePredicate(ExpressionContext.forIdentifier("foo"), true, "123", false, "456");
new RangePredicate(ExpressionContext.forIdentifier("foo"), true, "123", false, "456",
FieldSpec.DataType.STRING);
String predicateExpression = rangePredicate.toString();
assertEquals(predicateExpression, "(foo >= '123' AND foo < '456')");
Expression thriftExpression = CalciteSqlParser.compileToExpression(predicateExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,14 @@ public void testHardcodedQueries() {
List<FilterContext> children = filter.getChildren();
assertEquals(children.size(), 2);
assertEquals(children.get(0), FilterContext.forPredicate(
new RangePredicate(ExpressionContext.forIdentifier("foo"), false, "15", false, "*")));
new RangePredicate(ExpressionContext.forIdentifier("foo"), false, "15", false, "*", FieldSpec.DataType.INT)));
FilterContext orFilter = children.get(1);
assertEquals(orFilter.getType(), FilterContext.Type.OR);
assertEquals(orFilter.getChildren().size(), 2);
assertEquals(orFilter.getChildren().get(0), FilterContext.forPredicate(new RangePredicate(
ExpressionContext.forFunction(new FunctionContext(FunctionContext.Type.TRANSFORM, "div",
Arrays.asList(ExpressionContext.forIdentifier("bar"), ExpressionContext.forIdentifier("foo")))), true,
"10", true, "20")));
"10", true, "20", FieldSpec.DataType.INT)));
assertEquals(orFilter.getChildren().get(1),
FilterContext.forPredicate(new TextMatchPredicate(ExpressionContext.forIdentifier("foobar"), "potato")));
assertEquals(filter.toString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.regex.Pattern;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.request.context.RequestContextUtils;
Expand All @@ -35,10 +37,12 @@
import org.apache.pinot.common.request.context.predicate.NotEqPredicate;
import org.apache.pinot.common.request.context.predicate.NotInPredicate;
import org.apache.pinot.common.request.context.predicate.Predicate;
import org.apache.pinot.segment.local.segment.creator.impl.inv.json.BaseJsonIndexCreator;
import org.apache.pinot.common.request.context.predicate.RangePredicate;
import org.apache.pinot.common.request.context.predicate.RegexpLikePredicate;
import org.apache.pinot.segment.spi.index.creator.JsonIndexCreator;
import org.apache.pinot.segment.spi.index.mutable.MutableJsonIndex;
import org.apache.pinot.spi.config.table.JsonIndexConfig;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.exception.BadQueryRequestException;
import org.apache.pinot.spi.utils.JsonUtils;
import org.apache.pinot.sql.parsers.CalciteSqlParser;
Expand All @@ -53,7 +57,7 @@
*/
public class MutableJsonIndexImpl implements MutableJsonIndex {
private final JsonIndexConfig _jsonIndexConfig;
private final Map<String, RoaringBitmap> _postingListMap;
private final TreeMap<String, RoaringBitmap> _postingListMap;
private final IntList _docIdMapping;
private final ReentrantReadWriteLock.ReadLock _readLock;
private final ReentrantReadWriteLock.WriteLock _writeLock;
Expand All @@ -63,7 +67,7 @@ public class MutableJsonIndexImpl implements MutableJsonIndex {

public MutableJsonIndexImpl(JsonIndexConfig jsonIndexConfig) {
_jsonIndexConfig = jsonIndexConfig;
_postingListMap = new HashMap<>();
_postingListMap = new TreeMap<>();
_docIdMapping = new IntArrayList();

ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
Expand Down Expand Up @@ -230,7 +234,7 @@ private RoaringBitmap getMatchingFlattenedDocIds(Predicate predicate) {
if (!arrayIndex.equals(JsonUtils.WILDCARD)) {
// "[0]"=1 -> ".$index"='0' && "."='1'
// ".foo[1].bar"='abc' -> ".foo.$index"=1 && ".foo..bar"='abc'
String searchKey = leftPart + JsonUtils.ARRAY_INDEX_KEY + BaseJsonIndexCreator.KEY_VALUE_SEPARATOR + arrayIndex;
String searchKey = leftPart + JsonUtils.ARRAY_INDEX_KEY + JsonIndexCreator.KEY_VALUE_SEPARATOR + arrayIndex;
RoaringBitmap docIds = _postingListMap.get(searchKey);
if (docIds != null) {
if (matchingDocIds == null) {
Expand All @@ -250,7 +254,7 @@ private RoaringBitmap getMatchingFlattenedDocIds(Predicate predicate) {
if (predicateType == Predicate.Type.EQ || predicateType == Predicate.Type.NOT_EQ) {
String value = predicateType == Predicate.Type.EQ ? ((EqPredicate) predicate).getValue()
: ((NotEqPredicate) predicate).getValue();
String keyValuePair = key + BaseJsonIndexCreator.KEY_VALUE_SEPARATOR + value;
String keyValuePair = key + JsonIndexCreator.KEY_VALUE_SEPARATOR + value;
RoaringBitmap matchingDocIdsForKeyValuePair = _postingListMap.get(keyValuePair);
if (matchingDocIdsForKeyValuePair != null) {
if (matchingDocIds == null) {
Expand All @@ -267,7 +271,7 @@ private RoaringBitmap getMatchingFlattenedDocIds(Predicate predicate) {
: ((NotInPredicate) predicate).getValues();
RoaringBitmap matchingDocIdsForKeyValuePairs = new RoaringBitmap();
for (String value : values) {
String keyValuePair = key + BaseJsonIndexCreator.KEY_VALUE_SEPARATOR + value;
String keyValuePair = key + JsonIndexCreator.KEY_VALUE_SEPARATOR + value;
RoaringBitmap matchingDocIdsForKeyValuePair = _postingListMap.get(keyValuePair);
if (matchingDocIdsForKeyValuePair != null) {
matchingDocIdsForKeyValuePairs.or(matchingDocIdsForKeyValuePair);
Expand All @@ -291,6 +295,85 @@ private RoaringBitmap getMatchingFlattenedDocIds(Predicate predicate) {
} else {
return new RoaringBitmap();
}
} else if (predicateType == Predicate.Type.REGEXP_LIKE) {
Map<String, RoaringBitmap> subMap = getMatchingKeysMap(key);
if (subMap.isEmpty()) {
return new RoaringBitmap();
}
Pattern pattern = ((RegexpLikePredicate) predicate).getPattern();
RoaringBitmap result = null;

for (Map.Entry<String, RoaringBitmap> entry : subMap.entrySet()) {
if (!pattern.matcher(entry.getKey().substring(key.length() + 1)).matches()) {
continue;
}
if (result == null) {
result = entry.getValue().clone();
} else {
result.or(entry.getValue());
}
}

if (result == null) {
return new RoaringBitmap();
} else {
if (matchingDocIds == null) {
return result;
} else {
matchingDocIds.and(result);
return matchingDocIds;
}
}
} else if (predicateType == Predicate.Type.RANGE) {
Map<String, RoaringBitmap> subMap = getMatchingKeysMap(key);
if (subMap.isEmpty()) {
return new RoaringBitmap();
}
RoaringBitmap result = null;

RangePredicate rangePredicate = (RangePredicate) predicate;
FieldSpec.DataType rangeDataType = rangePredicate.getRangeDataType();
// Simplify to only support numeric and string types
if (rangeDataType.isNumeric()) {
rangeDataType = FieldSpec.DataType.DOUBLE;
} else {
rangeDataType = FieldSpec.DataType.STRING;
}

boolean lowerUnbounded = rangePredicate.getLowerBound().equals(RangePredicate.UNBOUNDED);
boolean upperUnbounded = rangePredicate.getUpperBound().equals(RangePredicate.UNBOUNDED);
boolean lowerInclusive = lowerUnbounded || rangePredicate.isLowerInclusive();
boolean upperInclusive = upperUnbounded || rangePredicate.isUpperInclusive();
Object lowerBound = lowerUnbounded ? null : rangeDataType.convert(rangePredicate.getLowerBound());
Object upperBound = upperUnbounded ? null : rangeDataType.convert(rangePredicate.getUpperBound());

for (Map.Entry<String, RoaringBitmap> entry : subMap.entrySet()) {
Object valueObj = rangeDataType.convert(entry.getKey().substring(key.length() + 1));
boolean lowerCompareResult =
lowerUnbounded || (lowerInclusive ? rangeDataType.compare(valueObj, lowerBound) >= 0
: rangeDataType.compare(valueObj, lowerBound) > 0);
boolean upperCompareResult =
upperUnbounded || (upperInclusive ? rangeDataType.compare(valueObj, upperBound) <= 0
: rangeDataType.compare(valueObj, upperBound) < 0);
if (lowerCompareResult && upperCompareResult) {
if (result == null) {
result = entry.getValue().clone();
} else {
result.or(entry.getValue());
}
}
}

if (result == null) {
return new RoaringBitmap();
} else {
if (matchingDocIds == null) {
return result;
} else {
matchingDocIds.and(result);
return matchingDocIds;
}
}
} else {
throw new IllegalStateException("Unsupported json_match predicate type: " + predicate);
}
Expand All @@ -301,10 +384,8 @@ public Map<String, RoaringBitmap> getMatchingDocsMap(String key) {
Map<String, RoaringBitmap> matchingDocsMap = new HashMap<>();
_readLock.lock();
try {
for (Map.Entry<String, RoaringBitmap> entry : _postingListMap.entrySet()) {
if (!entry.getKey().startsWith(key + BaseJsonIndexCreator.KEY_VALUE_SEPARATOR)) {
continue;
}
Map<String, RoaringBitmap> subMap = getMatchingKeysMap(key);
for (Map.Entry<String, RoaringBitmap> entry : subMap.entrySet()) {
MutableRoaringBitmap flattenedDocIds = entry.getValue().toMutableRoaringBitmap();
PeekableIntIterator it = flattenedDocIds.getIntIterator();
MutableRoaringBitmap postingList = new MutableRoaringBitmap();
Expand Down Expand Up @@ -342,6 +423,11 @@ public String[] getValuesForKeyAndDocs(int[] docIds, Map<String, RoaringBitmap>
return values;
}

/**
 * Returns the posting-list entries stored for the given JSON key, i.e. every entry whose map key starts with
 * {@code key} followed by {@code JsonIndexCreator.KEY_VALUE_SEPARATOR}.
 *
 * <p>The lower bound is inclusive so that the entry whose map key is exactly {@code key + separator} (an empty
 * value) is part of the result, the same set a {@code startsWith(key + separator)} scan over the map selects.
 * The upper bound, {@code key + KEY_VALUE_SEPARATOR_NEXT_CHAR}, is exclusive.
 * NOTE(review): this presumes KEY_VALUE_SEPARATOR_NEXT_CHAR sorts immediately after the separator - confirm.
 *
 * <p>The result is a view backed by {@code _postingListMap}, not a copy.
 *
 * @param key JSON key (without the separator) to look up
 * @return sub-map view of the matching "key + separator + value" entries; empty when there are none
 */
private Map<String, RoaringBitmap> getMatchingKeysMap(String key) {
  String lowerBound = key + JsonIndexCreator.KEY_VALUE_SEPARATOR;
  String upperBound = key + JsonIndexCreator.KEY_VALUE_SEPARATOR_NEXT_CHAR;
  return _postingListMap.subMap(lowerBound, true, upperBound, false);
}

/**
 * No-op: the body is empty, so calling this method does nothing.
 */
@Override
public void close() {
}
Expand Down
Loading

0 comments on commit c7cc821

Please sign in to comment.