-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[improvement][chat]Modify core workflow of NL2SQLParser, always invok…
…ing rule-based parsers first.#1729
- Loading branch information
1 parent
b01751a
commit 400b9f8
Showing
9 changed files
with
95 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,4 +22,5 @@ public class ChatParseResp { | |
public ChatParseResp(Long queryId) { | ||
this.queryId = queryId; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 47 additions & 10 deletions
57
.../main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,67 @@ | ||
package com.tencent.supersonic.chat.server.processor.parse; | ||
|
||
import com.google.common.collect.Lists; | ||
import com.google.common.collect.Sets; | ||
import com.tencent.supersonic.chat.server.pojo.ParseContext; | ||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; | ||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; | ||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; | ||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult; | ||
import lombok.extern.slf4j.Slf4j; | ||
|
||
import java.util.*; | ||
|
||
/** | ||
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \ | ||
**/ | ||
@Slf4j | ||
public class ParseInfoSortProcessor implements ParseResultProcessor { | ||
|
||
@Override | ||
public void process(ParseContext parseContext) { | ||
Set<String> parseInfoText = Sets.newHashSet(); | ||
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList(); | ||
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses(); | ||
|
||
parseContext.getResponse().getSelectedParses().forEach(p -> { | ||
if (!parseInfoText.contains(p.getTextInfo())) { | ||
sortedParseInfo.add(p); | ||
parseInfoText.add(p.getTextInfo()); | ||
selectedParses.sort((o1, o2) -> { | ||
DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches()); | ||
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches()); | ||
|
||
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity(); | ||
if (difference == 0) { | ||
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity(); | ||
if (difference == 0) { | ||
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity(); | ||
} | ||
if (difference == 0) { | ||
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt(); | ||
} | ||
} | ||
return difference >= 0 ? -1 : 1; | ||
}); | ||
// re-assign parseId | ||
for (int i = 0; i < selectedParses.size(); i++) { | ||
SemanticParseInfo parseInfo = selectedParses.get(i); | ||
parseInfo.setId(i + 1); | ||
} | ||
} | ||
|
||
sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0); | ||
parseContext.getResponse().setSelectedParses(sortedParseInfo); | ||
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) { | ||
double maxMetricSimilarity = 0; | ||
double maxDatasetSimilarity = 0; | ||
double totalSimilarity = 0; | ||
long maxMetricUseCnt = 0L; | ||
for (SchemaElementMatch match : elementMatches) { | ||
if (SchemaElementType.DATASET.equals(match.getElement().getType())) { | ||
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity()); | ||
} | ||
if (SchemaElementType.METRIC.equals(match.getElement().getType())) { | ||
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity()); | ||
if (Objects.nonNull(match.getElement().getUseCnt())) { | ||
maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt()); | ||
} | ||
} | ||
totalSimilarity += match.getSimilarity(); | ||
} | ||
return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity) | ||
.maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity) | ||
.build(); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,13 @@ | ||
package com.tencent.supersonic.common.pojo.enums; | ||
|
||
public enum Text2SQLType { | ||
ONLY_RULE, ONLY_LLM, RULE_AND_LLM; | ||
ONLY_RULE, ONLY_LLM, LLM_OR_RULE; | ||
|
||
public boolean enableRule() { | ||
return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM); | ||
return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE); | ||
} | ||
|
||
public boolean enableLLM() { | ||
return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM); | ||
return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters