Skip to content

Commit

Permalink
[improvement][headless-chat]Optimize HeuristicDataSetResolver to pr…
Browse files Browse the repository at this point in the history
…ioritize max similarity of dataset and metric.
  • Loading branch information
jerryjzhang committed Sep 20, 2024
1 parent 5ba401a commit a41a423
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 126 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.tencent.supersonic.headless.chat.parser.llm;

import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class DataSetMatchResult {
private Integer count = 0;
private double maxSimilarity;
private double maxMetricSimilarity;
private double maxDatesetSimilarity;
private double totalSimilarity;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,157 +4,86 @@
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

@Slf4j
public class HeuristicDataSetResolver implements DataSetResolver {

protected static Long selectDataSetBySchemaElementMatchScore(
Map<Long, SemanticQuery> dataSetQueryModes, SchemaMapInfo schemaMap) {
// dataSet count priority
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
if (Objects.nonNull(dataSetIdByDataSetCount)) {
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
return dataSetIdByDataSetCount;
public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
matchedDataSets.retainAll(agentDataSetIds);
}
if (matchedDataSets.size() == 1) {
return matchedDataSets.stream().findFirst().get();
}
return selectDataSetByMatchSimilarity(mapInfo);
}

Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
if (dataSetTypeMap.size() == 1) {
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
if (dataSetQueryModes.containsKey(dataSetSelect)) {
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
return dataSetSelect;
}
} else {
Entry<Long, DataSetMatchResult> maxDataSet =
dataSetTypeMap.entrySet().stream()
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
.sorted(
(o1, o2) -> {
int difference =
o2.getValue().getCount() - o1.getValue().getCount();
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
Entry<Long, DataSetMatchResult> selectedDataset =
dataSetMatchRet.entrySet().stream()
.sorted(
(o1, o2) -> {
double difference =
o1.getValue().getMaxDatesetSimilarity()
- o2.getValue().getMaxDatesetSimilarity();
if (difference == 0) {
difference =
o1.getValue().getMaxMetricSimilarity()
- o2.getValue().getMaxMetricSimilarity();
if (difference == 0) {
return (int)
((o2.getValue().getMaxSimilarity()
- o1.getValue()
.getMaxSimilarity())
* 100);
difference =
o1.getValue().getTotalSimilarity()
- o2.getValue().getTotalSimilarity();
}
return difference;
})
.findFirst()
.orElse(null);
if (maxDataSet != null) {
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
return maxDataSet.getKey();
}
}
return difference >= 0 ? -1 : 1;
})
.findFirst()
.orElse(null);
if (selectedDataset != null) {
log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
return selectedDataset.getKey();
}
return null;
}

private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
schemaMap.getDataSetElementMatches();
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets
// 0.5 point
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
if (Objects.nonNull(dataSetElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch :
dataSetElementMatches.entrySet()) {
Long dataSetId = dataSetElementMatch.getKey();
List<Double> dataSetMatchesScore =
dataSetElementMatch.getValue().stream()
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
.filter(
elementMatch ->
SchemaElementType.DATASET.equals(
elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0)
.collect(Collectors.toList());

if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
// get sum of dataSet match score
double score =
dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
dataSetIdToDataSetScore.put(dataSetId, score);
}
}
Entry<Long, Double> maxDataSetScore =
dataSetIdToDataSetScore.entrySet().stream()
.max(Comparator.comparingDouble(Entry::getValue))
.orElse(null);
log.info(
"maxDataSetCount:{},dataSetIdToDataSetCount:{}",
maxDataSetScore,
dataSetIdToDataSetScore);
if (Objects.nonNull(maxDataSetScore)) {
return maxDataSetScore.getKey();
}
}
return null;
}

public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
for (Entry<Long, List<SchemaElementMatch>> entry :
schemaMap.getDataSetElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches =
schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!dataSetCount.containsKey(entry.getKey())) {
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
double maxMetricSimilarity = 0;
double maxDatasetSimilarity = 0;
double totalSimilarity = 0;
for (SchemaElementMatch match : entry.getValue()) {
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
}
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(
schemaElementMatch ->
schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax =
schemaElementMatches.stream()
.sorted(
(o1, o2) ->
((int)
((o2.getSimilarity() - o1.getSimilarity())
* 100)))
.findFirst()
.orElse(null);
if (schemaElementMatchMax != null) {
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
}
dataSetMatchResult.setCount(schemaElementTypes.size());
totalSimilarity += match.getSimilarity();
}
dateSetMatchRet.put(
entry.getKey(),
DataSetMatchResult.builder()
.maxMetricSimilarity(maxMetricSimilarity)
.maxDatesetSimilarity(maxDatasetSimilarity)
.totalSimilarity(totalSimilarity)
.build());
}
return dataSetCount;
}

public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
matchedDataSets.retainAll(agentDataSetIds);
}
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
for (Long dataSetIds : matchedDataSets) {
dataSetQueryModes.put(dataSetIds, null);
}
if (dataSetQueryModes.size() == 1) {
return dataSetQueryModes.keySet().stream().findFirst().get();
}
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
return dateSetMatchRet;
}
}
Loading

0 comments on commit a41a423

Please sign in to comment.