From 41e1a1cc7293dedacd1f021408fd4f1466b12a92 Mon Sep 17 00:00:00 2001 From: lexluo Date: Wed, 28 Feb 2024 14:24:47 +0800 Subject: [PATCH] (improvement)(chat) In SchemaCorrector, removing filters from linkingValue that do not exist. --- .../core/corrector/BaseSemanticCorrector.java | 15 ++ .../chat/core/corrector/GroupByCorrector.java | 22 +-- .../chat/core/corrector/SchemaCorrector.java | 47 +++++- .../core/corrector/SchemaCorrectorTest.java | 143 ++++++++++++++++++ .../supersonic/common/util/DateUtils.java | 27 +++- 5 files changed, 230 insertions(+), 24 deletions(-) create mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java index 09a8beaa3..bf530e3ad 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java @@ -140,4 +140,19 @@ protected List getMetricElements(QueryContext queryContext, Long return semanticSchema.getMetrics(viewId); } + protected Set getDimensions(Long viewId, SemanticSchema semanticSchema) { + Set dimensions = semanticSchema.getDimensions(viewId).stream() + .flatMap( + schemaElement -> { + Set elements = new HashSet<>(); + elements.add(schemaElement.getName()); + if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { + elements.addAll(schemaElement.getAlias()); + } + return elements.stream(); + } + ).collect(Collectors.toSet()); + dimensions.add(TimeDimensionEnum.DAY.getChName()); + return dimensions; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java index a88c9c945..0b0f7c688 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java @@ -14,14 +14,12 @@ import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.service.ViewService; -import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; - -import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; /** * Perform SQL corrections on the "Group by" section in S2SQL. @@ -82,22 +80,6 @@ private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo sema return true; } - private Set getDimensions(Long viewId, SemanticSchema semanticSchema) { - Set dimensions = semanticSchema.getDimensions(viewId).stream() - .flatMap( - schemaElement -> { - Set elements = new HashSet<>(); - elements.add(schemaElement.getName()); - if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { - elements.addAll(schemaElement.getAlias()); - } - return elements.stream(); - } - ).collect(Collectors.toSet()); - dimensions.add(TimeDimensionEnum.DAY.getChName()); - return dimensions; - } - private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { Long viewId = semanticParseInfo.getViewId(); //add dimension group by diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java index 3b657f792..ac9b70f18 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java @@ -1,22 +1,30 @@ package com.tencent.supersonic.chat.core.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum; +import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper; -import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; - +import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; /** * Perform schema corrections on the Schema information in S2SQL. @@ -27,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { @Override public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + correctAggFunction(semanticParseInfo); replaceAlias(semanticParseInfo); @@ -105,4 +115,35 @@ private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false); sqlInfo.setCorrectS2SQL(sql); } + + public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String correctS2SQL = sqlInfo.getCorrectS2SQL(); + List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL); + if (CollectionUtils.isEmpty(whereExpressionList)) { + return; + } + List linkingValues = getLinkingValues(semanticParseInfo); + SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + Set dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema); + + if (CollectionUtils.isEmpty(linkingValues)) { + linkingValues = new ArrayList<>(); + } + Set linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName()) + .collect(Collectors.toSet()); + + Set removeFieldNames = whereExpressionList.stream() + .filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction())) + .filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName())) + .filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) + .filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName())) + .filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString())) + .filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName())) + .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet()); + + String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); + sqlInfo.setCorrectS2SQL(sql); + } + } diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java new file mode 100644 index 000000000..1abdfc6df --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java @@ -0,0 +1,143 @@ +package com.tencent.supersonic.chat.core.corrector; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.ViewSchema; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; +import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult; +import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.headless.api.pojo.QueryConfig; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; + +class SchemaCorrectorTest { + + private String json = "{\n" + + " \"viewId\": 1,\n" + + " \"llmReq\": {\n" + + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n" + + " \"filterCondition\": {\n" + + " \"tableName\": null\n" + + " },\n" + + " \"schema\": {\n" + + " \"domainName\": \"歌曲\",\n" + + " \"viewName\": \"歌曲\",\n" + + " \"fieldNameList\": [\n" + + " \"商务组\",\n" + + " \"歌曲名\",\n" + + " \"播放量\",\n" + + " \"播放份额\",\n" + + " \"数据日期\"\n" + + " ]\n" + + " },\n" + + " \"linking\": [\n" + + "\n" + + " ],\n" + + " \"currentDate\": \"2024-02-24\",\n" + + " \"priorExts\": \"播放份额是小数; \",\n" + + " \"sqlGenerationMode\": \"2_pass_auto_cot\"\n" + + " },\n" + + " \"request\": null,\n" + + " \"commonAgentTool\": {\n" + + " \"id\": \"y3LqVSRL\",\n" + + " \"name\": \"大模型语义解析\",\n" + + " \"type\": \"NL2SQL_LLM\",\n" + + " \"viewIds\": [\n" + + " 1\n" + + " ]\n" + + " },\n" + + " \"linkingValues\": [\n" + + "\n" + + " ]\n" + + "}"; + + @Test + void doCorrect() throws JsonProcessingException { + Long viewId = 1L; + QueryContext queryContext = buildQueryContext(viewId); + ObjectMapper objectMapper = new ObjectMapper(); + ParseResult parseResult = objectMapper.readValue(json, ParseResult.class); + + + String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' " + + "and 商务组 = 'xxx' order by 播放量 desc limit 10"; + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SqlInfo sqlInfo = new SqlInfo(); + sqlInfo.setS2SQL(sql); + sqlInfo.setCorrectS2SQL(sql); + semanticParseInfo.setSqlInfo(sqlInfo); + + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setView(viewId); + semanticParseInfo.setView(schemaElement); + + + semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); + + SchemaCorrector schemaCorrector = new SchemaCorrector(); + schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + + assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + + parseResult = objectMapper.readValue(json, ParseResult.class); + + List linkingValues = new ArrayList<>(); + ElementValue elementValue = new ElementValue(); + elementValue.setFieldName("商务组"); + elementValue.setFieldValue("xxx"); + linkingValues.add(elementValue); + parseResult.setLinkingValues(linkingValues); + semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); + + semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); + semanticParseInfo.getSqlInfo().setS2SQL(sql); + schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + + } + + private QueryContext buildQueryContext(Long viewId) { + QueryContext queryContext = new QueryContext(); + List viewSchemaList = new ArrayList<>(); + ViewSchema viewSchema = new ViewSchema(); + QueryConfig queryConfig = new QueryConfig(); + viewSchema.setQueryConfig(queryConfig); + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setView(viewId); + viewSchema.setView(schemaElement); + Set dimensions = new HashSet<>(); + SchemaElement element1 = new SchemaElement(); + element1.setView(1L); + element1.setName("歌曲名"); + dimensions.add(element1); + + SchemaElement element2 = new SchemaElement(); + element2.setView(1L); + element2.setName("商务组"); + dimensions.add(element2); + + SchemaElement element3 = new SchemaElement(); + element3.setView(1L); + element3.setName("发行日期"); + dimensions.add(element3); + + viewSchema.setDimensions(dimensions); + viewSchemaList.add(viewSchema); + + SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList); + queryContext.setSemanticSchema(semanticSchema); + return queryContext; + } +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java index 105d82418..5d62163a8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java @@ -1,19 +1,21 @@ package com.tencent.supersonic.common.util; +import com.tencent.supersonic.common.pojo.Constants; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; import java.time.temporal.ChronoField; import java.time.temporal.TemporalAdjuster; import java.time.temporal.TemporalAdjusters; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.Date; import java.util.List; import java.util.Objects; -import com.tencent.supersonic.common.pojo.Constants; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -166,4 +168,27 @@ public static List getDateList(String startDateStr, String endDateStr, S return datesInRange; } + public static boolean isAnyDateString(String value) { + List formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd"); + return isAnyDateString(value, formats); + } + + public static boolean isAnyDateString(String value, List formats) { + for (String format : formats) { + if (isDateString(value, format)) { + return true; + } + } + return false; + } + + public static boolean isDateString(String value, String format) { + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(format); + LocalDate.parse(value, formatter); + return true; + } catch (DateTimeParseException e) { + return false; + } + } }