Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[improvement][chat] The parser interface supports using the dataSetId provided by the frontend as the reference #1852

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public class ChatParseReq {
private String queryText;
private Integer chatId;
private Integer agentId;
private Long dataSetId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,56 +1,61 @@
package com.tencent.supersonic.chat.server.agent;

import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.springframework.util.CollectionUtils;

import java.util.*;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

@Data
public class Agent extends RecordInfo {

private static final int ONLINE_STATUS = 1;
private static final int OFFLINE_STATUS = 0;
private static final int ENABLED = 1;
private static final int DISABLED = 0;

private Integer id;
private String name;
private String description;
/** 0 offline, 1 online */
private Integer status = 1;
private Integer status = ONLINE_STATUS;
private List<String> examples;
private Integer enableSearch = 1;
private Integer enableFeedback = 0;
private Integer enableSearch = ENABLED;
private Integer enableFeedback = DISABLED;
private String toolConfig;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
private VisualConfig visualConfig;

public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(toolConfig, Map.class);
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList();
return Collections.emptyList();
}
List<Map> toolList = (List) map.get("tools");
return toolList.stream().filter(tool -> {
if (Objects.isNull(type)) {
return true;
}
return type.name().equals(tool.get("type"));
}).map(JSONObject::toJSONString).collect(Collectors.toList());
List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
return toolList.stream()
.filter(tool -> type == null || type.name().equals(tool.get("type")))
.map(JSONObject::toJSONString).collect(Collectors.toList());
}

public boolean enableSearch() {
return enableSearch == 1;
return enableSearch == ENABLED;
}

public boolean enableFeedback() {
return enableFeedback == 1;
return enableFeedback == ENABLED;
}

public boolean enableMemoryReview() {
return chatAppConfig.get(MemoryReviewTask.APP_KEY).isEnable();
ChatApp memoryReviewApp = chatAppConfig.get(MemoryReviewTask.APP_KEY);
return memoryReviewApp != null && memoryReviewApp.isEnable();
}

public static boolean containsAllModel(Set<Long> detectViewIds) {
Expand All @@ -60,7 +65,7 @@ public static boolean containsAllModel(Set<Long> detectViewIds) {
public List<DatasetTool> getParserTools(AgentToolType agentToolType) {
List<String> tools = this.getTools(agentToolType);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
return Collections.emptyList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, DatasetTool.class))
.collect(Collectors.toList());
Expand All @@ -75,33 +80,29 @@ public boolean containsDatasetTool() {
}

public boolean containsAnyTool() {
Map map = JSONObject.parseObject(toolConfig, Map.class);
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map)) {
return false;
}
List<Map> toolList = (List) map.get("tools");
if (CollectionUtils.isEmpty(toolList)) {
return false;
}

return true;
List<Map<String, Object>> toolList = (List<Map<String, Object>>) map.get("tools");
return !CollectionUtils.isEmpty(toolList);
}

public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) {
return Sets.newHashSet();
return Collections.emptySet();
}
return dataSetIds;
}

public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<DatasetTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
return Collections.emptySet();
}
return commonAgentTools.stream().map(DatasetTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream)
.collect(Collectors.toSet());
.filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
.flatMap(Collection::stream).collect(Collectors.toSet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package com.tencent.supersonic.chat.server.util;

import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.springframework.util.CollectionUtils;

import java.util.Collections;
import java.util.Objects;
import java.util.Set;

public class QueryReqConverter {

Expand All @@ -12,10 +18,23 @@ public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
queryNLReq.setText2SQLType(
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
queryNLReq.setDataSetIds(getDataSetIds(parseContext));
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());

return queryNLReq;
}

private static Set<Long> getDataSetIds(ParseContext parseContext) {
ChatParseReq chatParseReq = parseContext.getRequest();
Set<Long> dataSetIds = parseContext.getAgent().getDataSetIds();
Long requestDataSetId = chatParseReq.getDataSetId();

if (Objects.nonNull(requestDataSetId)) {
if (CollectionUtils.isEmpty(dataSetIds)) {
return Collections.singleton(requestDataSetId);
}
dataSetIds.removeIf(dataSetId -> !dataSetId.equals(requestDataSetId));
}
return dataSetIds;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ void testReplaceValue() {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);

Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " +
"AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
"SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' "
+ "AND 歌手名 = '林俊杰' AND 歌手名 = '陈' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);

replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
Expand All @@ -93,8 +93,8 @@ void testReplaceValue() {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);

Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND " +
"歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND "
+ "歌手名 = '林俊杰' AND 歌手名 = '陈' AND 歌曲发布时 = '2023-08-01') AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);

replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 "
Expand All @@ -105,9 +105,9 @@ void testReplaceValue() {
replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false);

Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND " +
"歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) " +
"AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
"SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' AND 歌手名 = '林俊杰' AND "
+ "歌手名 = '陈' AND 歌曲发布时 = '2023-08-01' AND 播放量 < (SELECT min(播放量) FROM 歌曲库 WHERE 语种 = '英文')) "
+ "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11",
replaceSql);

Map<String, Map<String, String>> filedNameToValueMap3 = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
@Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
String text = chatQueryContext.getRequest().getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
Expand All @@ -36,7 +36,7 @@ public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2T
}

public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectDataSetIds) {
throw new RuntimeException("Not implemented");
}

Expand Down
Loading