diff --git a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/DocumentService.java b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/DocumentService.java
index 1a86397..8adb322 100644
--- a/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/DocumentService.java
+++ b/backend/src/main/java/ch/xxx/aidoclibchat/usecase/service/DocumentService.java
@@ -14,7 +14,9 @@
 
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
+import java.util.StringTokenizer;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -72,7 +74,7 @@ public Long storeDocument(Document document) {
 
 	public AiResult queryDocuments(String query) {
 		var similarDocuments = this.documentVsRepository.retrieve(query);
-		Message systemMessage = this.getSystemMessage(similarDocuments);
+		Message systemMessage = this.getSystemMessage(similarDocuments, 2500);
 		UserMessage userMessage = new UserMessage(query);
 		Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
 		AiResponse response = aiClient.generate(prompt);
@@ -82,14 +84,56 @@ public AiResult queryDocuments(String query) {
 		return new AiResult(query, response.getGenerations(), documents);
 	}
 
-	private Message getSystemMessage(List similarDocuments) {
-		String documents = similarDocuments.stream().map(entry -> entry.getContent()).collect(Collectors.joining("\n"));
+	/**
+	 * Builds the system message from the non-blank contents of the given documents.
+	 * The contents are joined with newlines and passed to the system prompt template
+	 * as its {@code documents} parameter.
+	 *
+	 * @param similarDocuments documents whose content is put into the prompt
+	 * @param tokenLimit maximum number of tokens kept per document content; the limit
+	 *                   is applied to each content separately, not to the joined text
+	 * @return the system message created from the template
+	 */
+	private Message getSystemMessage(List similarDocuments, int tokenLimit) {
+		String documents = similarDocuments.stream().map(entry -> entry.getContent())
+				.filter(myStr -> myStr != null && !myStr.isBlank())
+				.map(myStr -> this.cutStringToTokenLimit(myStr, tokenLimit)).collect(Collectors.joining("\n"));
 		SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemPrompt);
 		Message systemMessage = systemPromptTemplate.createMessage(Map.of("documents", documents));
 		return systemMessage;
 
 	}
 
+	/**
+	 * Cuts the given text after its first {@code tokenLimit} tokens. Tokens are the
+	 * character runs between the delimiters space, '-', '.', ';' and ','.
+	 * <p>
+	 * The text is scanned once and cut exactly at the limit, so a text that is only
+	 * slightly over the limit keeps its leading content, and the loop always ends,
+	 * also for a {@code tokenLimit} of zero or less (no token is kept then).
+	 *
+	 * @param documentStr the text to shorten, must not be null
+	 * @param tokenLimit  maximum number of tokens in the result
+	 * @return the unchanged text if it is within the limit, otherwise its leading part
+	 *         containing {@code tokenLimit} tokens
+	 */
+	private String cutStringToTokenLimit(String documentStr, int tokenLimit) {
+		final String delimiters = " -.;,";
+		// returnDelims=true also hands back the delimiters, which keeps endIndex in sync with documentStr.
+		StringTokenizer tokenizer = new StringTokenizer(documentStr, delimiters, true);
+		int tokenCount = 0;
+		int endIndex = 0;
+		while (tokenizer.hasMoreTokens()) {
+			String part = tokenizer.nextToken();
+			boolean delimiter = part.length() == 1 && delimiters.indexOf(part.charAt(0)) >= 0;
+			if (!delimiter && ++tokenCount > tokenLimit) {
+				break;
+			}
+			endIndex += part.length();
+		}
+		return documentStr.substring(0, endIndex);
+	}
+
 	public List getDocumentList() {
 		return this.documentRepository.findAll();
 	}