diff --git a/README.md b/README.md index 38d2809c1..6bded88f0 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # SuperSonic (超音数) -**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI. +**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) with semantic layer, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI. @@ -13,7 +13,8 @@ The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language into SQL (so called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications. From our perspective, the key to filling the real-world gap lies in three aspects: -1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**. +1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**. + 2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to **mitigate hallucination**. 3. Utilize rule-based schema parsers when necessary to **improve efficiency**(in terms of latency and cost). diff --git a/README_CN.md b/README_CN.md index 1ae61b51d..b85af8c8c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -10,6 +10,8 @@ 在我们看来,为了在实际场景发挥价值,有三个关键点: 1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。 + + 2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。 3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。 diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java index 1d457610a..25adb87a5 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java @@ -7,7 +7,7 @@ public class QueryReq { private String queryText; private Integer chatId; - private Long modelId; + private Long viewId; private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java index a1c24f55f..7a723c56f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java @@ -65,6 +65,10 @@ public List getParserTools(AgentToolType agentToolType) { .collect(Collectors.toList()); } + public Set getViewIds() { + return getViewIds(null); + } + public Set getViewIds(AgentToolType agentToolType) { List commonAgentTools = getParserTools(agentToolType); if (CollectionUtils.isEmpty(commonAgentTools)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java index 09a8beaa3..bf530e3ad 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/BaseSemanticCorrector.java @@ -140,4 +140,19 @@ protected List getMetricElements(QueryContext queryContext, Long return semanticSchema.getMetrics(viewId); } + protected Set getDimensions(Long viewId, SemanticSchema semanticSchema) { + Set dimensions = semanticSchema.getDimensions(viewId).stream() + .flatMap( + schemaElement -> { + Set elements = new HashSet<>(); + elements.add(schemaElement.getName()); + if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { + elements.addAll(schemaElement.getAlias()); + } + return elements.stream(); + } + ).collect(Collectors.toSet()); + dimensions.add(TimeDimensionEnum.DAY.getChName()); + return dimensions; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java index a88c9c945..0b0f7c688 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/GroupByCorrector.java @@ -14,14 +14,12 @@ import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.service.ViewService; -import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; - -import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; /** * Perform SQL corrections on the "Group by" section in S2SQL. @@ -82,22 +80,6 @@ private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo sema return true; } - private Set getDimensions(Long viewId, SemanticSchema semanticSchema) { - Set dimensions = semanticSchema.getDimensions(viewId).stream() - .flatMap( - schemaElement -> { - Set elements = new HashSet<>(); - elements.add(schemaElement.getName()); - if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { - elements.addAll(schemaElement.getAlias()); - } - return elements.stream(); - } - ).collect(Collectors.toSet()); - dimensions.add(TimeDimensionEnum.DAY.getChName()); - return dimensions; - } - private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { Long viewId = semanticParseInfo.getViewId(); //add dimension group by diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java index 3b657f792..ac9b70f18 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrector.java @@ -1,22 +1,30 @@ package com.tencent.supersonic.chat.core.corrector; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum; +import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper; -import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; - +import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; /** * Perform schema corrections on the Schema information in S2SQL. @@ -27,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { @Override public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + correctAggFunction(semanticParseInfo); replaceAlias(semanticParseInfo); @@ -105,4 +115,35 @@ private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false); sqlInfo.setCorrectS2SQL(sql); } + + public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); + String correctS2SQL = sqlInfo.getCorrectS2SQL(); + List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL); + if (CollectionUtils.isEmpty(whereExpressionList)) { + return; + } + List linkingValues = getLinkingValues(semanticParseInfo); + SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + Set dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema); + + if (CollectionUtils.isEmpty(linkingValues)) { + linkingValues = new ArrayList<>(); + } + Set linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName()) + .collect(Collectors.toSet()); + + Set removeFieldNames = whereExpressionList.stream() + .filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction())) + .filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName())) + .filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) + .filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName())) + .filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString())) + .filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName())) + .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet()); + + String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); + sqlInfo.setCorrectS2SQL(sql); + } + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java index 8e054891d..20ba6743c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/QueryFilterMapper.java @@ -1,20 +1,22 @@ package com.tencent.supersonic.chat.core.mapper; import com.google.common.collect.Lists; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; +import com.tencent.supersonic.chat.core.agent.Agent; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; @Slf4j @@ -24,29 +26,35 @@ public class QueryFilterMapper implements SchemaMapper { @Override public void map(QueryContext queryContext) { - Long viewId = queryContext.getViewId(); - if (viewId == null || viewId <= 0) { + Agent agent = queryContext.getAgent(); + if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) { + return; + } + if (Agent.containsAllModel(agent.getViewIds())) { return; } + Set viewIds = agent.getViewIds(); SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); - clearOtherSchemaElementMatch(viewId, schemaMapInfo); - List schemaElementMatches = schemaMapInfo.getMatchedElements(viewId); - if (schemaElementMatches == null) { - schemaElementMatches = Lists.newArrayList(); - schemaMapInfo.setMatchedElements(viewId, schemaElementMatches); + clearOtherSchemaElementMatch(viewIds, schemaMapInfo); + for (Long viewId : viewIds) { + List schemaElementMatches = schemaMapInfo.getMatchedElements(viewId); + if (schemaElementMatches == null) { + schemaElementMatches = Lists.newArrayList(); + schemaMapInfo.setMatchedElements(viewId, schemaElementMatches); + } + addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches); } - addValueSchemaElementMatch(queryContext, schemaElementMatches); } - private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) { + private void clearOtherSchemaElementMatch(Set viewIds, SchemaMapInfo schemaMapInfo) { for (Map.Entry> entry : schemaMapInfo.getViewElementMatches().entrySet()) { - if (!entry.getKey().equals(modelId)) { + if (!viewIds.contains(entry.getKey())) { entry.getValue().clear(); } } } - private List addValueSchemaElementMatch(QueryContext queryContext, + private List addValueSchemaElementMatch(Long viewId, QueryContext queryContext, List candidateElementMatches) { QueryFilters queryFilters = queryContext.getQueryFilters(); if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { @@ -61,7 +69,7 @@ private List addValueSchemaElementMatch(QueryContext queryCo .name(String.valueOf(filter.getValue())) .type(SchemaElementType.VALUE) .bizName(filter.getBizName()) - .view(queryContext.getViewId()) + .view(viewId) .build(); SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(element) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java index 9031c6fb5..3fd10dfdd 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/AgentCheckParser.java @@ -23,20 +23,22 @@ public class AgentCheckParser implements SemanticParser { @Override public void parse(QueryContext queryContext, ChatContext chatContext) { List queries = queryContext.getCandidateQueries(); - agentCanSupport(queryContext, queries); + log.info("query size before agent filter:{}", queryContext.getCandidateQueries().size()); + filterQueries(queryContext, queries); + log.info("query size after agent filter: {}", queryContext.getCandidateQueries().size()); } - private void agentCanSupport(QueryContext queryContext, List queries) { + private void filterQueries(QueryContext queryContext, List queries) { Agent agent = queryContext.getAgent(); if (agent == null) { return; } List queryTools = getRuleTools(agent); if (CollectionUtils.isEmpty(queryTools)) { - queries.clear(); + queryContext.setCandidateQueries(Lists.newArrayList()); return; } - log.info("queries resolved:{} {}", agent.getName(), + log.info("agent name :{}, queries resolved: {}", agent.getName(), queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); queries.removeIf(query -> { for (RuleParserTool tool : queryTools) { @@ -46,10 +48,14 @@ private void agentCanSupport(QueryContext queryContext, List quer } if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) { if (QueryManager.isTagQuery(query.getQueryMode())) { - return !tool.getQueryTypes().contains(QueryType.TAG.name()); + if (!tool.getQueryTypes().contains(QueryType.TAG.name())) { + return true; + } } if (QueryManager.isMetricQuery(query.getQueryMode())) { - return !tool.getQueryTypes().contains(QueryType.METRIC.name()); + if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) { + return true; + } } } if (CollectionUtils.isEmpty(tool.getViewIds())) { @@ -62,7 +68,8 @@ private void agentCanSupport(QueryContext queryContext, List quer } return true; }); - log.info("rule queries witch can be supported by agent :{} {}", agent.getName(), + queryContext.setCandidateQueries(queries); + log.info("agent name :{}, rule queries witch can be supported by agent :{}", agent.getName(), queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java index 916f25c0f..48ce387b3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/rule/RuleSqlParser.java @@ -20,9 +20,9 @@ public class RuleSqlParser implements SemanticParser { private static List auxiliaryParsers = Arrays.asList( new ContextInheritParser(), - new AgentCheckParser(), new TimeRangeParser(), - new AggregateTypeParser() + new AggregateTypeParser(), + new AgentCheckParser() ); @Override diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java new file mode 100644 index 000000000..1abdfc6df --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/core/corrector/SchemaCorrectorTest.java @@ -0,0 +1,143 @@ +package com.tencent.supersonic.chat.core.corrector; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.ViewSchema; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; +import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult; +import com.tencent.supersonic.chat.core.pojo.QueryContext; +import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.headless.api.pojo.QueryConfig; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; + +class SchemaCorrectorTest { + + private String json = "{\n" + + " \"viewId\": 1,\n" + + " \"llmReq\": {\n" + + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n" + + " \"filterCondition\": {\n" + + " \"tableName\": null\n" + + " },\n" + + " \"schema\": {\n" + + " \"domainName\": \"歌曲\",\n" + + " \"viewName\": \"歌曲\",\n" + + " \"fieldNameList\": [\n" + + " \"商务组\",\n" + + " \"歌曲名\",\n" + + " \"播放量\",\n" + + " \"播放份额\",\n" + + " \"数据日期\"\n" + + " ]\n" + + " },\n" + + " \"linking\": [\n" + + "\n" + + " ],\n" + + " \"currentDate\": \"2024-02-24\",\n" + + " \"priorExts\": \"播放份额是小数; \",\n" + + " \"sqlGenerationMode\": \"2_pass_auto_cot\"\n" + + " },\n" + + " \"request\": null,\n" + + " \"commonAgentTool\": {\n" + + " \"id\": \"y3LqVSRL\",\n" + + " \"name\": \"大模型语义解析\",\n" + + " \"type\": \"NL2SQL_LLM\",\n" + + " \"viewIds\": [\n" + + " 1\n" + + " ]\n" + + " },\n" + + " \"linkingValues\": [\n" + + "\n" + + " ]\n" + + "}"; + + @Test + void doCorrect() throws JsonProcessingException { + Long viewId = 1L; + QueryContext queryContext = buildQueryContext(viewId); + ObjectMapper objectMapper = new ObjectMapper(); + ParseResult parseResult = objectMapper.readValue(json, ParseResult.class); + + + String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' " + + "and 商务组 = 'xxx' order by 播放量 desc limit 10"; + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SqlInfo sqlInfo = new SqlInfo(); + sqlInfo.setS2SQL(sql); + sqlInfo.setCorrectS2SQL(sql); + semanticParseInfo.setSqlInfo(sqlInfo); + + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setView(viewId); + semanticParseInfo.setView(schemaElement); + + + semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); + + SchemaCorrector schemaCorrector = new SchemaCorrector(); + schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + + assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + + parseResult = objectMapper.readValue(json, ParseResult.class); + + List linkingValues = new ArrayList<>(); + ElementValue elementValue = new ElementValue(); + elementValue.setFieldName("商务组"); + elementValue.setFieldValue("xxx"); + linkingValues.add(elementValue); + parseResult.setLinkingValues(linkingValues); + semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); + + semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); + semanticParseInfo.getSqlInfo().setS2SQL(sql); + schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + + } + + private QueryContext buildQueryContext(Long viewId) { + QueryContext queryContext = new QueryContext(); + List viewSchemaList = new ArrayList<>(); + ViewSchema viewSchema = new ViewSchema(); + QueryConfig queryConfig = new QueryConfig(); + viewSchema.setQueryConfig(queryConfig); + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setView(viewId); + viewSchema.setView(schemaElement); + Set dimensions = new HashSet<>(); + SchemaElement element1 = new SchemaElement(); + element1.setView(1L); + element1.setName("歌曲名"); + dimensions.add(element1); + + SchemaElement element2 = new SchemaElement(); + element2.setView(1L); + element2.setName("商务组"); + dimensions.add(element2); + + SchemaElement element3 = new SchemaElement(); + element3.setView(1L); + element3.setName("发行日期"); + dimensions.add(element3); + + viewSchema.setDimensions(dimensions); + viewSchemaList.add(viewSchema); + + SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList); + queryContext.setSemanticSchema(semanticSchema); + return queryContext; + } +} \ No newline at end of file diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java index 2ccf64e4a..5518ab7be 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/SearchServiceImpl.java @@ -97,7 +97,7 @@ public List search(QueryReq queryReq) { List originals = knowledgeService.getTerms(queryText); log.info("hanlp parse result: {}", originals); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); - Set detectViewIds = mapperHelper.getViewIds(queryReq.getModelId(), agentService.getAgent(agentId)); + Set detectViewIds = mapperHelper.getViewIds(queryReq.getViewId(), agentService.getAgent(agentId)); QueryContext queryContext = new QueryContext(); BeanUtils.copyProperties(queryReq, queryContext); @@ -123,7 +123,7 @@ public List search(QueryReq queryReq) { Set searchResults = new LinkedHashSet(); ViewInfoStat modelStat = NatureHelper.getViewStat(originals); - List possibleModels = getPossibleModels(queryReq, originals, modelStat, queryReq.getModelId()); + List possibleModels = getPossibleModels(queryReq, originals, modelStat, queryReq.getViewId()); // 5.1 priority dimension metric boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName, diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java index 40b3d84ca..dacdfb09e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java @@ -4,6 +4,7 @@ public enum TypeEnums { METRIC, DIMENSION, + TAG, DOMAIN, ENTITY, VIEW, diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java index 105d82418..5d62163a8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java @@ -1,19 +1,21 @@ package com.tencent.supersonic.common.util; +import com.tencent.supersonic.common.pojo.Constants; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; import java.time.temporal.ChronoField; import java.time.temporal.TemporalAdjuster; import java.time.temporal.TemporalAdjusters; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.Date; import java.util.List; import java.util.Objects; -import com.tencent.supersonic.common.pojo.Constants; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -166,4 +168,27 @@ public static List getDateList(String startDateStr, String endDateStr, S return datesInRange; } + public static boolean isAnyDateString(String value) { + List formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd"); + return isAnyDateString(value, formats); + } + + public static boolean isAnyDateString(String value, List formats) { + for (String format : formats) { + if (isDateString(value, format)) { + return true; + } + } + return false; + } + + public static boolean isDateString(String value, String format) { + try { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern(format); + LocalDate.parse(value, formatter); + return true; + } catch (DateTimeParseException e) { + return false; + } + } } diff --git a/docs/images/supersonic_ideas.png b/docs/images/supersonic_ideas.png new file mode 100644 index 000000000..2104e6af0 Binary files /dev/null and b/docs/images/supersonic_ideas.png differ diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TagResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TagResp.java index 6857330fa..ccc30c78d 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TagResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TagResp.java @@ -16,6 +16,10 @@ public class TagResp extends SchemaItem { private String type; + private Boolean isCollect; + + private boolean hasAdminRes; + private Map ext = new HashMap<>(); private TagDefineType tagDefineType = TagDefineType.FIELD; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/FileHandlerImpl.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/FileHandlerImpl.java index c22573671..a25ce17fa 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/FileHandlerImpl.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/FileHandlerImpl.java @@ -3,6 +3,8 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; + +import java.io.File; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -17,9 +19,9 @@ @Slf4j @Component public class FileHandlerImpl implements FileHandler { + public static final String FILE_SPILT = File.separator; private final LocalFileConfig localFileConfig; - public FileHandlerImpl(LocalFileConfig localFileConfig) { this.localFileConfig = localFileConfig; } @@ -31,8 +33,8 @@ public void backupFile(String fileName) { createDir(dictDirectoryBackup); } - String source = localFileConfig.getDictDirectoryLatest() + "/" + fileName; - String target = dictDirectoryBackup + "/" + fileName; + String source = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName; + String target = dictDirectoryBackup + FILE_SPILT + fileName; Path sourcePath = Paths.get(source); Path targetPath = Paths.get(target); try { @@ -88,7 +90,7 @@ public void writeFile(List lines, String fileName, Boolean append) { if (!existPath(dictDirectoryLatest)) { createDir(dictDirectoryLatest); } - String filePath = dictDirectoryLatest + "/" + fileName; + String filePath = dictDirectoryLatest + FILE_SPILT + fileName; if (existPath(filePath)) { backupFile(fileName); } @@ -117,7 +119,7 @@ public String getDictRootPath() { @Override public Boolean deleteDictFile(String fileName) { backupFile(fileName); - deleteFile(localFileConfig.getDictDirectoryLatest() + "/" + fileName); + deleteFile(localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName); return true; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/LocalFileConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/LocalFileConfig.java index 6e2c8b4ba..7dc232e3d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/LocalFileConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/file/LocalFileConfig.java @@ -1,10 +1,13 @@ package com.tencent.supersonic.headless.core.file; +import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper; import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Configuration; +import java.io.FileNotFoundException; + @Data @Configuration @Slf4j @@ -18,16 +21,21 @@ public class LocalFileConfig { private String dictDirectoryBackup; public String getDictDirectoryLatest() { - return getResourceDir() + dictDirectoryLatest; + return getDictDirectoryPrefixDir() + dictDirectoryLatest; } public String getDictDirectoryBackup() { - return getResourceDir() + dictDirectoryBackup; + return getDictDirectoryPrefixDir() + dictDirectoryBackup; } - private String getResourceDir() { - //return hanlpPropertiesPath = HanlpHelper.getHanlpPropertiesPath(); - return ClassLoader.getSystemClassLoader().getResource("").getPath(); + private String getDictDirectoryPrefixDir() { + try { + return HanlpHelper.getHanlpPropertiesPath(); + } catch (FileNotFoundException e) { + log.warn("getDictDirectoryPrefixDir error: " + e); + e.printStackTrace(); + } + return ""; } } \ No newline at end of file diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java index 311f3e124..e86298418 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java @@ -16,7 +16,7 @@ public class SysTimeDimensionBuilder { public static void addSysTimeDimension(List dims, DbAdaptor engineAdaptor) { - log.info("addSysTimeDimension before:{}, engineAdaptor:{}", dims, engineAdaptor); + log.debug("addSysTimeDimension before:{}, engineAdaptor:{}", dims, engineAdaptor); Dim timeDim = getTimeDim(dims); if (timeDim == null) { timeDim = Dim.getDefault(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagCustomMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagCustomMapper.java index f9344602e..597bcc8bf 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagCustomMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagCustomMapper.java @@ -8,4 +8,6 @@ @Mapper public interface TagCustomMapper { List query(TagFilter tagFilter); + + Boolean batchUpdateStatus(List tagDOList); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/TagRepository.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/TagRepository.java index 99d27eeff..7445776d4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/TagRepository.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/TagRepository.java @@ -15,4 +15,6 @@ public interface TagRepository { TagDO getTagById(Long id); List query(TagFilter tagFilter); + + Boolean batchUpdateStatus(List tagDOList); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java index 5445b7461..3a9fa5011 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java @@ -41,4 +41,9 @@ public TagDO getTagById(Long id) { public List query(TagFilter tagFilter) { return tagCustomMapper.query(tagFilter); } + + @Override + public Boolean batchUpdateStatus(List tagDOList) { + return tagCustomMapper.batchUpdateStatus(tagDOList); + } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/pojo/TagFilterPage.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/pojo/TagFilterPage.java new file mode 100644 index 000000000..bb1dbb1b4 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/pojo/TagFilterPage.java @@ -0,0 +1,11 @@ +package com.tencent.supersonic.headless.server.pojo; + + +import com.tencent.supersonic.headless.api.pojo.request.PageSchemaItemReq; + +import java.util.List; + +public class TagFilterPage extends PageSchemaItemReq { + private String type; + private List statusList; +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java index db77ca14f..401aa584b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java @@ -33,9 +33,10 @@ public boolean createCollectionIndicators(@RequestBody CollectDO collectDO, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); - return collectService.createCollectionIndicators(user, collectDO.getId()); + return collectService.createCollectionIndicators(user, collectDO); } + @Deprecated @DeleteMapping("/deleteCollectionIndicators/{id}") public boolean deleteCollectionIndicators(@PathVariable Long id, HttpServletRequest request, @@ -44,4 +45,12 @@ public boolean deleteCollectionIndicators(@PathVariable Long id, return collectService.deleteCollectionIndicators(user, id); } + @PostMapping("/deleteCollectionIndicators") + public boolean deleteCollectionIndicators(@RequestBody CollectDO collectDO, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + return collectService.deleteCollectionIndicators(user, collectDO); + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java index 4388b1c63..1a32a5d6c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java @@ -1,9 +1,12 @@ package com.tencent.supersonic.headless.server.rest; +import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.TagReq; import com.tencent.supersonic.headless.api.pojo.response.TagResp; +import com.tencent.supersonic.headless.server.pojo.TagFilterPage; import com.tencent.supersonic.headless.server.service.TagService; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -40,6 +43,14 @@ public TagResp update(@RequestBody TagReq tagReq, return tagService.update(tagReq, user); } + @PostMapping("/batchUpdateStatus") + public Boolean batchUpdateStatus(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + return tagService.batchUpdateStatus(metaBatchReq, user); + } + @DeleteMapping("delete/{id}") public Boolean delete(@PathVariable("id") Long id, HttpServletRequest request, @@ -53,7 +64,16 @@ public Boolean delete(@PathVariable("id") Long id, public TagResp getTag(@PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { - return tagService.getTag(id); + User user = UserHolder.findUser(request, response); + return tagService.getTag(id, user); + } + + @PostMapping("/queryTag") + public PageInfo queryPage(@RequestBody TagFilterPage tagFilterPage, + HttpServletRequest request, + HttpServletResponse response) throws Exception { + User user = UserHolder.findUser(request, response); + return tagService.queryPage(tagFilterPage, user); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/CollectService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/CollectService.java index a8184f0c5..6dffeecda 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/CollectService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/CollectService.java @@ -10,10 +10,12 @@ public interface CollectService { - Boolean createCollectionIndicators(User user, Long id); + Boolean createCollectionIndicators(User user, CollectDO collectDO); Boolean deleteCollectionIndicators(User user, Long id); + Boolean deleteCollectionIndicators(User user, CollectDO collectDO); + List getCollectList(String username); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/TagService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/TagService.java index 5724530e9..dc8b9e196 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/TagService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/TagService.java @@ -1,9 +1,12 @@ package com.tencent.supersonic.headless.server.service; +import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.TagReq; import com.tencent.supersonic.headless.api.pojo.response.TagResp; import com.tencent.supersonic.headless.server.pojo.TagFilter; +import com.tencent.supersonic.headless.server.pojo.TagFilterPage; import java.util.List; public interface TagService { @@ -14,7 +17,11 @@ public interface TagService { void delete(Long id, User user) throws Exception; - TagResp getTag(Long id); + TagResp getTag(Long id, User user); List query(TagFilter tagFilter); + + PageInfo queryPage(TagFilterPage tagFilterPage, User user); + + Boolean batchUpdateStatus(MetaBatchReq metaBatchReq, User user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java index ff37083b4..942ab407e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.server.service.CollectService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.util.Strings; import org.springframework.stereotype.Service; import javax.annotation.Resource; @@ -22,12 +23,12 @@ public class CollectServiceImpl implements CollectService { private CollectMapper collectMapper; @Override - public Boolean createCollectionIndicators(User user, Long id) { - CollectDO collectDO = new CollectDO(); - collectDO.setType(type); - collectDO.setUsername(user.getName()); - collectDO.setCollectId(id); - collectMapper.insert(collectDO); + public Boolean createCollectionIndicators(User user, CollectDO collectReq) { + CollectDO collect = new CollectDO(); + collect.setType(Strings.isEmpty(collectReq.getType()) ? type : collectReq.getType()); + collect.setUsername(user.getName()); + collect.setCollectId(collectReq.getId()); + collectMapper.insert(collect); return true; } @@ -41,6 +42,16 @@ public Boolean deleteCollectionIndicators(User user, Long id) { return true; } + @Override + public Boolean deleteCollectionIndicators(User user, CollectDO collectReq) { + QueryWrapper collectDOQueryWrapper = new QueryWrapper<>(); + collectDOQueryWrapper.lambda().eq(CollectDO::getUsername, user.getName()); + collectDOQueryWrapper.lambda().eq(CollectDO::getCollectId, collectReq.getCollectId()); + collectDOQueryWrapper.lambda().eq(CollectDO::getType, collectReq.getType()); + collectMapper.delete(collectDOQueryWrapper); + return true; + } + @Override public List getCollectList(String username) { QueryWrapper queryWrapper = new QueryWrapper<>(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagServiceImpl.java index 7babbb66b..0ca55c6c7 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagServiceImpl.java @@ -1,25 +1,40 @@ package com.tencent.supersonic.headless.server.service.impl; import com.alibaba.fastjson.JSONObject; +import com.github.pagehelper.PageHelper; +import com.github.pagehelper.PageInfo; +import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.pojo.enums.StatusEnum; +import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.headless.api.pojo.TagDefineParams; import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; +import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.TagReq; + +import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.TagResp; +import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO; import com.tencent.supersonic.headless.server.persistence.dataobject.TagDO; import com.tencent.supersonic.headless.server.persistence.repository.TagRepository; import com.tencent.supersonic.headless.server.pojo.TagFilter; +import com.tencent.supersonic.headless.server.pojo.TagFilterPage; +import com.tencent.supersonic.headless.server.service.CollectService; +import com.tencent.supersonic.headless.server.service.ModelService; import com.tencent.supersonic.headless.server.service.TagService; import com.tencent.supersonic.headless.server.utils.NameCheckUtils; + import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; + import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -31,9 +46,14 @@ public class TagServiceImpl implements TagService { private final TagRepository tagRepository; + private final ModelService modelService; + private final CollectService collectService; - public TagServiceImpl(TagRepository tagRepository) { + public TagServiceImpl(TagRepository tagRepository, ModelService modelService, + CollectService collectService) { this.tagRepository = tagRepository; + this.modelService = modelService; + this.collectService = collectService; } @Override @@ -85,8 +105,21 @@ public void delete(Long id, User user) throws Exception { } @Override - public TagResp getTag(Long id) { - return convert(tagRepository.getTagById(id)); + public TagResp getTag(Long id, User user) { + // return convert(tagRepository.getTagById(id)); + TagDO tagDO = tagRepository.getTagById(id); + TagResp tagResp = fillCollectAndAdminInfo(tagDO, user); + return tagResp; + } + + private TagResp fillCollectAndAdminInfo(TagDO tagDO, User user) { + List collectIds = collectService.getCollectList(user.getName()) + .stream().filter(collectDO -> TypeEnums.TAG.name().equalsIgnoreCase(collectDO.getType())) + .map(CollectDO::getCollectId).collect(Collectors.toList()); + + List tagRespList = convertList(new ArrayList<>(Arrays.asList(tagDO)), collectIds); + fillAdminRes(tagRespList, user); + return tagRespList.get(0); } @Override @@ -98,10 +131,98 @@ public List query(TagFilter tagFilter) { return new ArrayList<>(); } + @Override + public PageInfo queryPage(TagFilterPage tagFilterPage, User user) { + TagFilter tagFilter = new TagFilter(); + BeanUtils.copyProperties(tagFilterPage, tagFilter); + List modelRespList = modelService.getAllModelByDomainIds(tagFilterPage.getDomainIds()); + List modelIds = modelRespList.stream().map(ModelResp::getId).collect(Collectors.toList()); + tagFilterPage.getModelIds().addAll(modelIds); + tagFilter.setModelIds(tagFilterPage.getModelIds()); + + List collectList = collectService.getCollectList(user.getName()) + .stream().filter(collectDO -> TypeEnums.TAG.name().equalsIgnoreCase(collectDO.getType())) + .collect(Collectors.toList()); + List collectIds = collectList.stream().map(CollectDO::getCollectId).collect(Collectors.toList()); + if (tagFilterPage.isHasCollect()) { + if (CollectionUtils.isEmpty(collectIds)) { + tagFilter.setIds(Lists.newArrayList(-1L)); + } else { + tagFilter.setIds(collectIds); + } + } + + PageInfo tagDOPageInfo = PageHelper.startPage(tagFilterPage.getCurrent(), + tagFilterPage.getPageSize()) + .doSelectPageInfo(() -> query(tagFilter)); + PageInfo pageInfo = new PageInfo<>(); + BeanUtils.copyProperties(tagDOPageInfo, pageInfo); + List tagRespList = convertList(tagDOPageInfo.getList(), collectIds); + fillAdminRes(tagRespList, user); + pageInfo.setList(tagRespList); + + return pageInfo; + } + + @Override + public Boolean batchUpdateStatus(MetaBatchReq metaBatchReq, User user) { + if (Objects.isNull(metaBatchReq) || CollectionUtils.isEmpty(metaBatchReq.getIds()) + || Objects.isNull(metaBatchReq.getStatus())) { + return false; + } + TagFilter tagFilter = new TagFilter(); + tagFilter.setIds(metaBatchReq.getIds()); + List tagDOList = tagRepository.query(tagFilter); + if (CollectionUtils.isEmpty(tagDOList)) { + return true; + } + tagDOList.stream().forEach(tagDO -> { + tagDO.setStatus(metaBatchReq.getStatus()); + tagDO.setUpdatedAt(new Date()); + tagDO.setUpdatedBy(user.getName()); + }); + + tagRepository.batchUpdateStatus(tagDOList); + // todo sendEventBatch + + return true; + } + + private void fillAdminRes(List tagRespList, User user) { + List modelRespList = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); + if (CollectionUtils.isEmpty(modelRespList)) { + return; + } + Set modelIdSet = modelRespList.stream().map(ModelResp::getId).collect(Collectors.toSet()); + for (TagResp tagResp : tagRespList) { + if (modelIdSet.contains(tagResp.getModelId())) { + tagResp.setHasAdminRes(true); + } else { + tagResp.setHasAdminRes(false); + } + } + } + + private List convertList(List tagDOList, List collectIds) { + List tagRespList = new ArrayList<>(); + if (CollectionUtils.isNotEmpty(tagDOList)) { + tagDOList.stream().forEach(tagDO -> { + TagResp tagResp = convert(tagDO); + if (CollectionUtils.isNotEmpty(collectIds) && collectIds.contains(tagDO.getId())) { + tagResp.setIsCollect(true); + } else { + tagResp.setIsCollect(false); + } + tagRespList.add(tagResp); + }); + } + return tagRespList; + } + private void checkExit(TagReq tagReq) { TagFilter tagFilter = new TagFilter(); tagFilter.setModelIds(Arrays.asList(tagReq.getModelId())); - //tagFilter.setStatusList(Arrays.asList(StatusEnum.ONLINE.getCode(),StatusEnum.OFFLINE.getCode())); + List tagResps = query(tagFilter); if (!CollectionUtils.isEmpty(tagResps)) { Long bizNameSameCount = tagResps.stream().filter(tagResp -> !tagResp.getId().equals(tagReq.getId())) diff --git a/headless/server/src/main/resources/mapper/custom/TagCustomMapper.xml b/headless/server/src/main/resources/mapper/custom/TagCustomMapper.xml index 339aefbb7..b6862ec28 100644 --- a/headless/server/src/main/resources/mapper/custom/TagCustomMapper.xml +++ b/headless/server/src/main/resources/mapper/custom/TagCustomMapper.xml @@ -71,7 +71,7 @@ and ( id like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or name like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or biz_name like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or - description like CONCAT('%',#{key , jdbcType=VARCHAR},'%') + description like CONCAT('%',#{key , jdbcType=VARCHAR},'%')) and id like CONCAT('%',#{id , jdbcType=VARCHAR},'%') @@ -107,4 +107,14 @@ + + + update s2_tag + set status = #{tag.status,jdbcType=INTEGER}, + updated_at = #{tag.updatedAt,jdbcType=TIMESTAMP}, + updated_by = #{tag.updatedBy,jdbcType=VARCHAR} + where id = #{tag.id,jdbcType=BIGINT} + + + diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index 3e5c8a916..9d85c8e50 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -212,4 +212,33 @@ CREATE TABLE s2_tag( `updated_by` varchar(100) DEFAULT NULL , `ext` LONGVARCHAR DEFAULT NULL , PRIMARY KEY (`id`) -)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; \ No newline at end of file +)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +--20240301 +CREATE TABLE IF NOT EXISTS `s2_dictionary_conf` ( + `id` INT NOT NULL AUTO_INCREMENT, + `description` varchar(255) , + `type` varchar(255) NOT NULL , + `item_id` INT NOT NULL , + `config` text , + `status` varchar(255) NOT NULL , + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP , + `created_by` varchar(100) NOT NULL , + PRIMARY KEY (`id`) +); +COMMENT ON TABLE s2_dictionary_conf IS 'dictionary conf information table'; + +CREATE TABLE IF NOT EXISTS `s2_dictionary_task` ( + `id` INT NOT NULL AUTO_INCREMENT, + `name` varchar(255) NOT NULL , + `description` varchar(255) , + `type` varchar(255) NOT NULL , + `item_id` INT NOT NULL , + `config` text , + `status` varchar(255) NOT NULL , + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP , + `created_by` varchar(100) NOT NULL , + `elapsed_ms` bigINT DEFAULT NULL , + PRIMARY KEY (`id`) +); +COMMENT ON TABLE s2_dictionary_task IS 'dictionary task information table'; diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index a8e38fad5..4dd166098 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -205,29 +205,28 @@ CREATE TABLE IF NOT EXISTS `s2_dictionary_conf` ( `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, `description` varchar(255) , `type` varchar(255) NOT NULL , - `item_id` INT NOT NULL , -- task Request Parameters md5 - `config` mediumtext , -- remark related information - `status` varchar(255) NOT NULL , -- the final status of the task + `item_id` INT NOT NULL , + `config` mediumtext , + `status` varchar(255) NOT NULL , `created_at` datetime NOT NULL COMMENT '创建时间' , `created_by` varchar(100) NOT NULL , PRIMARY KEY (`id`) -); -COMMENT ON TABLE s2_dictionary_conf IS '字典配置信息表'; +) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='字典配置信息表'; + CREATE TABLE IF NOT EXISTS `s2_dictionary_task` ( `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, - `name` varchar(255) NOT NULL , -- task name + `name` varchar(255) NOT NULL , `description` varchar(255) , `type` varchar(255) NOT NULL , - `item_id` INT NOT NULL , -- task Request Parameters md5 - `config` mediumtext , -- remark related information - `status` varchar(255) NOT NULL , -- the final status of the task + `item_id` INT NOT NULL , + `config` mediumtext , + `status` varchar(255) NOT NULL , `created_at` datetime DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', `created_by` varchar(100) NOT NULL , - `elapsed_ms` int(10) DEFAULT NULL , -- the task takes time in milliseconds + `elapsed_ms` int(10) DEFAULT NULL , PRIMARY KEY (`id`) -); -COMMENT ON TABLE s2_dictionary_task IS 'dictionary task information table'; +) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='字典运行任务表'; CREATE TABLE `s2_dimension` ( diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/plugin/PluginRecognizeTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/plugin/PluginRecognizeTest.java index 66c29febb..055d40e42 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/plugin/PluginRecognizeTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/plugin/PluginRecognizeTest.java @@ -57,7 +57,7 @@ public void webPageRecognizeWithQueryFilter() throws Exception { QueryFilter queryFilter = new QueryFilter(); queryFilter.setElementID(2L); queryFilter.setValue("alice"); - queryRequest.setModelId(1L); + // queryRequest.setModelId(1L); queryFilters.getFilters().add(queryFilter); queryRequest.setQueryFilters(queryFilters); diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 1fb4f1692..d30eec7c3 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -77,4 +77,4 @@ logging: inMemoryEmbeddingStore: persistent: - path: /tmp + path: d:// diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 973d55d21..e12dd1d64 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -529,4 +529,24 @@ CREATE TABLE IF NOT EXISTS `s2_view` ( query_config VARCHAR(3000), `admin` varchar(3000) DEFAULT NULL, `admin_org` varchar(3000) DEFAULT NULL -); \ No newline at end of file +); + +CREATE TABLE IF NOT EXISTS `s2_tag` ( + `id` INT NOT NULL AUTO_INCREMENT, + `model_id` INT NOT NULL , + `name` varchar(255) NOT NULL , + `biz_name` varchar(255) NOT NULL , + `description` varchar(500) DEFAULT NULL , + `status` INT NOT NULL , + `sensitive_level` INT NOT NULL , + `type` varchar(50) NOT NULL , -- ATOMIC, DERIVED + `define_type` varchar(50) NOT NULL, -- FIELD, DIMENSION + `type_params` LONGVARCHAR DEFAULT NULL , + `created_at` TIMESTAMP NOT NULL , + `created_by` varchar(100) NOT NULL , + `updated_at` TIMESTAMP DEFAULT NULL , + `updated_by` varchar(100) DEFAULT NULL , + `ext` LONGVARCHAR DEFAULT NULL , + PRIMARY KEY (`id`) + ); +COMMENT ON TABLE s2_tag IS 'tag information'; \ No newline at end of file