
Commit

Merge branch 'tencentmusic:master' into master
sevenliu1896 committed Mar 2, 2024
2 parents 0411466 + 93534af commit f6fc749
Showing 35 changed files with 586 additions and 95 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

# SuperSonic (超音数)

-**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
+**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) with semantic layer, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.

<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>

@@ -13,7 +13,8 @@
The emergence of Large Language Models (LLMs) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLMs to convert natural language into SQL (so-called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications.

From our perspective, the key to filling the real-world gap lies in three aspects:
-1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
+1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
+<img src="./docs/images/supersonic_ideas.png" height="65%" width="65%" align="center"/>
2. Augment the LLM with schema mappers (as a kind of preprocessor) and semantic correctors (as a kind of postprocessor) to **mitigate hallucination**.
3. Utilize rule-based schema parsers when necessary to **improve efficiency** (in terms of latency and cost).

2 changes: 2 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@

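 In our view, three key points matter for delivering value in real-world scenarios:
 1. Integrate HeadlessBI, encapsulating underlying data details (joins, keys, formulas, etc.) in a unified semantic layer to reduce the **complexity** of SQL generation.
+
+<img src="./docs/images/supersonic_ideas.png" height="65%" width="65%" align="center"/>
 2. Place a schema mapper before the LLM and a semantic corrector after it to mitigate the **hallucination** commonly seen in LLMs.
 3. Design heuristic rules to improve the **efficiency** of semantic parsing in certain specific scenarios.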

Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
 public class QueryReq {
     private String queryText;
     private Integer chatId;
-    private Long modelId;
+    private Long viewId;
     private User user;
     private QueryFilters queryFilters;
     private boolean saveAnswer = true;
Original file line number Diff line number Diff line change
@@ -65,6 +65,10 @@ public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
                 .collect(Collectors.toList());
     }

+    public Set<Long> getViewIds() {
+        return getViewIds(null);
+    }
+
     public Set<Long> getViewIds(AgentToolType agentToolType) {
         List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
         if (CollectionUtils.isEmpty(commonAgentTools)) {
Original file line number Diff line number Diff line change
@@ -140,4 +140,19 @@ protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long
         return semanticSchema.getMetrics(viewId);
     }

+    protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
+        Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
+                .flatMap(
+                        schemaElement -> {
+                            Set<String> elements = new HashSet<>();
+                            elements.add(schemaElement.getName());
+                            if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
+                                elements.addAll(schemaElement.getAlias());
+                            }
+                            return elements.stream();
+                        }
+                ).collect(Collectors.toSet());
+        dimensions.add(TimeDimensionEnum.DAY.getChName());
+        return dimensions;
+    }
 }
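The `getDimensions` helper moved into the base corrector flattens each dimension's name and aliases into one set and then appends a time-dimension name. The following is a minimal sketch of that idea only, not the project's code: `Element`, `namesOf`, and `dimensionNames` are hypothetical stand-ins for `SchemaElement` and the real method. It collects into an explicit `HashSet` because `Collectors.toSet()` makes no documented guarantee that the returned set is mutable, which matters when an element is added afterwards.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class DimensionNamesSketch {

    /** Hypothetical stand-in for SchemaElement: a name plus optional aliases. */
    static final class Element {
        final String name;
        final List<String> alias; // may be null or empty

        Element(String name, List<String> alias) {
            this.name = name;
            this.alias = alias;
        }
    }

    /** Streams an element's name followed by its aliases, if it has any. */
    static Stream<String> namesOf(Element element) {
        Stream<String> name = Stream.of(element.name);
        return element.alias == null ? name : Stream.concat(name, element.alias.stream());
    }

    /** Flattens names and aliases into one mutable set, then adds the time-dimension name. */
    static Set<String> dimensionNames(List<Element> elements, String timeDimensionName) {
        Set<String> names = elements.stream()
                .flatMap(DimensionNamesSketch::namesOf)
                .collect(Collectors.toCollection(HashSet::new));
        names.add(timeDimensionName);
        return names;
    }

    public static void main(String[] args) {
        List<Element> elements = List.of(
                new Element("city", List.of("town")),
                new Element("singer", null));
        System.out.println(dimensionNames(elements, "date"));
    }
}
```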
Original file line number Diff line number Diff line change
@@ -14,14 +14,12 @@
 import com.tencent.supersonic.headless.server.pojo.MetaFilter;
 import com.tencent.supersonic.headless.server.service.ModelService;
 import com.tencent.supersonic.headless.server.service.ViewService;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.util.CollectionUtils;
-
-import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.util.CollectionUtils;

/**
* Perform SQL corrections on the "Group by" section in S2SQL.
@@ -82,22 +80,6 @@ private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo sema
         return true;
     }

-    private Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
-        Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
-                .flatMap(
-                        schemaElement -> {
-                            Set<String> elements = new HashSet<>();
-                            elements.add(schemaElement.getName());
-                            if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
-                                elements.addAll(schemaElement.getAlias());
-                            }
-                            return elements.stream();
-                        }
-                ).collect(Collectors.toSet());
-        dimensions.add(TimeDimensionEnum.DAY.getChName());
-        return dimensions;
-    }
-
     private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
         Long viewId = semanticParseInfo.getViewId();
         //add dimension group by
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
package com.tencent.supersonic.chat.core.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

/**
* Perform schema corrections on the Schema information in S2SQL.
@@ -27,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
     @Override
     public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

+        removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
+
         correctAggFunction(semanticParseInfo);

         replaceAlias(semanticParseInfo);
@@ -105,4 +115,35 @@ private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo)
         String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
         sqlInfo.setCorrectS2SQL(sql);
     }
+
+    public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
+        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
+        String correctS2SQL = sqlInfo.getCorrectS2SQL();
+        List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
+        if (CollectionUtils.isEmpty(whereExpressionList)) {
+            return;
+        }
+        List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
+        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
+        Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema);
+
+        if (CollectionUtils.isEmpty(linkingValues)) {
+            linkingValues = new ArrayList<>();
+        }
+        Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
+                .collect(Collectors.toSet());
+
+        Set<String> removeFieldNames = whereExpressionList.stream()
+                .filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
+                .filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
+                .filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
+                .filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
+                .filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
+                .filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
+                .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
+
+        String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
+        sqlInfo.setCorrectS2SQL(sql);
+    }
+
 }
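The filter chain in `removeFilterIfNotInLinkingValue` selects equality filters on known dimensions whose field was never linked to a value, so that they can be dropped from the WHERE clause. The sketch below reproduces only that selection logic under stated assumptions: `Condition` is a hypothetical stand-in for `FieldExpression`, the equality operator is assumed to be `"="`, and an ISO-style date pattern stands in for `DateUtils.isAnyDateString`. Unlike the committed code, it also skips conditions with a null value instead of calling `toString()` on them.

```java
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

final class WhereFilterPruningSketch {

    /** Hypothetical stand-in for FieldExpression. */
    static final class Condition {
        final String function; // e.g. "SUM", or null when the field is used bare
        final String field;
        final String operator;
        final Object value;

        Condition(String function, String field, String operator, Object value) {
            this.function = function;
            this.field = field;
            this.operator = operator;
            this.value = value;
        }
    }

    // Assumption: a simplified date check, not the project's DateUtils.
    private static final Pattern ISO_DATE = Pattern.compile("\\d{4}-\\d{2}-\\d{2}");

    /** Returns the field names whose equality filters have no supporting linked value. */
    static Set<String> fieldsToRemove(List<Condition> where, Set<String> dimensions,
            Set<String> timeDimensions, Set<String> linkedFields) {
        return where.stream()
                .filter(c -> c.function == null || c.function.trim().isEmpty())
                .filter(c -> !timeDimensions.contains(c.field))
                .filter(c -> "=".equals(c.operator))
                .filter(c -> dimensions.contains(c.field))
                .filter(c -> c.value != null && !ISO_DATE.matcher(c.value.toString()).matches())
                .filter(c -> !linkedFields.contains(c.field))
                .map(c -> c.field)
                .collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        List<Condition> where = List.of(
                new Condition(null, "city", "=", "Beijing"),
                new Condition(null, "singer", "=", "Jay"));
        System.out.println(fieldsToRemove(where, Set.of("city", "singer"), Set.of("date"), Set.of("singer")));
    }
}
```

In this sketch a filter such as `city = 'Beijing'` is dropped when no linked value backs the `city` field, while filters on linked fields, time dimensions, non-equality operators, and aggregated fields are left alone.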
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
 package com.tencent.supersonic.chat.core.mapper;

 import com.google.common.collect.Lists;
-import com.tencent.supersonic.headless.api.pojo.SchemaElement;
 import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
-import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
 import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
 import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
 import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
-import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
+import com.tencent.supersonic.chat.core.agent.Agent;
 import com.tencent.supersonic.chat.core.pojo.QueryContext;
 import com.tencent.supersonic.common.pojo.Constants;
+import com.tencent.supersonic.headless.api.pojo.SchemaElement;
+import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
+import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.util.CollectionUtils;

 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;

 @Slf4j
@@ -24,29 +26,35 @@ public class QueryFilterMapper implements SchemaMapper {

     @Override
     public void map(QueryContext queryContext) {
-        Long viewId = queryContext.getViewId();
-        if (viewId == null || viewId <= 0) {
+        Agent agent = queryContext.getAgent();
+        if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) {
             return;
         }
+        if (Agent.containsAllModel(agent.getViewIds())) {
+            return;
+        }
+        Set<Long> viewIds = agent.getViewIds();
         SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
-        clearOtherSchemaElementMatch(viewId, schemaMapInfo);
-        List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
-        if (schemaElementMatches == null) {
-            schemaElementMatches = Lists.newArrayList();
-            schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
+        clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
+        for (Long viewId : viewIds) {
+            List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
+            if (schemaElementMatches == null) {
+                schemaElementMatches = Lists.newArrayList();
+                schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
+            }
+            addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches);
         }
-        addValueSchemaElementMatch(queryContext, schemaElementMatches);
     }

-    private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
+    private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
         for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
-            if (!entry.getKey().equals(modelId)) {
+            if (!viewIds.contains(entry.getKey())) {
                 entry.getValue().clear();
             }
         }
     }

-    private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
+    private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
             List<SchemaElementMatch> candidateElementMatches) {
         QueryFilters queryFilters = queryContext.getQueryFilters();
         if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
@@ -61,7 +69,7 @@ private List<SchemaElementMatch> addValueSchemaElementMatch(QueryCo
                     .name(String.valueOf(filter.getValue()))
                     .type(SchemaElementType.VALUE)
                     .bizName(filter.getBizName())
-                    .view(queryContext.getViewId())
+                    .view(viewId)
                     .build();
             SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                     .element(element)
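The `QueryFilterMapper` change above switches from a single `viewId` taken from the request to the set of view IDs configured on the agent: matches for views outside that set are cleared, and every view in the set gets a (possibly empty) match list. A minimal sketch of that scoping step follows, assuming a plain `Map<Long, List<String>>` in place of `SchemaMapInfo`; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

final class ViewMatchScopeSketch {

    /**
     * Clears the matches of every view outside viewIds and makes sure each
     * view inside viewIds has a match list that later steps can append to.
     */
    static void retainViews(Set<Long> viewIds, Map<Long, List<String>> matchesByView) {
        for (Map.Entry<Long, List<String>> entry : matchesByView.entrySet()) {
            if (!viewIds.contains(entry.getKey())) {
                entry.getValue().clear();
            }
        }
        for (Long viewId : viewIds) {
            matchesByView.computeIfAbsent(viewId, id -> new ArrayList<>());
        }
    }

    public static void main(String[] args) {
        Map<Long, List<String>> matches = new HashMap<>();
        matches.put(1L, new ArrayList<>(List.of("city")));
        matches.put(2L, new ArrayList<>(List.of("singer")));
        retainViews(Set.of(1L, 3L), matches);
        System.out.println(matches);
    }
}
```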
Original file line number Diff line number Diff line change
@@ -23,20 +23,22 @@ public class AgentCheckParser implements SemanticParser {
     @Override
     public void parse(QueryContext queryContext, ChatContext chatContext) {
         List<SemanticQuery> queries = queryContext.getCandidateQueries();
-        agentCanSupport(queryContext, queries);
+        log.info("query size before agent filter:{}", queryContext.getCandidateQueries().size());
+        filterQueries(queryContext, queries);
+        log.info("query size after agent filter: {}", queryContext.getCandidateQueries().size());
     }

-    private void agentCanSupport(QueryContext queryContext, List<SemanticQuery> queries) {
+    private void filterQueries(QueryContext queryContext, List<SemanticQuery> queries) {
         Agent agent = queryContext.getAgent();
         if (agent == null) {
             return;
         }
         List<RuleParserTool> queryTools = getRuleTools(agent);
         if (CollectionUtils.isEmpty(queryTools)) {
-            queries.clear();
+            queryContext.setCandidateQueries(Lists.newArrayList());
             return;
         }
-        log.info("queries resolved:{} {}", agent.getName(),
+        log.info("agent name :{}, queries resolved: {}", agent.getName(),
                 queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
         queries.removeIf(query -> {
             for (RuleParserTool tool : queryTools) {
@@ -46,10 +48,14 @@ private void agentCanSupport(QueryContext queryContext, List<SemanticQuery> quer
                 }
                 if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
                     if (QueryManager.isTagQuery(query.getQueryMode())) {
-                        return !tool.getQueryTypes().contains(QueryType.TAG.name());
+                        if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
+                            return true;
+                        }
                     }
                     if (QueryManager.isMetricQuery(query.getQueryMode())) {
-                        return !tool.getQueryTypes().contains(QueryType.METRIC.name());
+                        if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
+                            return true;
+                        }
                     }
                 }
                 if (CollectionUtils.isEmpty(tool.getViewIds())) {
@@ -62,7 +68,8 @@ private void agentCanSupport(QueryContext queryContext, List<SemanticQuery> quer
             }
             return true;
         });
-        log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
+        queryContext.setCandidateQueries(queries);
+        log.info("agent name :{}, rule queries witch can be supported by agent :{}", agent.getName(),
                 queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
     }

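The behavioural change inside the `removeIf` predicate is easy to miss: the old code returned the result of the query-type check directly, so a matching type kept the query without ever reaching the view check, whereas the new code rejects only on a mismatch and otherwise falls through. The sketch below contrasts the two shapes for a single tool. It is an illustration under assumptions: `Tool` is a hypothetical stand-in for `RuleParserTool`, and the final view check is a simplified guess at the elided part of the method.

```java
import java.util.Set;

final class AgentFilterSketch {

    /** Hypothetical stand-in for RuleParserTool. */
    static final class Tool {
        final Set<String> queryTypes;
        final Set<Long> viewIds;

        Tool(Set<String> queryTypes, Set<Long> viewIds) {
            this.queryTypes = queryTypes;
            this.viewIds = viewIds;
        }
    }

    /** Old shape: the type check decides alone, so the view check is skipped when types are configured. */
    static boolean removeOld(String queryType, Long viewId, Tool tool) {
        if (!tool.queryTypes.isEmpty()) {
            return !tool.queryTypes.contains(queryType);
        }
        return !tool.viewIds.contains(viewId);
    }

    /** New shape: a type mismatch removes the query; a match falls through to the view check. */
    static boolean removeNew(String queryType, Long viewId, Tool tool) {
        if (!tool.queryTypes.isEmpty() && !tool.queryTypes.contains(queryType)) {
            return true;
        }
        return !tool.viewIds.contains(viewId);
    }

    public static void main(String[] args) {
        Tool tool = new Tool(Set.of("METRIC"), Set.of(1L));
        System.out.println("old removes METRIC query on view 2: " + removeOld("METRIC", 2L, tool));
        System.out.println("new removes METRIC query on view 2: " + removeNew("METRIC", 2L, tool));
    }
}
```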
Original file line number Diff line number Diff line change
@@ -20,9 +20,9 @@ public class RuleSqlParser implements SemanticParser {

     private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
             new ContextInheritParser(),
-            new AgentCheckParser(),
             new TimeRangeParser(),
-            new AggregateTypeParser()
+            new AggregateTypeParser(),
+            new AgentCheckParser()
     );

     @Override
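Moving `AgentCheckParser` to the end of `auxiliaryParsers` matters if any parser that runs after it can add or alter candidate queries, which is an assumption here since the diff does not show what `TimeRangeParser` and `AggregateTypeParser` do. The sketch below uses hypothetical `Consumer` steps over a list of strings to show why a filtering step placed last sees every candidate, while one placed earlier can miss later additions.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

final class ParserOrderSketch {

    // Hypothetical parser steps acting on a shared candidate list.
    static final Consumer<List<String>> ADD_SUPPORTED = candidates -> candidates.add("supported");
    static final Consumer<List<String>> ADD_UNSUPPORTED = candidates -> candidates.add("unsupported");
    static final Consumer<List<String>> AGENT_CHECK =
            candidates -> candidates.removeIf(query -> query.equals("unsupported"));

    /** Runs the steps in order against one candidate list and returns the result. */
    static List<String> run(List<Consumer<List<String>>> parsers) {
        List<String> candidates = new ArrayList<>();
        for (Consumer<List<String>> parser : parsers) {
            parser.accept(candidates);
        }
        return candidates;
    }

    public static void main(String[] args) {
        System.out.println("filter in the middle: " + run(List.of(ADD_SUPPORTED, AGENT_CHECK, ADD_UNSUPPORTED)));
        System.out.println("filter last: " + run(List.of(ADD_SUPPORTED, ADD_UNSUPPORTED, AGENT_CHECK)));
    }
}
```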
0 comments on commit f6fc749
