Skip to content

Commit

Permalink
(improvement)(chat) 优化提示工程、重试机制
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyudong committed Sep 12, 2024
1 parent 2fa3bfe commit 40de8cf
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ private void tryParse(ChatQueryContext queryCtx, Long dataSetId) {
} catch (Exception e) {
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
}
Double temperature = llmReq.getModelConfig().getTemperature();
if(temperature == 0){
//报错时增加随机性,减少无效重试
llmReq.getModelConfig().setTemperature(0.5);
}
currentRetry++;
}
if (MapUtils.isEmpty(sqlRespMap)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.AiServices;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
Expand All @@ -24,6 +25,18 @@
@Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {

@Data
static class SemanticSql {
@Description("thought or remarks to tell users about the sql, make it short.")
private String thought;
@Description("sql to generate")
private String sql;
}

interface SemanticSqlExtractor {
SemanticSql generateSemanticSql(String text);
}

private static final String INSTRUCTION =
""
+ "\n#Role: You are a data analyst experienced in SQL languages."
Expand All @@ -36,9 +49,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "4.DO NOT calculate date range using functions."
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "6.ONLY respond with the converted SQL statement."
+ "\n#Exemplars:\n{{exemplar}}"
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
+ "\n#Question:"
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";

@Override
public LLMResp generate(LLMReq llmReq) {
Expand All @@ -55,6 +68,8 @@ public LLMResp generate(LLMReq llmReq) {
Prompt prompt = generatePrompt(llmReq, llmResp);
prompt2Exemplar.put(prompt, exemplars);
}
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
SemanticSqlExtractor extractor = AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);

// 3.perform multiple self-consistency inferences parallelly
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
Expand All @@ -66,18 +81,10 @@ public LLMResp generate(LLMReq llmReq) {
keyPipelineLog.info(
"OnePassSCSqlGenStrategy reqPrompt:\n{}",
prompt.toUserMessage());
ChatLanguageModel chatLanguageModel =
getChatLanguageModel(llmReq.getModelConfig());
Response<AiMessage> response =
chatLanguageModel.generate(prompt.toUserMessage());
String sqlOutput =
StringUtils.normalizeSpace(response.content().text());
// replace ```
String sqlOutputFormat =
sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim();
output2Prompt.put(sqlOutputFormat, prompt);
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
output2Prompt.put(s2Sql.getSql(), prompt);
keyPipelineLog.info(
"OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat);
"OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql());
});

// 4.format response.
Expand Down

0 comments on commit 40de8cf

Please sign in to comment.