Skip to content

Commit

Permalink
[improvement][chat]Modify core workflow of NL2SQLParser, always invok…
Browse files Browse the repository at this point in the history
…ing rule-based parsers first.#1729
  • Loading branch information
jerryjzhang committed Oct 29, 2024
1 parent b01751a commit 400b9f8
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ public class ChatParseResp {
public ChatParseResp(Long queryId) {
this.queryId = queryId;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.*;
import java.util.stream.Collectors;

import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
Expand Down Expand Up @@ -78,29 +73,46 @@ public void parse(ParseContext parseContext) {
return;
}

QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx =
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}

if (parseContext.needRuleParse()) {
// first go with rule-based parsers unless the user has already selected one parse.
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
ChatParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);

// inject the semantic parse saved in the chat context
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx =
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}

// for every requested dataSet, repeatedly invoke the rule-based parser
// with different mapModes until a valid semantic parse is derived.
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
if (!parseResp.getSelectedParses().isEmpty()) {
break;
}
}
}
}

// next go with llm-based parsers unless LLM is disabled or user feedback is needed.
if (parseContext.needLLMParse() && !parseContext.needFeedback()) {
SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
queryNLReq.setText2SQLType(Text2SQLType.LLM_OR_RULE);

// either the user or the system selects one parse from the candidate parses.
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
: parseContext.getResponse().getSelectedParses().get(0));
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
parseContext.getResponse().getSelectedParses().clear();

parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq);
doParse(queryNLReq, parseContext.getResponse());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ public boolean needFeedback() {
&& response.getSelectedParses().size() > 1);
}

public boolean needRuleParse() {
return Objects.isNull(request.getSelectedParse());
}

public boolean needLLMParse() {
return enableLLM() && (Objects.nonNull(request.getSelectedParse())
|| !response.getSelectedParses().isEmpty());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,67 @@
package com.tencent.supersonic.chat.server.processor.parse;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
import lombok.extern.slf4j.Slf4j;

import java.util.*;

/**
* ParseInfoSortProcessor sorts candidate parse info based on a certain algorithm.
**/
@Slf4j
public class ParseInfoSortProcessor implements ParseResultProcessor {

@Override
public void process(ParseContext parseContext) {
Set<String> parseInfoText = Sets.newHashSet();
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList();
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();

parseContext.getResponse().getSelectedParses().forEach(p -> {
if (!parseInfoText.contains(p.getTextInfo())) {
sortedParseInfo.add(p);
parseInfoText.add(p.getTextInfo());
// Rank candidates best-first. Keys are compared in priority order: dataset similarity,
// metric similarity, total similarity, then metric use count. The comparator returns 0
// for true ties so that it honours the Comparator contract (sgn(compare(a, b)) ==
// -sgn(compare(b, a)) and compare(a, a) == 0); a comparator that never returns 0 can make
// List.sort throw "Comparison method violates its general contract!".
selectedParses.sort((o1, o2) -> {
    DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
    DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());

    // Higher values must come first, hence mr2 is always the left operand.
    int order = Double.compare(mr2.getMaxDatesetSimilarity(), mr1.getMaxDatesetSimilarity());
    if (order == 0) {
        order = Double.compare(mr2.getMaxMetricSimilarity(), mr1.getMaxMetricSimilarity());
    }
    if (order == 0) {
        order = Double.compare(mr2.getTotalSimilarity(), mr1.getTotalSimilarity());
    }
    if (order == 0) {
        // Compare the counts as longs instead of folding them into a double difference.
        order = Long.compare(mr2.getMaxMetricUseCnt(), mr1.getMaxMetricUseCnt());
    }
    return order;
});
// re-assign parseId
for (int i = 0; i < selectedParses.size(); i++) {
SemanticParseInfo parseInfo = selectedParses.get(i);
parseInfo.setId(i + 1);
}
}

sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0);
parseContext.getResponse().setSelectedParses(sortedParseInfo);
/**
 * Summarises the schema element matches of one candidate parse.
 *
 * <p>Collects the highest similarity among DATASET matches, the highest similarity and the
 * highest non-null use count among METRIC matches, and the sum of the similarities of all
 * matches regardless of type.
 *
 * @param elementMatches the schema element matches of a single parse; iterated as-is, so it
 *        must not be {@code null}
 * @return a {@link DataSetMatchResult} carrying the four aggregated values
 */
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
    double maxMetricSimilarity = 0;
    double maxDatasetSimilarity = 0;
    double totalSimilarity = 0;
    long maxMetricUseCnt = 0L;
    for (SchemaElementMatch match : elementMatches) {
        if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
            maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
        }
        if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
            maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
            // The use count may be absent on an element; skip it rather than unbox null.
            if (Objects.nonNull(match.getElement().getUseCnt())) {
                maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
            }
        }
        totalSimilarity += match.getSimilarity();
    }
    // The computed use count has to be handed to the builder as well; otherwise the value
    // gathered above is discarded and the use-count tie-breaker never sees it.
    // NOTE(review): assumes the builder exposes maxMetricUseCnt(...), matching the
    // getMaxMetricUseCnt() getter read by the sort comparator — confirm.
    return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
            .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
            .maxMetricUseCnt(maxMetricUseCnt).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
queryNLReq.setText2SQLType(
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
parseContext.enableLLM() ? Text2SQLType.LLM_OR_RULE : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(getDataSetIds(parseContext));
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package com.tencent.supersonic.common.pojo.enums;

public enum Text2SQLType {
ONLY_RULE, ONLY_LLM, RULE_AND_LLM;
ONLY_RULE, ONLY_LLM, LLM_OR_RULE;

public boolean enableRule() {
return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM);
return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE);
}

public boolean enableLLM() {
return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM);
return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class QueryNLReq extends SemanticQueryReq {
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private Text2SQLType text2SQLType = Text2SQLType.LLM_OR_RULE;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private QueryDataType queryDataType = QueryDataType.ALL;
private Map<String, ChatApp> chatAppConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,12 @@ public static QueryStructReq buildStructReq(SemanticParseInfo parseInfo) {
queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
.collect(Collectors.toList()));
queryStructReq.setLimit(parseInfo.getLimit());
// only one metric is queried at once
Set<SchemaElement> metrics = parseInfo.getMetrics();
if (!CollectionUtils.isEmpty(metrics)) {
SchemaElement metricElement = parseInfo.getMetrics().iterator().next();
Set<Order> order =
getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement);
queryStructReq
.setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
queryStructReq.setOrders(new ArrayList<>(order));

for (SchemaElement metricElement : parseInfo.getMetrics()) {
queryStructReq.getAggregators()
.addAll(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
queryStructReq.setOrders(new ArrayList<>(
getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement)));
}

deletionDuplicated(queryStructReq);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static User getUserTom() {
public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) {
ChatParseReq chatParseReq = new ChatParseReq();
chatParseReq.setQueryText(query);
chatParseReq.setAgentId(metricAgentId);
chatParseReq.setChatId(id);
chatParseReq.setUser(user_test);
chatParseReq.setDisableLLM(!enableLLM);
Expand Down

0 comments on commit 400b9f8

Please sign in to comment.