Skip to content

Commit

Permalink
(improvement)(Headless) corrector supports subselect sql (#1006)
Browse files Browse the repository at this point in the history
* [improvement] corrector support subselect sql

---------

Co-authored-by: zuopengge
  • Loading branch information
mainmainer authored May 17, 2024
1 parent 2411cb3 commit 7949efe
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ private static void replaceFieldsInPlainOneSelect(Map<String, String> fieldNameM
replaceAsName(fieldNameMap, selectItem);
}

if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
}

//3. replace oder by fields
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
if (!CollectionUtils.isEmpty(orderByElements)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,28 @@ public static List<PlainSelect> getPlainSelect(String sql) {
List<PlainSelect> plainSelectList = new ArrayList<>();
if (selectStatement instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement;
plainSelectList.add(plainSelect);
getSubPlainSelect(plainSelect, plainSelectList);
} else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList.getSelects().forEach(subSelectBody -> {
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
getSubPlainSelect(subPlainSelect, plainSelectList);
});
}
}
return plainSelectList;
}

public static void getSubPlainSelect(PlainSelect plainSelect, List<PlainSelect> plainSelectList) {
plainSelectList.add(plainSelect);
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
getSubPlainSelect(subPlainSelect, plainSelectList);
}
}

public static Select getSelect(String sql) {
Statement statement = null;
try {
Expand Down Expand Up @@ -275,6 +284,14 @@ public static List<FieldExpression> getWhereExpressions(String sql) {
if (Objects.nonNull(where)) {
where.accept(new FieldAndValueAcquireVisitor(result));
}
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem();
PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect();
Expression subWhere = subPlainSelect.getWhere();
if (Objects.nonNull(subWhere)) {
subWhere.accept(new FieldAndValueAcquireVisitor(result));
}
}
}
return new ArrayList<>(result);
}
Expand Down
2 changes: 1 addition & 1 deletion evaluation/build_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def build_dataSet(domain_id,model_id1,model_id2,model_id3,model_id4):
{"id":model_id3,"includesAll":False,"metrics":metric_list3,"dimensions":dimension_list3},
{"id":model_id4,"includesAll":False,"metrics":metric_list4,"dimensions":dimension_list4}
]},
"queryConfig":{"tagTypeDefaultConfig":{},"metricTypeDefaultConfig":{"timeDefaultConfig":{"unit":1,"period":"DAY","timeMode":"RECENT"}}},"admins":["admin"],"admin":"admin"}
"queryConfig":{"tagTypeDefaultConfig":{},"metricTypeDefaultConfig":{"timeDefaultConfig":{"unit":0,"period":"DAY","timeMode":"RECENT"}}},"admins":["admin"],"admin":"admin"}
url=get_url_pre()+"/api/semantic/dataSet"
authorization=get_authorization()
header = {}
Expand Down
3 changes: 2 additions & 1 deletion evaluation/build_pred_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@ def get_pred_result():
pred_sql_list=[]
default_sql="select * from tablea "
time_cost=[]
time.sleep(60)
for i in range(0,len(questions)):
start_time = time.time()
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
end_time = time.time()
cost='%.3f'%(end_time-start_time)
time_cost.append(cost)
pred_sql_list.append(pred_sql)
time.sleep(60)
time.sleep(3)
write_sql(pred_sql_path, pred_sql_list)

return [float(cost) for cost in time_cost]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ private static Pair<String, String> getDefaultDate(String defaultDate, TimeDefau
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
String startDate = DateUtils.getBeforeDate(unit, datePeriodEnum);
String endDate = DateUtils.getBeforeDate(1, datePeriodEnum);
if (unit == 0) {
endDate = startDate;
}
if (TimeMode.LAST.equals(timeMode)) {
endDate = startDate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,4 @@ ModerationModel moderationModel(S2LangChain4jProperties properties) {
.build();
}

}
}

0 comments on commit 7949efe

Please sign in to comment.