Skip to content

Commit

Permalink
[improvement][headless]Restructure LLMReq and LLMSchema.
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjzhang committed Sep 12, 2024
1 parent 4b1dab8 commit c99d240
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticPar
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
return null;
}
return parseResult.getLinkingValues();
return parseResult.getLlmReq().getSchema().getValues();
}

private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,15 @@
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -39,7 +36,7 @@ public class LLMRequestService {

public boolean isSkip(ChatQueryContext queryCtx) {
if (!queryCtx.getText2SQLType().enableLLM()) {
log.info("not enable llm, skip");
log.info("LLM disabled, skip");
return true;
}

Expand All @@ -57,33 +54,28 @@ public Long getDataSetId(ChatQueryContext queryCtx) {
}

public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
List<LLMReq.ElementValue> linkingValues = getValues(queryCtx, dataSetId);
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
String queryText = queryCtx.getQueryText();

LLMReq llmReq = new LLMReq();

llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmReq.setSchema(llmSchema);
llmSchema.setDataSetId(dataSetId);
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId));
llmSchema.setDimensions(getMappedDimensions(queryCtx, dataSetId));
llmSchema.setPartitionTime(getPartitionTime(queryCtx, dataSetId));
llmSchema.setPrimaryKey(getPrimaryKey(queryCtx, dataSetId));
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
llmReq.setSchema(llmSchema);

List<LLMReq.ElementValue> linking = new ArrayList<>();
boolean linkingValueEnabled =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));

if (linkingValueEnabled) {
linking.addAll(linkingValues);
llmSchema.setValues(getMappedValues(queryCtx, dataSetId));
}
llmReq.setLinking(linking);

llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
llmReq.setSqlGenType(
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig());
Expand All @@ -102,7 +94,7 @@ public LLMResp runText2SQL(LLMReq llmReq) {
return result;
}

protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
protected List<LLMReq.Term> getMappedTerms(ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
Expand All @@ -126,31 +118,8 @@ protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId)
.collect(Collectors.toList());
}

/**
 * Builds a lookup from metric field name to that metric's data format type.
 *
 * <p>Only metrics whose data format type is non-null contribute. Each such metric
 * contributes its own name plus every entry of its alias list (when the list is
 * non-empty). If the same field name is produced more than once, the first mapping
 * encountered is kept.
 *
 * @param semanticSchema schema whose metrics are scanned
 * @return map of metric name / alias to data format type
 */
private Map<String, String> getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) {
    return semanticSchema.getMetrics().stream()
            .filter(metric -> metric.getDataFormatType() != null)
            .flatMap(
                    metric -> {
                        String formatType = metric.getDataFormatType();
                        // The metric's own name first, followed by all of its aliases.
                        List<String> fieldNames = new ArrayList<>();
                        fieldNames.add(metric.getName());
                        List<String> aliases = metric.getAlias();
                        if (!CollectionUtils.isEmpty(aliases)) {
                            fieldNames.addAll(aliases);
                        }
                        return fieldNames.stream()
                                .map(fieldName -> Pair.of(fieldName, formatType));
                    })
            // Keep the first mapping when a name occurs more than once.
            .collect(
                    Collectors.toMap(
                            Pair::getLeft, Pair::getRight, (first, second) -> first));
}

public List<LLMReq.ElementValue> getValues(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
protected List<LLMReq.ElementValue> getMappedValues(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
Expand All @@ -177,7 +146,7 @@ public List<LLMReq.ElementValue> getValues(@NotNull ChatQueryContext queryCtx, L
return new ArrayList<>(valueMatches);
}

protected List<SchemaElement> getMatchedMetrics(
protected List<SchemaElement> getMappedMetrics(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
Expand All @@ -200,7 +169,7 @@ protected List<SchemaElement> getMatchedMetrics(
return schemaElements;
}

protected List<SchemaElement> getMatchedDimensions(
protected List<SchemaElement> getMappedDimensions(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {

List<SchemaElementMatch> matchedElements =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ private void tryParse(ChatQueryContext queryCtx, Long dataSetId) {
.dataSetId(dataSetId)
.llmReq(llmReq)
.llmResp(llmResp)
.linkingValues(llmReq.getLinking())
.build();
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

@Data
@Builder
@AllArgsConstructor
Expand All @@ -23,6 +21,4 @@ public class ParseResult {
private LLMResp llmResp;

private QueryNLReq request;

private List<LLMReq.ElementValue> linkingValues;
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public String buildSchemaStr(LLMReq llmReq) {
});

List<String> values = Lists.newArrayList();
llmReq.getLinking().stream()
llmReq.getSchema().getValues().stream()
.forEach(
value -> {
StringBuilder valueStr = new StringBuilder();
Expand Down Expand Up @@ -176,7 +176,7 @@ public String buildSchemaStr(LLMReq llmReq) {
}

private String buildTermStr(LLMReq llmReq) {
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
List<LLMReq.Term> terms = llmReq.getTerms();
List<String> termStr = Lists.newArrayList();
terms.stream()
.forEach(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import org.apache.commons.collections4.CollectionUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

@Data
public class LLMReq {
private String queryText;
private LLMSchema schema;
private List<ElementValue> linking;
private List<Term> terms;
private String currentDate;
private String priorExts;
private SqlGenType sqlGenType;
Expand All @@ -32,12 +35,30 @@ public static class ElementValue {
public static class LLMSchema {
private Long dataSetId;
private String dataSetName;
private List<String> fieldNameList;
private List<SchemaElement> metrics;
private List<SchemaElement> dimensions;
private List<ElementValue> values;
private SchemaElement partitionTime;
private SchemaElement primaryKey;
private List<Term> terms;

/**
 * Collects the names of the schema elements currently held by this schema.
 *
 * <p>Names are appended in a fixed order: metrics, dimensions, the partition-time
 * element, then the primary-key element. Any part that has not been set (a null or
 * empty list, or a null element) is skipped, so this getter is safe to call on a
 * partially populated schema.
 *
 * @return a newly created, mutable list of field names; never {@code null}
 */
public List<String> getFieldNameList() {
    List<String> fieldNameList = new ArrayList<>();
    if (CollectionUtils.isNotEmpty(metrics)) {
        fieldNameList.addAll(
                metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()));
    }
    if (CollectionUtils.isNotEmpty(dimensions)) {
        fieldNameList.addAll(
                dimensions.stream().map(SchemaElement::getName).collect(Collectors.toList()));
    }
    // partitionTime and primaryKey have no initializer, so they stay null until a
    // setter is called; guard them the same way the lists above are guarded.
    if (partitionTime != null) {
        fieldNameList.add(partitionTime.getName());
    }
    if (primaryKey != null) {
        fieldNameList.add(primaryKey.getName());
    }
    return fieldNameList;
}
}

@Data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import org.junit.Assert;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

@Disabled
class SchemaCorrectorTest {

private String json =
Expand All @@ -37,17 +39,10 @@ class SchemaCorrectorTest {
+ " \"数据日期\"\n"
+ " ]\n"
+ " },\n"
+ " \"linking\": [\n"
+ "\n"
+ " ],\n"
+ " \"currentDate\": \"2024-02-24\",\n"
+ " \"priorExts\": \"播放份额是小数; \",\n"
+ " \"sqlGenType\": \"1_pass_self_consistency\"\n"
+ " },\n"
+ " \"request\": null,\n"
+ " \"linkingValues\": [\n"
+ "\n"
+ " ]\n"
+ " \"request\": null\n"
+ "}";

@Test
Expand Down Expand Up @@ -86,7 +81,6 @@ void doCorrect() throws JsonProcessingException {
elementValue.setFieldName("商务组");
elementValue.setFieldValue("xxx");
linkingValues.add(elementValue);
parseResult.setLinkingValues(linkingValues);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);

semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
Expand Down

0 comments on commit c99d240

Please sign in to comment.