diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index 55eb051af..2fa86011e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -74,6 +74,11 @@ private void tryParse(ChatQueryContext queryCtx, Long dataSetId) { } catch (Exception e) { log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e); } + Double temperature = llmReq.getModelConfig().getTemperature(); + if(temperature == 0){ + //报错时增加随机性,减少无效重试 + llmReq.getModelConfig().setTemperature(0.5); + } currentRetry++; } if (MapUtils.isEmpty(sqlRespMap)) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index cdc4bf395..220cb9a6c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -5,11 +5,12 @@ import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; -import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.structured.Description; +import dev.langchain4j.service.AiServices; +import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; @@ -24,6 +25,18 @@ @Slf4j public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + @Data + static class SemanticSql { + @Description("thought or remarks to tell users about the sql, make it short.") + private String thought; + @Description("sql to generate") + private String sql; + } + + interface SemanticSqlExtractor { + SemanticSql generateSemanticSql(String text); + } + private static final String INSTRUCTION = "" + "\n#Role: You are a data analyst experienced in SQL languages." @@ -36,9 +49,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "4.DO NOT calculate date range using functions." + "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." - + "6.ONLY respond with the converted SQL statement." + "\n#Exemplars:\n{{exemplar}}" - + "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:"; + + "\n#Question:" + + "Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; @Override public LLMResp generate(LLMReq llmReq) { @@ -55,6 +68,8 @@ public LLMResp generate(LLMReq llmReq) { Prompt prompt = generatePrompt(llmReq, llmResp); prompt2Exemplar.put(prompt, exemplars); } + ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig()); + SemanticSqlExtractor extractor = AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); // 3.perform multiple self-consistency inferences parallelly Map output2Prompt = new ConcurrentHashMap<>(); @@ -66,18 +81,10 @@ public LLMResp generate(LLMReq llmReq) { keyPipelineLog.info( "OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage()); - ChatLanguageModel chatLanguageModel = - getChatLanguageModel(llmReq.getModelConfig()); - Response response = - chatLanguageModel.generate(prompt.toUserMessage()); - String sqlOutput = - StringUtils.normalizeSpace(response.content().text()); - // replace ``` - String sqlOutputFormat = - sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim(); - output2Prompt.put(sqlOutputFormat, prompt); + SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); + output2Prompt.put(s2Sql.getSql(), prompt); keyPipelineLog.info( - "OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat); + "OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql()); }); // 4.format response.