Skip to content

Commit

Permalink
[improvement][headless]Add databaseType into the Schema part of the T…
Browse files Browse the repository at this point in the history
…ext2SQL prompt. #1621
  • Loading branch information
jerryjzhang committed Sep 12, 2024
1 parent 47cc933 commit 37f1239
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@Data
public class DataSetSchema {

private String databaseType;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
@NoArgsConstructor
public class DataSetSchemaResp extends DataSetResp {

private String databaseType;
private List<MetricSchemaResp> metrics = Lists.newArrayList();
private List<DimSchemaResp> dimensions = Lists.newArrayList();
private List<ModelResp> modelResps = Lists.newArrayList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmReq.setSchema(llmSchema);
llmSchema.setDatabaseType(getDatabaseType(queryCtx, dataSetId));
llmSchema.setDataSetId(dataSetId);
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setMetrics(getMappedMetrics(queryCtx, dataSetId));
Expand Down Expand Up @@ -205,4 +206,14 @@ protected SchemaElement getPrimaryKey(@NotNull ChatQueryContext queryCtx, Long d
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
return dataSetSchema.getPrimaryKey();
}

protected String getDatabaseType(@NotNull ChatQueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
return null;
}
Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
return dataSetSchema.getDatabaseType();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,17 @@ public String buildSchemaStr(LLMReq llmReq) {
primaryKeyStr = String.format("%s", llmReq.getSchema().getPrimaryKey().getName());
}

String databaseTypeStr = "";
if (llmReq.getSchema().getDatabaseType() != null) {
databaseTypeStr = llmReq.getSchema().getDatabaseType();
}

String template =
"Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
return String.format(
template,
databaseTypeStr,
tableStr,
partitionTimeStr,
primaryKeyStr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public static class ElementValue {

@Data
public static class LLMSchema {
private String databaseType;
private Long dataSetId;
private String dataSetName;
private List<SchemaElement> metrics;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ public List<DataSetSchemaResp> buildDataSetSchema(DataSetFilterReq filter) {
.collect(Collectors.toList()));
dataSetSchemaResp.setTermResps(
termMaps.getOrDefault(dataSetResp.getDomainId(), Lists.newArrayList()));
if (!CollectionUtils.isEmpty(dataSetSchemaResp.getModelResps())) {
DatabaseResp databaseResp =
databaseService.getDatabase(
dataSetSchemaResp.getModelResps().get(0).getDatabaseId());
dataSetSchemaResp.setDatabaseType(databaseResp.getType());
}
dataSetSchemaResps.add(dataSetSchemaResp);
}
fillStaticInfo(dataSetSchemaResps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public static DataSetSchema build(DataSetSchemaResp resp) {
.type(SchemaElementType.DATASET)
.build();
dataSetSchema.setDataSet(dataSet);
dataSetSchema.setDatabaseType(resp.getDatabaseType());

Set<SchemaElement> metrics = getMetrics(resp);
dataSetSchema.getMetrics().addAll(metrics);
Expand Down
16 changes: 8 additions & 8 deletions launchers/standalone/src/main/resources/s2-exemplar.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,49 @@
{
"question": "比较jackjchen和robinlee今年以来的访问次数",
"sideInfo": "CurrentDate=[2020-12-01],DomainTerms=[<核心用户 COMMENT '核心用户指tom和lucy'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<数据日期>], Values[<用户='jackjchen'>,<用户='robinlee'>]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<数据日期>], Values[<用户='jackjchen'>,<用户='robinlee'>]",
"sql": "SELECT 用户, 访问次数 FROM 超音数产品 WHERE 用户 IN ('jackjchen', 'robinlee') AND 数据日期 >= '2020-01-01' AND 数据日期 <= '2020-12-01'"
},
{
"question": "超音数近12个月访问人数 按部门",
"sideInfo": "CurrentDate=[2022-11-06]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"sql": "SELECT 部门, 数据日期, 访问人数 FROM 超音数产品 WHERE 数据日期 >= '2021-11-06' AND 数据日期 <= '2022-11-06'"
},
{
"question": "超音数过去90天美术部、技术研发部的访问时长",
"sideInfo": "CurrentDate=[2023-04-21]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
},
{
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
"sideInfo": "CurrentDate=[2023-07-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]",
"dbSchema": "Table:[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>,<数据日期>], Values:[<部门='美术部'>]",
"dbSchema": "DatabaseType=[h2], Table:[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>,<数据日期>], Values:[<部门='美术部'>]",
"sql": "SELECT 用户 FROM 超音数产品 WHERE 部门 = '美术部' AND 访问时长 < 1"
},
{
"question": "超音数本月pv最高的用户有哪些",
"sideInfo": "CurrentDate=[2023-08-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-08-01' AND 数据日期 <= '2023-08-31' ORDER BY 访问次数 DESC LIMIT 1"
},
{
"question": "超音数访问次数大于1k的部门是哪些",
"sideInfo": "CurrentDate=[2023-09-14]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"sql": "SELECT 部门 FROM 超音数产品 WHERE 访问次数 > 1000"
},
{
"question": "过去半个月核心用户的访问次数",
"sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<核心用户 COMMENT '用户为alice'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"sql": "SELECT 用户,SUM(访问次数) FROM 超音数产品 WHERE 用户='alice' AND 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15'"
},
{
"question": "过去半个月忠实用户有哪一些",
"sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<忠实用户 COMMENT '一段时间内总访问次数大于100的用户'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15' GROUP BY 用户 HAVING SUM(访问次数) > 100"
}
]
16 changes: 8 additions & 8 deletions launchers/standalone/src/test/resources/s2-exemplar.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,49 @@
{
"question": "比较jackjchen和robinlee今年以来的访问次数",
"sideInfo": "CurrentDate=[2020-12-01],DomainTerms=[<核心用户 COMMENT '核心用户指tom和lucy'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<数据日期>], Values[<用户='jackjchen'>,<用户='robinlee'>]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<数据日期>], Values[<用户='jackjchen'>,<用户='robinlee'>]",
"sql": "SELECT 用户, 访问次数 FROM 超音数产品 WHERE 用户 IN ('jackjchen', 'robinlee') AND 数据日期 >= '2020-01-01' AND 数据日期 <= '2020-12-01'"
},
{
"question": "超音数近12个月访问人数 按部门",
"sideInfo": "CurrentDate=[2022-11-06]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>,<访问用户数 ALIAS 'UV,访问人数,' COMMENT '访问的用户个数' AGGREGATE 'COUNT'>,<人均访问次数 ALIAS '平均访问次数,' COMMENT '每个用户平均访问的次数'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"sql": "SELECT 部门, 数据日期, 访问人数 FROM 超音数产品 WHERE 数据日期 >= '2021-11-06' AND 数据日期 <= '2022-11-06'"
},
{
"question": "超音数过去90天美术部、技术研发部的访问时长",
"sideInfo": "CurrentDate=[2023-04-21]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions=[<数据日期>], Values=[<部门='美术部'>,<部门='技术研发部'>]",
"sql": "SELECT 部门, 访问时长 FROM 超音数产品 WHERE 部门 IN ('美术部', '技术研发部') AND 数据日期 >= '2023-01-20' AND 数据日期 <= '2023-04-21'"
},
{
"question": "超音数访问时长小于1小时,且来自美术部的用户是哪些",
"sideInfo": "CurrentDate=[2023-07-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]",
"dbSchema": "Table:[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>,<数据日期>], Values:[<部门='美术部'>]",
"dbSchema": "DatabaseType=[h2], Table:[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics:[<访问时长 COMMENT '一段时间内用户的访问时长' AGGREGATE 'SUM'>], Dimensions:[<用户>,<数据日期>], Values:[<部门='美术部'>]",
"sql": "SELECT 用户 FROM 超音数产品 WHERE 部门 = '美术部' AND 访问时长 < 1"
},
{
"question": "超音数本月pv最高的用户有哪些",
"sideInfo": "CurrentDate=[2023-08-31],DomainTerms=[<核心用户 COMMENT '用户为tom和lucy'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-08-01' AND 数据日期 <= '2023-08-31' ORDER BY 访问次数 DESC LIMIT 1"
},
{
"question": "超音数访问次数大于1k的部门是哪些",
"sideInfo": "CurrentDate=[2023-09-14]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"sql": "SELECT 部门 FROM 超音数产品 WHERE 访问次数 > 1000"
},
{
"question": "过去半个月核心用户的访问次数",
"sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<核心用户 COMMENT '用户为alice'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<部门>,<数据日期>], Values=[]",
"sql": "SELECT 用户,SUM(访问次数) FROM 超音数产品 WHERE 用户='alice' AND 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15'"
},
{
"question": "过去半个月忠实用户有哪一些",
"sideInfo": "CurrentDate=[2023-09-15],DomainTerms=[<忠实用户 COMMENT '一段时间内总访问次数大于100的用户'>]",
"dbSchema": "Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"dbSchema": "DatabaseType=[h2], Table=[超音数产品], PartitionTimeField=[数据日期 FORMAT 'yyyy-MM-dd'], Metrics=[<访问次数 ALIAS 'pv' COMMENT '一段时间内用户的访问次数' AGGREGATE 'SUM'>], Dimensions=[<用户>,<数据日期>], Values=[]",
"sql": "SELECT 用户 FROM 超音数产品 WHERE 数据日期 >= '2023-09-01' AND 数据日期 <= '2023-09-15' GROUP BY 用户 HAVING SUM(访问次数) > 100"
}
]

0 comments on commit 37f1239

Please sign in to comment.