Skip to content

Commit

Permalink
(improvement)(chat) In SchemaCorrector, removing filters from linking…
Browse files Browse the repository at this point in the history
…Value that do not exist. (#775)
  • Loading branch information
lexluo09 authored Feb 29, 2024
1 parent 6813582 commit eba3a8a
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,19 @@ protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long
return semanticSchema.getMetrics(viewId);
}

/**
 * Collects every name under which a dimension of the given view can be referenced.
 *
 * <p>For each dimension returned by {@code semanticSchema.getDimensions(viewId)} the
 * dimension name is included and, when the alias collection is not empty, all of its
 * aliases. The Chinese name of {@code TimeDimensionEnum.DAY} is always added as well.
 *
 * @param viewId         id of the view whose dimensions are looked up
 * @param semanticSchema schema the dimensions are read from
 * @return the set of dimension names and aliases plus the day time-dimension name
 */
protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
    Set<String> names = new HashSet<>();
    for (SchemaElement dimension : semanticSchema.getDimensions(viewId)) {
        names.add(dimension.getName());
        if (!CollectionUtils.isEmpty(dimension.getAlias())) {
            names.addAll(dimension.getAlias());
        }
    }
    names.add(TimeDimensionEnum.DAY.getChName());
    return names;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

/**
* Perform SQL corrections on the "Group by" section in S2SQL.
Expand Down Expand Up @@ -82,22 +80,6 @@ private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo sema
return true;
}

/**
 * Builds the set of names (dimension names and non-empty alias lists) of all
 * dimensions of the given view, plus the Chinese name of {@code TimeDimensionEnum.DAY}.
 *
 * @param viewId         id of the view whose dimensions are looked up
 * @param semanticSchema schema the dimensions are read from
 * @return the collected dimension names and aliases
 */
private Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
    Set<String> result = new HashSet<>();
    semanticSchema.getDimensions(viewId).forEach(element -> {
        result.add(element.getName());
        if (!CollectionUtils.isEmpty(element.getAlias())) {
            result.addAll(element.getAlias());
        }
    });
    result.add(TimeDimensionEnum.DAY.getChName());
    return result;
}

private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long viewId = semanticParseInfo.getViewId();
//add dimension group by
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
package com.tencent.supersonic.chat.core.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

/**
* Perform schema corrections on the Schema information in S2SQL.
Expand All @@ -27,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);

correctAggFunction(semanticParseInfo);

replaceAlias(semanticParseInfo);
Expand Down Expand Up @@ -105,4 +115,35 @@ private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo)
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
}

/**
 * Removes dimension equality filters from the corrected S2SQL when the filtered
 * field has no entry among the linking values of the parse result.
 *
 * <p>Nothing happens when the corrected S2SQL has no WHERE expressions. Otherwise the
 * field names selected by {@link #isRemovableFilter} are handed to
 * {@code SqlRemoveHelper.removeWhereCondition} and its result is stored back as the
 * corrected S2SQL. The helper is invoked even when no field was selected, so the
 * stored SQL is always the helper's output.
 *
 * @param queryContext      supplies the semantic schema used to resolve dimension names
 * @param semanticParseInfo parse info whose {@code SqlInfo} is read and updated in place
 */
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
    SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
    String correctS2SQL = sqlInfo.getCorrectS2SQL();
    List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
    if (CollectionUtils.isEmpty(whereExpressionList)) {
        return;
    }

    Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), queryContext.getSemanticSchema());

    List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
    if (CollectionUtils.isEmpty(linkingValues)) {
        // No linking values: every candidate filter counts as "not linked".
        linkingValues = new ArrayList<>();
    }
    Set<String> linkingFieldNames = linkingValues.stream()
            .map(ElementValue::getFieldName)
            .collect(Collectors.toSet());

    Set<String> removeFieldNames = whereExpressionList.stream()
            .filter(expression -> isRemovableFilter(expression, dimensions, linkingFieldNames))
            .map(FieldExpression::getFieldName)
            .collect(Collectors.toSet());

    String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
    sqlInfo.setCorrectS2SQL(sql);
}

/**
 * Decides whether a WHERE expression should be dropped: it must have a value, use no
 * function, not target a time dimension, use the EQUALS operator, target a known
 * dimension name, compare against a non-date value and have no linking value.
 *
 * <p>Expressions without a value are kept; previously they caused a
 * {@code NullPointerException} when the value was converted to a string.
 */
private static boolean isRemovableFilter(FieldExpression expression, Set<String> dimensions,
        Set<String> linkingFieldNames) {
    Object fieldValue = expression.getFieldValue();
    if (fieldValue == null) {
        return false;
    }
    String fieldName = expression.getFieldName();
    return StringUtils.isBlank(expression.getFunction())
            && !TimeDimensionEnum.containsTimeDimension(fieldName)
            && FilterOperatorEnum.EQUALS.getValue().equals(expression.getOperator())
            && dimensions.contains(fieldName)
            && !DateUtils.isAnyDateString(fieldValue.toString())
            && !linkingFieldNames.contains(fieldName);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package com.tencent.supersonic.chat.core.corrector;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.junit.jupiter.api.Test;

/**
 * Tests {@code SchemaCorrector.removeFilterIfNotInLinkingValue} with and without a
 * linking value for the filtered dimension.
 */
class SchemaCorrectorTest {

    /** Serialized {@code ParseResult} with empty {@code linking} and {@code linkingValues} arrays. */
    private String json = "{\n"
            + " \"viewId\": 1,\n"
            + " \"llmReq\": {\n"
            + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
            + " \"filterCondition\": {\n"
            + " \"tableName\": null\n"
            + " },\n"
            + " \"schema\": {\n"
            + " \"domainName\": \"歌曲\",\n"
            + " \"viewName\": \"歌曲\",\n"
            + " \"fieldNameList\": [\n"
            + " \"商务组\",\n"
            + " \"歌曲名\",\n"
            + " \"播放量\",\n"
            + " \"播放份额\",\n"
            + " \"数据日期\"\n"
            + " ]\n"
            + " },\n"
            + " \"linking\": [\n"
            + "\n"
            + " ],\n"
            + " \"currentDate\": \"2024-02-24\",\n"
            + " \"priorExts\": \"播放份额是小数; \",\n"
            + " \"sqlGenerationMode\": \"2_pass_auto_cot\"\n"
            + " },\n"
            + " \"request\": null,\n"
            + " \"commonAgentTool\": {\n"
            + " \"id\": \"y3LqVSRL\",\n"
            + " \"name\": \"大模型语义解析\",\n"
            + " \"type\": \"NL2SQL_LLM\",\n"
            + " \"viewIds\": [\n"
            + " 1\n"
            + " ]\n"
            + " },\n"
            + " \"linkingValues\": [\n"
            + "\n"
            + " ]\n"
            + "}";

    @Test
    void doCorrect() throws JsonProcessingException {
        final Long viewId = 1L;
        final String originalSql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
                + "and 商务组 = 'xxx' order by 播放量 desc limit 10";

        QueryContext context = buildQueryContext(viewId);
        ObjectMapper mapper = new ObjectMapper();
        SchemaCorrector corrector = new SchemaCorrector();

        // Scenario 1: no linking values, so the "商务组" equality filter is expected to be removed.
        SemanticParseInfo parseInfo = buildParseInfo(viewId, originalSql);
        ParseResult withoutLinking = mapper.readValue(json, ParseResult.class);
        parseInfo.getProperties().put(Constants.CONTEXT, withoutLinking);

        corrector.removeFilterIfNotInLinkingValue(context, parseInfo);

        assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
                + "ORDER BY 播放量 DESC LIMIT 10", parseInfo.getSqlInfo().getCorrectS2SQL());

        // Scenario 2: a linking value exists for "商务组", so the filter is expected to stay.
        ElementValue linkedValue = new ElementValue();
        linkedValue.setFieldName("商务组");
        linkedValue.setFieldValue("xxx");
        List<ElementValue> linked = new ArrayList<>();
        linked.add(linkedValue);

        ParseResult withLinking = mapper.readValue(json, ParseResult.class);
        withLinking.setLinkingValues(linked);
        parseInfo.getProperties().put(Constants.CONTEXT, withLinking);

        // Reset the SQL that the first scenario rewrote.
        parseInfo.getSqlInfo().setCorrectS2SQL(originalSql);
        parseInfo.getSqlInfo().setS2SQL(originalSql);

        corrector.removeFilterIfNotInLinkingValue(context, parseInfo);

        assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
                + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", parseInfo.getSqlInfo().getCorrectS2SQL());
    }

    /** Creates parse info for the given view whose S2SQL and corrected S2SQL are both {@code sql}. */
    private SemanticParseInfo buildParseInfo(Long viewId, String sql) {
        SqlInfo sqlInfo = new SqlInfo();
        sqlInfo.setS2SQL(sql);
        sqlInfo.setCorrectS2SQL(sql);

        SchemaElement view = new SchemaElement();
        view.setView(viewId);

        SemanticParseInfo parseInfo = new SemanticParseInfo();
        parseInfo.setSqlInfo(sqlInfo);
        parseInfo.setView(view);
        return parseInfo;
    }

    /** Creates a query context whose schema holds one view with three dimensions. */
    private QueryContext buildQueryContext(Long viewId) {
        SchemaElement view = new SchemaElement();
        view.setView(viewId);

        Set<SchemaElement> dimensions = new HashSet<>();
        for (String name : new String[] {"歌曲名", "商务组", "发行日期"}) {
            dimensions.add(buildDimension(name));
        }

        ViewSchema viewSchema = new ViewSchema();
        viewSchema.setQueryConfig(new QueryConfig());
        viewSchema.setView(view);
        viewSchema.setDimensions(dimensions);

        List<ViewSchema> viewSchemas = new ArrayList<>();
        viewSchemas.add(viewSchema);

        QueryContext context = new QueryContext();
        context.setSemanticSchema(new SemanticSchema(viewSchemas));
        return context;
    }

    /** Creates a schema element with the given name that belongs to view {@code 1}. */
    private SchemaElement buildDimension(String name) {
        SchemaElement dimension = new SchemaElement();
        dimension.setView(1L);
        dimension.setName(name);
        return dimension;
    }
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
package com.tencent.supersonic.common.util;

import com.tencent.supersonic.common.pojo.Constants;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.time.temporal.ChronoField;
import java.time.temporal.TemporalAdjuster;
import java.time.temporal.TemporalAdjusters;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.extern.slf4j.Slf4j;

@Slf4j
Expand Down Expand Up @@ -166,4 +168,27 @@ public static List<String> getDateList(String startDateStr, String endDateStr, S
return datesInRange;
}

/**
 * Checks {@code value} against the default date patterns {@code yyyy-MM-dd},
 * {@code yyyy-MM} and {@code yyyy/MM/dd}.
 *
 * @param value text to test
 * @return {@code true} if {@code value} matches at least one of the default patterns
 */
public static boolean isAnyDateString(String value) {
    return isAnyDateString(value, Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd"));
}

/**
 * Checks whether {@code value} matches at least one of the given date patterns.
 * Patterns are tried in list order and evaluation stops at the first match.
 *
 * @param value   text to test
 * @param formats patterns passed one by one to {@code isDateString}
 * @return {@code true} if any pattern matches, {@code false} otherwise
 */
public static boolean isAnyDateString(String value, List<String> formats) {
    return formats.stream().anyMatch(pattern -> isDateString(value, pattern));
}

/**
 * Checks whether {@code value} can be parsed completely with the given pattern.
 *
 * <p>The text is parsed with {@link DateTimeFormatter#parse(CharSequence)} rather than
 * {@code LocalDate.parse}: the latter needs a full year-month-day and therefore
 * rejected every input for partial patterns such as {@code yyyy-MM}.
 *
 * @param value  text to test; {@code null} yields {@code false}
 * @param format pattern understood by {@link DateTimeFormatter#ofPattern(String)}
 * @return {@code true} if the whole text matches the pattern, {@code false} otherwise
 * @throws IllegalArgumentException if {@code format} is not a valid pattern
 */
public static boolean isDateString(String value, String format) {
    if (value == null) {
        // A missing value is simply "not a date" instead of a NullPointerException.
        return false;
    }
    try {
        DateTimeFormatter.ofPattern(format).parse(value);
        return true;
    } catch (DateTimeParseException e) {
        return false;
    }
}
}

0 comments on commit eba3a8a

Please sign in to comment.