-
Notifications
You must be signed in to change notification settings - Fork 407
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(improvement)(chat)Implement a new version of multi-turn conversation.
- Loading branch information
1 parent
710f120
commit a591c5e
Showing
10 changed files
with
150 additions
and
249 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
141 changes: 141 additions & 0 deletions
141
chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
package com.tencent.supersonic.chat.server.parser; | ||
|
||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; | ||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; | ||
import com.tencent.supersonic.chat.server.util.QueryReqConverter; | ||
import com.tencent.supersonic.common.util.ContextUtils; | ||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; | ||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; | ||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; | ||
import com.tencent.supersonic.headless.api.pojo.response.MapResp; | ||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; | ||
import com.tencent.supersonic.headless.server.service.ChatQueryService; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||
import dev.langchain4j.model.input.Prompt; | ||
import dev.langchain4j.model.input.PromptTemplate; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.extern.slf4j.Slf4j; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.core.env.Environment; | ||
|
||
import java.util.*; | ||
import java.util.stream.Collectors; | ||
|
||
@Slf4j | ||
public class MultiTurnParser implements ChatParser { | ||
|
||
private static final Logger keyPipelineLog = LoggerFactory.getLogger(MultiTurnParser.class); | ||
|
||
private static final PromptTemplate promptTemplate = PromptTemplate.from( | ||
"You are a data product manager experienced in data requirements." + | ||
"Your will be provided with current and history questions asked by a user," + | ||
"along with their mapped schema elements(metric, dimension and value), " + | ||
"please try understanding the semantics and rewrite a question(keep relevant metrics, dimensions, values and date ranges)." + | ||
"Current Question: {{curtQuestion}} " + | ||
"Current Mapped Schema: {{curtSchema}} " + | ||
"History Question: {{histQuestion}} " + | ||
"History Mapped Schema: {{histSchema}} " + | ||
"Rewritten Question: "); | ||
|
||
@Override | ||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { | ||
Environment environment = ContextUtils.getBean(Environment.class); | ||
Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class); | ||
if (Boolean.FALSE.equals(multiTurn)) { | ||
return; | ||
} | ||
|
||
// derive mapping result of current question and parsing result of last question. | ||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); | ||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); | ||
MapResp currentMapResult = chatQueryService.performMapping(queryReq); | ||
|
||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1); | ||
if (historyParseResults.size() == 0) { | ||
return; | ||
} | ||
ParseResp lastParseResult = historyParseResults.get(0); | ||
Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId(); | ||
|
||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); | ||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); | ||
String rewrittenQuery = rewriteQuery(RewriteContext.builder() | ||
.curtQuestion(currentMapResult.getQueryText()) | ||
.histQuestion(lastParseResult.getQueryText()) | ||
.curtSchema(curtMapStr) | ||
.histSchema(histMapStr) | ||
.build()); | ||
chatParseContext.setQueryText(rewrittenQuery); | ||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", | ||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); | ||
} | ||
|
||
private String rewriteQuery(RewriteContext context) { | ||
|
||
Map<String, Object> variables = new HashMap<>(); | ||
variables.put("curtQuestion", context.getCurtQuestion()); | ||
variables.put("histQuestion", context.getHistQuestion()); | ||
variables.put("curtSchema", context.getCurtSchema()); | ||
variables.put("histSchema", context.getHistSchema()); | ||
|
||
Prompt prompt = promptTemplate.apply(variables); | ||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); | ||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); | ||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage()); | ||
String result = response.content().text(); | ||
keyPipelineLog.info("model response:{}", result); | ||
//3.format response. | ||
String rewriteQuery = response.content().text(); | ||
|
||
return rewriteQuery; | ||
} | ||
|
||
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) { | ||
List<String> metrics = new ArrayList<>(); | ||
List<String> dimensions = new ArrayList<>(); | ||
List<String> values = new ArrayList<>(); | ||
|
||
for(SchemaElementMatch match : elementMatches) { | ||
if (match.getElement().getType().equals(SchemaElementType.METRIC)) { | ||
metrics.add(match.getWord()); | ||
} else if (match.getElement().getType().equals(SchemaElementType.DIMENSION)) { | ||
dimensions.add(match.getWord()); | ||
} else if (match.getElement().getType().equals(SchemaElementType.VALUE)) { | ||
values.add(match.getWord()); | ||
} | ||
} | ||
|
||
StringBuilder prompt = new StringBuilder(); | ||
prompt.append(String.format("'metrics:':[%s]", String.join(",", metrics))); | ||
prompt.append(","); | ||
prompt.append(String.format("'dimensions:':[%s]", String.join(",", dimensions))); | ||
prompt.append(","); | ||
prompt.append(String.format("'values:':[%s]", String.join(",", values))); | ||
|
||
return prompt.toString(); | ||
} | ||
|
||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) { | ||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); | ||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) | ||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList()); | ||
|
||
List<ParseResp> contextualList = contextualParseInfoList.subList(0, | ||
Math.min(multiNum, contextualParseInfoList.size())); | ||
Collections.reverse(contextualList); | ||
return contextualList; | ||
} | ||
|
||
@Data | ||
@Builder | ||
public static class RewriteContext { | ||
private String curtQuestion; | ||
private String histQuestion; | ||
private String curtSchema; | ||
private String histSchema; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 0 additions & 32 deletions
32
...main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java
This file was deleted.
Oops, something went wrong.
14 changes: 0 additions & 14 deletions
14
...re/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.