Skip to content

Commit

Permalink
[improvement][Headless] Embedding supports Chinese by default and fixes the issue of abnormal number recognition (#726)
Browse files Browse the repository at this point in the history
  • Loading branch information
lexluo09 authored Feb 18, 2024
1 parent 39158d6 commit fdb6954
Show file tree
Hide file tree
Showing 19 changed files with 62 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class OptimizationConfig {
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;

@Value("${embedding.mapper.distance.threshold:0.58}")
@Value("${embedding.mapper.distance.threshold:0.01}")
private Double embeddingMapperDistanceThreshold;

@Value("${s2SQL.linking.value.switch:true}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
Expand All @@ -19,6 +13,11 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
@Slf4j
Expand All @@ -29,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {

@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) {
Set<Long> detectViewIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
Expand Down Expand Up @@ -57,9 +56,9 @@ public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> d
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index);
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset);
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
Expand Down Expand Up @@ -151,7 +150,7 @@ public void logTerms(List<S2Term> terms) {

public abstract String getMapKey(T a);

public abstract void detectByStep(QueryContext queryContext, Set<T> results,
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset);
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
String detectSegment, int offset);

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,12 @@ public String getMapKey(DatabaseMapResult a) {
}

public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return;
}
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());

Double metricDimensionThresholdConfig = getThreshold(queryContext);

Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);

for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
Expand All @@ -73,9 +70,9 @@ public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> exist
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(viewIds)) {
if (!CollectionUtils.isEmpty(detectViewIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> viewIds.contains(schemaElement.getView()))
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.core.mapper;

import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
Expand Down Expand Up @@ -34,14 +34,12 @@ public void doMap(QueryContext queryContext) {
//2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());

SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
SchemaElement.class);
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
if (Objects.isNull(viewId)) {
continue;
}
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId,
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import java.util.ArrayList;
import java.util.Comparator;
Expand Down Expand Up @@ -47,6 +47,12 @@ public String getMapKey(EmbeddingResult a) {
return a.getName() + Constants.UNDERLINE + a.getId();
}

/**
 * Per-segment detection hook required by the base strategy contract.
 *
 * <p>This implementation is intentionally a no-op: it ignores every argument and adds nothing to
 * {@code existResults}.
 * NOTE(review): the matching work appears to be done by {@code detectByBatch} in this class
 * instead — confirm against the base class before adding logic here.
 *
 * @param queryContext  the current query context (unused)
 * @param existResults  accumulated results (left untouched)
 * @param detectViewIds view ids to restrict detection to (unused)
 * @param detectSegment the text segment under detection (unused)
 * @param offset        step offset supplied by the caller (unused)
 */
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
// Intentionally empty.

}

@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
Set<String> detectSegments) {
Expand Down Expand Up @@ -111,9 +117,4 @@ private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detec
selectResultInOneRound(results, oneRoundResults);
}

@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
return;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existRes
}

public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
String text = queryContext.getQueryText();
String detectSegment = text.substring(startIndex, index);

String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ public String getMapKey(HanlpMapResult a) {
}

@Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds,
Integer startIndex,
Integer i, int offset) {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ public void onApplicationEvent(DataEvent dataEvent) {
DictWord dictWord = new DictWord();
dictWord.setWord(dataItem.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + 1 + sign + dataItem.getId()
+ sign + dataItem.getType().name().toLowerCase();
String nature = sign + 1 + sign + dataItem.getId() + dataItem.getType().name().toLowerCase();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
@Builder
public class DataItem {

private Long id;
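/**
* Identifier stored with a trailing underscore (_) already appended, e.g. "12_".
* Consumers concatenate the lower-cased type name directly after it without adding another separator.
*/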
private String id;

private String bizName;

Expand All @@ -18,6 +21,9 @@ public class DataItem {

private TypeEnums type;

/**
* Model identifier stored with a trailing underscore (_) already appended by the producer, e.g. "3_".
*/
private String modelId;

private String defaultAgg;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public void init() {
parameters.add(new Parameter("embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.distance.threshold",
"0.58", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
"0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));

//parser config
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQu
List<Retrieval> retrievals = new ArrayList<>();
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
Retrieval retrieval = new Retrieval();
retrieval.setDistance(embeddingMatch.score());
retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(embeddingMatch.embeddingId());
retrieval.setQuery(embeddingMatch.embedded().getQuery());
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.util.ComponentFactory;
Expand Down Expand Up @@ -43,8 +42,7 @@ public void onApplicationEvent(DataEvent event) {
.map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(
dataItem.getId().toString() + Constants.UNDERLINE
+ dataItem.getType().name().toLowerCase());
dataItem.getId() + dataItem.getType().name().toLowerCase());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void updateDimension(DimensionReq dimensionReq, User user) {
if (!oldName.equals(dimensionDO.getName())) {
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION)
.id(dimensionDO.getId()).build(), EventType.UPDATE);
.id(dimensionDO.getId() + Constants.UNDERLINE).build(), EventType.UPDATE);
}
}

Expand Down Expand Up @@ -366,8 +366,9 @@ public void sendDimensionEventBatch(List<Long> modelIds, EventType eventType) {

private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
List<DataItem> dataItems = dimensionDOS.stream()
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName())
.modelId(dimensionDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.DIMENSION).build())
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
.name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.DIMENSION).build())
.collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ private void sendEvent(DataItem dataItem, EventType eventType) {
private DataItem getDataItem(MetricDO metricDO) {
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
new HashMap<>(), Lists.newArrayList());
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
return DataItem.builder().id(metricDO.getId() + Constants.UNDERLINE).name(metricDO.getName())
.bizName(metricDO.getBizName())
.modelId(metricDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build();
Expand Down
2 changes: 1 addition & 1 deletion launchers/common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import static dev.langchain4j.internal.Utils.isNullOrBlank;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
Expand Down Expand Up @@ -248,7 +248,7 @@ EmbeddingModel embeddingModel(S2LangChain4jProperties properties) {
case IN_PROCESS:
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
if (Objects.isNull(inProcess) || isNullOrBlank(inProcess.getModelPath())) {
return new AllMiniLmL6V2EmbeddingModel();
return new BgeSmallZhEmbeddingModel();
}
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());

Expand Down
Loading

0 comments on commit fdb6954

Please sign in to comment.