Skip to content

Commit

Permalink
Throws and exception for radial search when mapping is for on-disk mo…
Browse files Browse the repository at this point in the history
…de (opensearch-project#2055)

Signed-off-by: Tejas Shah <shatejas@amazon.com>
  • Loading branch information
shatejas authored Sep 7, 2024
1 parent 7bf52ac commit cbc6343
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.mapper;

import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;

import java.util.Optional;

Expand Down Expand Up @@ -48,6 +49,14 @@ default CompressionLevel getCompressionLevel() {
return CompressionLevel.NOT_CONFIGURED;
}

/**
* Returns quantization config
* @return
*/
default QuantizationConfig getQuantizationConfig() {
return QuantizationConfig.EMPTY;
}

/**
*
* @return the dimension of the index; for model based indices, it will be null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ public KNNVectorFieldMapper build(BuilderContext context) {
hasDocValues.get(),
modelDao,
indexCreatedVersion,
originalParameters
originalParameters,
knnMethodConfigContext
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ public static MethodFieldMapper createFieldMapper(
boolean hasDocValues,
OriginalMappingParameters originalMappingParameters
) {

KNNMethodContext knnMethodContext = originalMappingParameters.getResolvedKnnMethodContext();
QuantizationConfig quantizationConfig = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext)
.getQuantizationConfig();

final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(
fullname,
metaValue,
Expand All @@ -75,6 +81,11 @@ public Mode getMode() {
public CompressionLevel getCompressionLevel() {
return knnMethodConfigContext.getCompressionLevel();
}

@Override
public QuantizationConfig getQuantizationConfig() {
return quantizationConfig;
}
}
);
return new MethodFieldMapper(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,17 @@ public static ModelFieldMapper createFieldMapper(
boolean hasDocValues,
ModelDao modelDao,
Version indexCreatedVersion,
OriginalMappingParameters originalMappingParameters
OriginalMappingParameters originalMappingParameters,
KNNMethodConfigContext knnMethodConfigContext
) {

final KNNMethodContext knnMethodContext = originalMappingParameters.getKnnMethodContext();
final QuantizationConfig quantizationConfig = knnMethodContext == null
? QuantizationConfig.EMPTY
: knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext)
.getQuantizationConfig();

final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() {
private Integer dimension = null;
private Mode mode = null;
Expand Down Expand Up @@ -94,6 +102,11 @@ public CompressionLevel getCompressionLevel() {
return compressionLevel;
}

@Override
public QuantizationConfig getQuantizationConfig() {
return quantizationConfig;
}

// ModelMetadata relies on cluster state which may not be available during field mapper creation. Thus,
// we lazily initialize it.
private void initFromModelMetadata() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.model.QueryContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.parser.RescoreParser;
Expand Down Expand Up @@ -451,6 +452,10 @@ protected Query doToQuery(QueryShardContext context) {
if (vectorDataType == VectorDataType.BINARY) {
throw new UnsupportedOperationException(String.format(Locale.ROOT, "Binary data type does not support radial search"));
}

if (knnMappingConfig.getQuantizationConfig() != QuantizationConfig.EMPTY) {
throw new UnsupportedOperationException("Radial search is not supported for indices which have quantization enabled");
}
}

// Currently, k-NN supports distance and score types radial search
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -18,6 +19,7 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.ResultUtil;
Expand All @@ -39,6 +41,7 @@
* for k-NN query if required. This is done by overriding rewrite method to execute ANN on each leaf
* {@link KNNQuery} does not give the ability to post process segment results.
*/
@Log4j2
@Getter
@RequiredArgsConstructor
public class NativeEngineKnnVectorQuery extends Query {
Expand All @@ -60,7 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
int firstPassK = rescoreContext.getFirstPassK(finalK);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
ResultUtil.reduceToTopK(perLeafResults, firstPassK);

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
long rescoreTime = stopWatch.stop().totalTime().millis();
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
}
ResultUtil.reduceToTopK(perLeafResults, finalK);
TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,12 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy
MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)
) {
KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(VectorDataType.FLOAT)
.versionCreated(CURRENT)
.dimension(TEST_DIMENSION)
.build();

for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
Expand Down Expand Up @@ -1173,7 +1179,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy
false,
modelDao,
CURRENT,
originalMappingParameters
originalMappingParameters,
knnMethodConfigContext
);

modelFieldMapper.parseCreateField(parseContext);
Expand Down Expand Up @@ -1214,7 +1221,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy
false,
modelDao,
CURRENT,
originalMappingParameters
originalMappingParameters,
knnMethodConfigContext
);

modelFieldMapper.parseCreateField(parseContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand All @@ -41,6 +44,7 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -440,6 +444,47 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() {
assertTrue(e.getMessage().contains("Binary data type does not support radial search"));
}

public void testDoToQuery_whenRadialSearchOnDiskMode_thenException() {
float[] queryVector = { 1.0f };
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.maxDistance(MAX_DISTANCE)
.build();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
MethodComponentContext methodComponentContext = new MethodComponentContext(
org.opensearch.knn.common.KNNConstants.METHOD_HNSW,
ImmutableMap.of()
);
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(new KNNMappingConfig() {
@Override
public Optional<KNNMethodContext> getKnnMethodContext() {
return Optional.of(knnMethodContext);
}

@Override
public int getDimension() {
return 1;
}

public Mode getMode() {
return Mode.ON_DISK;
}

public QuantizationConfig getQuantizationConfig() {
return QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
}
});
Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
assertEquals("Radial search is not supported for indices which have quantization enabled", e.getMessage());
}

public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
// Given
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down

0 comments on commit cbc6343

Please sign in to comment.