Skip to content

Commit

Permalink
(improvement)(Chat) QueryFilterMapper obtain viewId from agent (#778)
Browse files Browse the repository at this point in the history
Co-authored-by: jolunoluo
  • Loading branch information
lxwcodemonkey committed Mar 1, 2024
1 parent 532a005 commit 2052352
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
.collect(Collectors.toList());
}

public Set<Long> getViewIds() {
return getViewIds(null);
}

public Set<Long> getViewIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package com.tencent.supersonic.chat.core.mapper;

import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

@Slf4j
Expand All @@ -24,29 +26,35 @@ public class QueryFilterMapper implements SchemaMapper {

@Override
public void map(QueryContext queryContext) {
Long viewId = queryContext.getViewId();
if (viewId == null || viewId <= 0) {
Agent agent = queryContext.getAgent();
if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) {
return;
}
if (Agent.containsAllModel(agent.getViewIds())) {
return;
}
Set<Long> viewIds = agent.getViewIds();
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
for (Long viewId : viewIds) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
}
addValueSchemaElementMatch(viewId, queryContext, schemaElementMatches);
}
addValueSchemaElementMatch(queryContext, schemaElementMatches);
}

private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
if (!entry.getKey().equals(modelId)) {
if (!viewIds.contains(entry.getKey())) {
entry.getValue().clear();
}
}
}

private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = queryContext.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
Expand All @@ -61,7 +69,7 @@ private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryCo
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.view(queryContext.getViewId())
.view(viewId)
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ public class RuleSqlParser implements SemanticParser {

private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
new ContextInheritParser(),
new AgentCheckParser(),
new TimeRangeParser(),
new AggregateTypeParser()
new AggregateTypeParser(),
new AgentCheckParser()
);

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ logging:

inMemoryEmbeddingStore:
persistent:
path: /tmp
path: d://

0 comments on commit 2052352

Please sign in to comment.