diff --git a/build.gradle b/build.gradle
index 5a364835b36..b71603a1056 100644
--- a/build.gradle
+++ b/build.gradle
@@ -351,7 +351,7 @@ dependencies {
exclude group: 'org.jetbrains.kotlin'
}
-
+ implementation 'org.apache.velocity:velocity-engine-core:2.3'
implementation platform('ai.djl:bom:0.30.0')
implementation 'ai.djl:api'
implementation 'ai.djl.huggingface:tokenizers'
diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java
index 49906ad2625..f0151b8988e 100644
--- a/src/main/java/module-info.java
+++ b/src/main/java/module-info.java
@@ -160,6 +160,7 @@
uses ai.djl.repository.RepositoryFactory;
uses ai.djl.repository.zoo.ZooProvider;
uses dev.langchain4j.spi.prompt.PromptTemplateFactory;
+ requires velocity.engine.core;
// endregion
// region: Lucene
diff --git a/src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml b/src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
index b81bd96ba06..64e737747f5 100644
--- a/src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
+++ b/src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
@@ -16,6 +16,9 @@
+
+
+
-
-
@@ -235,5 +234,37 @@
glyph="REFRESH"/>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/main/java/org/jabref/gui/preferences/ai/AiTab.java b/src/main/java/org/jabref/gui/preferences/ai/AiTab.java
index 3360e3a5b4a..1646e57420e 100644
--- a/src/main/java/org/jabref/gui/preferences/ai/AiTab.java
+++ b/src/main/java/org/jabref/gui/preferences/ai/AiTab.java
@@ -7,6 +7,7 @@
import javafx.scene.control.Button;
import javafx.scene.control.CheckBox;
import javafx.scene.control.ComboBox;
+import javafx.scene.control.TextArea;
import javafx.scene.control.TextField;
import org.jabref.gui.actions.ActionFactory;
@@ -15,13 +16,13 @@
import org.jabref.gui.preferences.AbstractPreferenceTabView;
import org.jabref.gui.preferences.PreferencesTab;
import org.jabref.gui.util.ViewModelListCellFactory;
+import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.help.HelpFile;
import org.jabref.logic.l10n.Localization;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;
import com.airhacks.afterburner.views.ViewLoader;
-import com.dlsc.gemsfx.ResizableTextArea;
import com.dlsc.unitfx.IntegerInputField;
import de.saxsys.mvvmfx.utils.validation.visualization.ControlsFxVisualizer;
import org.controlsfx.control.SearchableComboBox;
@@ -43,7 +44,6 @@ public class AiTab extends AbstractPreferenceTabView implements
@FXML private TextField apiBaseUrlTextField;
@FXML private SearchableComboBox<EmbeddingModel> embeddingModelComboBox;
- @FXML private ResizableTextArea instructionTextArea;
@FXML private TextField temperatureTextField;
@FXML private IntegerInputField contextWindowSizeTextField;
@FXML private IntegerInputField documentSplitterChunkSizeTextField;
@@ -51,8 +51,14 @@ public class AiTab extends AbstractPreferenceTabView implements
@FXML private IntegerInputField ragMaxResultsCountTextField;
@FXML private TextField ragMinScoreTextField;
+ @FXML private TextArea systemMessageTextArea;
+ @FXML private TextArea userMessageTextArea;
+ @FXML private TextArea summarizationChunkTextArea;
+ @FXML private TextArea summarizationCombineTextArea;
+
@FXML private Button generalSettingsHelp;
@FXML private Button expertSettingsHelp;
+ @FXML private Button templatesHelp;
private final ControlsFxVisualizer visualizer = new ControlsFxVisualizer();
@@ -74,14 +80,14 @@ public void initialize() {
new ViewModelListCellFactory<AiProvider>()
.withText(AiProvider::toString)
.install(aiProviderComboBox);
- aiProviderComboBox.setItems(viewModel.aiProvidersProperty());
+ aiProviderComboBox.itemsProperty().bind(viewModel.aiProvidersProperty());
aiProviderComboBox.valueProperty().bindBidirectional(viewModel.selectedAiProviderProperty());
aiProviderComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());
new ViewModelListCellFactory<String>()
.withText(text -> text)
.install(chatModelComboBox);
- chatModelComboBox.setItems(viewModel.chatModelsProperty());
+ chatModelComboBox.itemsProperty().bind(viewModel.chatModelsProperty());
chatModelComboBox.valueProperty().bindBidirectional(viewModel.selectedChatModelProperty());
chatModelComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());
@@ -123,9 +129,6 @@ public void initialize() {
apiBaseUrlTextField.setDisable(newValue || viewModel.disableExpertSettingsProperty().get())
);
- instructionTextArea.textProperty().bindBidirectional(viewModel.instructionProperty());
- instructionTextArea.disableProperty().bind(viewModel.disableExpertSettingsProperty());
-
// bindBidirectional doesn't work well with number input fields ({@link IntegerInputField}, {@link DoubleInputField}),
// so they are expanded into `addListener` calls.
@@ -180,7 +183,6 @@ public void initialize() {
visualizer.initVisualization(viewModel.getChatModelValidationStatus(), chatModelComboBox);
visualizer.initVisualization(viewModel.getApiBaseUrlValidationStatus(), apiBaseUrlTextField);
visualizer.initVisualization(viewModel.getEmbeddingModelValidationStatus(), embeddingModelComboBox);
- visualizer.initVisualization(viewModel.getSystemMessageValidationStatus(), instructionTextArea);
visualizer.initVisualization(viewModel.getTemperatureTypeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getTemperatureRangeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getMessageWindowSizeValidationStatus(), contextWindowSizeTextField);
@@ -191,9 +193,15 @@ public void initialize() {
visualizer.initVisualization(viewModel.getRagMinScoreRangeValidationStatus(), ragMinScoreTextField);
});
+ systemMessageTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.CHATTING_SYSTEM_MESSAGE));
+ userMessageTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.CHATTING_USER_MESSAGE));
+ summarizationChunkTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.SUMMARIZATION_CHUNK));
+ summarizationCombineTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.SUMMARIZATION_COMBINE));
+
ActionFactory actionFactory = new ActionFactory();
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_GENERAL_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), generalSettingsHelp);
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_EXPERT_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), expertSettingsHelp);
+ actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_TEMPLATES, dialogService, preferences.getExternalApplicationsPreferences()), templatesHelp);
}
@Override
@@ -206,6 +214,11 @@ private void onResetExpertSettingsButtonClick() {
viewModel.resetExpertSettings();
}
+ @FXML
+ private void onResetTemplatesButtonClick() {
+ viewModel.resetTemplates();
+ }
+
public ReadOnlyBooleanProperty aiEnabledProperty() {
return enableAi.selectedProperty();
}
diff --git a/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java b/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
index b8871f235e4..2cb019cf968 100644
--- a/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
+++ b/src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
@@ -1,7 +1,9 @@
package org.jabref.gui.preferences.ai;
+import java.util.Arrays;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Objects;
import javafx.beans.property.BooleanProperty;
@@ -20,6 +22,7 @@
import org.jabref.gui.preferences.PreferenceTabViewModel;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
+import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.preferences.CliPreferences;
import org.jabref.logic.util.LocalizedNumbers;
@@ -79,7 +82,13 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();
private final StringProperty gpt4AllApiBaseUrl = new SimpleStringProperty();
- private final StringProperty instruction = new SimpleStringProperty();
+ private final Map<AiTemplate, StringProperty> templateSources = Map.of(
+ AiTemplate.CHATTING_SYSTEM_MESSAGE, new SimpleStringProperty(),
+ AiTemplate.CHATTING_USER_MESSAGE, new SimpleStringProperty(),
+ AiTemplate.SUMMARIZATION_CHUNK, new SimpleStringProperty(),
+ AiTemplate.SUMMARIZATION_COMBINE, new SimpleStringProperty()
+ );
+
private final StringProperty temperature = new SimpleStringProperty();
private final IntegerProperty contextWindowSize = new SimpleIntegerProperty();
private final IntegerProperty documentSplitterChunkSize = new SimpleIntegerProperty();
@@ -96,7 +105,6 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final Validator chatModelValidator;
private final Validator apiBaseUrlValidator;
private final Validator embeddingModelValidator;
- private final Validator instructionValidator;
private final Validator temperatureTypeValidator;
private final Validator temperatureRangeValidator;
private final Validator contextWindowSizeValidator;
@@ -242,11 +250,6 @@ public AiTabViewModel(CliPreferences preferences) {
Objects::nonNull,
ValidationMessage.error(Localization.lang("Embedding model has to be provided")));
- this.instructionValidator = new FunctionBasedValidator<>(
- instruction,
- message -> !StringUtil.isBlank(message),
- ValidationMessage.error(Localization.lang("The instruction has to be provided")));
-
this.temperatureTypeValidator = new FunctionBasedValidator<>(
temperature,
temp -> LocalizedNumbers.stringToDouble(temp).isPresent(),
@@ -318,7 +321,10 @@ public void setValues() {
customizeExpertSettings.setValue(aiPreferences.getCustomizeExpertSettings());
selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());
- instruction.setValue(aiPreferences.getInstruction());
+
+ Arrays.stream(AiTemplate.values()).forEach(template ->
+ templateSources.get(template).set(aiPreferences.getTemplate(template)));
+
temperature.setValue(LocalizedNumbers.doubleToString(aiPreferences.getTemperature()));
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
documentSplitterChunkSize.setValue(aiPreferences.getDocumentSplitterChunkSize());
@@ -359,7 +365,9 @@ public void storeSettings() {
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get() == null ? "" : huggingFaceApiBaseUrl.get());
aiPreferences.setGpt4AllApiBaseUrl(gpt4AllApiBaseUrl.get() == null ? "" : gpt4AllApiBaseUrl.get());
- aiPreferences.setInstruction(instruction.get());
+ Arrays.stream(AiTemplate.values()).forEach(template ->
+ aiPreferences.setTemplate(template, templateSources.get(template).get()));
+
// We already check the correctness of temperature and RAG minimum score in validators, so we don't need to check it here.
aiPreferences.setTemperature(LocalizedNumbers.stringToDouble(oldLocale, temperature.get()).get());
aiPreferences.setContextWindowSize(contextWindowSize.get());
@@ -373,8 +381,6 @@ public void resetExpertSettings() {
String resetApiBaseUrl = selectedAiProvider.get().getApiUrl();
currentApiBaseUrl.set(resetApiBaseUrl);
- instruction.set(AiDefaultPreferences.SYSTEM_MESSAGE);
-
contextWindowSize.set(AiDefaultPreferences.getContextWindowSize(selectedAiProvider.get(), currentChatModel.get()));
temperature.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.TEMPERATURE));
@@ -384,6 +390,11 @@ public void resetExpertSettings() {
ragMinScore.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.RAG_MIN_SCORE));
}
+ public void resetTemplates() {
+ Arrays.stream(AiTemplate.values()).forEach(template ->
+ templateSources.get(template).set(AiDefaultPreferences.TEMPLATES.get(template)));
+ }
+
@Override
public boolean validateSettings() {
if (enableAi.get()) {
@@ -410,7 +421,6 @@ public boolean validateExpertSettings() {
List<Validator> validators = List.of(
apiBaseUrlValidator,
embeddingModelValidator,
- instructionValidator,
temperatureTypeValidator,
temperatureRangeValidator,
contextWindowSizeValidator,
@@ -484,8 +494,8 @@ public BooleanProperty disableApiBaseUrlProperty() {
return disableApiBaseUrl;
}
- public StringProperty instructionProperty() {
- return instruction;
+ public Map<AiTemplate, StringProperty> getTemplateSources() {
+ return templateSources;
}
public StringProperty temperatureProperty() {
@@ -536,10 +546,6 @@ public ValidationStatus getEmbeddingModelValidationStatus() {
return embeddingModelValidator.getValidationStatus();
}
- public ValidationStatus getSystemMessageValidationStatus() {
- return instructionValidator.getValidationStatus();
- }
-
public ValidationStatus getTemperatureTypeValidationStatus() {
return temperatureTypeValidator.getValidationStatus();
}
diff --git a/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java b/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
index 3cdbd9bfda8..704b33c9f29 100644
--- a/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
+++ b/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
@@ -4,6 +4,7 @@
import java.util.List;
import java.util.Map;
+import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;
@@ -80,6 +81,46 @@ public String toString() {
public static final int FALLBACK_CONTEXT_WINDOW_SIZE = 8196;
+ public static final Map<AiTemplate, String> TEMPLATES = Map.of(
+ AiTemplate.CHATTING_SYSTEM_MESSAGE, """
+ You are an AI assistant that analyses research papers. You answer questions about papers.
+ You will be supplied with the necessary information. The supplied information will contain mentions of papers in form '@citationKey'.
+ Whenever you refer to a paper, use its citation key in the same form with @ symbol. Whenever you find relevant information, always use the citation key.
+
+ Here are the papers you are analyzing:
+ #foreach( $entry in $entries )
+ ${CanonicalBibEntry.getCanonicalRepresentation($entry)}
+ #end""",
+
+ AiTemplate.CHATTING_USER_MESSAGE, """
+ $message
+
+ Here is some relevant information for you:
+ #foreach( $excerpt in $excerpts )
+ ${excerpt.citationKey()}:
+ ${excerpt.text()}
+ #end""",
+
+ AiTemplate.SUMMARIZATION_CHUNK, """
+ Please provide an overview of the following text. It is a part of a scientific paper.
+ The summary should include the main objectives, methodologies used, key findings, and conclusions.
+ Mention any significant experiments, data, or discussions presented in the paper.
+
+ DOCUMENT:
+ $document
+
+ OVERVIEW:""",
+
+ AiTemplate.SUMMARIZATION_COMBINE, """
+ You have written an overview of a scientific paper. You have been collecting notes from various parts
+ of the paper. Now your task is to combine all of the notes in one structured message.
+
+ SUMMARIES:
+ $summaries
+
+ FINAL OVERVIEW:"""
+ );
+
public static List<String> getAvailableModels(AiProvider aiProvider) {
return Arrays.stream(AiDefaultPreferences.PredefinedChatModel.values())
.filter(model -> model.getAiProvider() == aiProvider)
diff --git a/src/main/java/org/jabref/logic/ai/AiPreferences.java b/src/main/java/org/jabref/logic/ai/AiPreferences.java
index 48b10812338..de1025e72d7 100644
--- a/src/main/java/org/jabref/logic/ai/AiPreferences.java
+++ b/src/main/java/org/jabref/logic/ai/AiPreferences.java
@@ -1,6 +1,7 @@
package org.jabref.logic.ai;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
import javafx.beans.property.BooleanProperty;
@@ -15,6 +16,7 @@
import javafx.beans.property.SimpleStringProperty;
import javafx.beans.property.StringProperty;
+import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;
import org.jabref.model.strings.StringUtil;
@@ -59,6 +61,8 @@ public class AiPreferences {
private final IntegerProperty ragMaxResultsCount;
private final DoubleProperty ragMinScore;
+ private final Map<AiTemplate, StringProperty> templates;
+
private Runnable apiKeyChangeListener;
public AiPreferences(boolean enableAi,
@@ -83,7 +87,8 @@ public AiPreferences(boolean enableAi,
int documentSplitterChunkSize,
int documentSplitterOverlapSize,
int ragMaxResultsCount,
- double ragMinScore
+ double ragMinScore,
+ Map<AiTemplate, String> templates
) {
this.enableAi = new SimpleBooleanProperty(enableAi);
this.autoGenerateEmbeddings = new SimpleBooleanProperty(autoGenerateEmbeddings);
@@ -113,6 +118,13 @@ public AiPreferences(boolean enableAi,
this.documentSplitterOverlapSize = new SimpleIntegerProperty(documentSplitterOverlapSize);
this.ragMaxResultsCount = new SimpleIntegerProperty(ragMaxResultsCount);
this.ragMinScore = new SimpleDoubleProperty(ragMinScore);
+
+ this.templates = Map.of(
+ AiTemplate.CHATTING_SYSTEM_MESSAGE, new SimpleStringProperty(templates.get(AiTemplate.CHATTING_SYSTEM_MESSAGE)),
+ AiTemplate.CHATTING_USER_MESSAGE, new SimpleStringProperty(templates.get(AiTemplate.CHATTING_USER_MESSAGE)),
+ AiTemplate.SUMMARIZATION_CHUNK, new SimpleStringProperty(templates.get(AiTemplate.SUMMARIZATION_CHUNK)),
+ AiTemplate.SUMMARIZATION_COMBINE, new SimpleStringProperty(templates.get(AiTemplate.SUMMARIZATION_COMBINE))
+ );
}
public String getApiKeyForAiProvider(AiProvider aiProvider) {
@@ -546,4 +558,16 @@ public void setApiKeyChangeListener(Runnable apiKeyChangeListener) {
public void apiKeyUpdated() {
apiKeyChangeListener.run();
}
+
+ public void setTemplate(AiTemplate aiTemplate, String template) {
+ templates.get(aiTemplate).set(template);
+ }
+
+ public String getTemplate(AiTemplate aiTemplate) {
+ return templates.get(aiTemplate).get();
+ }
+
+ public StringProperty templateProperty(AiTemplate aiTemplate) {
+ return templates.get(aiTemplate);
+ }
}
diff --git a/src/main/java/org/jabref/logic/ai/AiService.java b/src/main/java/org/jabref/logic/ai/AiService.java
index a068037b423..2e1b76b67b0 100644
--- a/src/main/java/org/jabref/logic/ai/AiService.java
+++ b/src/main/java/org/jabref/logic/ai/AiService.java
@@ -17,6 +17,7 @@
import org.jabref.logic.ai.ingestion.storages.MVStoreFullyIngestedDocumentsTracker;
import org.jabref.logic.ai.summarization.SummariesService;
import org.jabref.logic.ai.summarization.storages.MVStoreSummariesStorage;
+import org.jabref.logic.ai.templates.TemplatesService;
import org.jabref.logic.citationkeypattern.CitationKeyPatternPreferences;
import org.jabref.logic.util.Directories;
import org.jabref.logic.util.NotificationService;
@@ -71,11 +72,12 @@ public AiService(AiPreferences aiPreferences,
this.mvStoreFullyIngestedDocumentsTracker = new MVStoreFullyIngestedDocumentsTracker(Directories.getAiFilesDirectory().resolve(FULLY_INGESTED_FILE_NAME), notificationService);
this.mvStoreSummariesStorage = new MVStoreSummariesStorage(Directories.getAiFilesDirectory().resolve(SUMMARIES_FILE_NAME), notificationService);
+ TemplatesService templatesService = new TemplatesService(aiPreferences);
this.chatHistoryService = new ChatHistoryService(citationKeyPatternPreferences, mvStoreChatHistoryStorage);
this.jabRefChatLanguageModel = new JabRefChatLanguageModel(aiPreferences);
this.jabRefEmbeddingModel = new JabRefEmbeddingModel(aiPreferences, notificationService, taskExecutor);
- this.aiChatService = new AiChatService(aiPreferences, jabRefChatLanguageModel, jabRefEmbeddingModel, mvStoreEmbeddingStore, cachedThreadPool);
+ this.aiChatService = new AiChatService(aiPreferences, jabRefChatLanguageModel, jabRefEmbeddingModel, mvStoreEmbeddingStore, templatesService);
this.ingestionService = new IngestionService(
aiPreferences,
@@ -91,6 +93,7 @@ public AiService(AiPreferences aiPreferences,
aiPreferences,
mvStoreSummariesStorage,
jabRefChatLanguageModel,
+ templatesService,
shutdownSignal,
filePreferences,
taskExecutor
diff --git a/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java b/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java
index f079b093604..d237afbdd0b 100644
--- a/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java
+++ b/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java
@@ -1,8 +1,7 @@
package org.jabref.logic.ai.chatting;
import java.util.List;
-import java.util.concurrent.Executor;
-import java.util.stream.Collectors;
+import java.util.Optional;
import javafx.beans.property.StringProperty;
import javafx.collections.ListChangeListener;
@@ -10,15 +9,15 @@
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.ingestion.FileEmbeddingsManager;
+import org.jabref.logic.ai.templates.AiTemplate;
+import org.jabref.logic.ai.templates.PaperExcerpt;
+import org.jabref.logic.ai.templates.TemplatesService;
import org.jabref.logic.ai.util.ErrorMessage;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
-import org.jabref.model.entry.CanonicalBibEntry;
import org.jabref.model.entry.LinkedFile;
import org.jabref.model.util.ListUtil;
-import dev.langchain4j.chain.Chain;
-import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
@@ -29,14 +28,12 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiTokenizer;
-import dev.langchain4j.rag.DefaultRetrievalAugmentor;
-import dev.langchain4j.rag.RetrievalAugmentor;
-import dev.langchain4j.rag.content.retriever.ContentRetriever;
-import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
-import jakarta.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -47,7 +44,7 @@ public class AiChatLogic {
private final ChatLanguageModel chatLanguageModel;
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
- private final Executor cachedThreadPool;
+ private final TemplatesService templatesService;
private final ObservableList<ChatMessage> chatHistory;
private final ObservableList<BibEntry> entries;
@@ -55,13 +52,14 @@ public class AiChatLogic {
private final BibDatabaseContext bibDatabaseContext;
private ChatMemory chatMemory;
- private Chain chain;
+
+ private Optional<Filter> filter = Optional.empty();
public AiChatLogic(AiPreferences aiPreferences,
ChatLanguageModel chatLanguageModel,
EmbeddingModel embeddingModel,
EmbeddingStore<TextSegment> embeddingStore,
- Executor cachedThreadPool,
+ TemplatesService templatesService,
StringProperty name,
ObservableList<ChatMessage> chatHistory,
ObservableList<BibEntry> entries,
@@ -71,26 +69,30 @@ public AiChatLogic(AiPreferences aiPreferences,
this.chatLanguageModel = chatLanguageModel;
this.embeddingModel = embeddingModel;
this.embeddingStore = embeddingStore;
- this.cachedThreadPool = cachedThreadPool;
+ this.templatesService = templatesService;
this.chatHistory = chatHistory;
this.entries = entries;
this.name = name;
this.bibDatabaseContext = bibDatabaseContext;
- this.entries.addListener((ListChangeListener<BibEntry>) change -> rebuildChain());
+ this.entries.addListener((ListChangeListener<BibEntry>) change -> rebuildFilter());
setupListeningToPreferencesChanges();
rebuildFull(chatHistory);
}
private void setupListeningToPreferencesChanges() {
- aiPreferences.instructionProperty().addListener(obs -> setSystemMessage(aiPreferences.getInstruction()));
+ aiPreferences
+ .templateProperty(AiTemplate.CHATTING_SYSTEM_MESSAGE)
+ .addListener(obs ->
+ setSystemMessage(templatesService.makeChattingSystemMessage(entries)));
+
aiPreferences.contextWindowSizeProperty().addListener(obs -> rebuildFull(chatMemory.messages()));
}
private void rebuildFull(List<ChatMessage> chatMessages) {
rebuildChatMemory(chatMessages);
- rebuildChain();
+ rebuildFilter();
}
private void rebuildChatMemory(List<ChatMessage> chatMessages) {
@@ -101,75 +103,93 @@ private void rebuildChatMemory(List chatMessages) {
chatMessages.stream().filter(chatMessage -> !(chatMessage instanceof ErrorMessage)).forEach(chatMemory::add);
- setSystemMessage(aiPreferences.getInstruction());
+ setSystemMessage(templatesService.makeChattingSystemMessage(entries));
}
- private void rebuildChain() {
+ private void rebuildFilter() {
List<LinkedFile> linkedFiles = ListUtil.getLinkedFiles(entries).toList();
- @Nullable Filter filter;
if (linkedFiles.isEmpty()) {
- // You must not pass an empty list to langchain4j {@link IsIn} filter.
- filter = null;
+ filter = Optional.empty();
} else {
- filter = MetadataFilterBuilder
+ filter = Optional.of(MetadataFilterBuilder
.metadataKey(FileEmbeddingsManager.LINK_METADATA_KEY)
.isIn(linkedFiles
.stream()
.map(LinkedFile::getLink)
.toList()
- );
+ ));
}
+ }
+
+ private void setSystemMessage(String systemMessage) {
+ chatMemory.add(new SystemMessage(systemMessage));
+ }
+
+ public AiMessage execute(UserMessage message) {
+ // Message will be automatically added to ChatMemory through ConversationalRetrievalChain.
- ContentRetriever contentRetriever = EmbeddingStoreContentRetriever
+ chatHistory.add(message);
+
+ LOGGER.info("Sending message to AI provider ({}) for answering in {}: {}",
+ aiPreferences.getAiProvider().getApiUrl(),
+ name.get(),
+ message.singleText());
+
+ EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest
.builder()
- .embeddingStore(embeddingStore)
- .filter(filter)
- .embeddingModel(embeddingModel)
.maxResults(aiPreferences.getRagMaxResultsCount())
.minScore(aiPreferences.getRagMinScore())
+ .filter(filter.orElse(null))
+ .queryEmbedding(embeddingModel.embed(message.singleText()).content())
.build();
- RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor
- .builder()
- .contentRetriever(contentRetriever)
- .contentInjector(new JabRefContentInjector(bibDatabaseContext))
- .executor(cachedThreadPool)
- .build();
+ EmbeddingSearchResult<TextSegment> embeddingSearchResult = embeddingStore.search(embeddingSearchRequest);
- this.chain = ConversationalRetrievalChain
+ List<PaperExcerpt> excerpts = embeddingSearchResult
+ .matches()
+ .stream()
+ .map(EmbeddingMatch::embedded)
+ .map(textSegment -> {
+ String link = textSegment.metadata().getString(FileEmbeddingsManager.LINK_METADATA_KEY);
+
+ if (link == null) {
+ return new PaperExcerpt("", textSegment.text());
+ } else {
+ return new PaperExcerpt(findEntryByLink(link).flatMap(BibEntry::getCitationKey).orElse(""), textSegment.text());
+ }
+ })
+ .toList();
+
+ LOGGER.debug("Found excerpts for the message: {}", excerpts);
+
+ // This is crazy, but langchain4j {@link ChatMemory} does not allow to remove single messages.
+ ChatMemory tempChatMemory = TokenWindowChatMemory
.builder()
- .chatLanguageModel(chatLanguageModel)
- .retrievalAugmentor(retrievalAugmentor)
- .chatMemory(chatMemory)
+ .maxTokens(aiPreferences.getContextWindowSize(), new OpenAiTokenizer())
.build();
- }
- private void setSystemMessage(String systemMessage) {
- chatMemory.add(new SystemMessage(augmentSystemMessage(systemMessage)));
- }
+ chatMemory.messages().forEach(tempChatMemory::add);
- private String augmentSystemMessage(String systemMessage) {
- String entriesInfo = entries.stream().map(CanonicalBibEntry::getCanonicalRepresentation).collect(Collectors.joining("\n"));
+ tempChatMemory.add(new UserMessage(templatesService.makeChattingUserMessage(message.singleText(), excerpts)));
+ chatMemory.add(message);
- return systemMessage + "\n" + entriesInfo;
- }
+ AiMessage aiMessage = chatLanguageModel.generate(tempChatMemory.messages()).content();
- public AiMessage execute(UserMessage message) {
- // Message will be automatically added to ChatMemory through ConversationalRetrievalChain.
-
- LOGGER.info("Sending message to AI provider ({}) for answering in {}: {}",
- aiPreferences.getAiProvider().getApiUrl(),
- name.get(),
- message.singleText());
+ chatMemory.add(aiMessage);
+ chatHistory.add(aiMessage);
- chatHistory.add(message);
- AiMessage result = new AiMessage(chain.execute(message.singleText()));
- chatHistory.add(result);
+ LOGGER.debug("Message was answered by the AI provider for {}: {}", name.get(), aiMessage.text());
- LOGGER.debug("Message was answered by the AI provider for {}: {}", name.get(), result.text());
+ return aiMessage;
+ }
- return result;
+ private Optional<BibEntry> findEntryByLink(String link) {
+ return bibDatabaseContext
+ .getEntries()
+ .stream()
+ .filter(entry -> entry.getFiles().stream().anyMatch(file -> file.getLink().equals(link)))
+ .findFirst();
}
public ObservableList<ChatMessage> getChatHistory() {
diff --git a/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java b/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java
index d12b20ec502..d0d1698c497 100644
--- a/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java
+++ b/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java
@@ -1,11 +1,10 @@
package org.jabref.logic.ai.chatting;
-import java.util.concurrent.Executor;
-
import javafx.beans.property.StringProperty;
import javafx.collections.ObservableList;
import org.jabref.logic.ai.AiPreferences;
+import org.jabref.logic.ai.templates.TemplatesService;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
@@ -20,19 +19,19 @@ public class AiChatService {
private final ChatLanguageModel chatLanguageModel;
private final EmbeddingModel embeddingModel;
private final EmbeddingStore<TextSegment> embeddingStore;
- private final Executor cachedThreadPool;
+ private final TemplatesService templatesService;
public AiChatService(AiPreferences aiPreferences,
ChatLanguageModel chatLanguageModel,
EmbeddingModel embeddingModel,
EmbeddingStore<TextSegment> embeddingStore,
- Executor cachedThreadPool
+ TemplatesService templatesService
) {
this.aiPreferences = aiPreferences;
this.chatLanguageModel = chatLanguageModel;
this.embeddingModel = embeddingModel;
this.embeddingStore = embeddingStore;
- this.cachedThreadPool = cachedThreadPool;
+ this.templatesService = templatesService;
}
public AiChatLogic makeChat(
@@ -46,7 +45,7 @@ public AiChatLogic makeChat(
chatLanguageModel,
embeddingModel,
embeddingStore,
- cachedThreadPool,
+ templatesService,
name,
chatHistory,
entries,
diff --git a/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryForSeveralTask.java b/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryForSeveralTask.java
index 4940fba8b21..849040d4de0 100644
--- a/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryForSeveralTask.java
+++ b/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryForSeveralTask.java
@@ -12,6 +12,7 @@
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.processingstatus.ProcessingInfo;
import org.jabref.logic.ai.processingstatus.ProcessingState;
+import org.jabref.logic.ai.templates.TemplatesService;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.BackgroundTask;
import org.jabref.logic.util.ProgressCounter;
@@ -36,6 +37,7 @@ public class GenerateSummaryForSeveralTask extends BackgroundTask {
private final BibDatabaseContext bibDatabaseContext;
private final SummariesStorage summariesStorage;
private final ChatLanguageModel chatLanguageModel;
+ private final TemplatesService templatesService;
private final ReadOnlyBooleanProperty shutdownSignal;
private final AiPreferences aiPreferences;
private final FilePreferences filePreferences;
@@ -51,6 +53,7 @@ public GenerateSummaryForSeveralTask(
BibDatabaseContext bibDatabaseContext,
SummariesStorage summariesStorage,
ChatLanguageModel chatLanguageModel,
+ TemplatesService templatesService,
ReadOnlyBooleanProperty shutdownSignal,
AiPreferences aiPreferences,
FilePreferences filePreferences,
@@ -61,6 +64,7 @@ public GenerateSummaryForSeveralTask(
this.bibDatabaseContext = bibDatabaseContext;
this.summariesStorage = summariesStorage;
this.chatLanguageModel = chatLanguageModel;
+ this.templatesService = templatesService;
this.shutdownSignal = shutdownSignal;
this.aiPreferences = aiPreferences;
this.filePreferences = filePreferences;
@@ -95,6 +99,7 @@ public Void call() throws Exception {
bibDatabaseContext,
summariesStorage,
chatLanguageModel,
+ templatesService,
shutdownSignal,
aiPreferences,
filePreferences
diff --git a/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java b/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java
index f863c8f01e0..57fa54d405f 100644
--- a/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java
+++ b/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java
@@ -3,7 +3,6 @@
import java.nio.file.Path;
import java.time.LocalDateTime;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -14,6 +13,8 @@
import org.jabref.logic.FilePreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.ingestion.FileToDocument;
+import org.jabref.logic.ai.templates.AiTemplate;
+import org.jabref.logic.ai.templates.TemplatesService;
import org.jabref.logic.ai.util.CitationKeyCheck;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.BackgroundTask;
@@ -27,8 +28,6 @@
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.input.Prompt;
-import dev.langchain4j.model.input.PromptTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,32 +41,6 @@
public class GenerateSummaryTask extends BackgroundTask<Summary> {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateSummaryTask.class);
- // Be careful when constructing prompt.
- // 1. It should contain variables `bullets` and `chunk`.
- // 2. Variables should be wrapped in `{{` and `}}` and only with them. No whitespace inside.
- private static final PromptTemplate CHUNK_PROMPT_TEMPLATE = PromptTemplate.from(
- """
- Please provide an overview of the following text. It's a part of a scientific paper.
- The summary should include the main objectives, methodologies used, key findings, and conclusions.
- Mention any significant experiments, data, or discussions presented in the paper.
-
- DOCUMENT:
- {{document}}
-
- OVERVIEW:"""
- );
-
- private static final PromptTemplate COMBINE_PROMPT_TEMPLATE = PromptTemplate.from(
- """
- You have written an overview of a scientific paper. You have been collecting notes from various parts
- of the paper. Now your task is to combine all of the notes in one structured message.
-
- SUMMARIES:
- {{summaries}}
-
- FINAL OVERVIEW:"""
- );
-
private static final int MAX_OVERLAP_SIZE_IN_CHARS = 100;
private static final int CHAR_TOKEN_FACTOR = 4; // Means, every token is roughly 4 characters.
@@ -76,6 +49,7 @@ public class GenerateSummaryTask extends BackgroundTask {
private final String citationKey;
private final ChatLanguageModel chatLanguageModel;
private final SummariesStorage summariesStorage;
+ private final TemplatesService templatesService;
private final ReadOnlyBooleanProperty shutdownSignal;
private final AiPreferences aiPreferences;
private final FilePreferences filePreferences;
@@ -86,6 +60,7 @@ public GenerateSummaryTask(BibEntry entry,
BibDatabaseContext bibDatabaseContext,
SummariesStorage summariesStorage,
ChatLanguageModel chatLanguageModel,
+ TemplatesService templatesService,
ReadOnlyBooleanProperty shutdownSignal,
AiPreferences aiPreferences,
FilePreferences filePreferences
@@ -95,6 +70,7 @@ public GenerateSummaryTask(BibEntry entry,
this.citationKey = entry.getCitationKey().orElse("");
this.chatLanguageModel = chatLanguageModel;
this.summariesStorage = summariesStorage;
+ this.templatesService = templatesService;
this.shutdownSignal = shutdownSignal;
this.aiPreferences = aiPreferences;
this.filePreferences = filePreferences;
@@ -222,7 +198,7 @@ private Optional generateSummary(LinkedFile linkedFile) throws Interrupt
public String summarizeOneDocument(String filePath, String document) throws InterruptedException {
addMoreWork(1); // For the combination of summary chunks.
- DocumentSplitter documentSplitter = DocumentSplitters.recursive(aiPreferences.getContextWindowSize() - MAX_OVERLAP_SIZE_IN_CHARS * 2 - estimateTokenCount(CHUNK_PROMPT_TEMPLATE), MAX_OVERLAP_SIZE_IN_CHARS);
+ DocumentSplitter documentSplitter = DocumentSplitters.recursive(aiPreferences.getContextWindowSize() - MAX_OVERLAP_SIZE_IN_CHARS * 2 - estimateTokenCount(aiPreferences.getTemplate(AiTemplate.SUMMARIZATION_CHUNK)), MAX_OVERLAP_SIZE_IN_CHARS);
List<String> chunkSummaries = documentSplitter.split(new Document(document)).stream().map(TextSegment::text).toList();
@@ -243,10 +219,10 @@ public String summarizeOneDocument(String filePath, String document) throws Inte
throw new InterruptedException();
}
- Prompt prompt = CHUNK_PROMPT_TEMPLATE.apply(Collections.singletonMap("document", chunkSummary));
+ String prompt = templatesService.makeSummarizationChunk(chunkSummary);
LOGGER.debug("Sending request to AI provider to summarize a chunk from file \"{}\" of entry {}", filePath, citationKey);
- String chunk = chatLanguageModel.generate(prompt.toString());
+ String chunk = chatLanguageModel.generate(prompt);
LOGGER.debug("Chunk summary for file \"{}\" of entry {} was generated successfully", filePath, citationKey);
list.add(chunk);
@@ -254,7 +230,7 @@ public String summarizeOneDocument(String filePath, String document) throws Inte
}
chunkSummaries = list;
- } while (estimateTokenCount(chunkSummaries) > aiPreferences.getContextWindowSize() - estimateTokenCount(COMBINE_PROMPT_TEMPLATE));
+ } while (estimateTokenCount(chunkSummaries) > aiPreferences.getContextWindowSize() - estimateTokenCount(aiPreferences.getTemplate(AiTemplate.SUMMARIZATION_COMBINE)));
if (chunkSummaries.size() == 1) {
doneOneWork(); // No need to call LLM for combination of summary chunks.
@@ -262,14 +238,14 @@ public String summarizeOneDocument(String filePath, String document) throws Inte
return chunkSummaries.getFirst();
}
- Prompt prompt = COMBINE_PROMPT_TEMPLATE.apply(Collections.singletonMap("summaries", String.join("\n\n", chunkSummaries)));
+ String prompt = templatesService.makeSummarizationCombine(chunkSummaries);
if (shutdownSignal.get()) {
throw new InterruptedException();
}
LOGGER.debug("Sending request to AI provider to combine summary chunk(s) for file \"{}\" of entry {}", filePath, citationKey);
- String result = chatLanguageModel.generate(prompt.toString());
+ String result = chatLanguageModel.generate(prompt);
LOGGER.debug("Summary of the file \"{}\" of entry {} was generated successfully", filePath, citationKey);
doneOneWork();
@@ -284,10 +260,6 @@ private static int estimateTokenCount(List chunkSummaries) {
return chunkSummaries.stream().mapToInt(GenerateSummaryTask::estimateTokenCount).sum();
}
- private static int estimateTokenCount(PromptTemplate promptTemplate) {
- return estimateTokenCount(promptTemplate.template());
- }
-
private static int estimateTokenCount(String string) {
return estimateTokenCount(string.length());
}
diff --git a/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java b/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java
index afac5e823f5..faf8a4dece4 100644
--- a/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java
+++ b/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java
@@ -12,6 +12,7 @@
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.processingstatus.ProcessingInfo;
import org.jabref.logic.ai.processingstatus.ProcessingState;
+import org.jabref.logic.ai.templates.TemplatesService;
import org.jabref.logic.ai.util.CitationKeyCheck;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.database.BibDatabaseContext;
@@ -44,6 +45,7 @@ public class SummariesService {
private final AiPreferences aiPreferences;
private final SummariesStorage summariesStorage;
private final ChatLanguageModel chatLanguageModel;
+ private final TemplatesService templatesService;
private final BooleanProperty shutdownSignal;
private final FilePreferences filePreferences;
private final TaskExecutor taskExecutor;
@@ -51,6 +53,7 @@ public class SummariesService {
public SummariesService(AiPreferences aiPreferences,
SummariesStorage summariesStorage,
ChatLanguageModel chatLanguageModel,
+ TemplatesService templatesService,
BooleanProperty shutdownSignal,
FilePreferences filePreferences,
TaskExecutor taskExecutor
@@ -58,6 +61,7 @@ public SummariesService(AiPreferences aiPreferences,
this.aiPreferences = aiPreferences;
this.summariesStorage = summariesStorage;
this.chatLanguageModel = chatLanguageModel;
+ this.templatesService = templatesService;
this.shutdownSignal = shutdownSignal;
this.filePreferences = filePreferences;
this.taskExecutor = taskExecutor;
@@ -134,7 +138,7 @@ public List> summarize(StringProperty groupNam
private void startSummarizationTask(BibEntry entry, BibDatabaseContext bibDatabaseContext, ProcessingInfo processingInfo) {
processingInfo.setState(ProcessingState.PROCESSING);
- new GenerateSummaryTask(entry, bibDatabaseContext, summariesStorage, chatLanguageModel, shutdownSignal, aiPreferences, filePreferences)
+ new GenerateSummaryTask(entry, bibDatabaseContext, summariesStorage, chatLanguageModel, templatesService, shutdownSignal, aiPreferences, filePreferences)
.onSuccess(processingInfo::setSuccess)
.onFailure(processingInfo::setException)
.executeWith(taskExecutor);
@@ -143,7 +147,7 @@ private void startSummarizationTask(BibEntry entry, BibDatabaseContext bibDataba
private void startSummarizationTask(StringProperty groupName, List<ProcessingInfo<BibEntry, Summary>> entries, BibDatabaseContext bibDatabaseContext) {
entries.forEach(processingInfo -> processingInfo.setState(ProcessingState.PROCESSING));
- new GenerateSummaryForSeveralTask(groupName, entries, bibDatabaseContext, summariesStorage, chatLanguageModel, shutdownSignal, aiPreferences, filePreferences, taskExecutor)
+ new GenerateSummaryForSeveralTask(groupName, entries, bibDatabaseContext, summariesStorage, chatLanguageModel, templatesService, shutdownSignal, aiPreferences, filePreferences, taskExecutor)
.executeWith(taskExecutor);
}
diff --git a/src/main/java/org/jabref/logic/ai/templates/AiTemplate.java b/src/main/java/org/jabref/logic/ai/templates/AiTemplate.java
new file mode 100644
index 00000000000..50feff648d3
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/templates/AiTemplate.java
@@ -0,0 +1,30 @@
+package org.jabref.logic.ai.templates;
+
+import org.jabref.logic.l10n.Localization;
+
+public enum AiTemplate {
+ // System message that is applied in the AI chat.
+ CHATTING_SYSTEM_MESSAGE,
+
+ // Template of a last user message in the AI chat with embeddings.
+ CHATTING_USER_MESSAGE,
+
+ // Template that is used to summarize the chunks of text.
+ SUMMARIZATION_CHUNK,
+
+ // Template that is used to combine the summarized chunks of text.
+ SUMMARIZATION_COMBINE;
+
+ public String getLocalizedName() {
+ return switch (this) {
+ case CHATTING_SYSTEM_MESSAGE ->
+ Localization.lang("System message for chatting");
+ case CHATTING_USER_MESSAGE ->
+ Localization.lang("User message for chatting");
+ case SUMMARIZATION_CHUNK ->
+ Localization.lang("Completion text for summarization of a chunk");
+ case SUMMARIZATION_COMBINE ->
+ Localization.lang("Completion text for summarization of several chunks");
+ };
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/templates/PaperExcerpt.java b/src/main/java/org/jabref/logic/ai/templates/PaperExcerpt.java
new file mode 100644
index 00000000000..b8fbf810650
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/templates/PaperExcerpt.java
@@ -0,0 +1,3 @@
+package org.jabref.logic.ai.templates;
+
+public record PaperExcerpt(String citationKey, String text) { }
diff --git a/src/main/java/org/jabref/logic/ai/templates/TemplatesService.java b/src/main/java/org/jabref/logic/ai/templates/TemplatesService.java
new file mode 100644
index 00000000000..79e1634491d
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/templates/TemplatesService.java
@@ -0,0 +1,62 @@
+package org.jabref.logic.ai.templates;
+
+import java.io.StringWriter;
+import java.util.List;
+
+import org.jabref.logic.ai.AiPreferences;
+import org.jabref.model.entry.BibEntry;
+import org.jabref.model.entry.CanonicalBibEntry;
+
+import org.apache.velocity.VelocityContext;
+import org.apache.velocity.app.VelocityEngine;
+
+public class TemplatesService {
+ private final AiPreferences aiPreferences;
+ private final VelocityEngine velocityEngine = new VelocityEngine();
+ private final VelocityContext baseContext = new VelocityContext();
+
+ public TemplatesService(AiPreferences aiPreferences) {
+ this.aiPreferences = aiPreferences;
+
+ velocityEngine.init();
+
+ baseContext.put("CanonicalBibEntry", CanonicalBibEntry.class);
+ }
+
+ public String makeChattingSystemMessage(List<BibEntry> entries) {
+ VelocityContext context = new VelocityContext(baseContext);
+ context.put("entries", entries);
+
+ return makeTemplate(AiTemplate.CHATTING_SYSTEM_MESSAGE, context);
+ }
+
+ public String makeChattingUserMessage(String message, List<PaperExcerpt> excerpts) {
+ VelocityContext context = new VelocityContext(baseContext);
+ context.put("message", message);
+ context.put("excerpts", excerpts);
+
+ return makeTemplate(AiTemplate.CHATTING_USER_MESSAGE, context);
+ }
+
+ public String makeSummarizationChunk(String text) {
+ VelocityContext context = new VelocityContext(baseContext);
+ context.put("text", text);
+
+ return makeTemplate(AiTemplate.SUMMARIZATION_CHUNK, context);
+ }
+
+ public String makeSummarizationCombine(List<String> chunks) {
+ VelocityContext context = new VelocityContext(baseContext);
+ context.put("chunks", chunks);
+
+ return makeTemplate(AiTemplate.SUMMARIZATION_COMBINE, context);
+ }
+
+ private String makeTemplate(AiTemplate template, VelocityContext context) {
+ StringWriter writer = new StringWriter();
+
+ velocityEngine.evaluate(context, writer, template.name(), aiPreferences.getTemplate(template));
+
+ return writer.toString();
+ }
+}
diff --git a/src/main/java/org/jabref/logic/help/HelpFile.java b/src/main/java/org/jabref/logic/help/HelpFile.java
index c4f91c54153..7ead357ec38 100644
--- a/src/main/java/org/jabref/logic/help/HelpFile.java
+++ b/src/main/java/org/jabref/logic/help/HelpFile.java
@@ -48,7 +48,8 @@ public enum HelpFile {
SQL_DATABASE_MIGRATION("collaborative-work/sqldatabase/sqldatabasemigration"),
PUSH_TO_APPLICATION("cite/pushtoapplications"),
AI_GENERAL_SETTINGS("ai/preferences"),
- AI_EXPERT_SETTINGS("ai/preferences#ai-expert-settings");
+ AI_EXPERT_SETTINGS("ai/preferences#ai-expert-settings"),
+ AI_TEMPLATES("ai/preferences#templates");
private final String pageName;
diff --git a/src/main/java/org/jabref/logic/preferences/JabRefCliPreferences.java b/src/main/java/org/jabref/logic/preferences/JabRefCliPreferences.java
index 65a531f1356..095febc0b7f 100644
--- a/src/main/java/org/jabref/logic/preferences/JabRefCliPreferences.java
+++ b/src/main/java/org/jabref/logic/preferences/JabRefCliPreferences.java
@@ -38,6 +38,7 @@
import org.jabref.logic.LibraryPreferences;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
+import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.bibtex.FieldPreferences;
import org.jabref.logic.citationkeypattern.CitationKeyPattern;
import org.jabref.logic.citationkeypattern.CitationKeyPatternPreferences;
@@ -372,6 +373,11 @@ public class JabRefCliPreferences implements CliPreferences {
private static final String AI_RAG_MAX_RESULTS_COUNT = "aiRagMaxResultsCount";
private static final String AI_RAG_MIN_SCORE = "aiRagMinScore";
+ private static final String AI_CHATTING_SYSTEM_MESSAGE_TEMPLATE = "aiChattingSystemMessageTemplate";
+ private static final String AI_CHATTING_USER_MESSAGE_TEMPLATE = "aiChattingUserMessageTemplate";
+ private static final String AI_SUMMARIZATION_CHUNK_TEMPLATE = "aiSummarizationChunkTemplate";
+ private static final String AI_SUMMARIZATION_COMBINE_TEMPLATE = "aiSummarizationCombineTemplate";
+
private static final Logger LOGGER = LoggerFactory.getLogger(JabRefCliPreferences.class);
private static final Preferences PREFS_NODE = Preferences.userRoot().node("/org/jabref");
@@ -658,6 +664,14 @@ protected JabRefCliPreferences() {
defaults.put(AI_DOCUMENT_SPLITTER_OVERLAP_SIZE, AiDefaultPreferences.DOCUMENT_SPLITTER_OVERLAP);
defaults.put(AI_RAG_MAX_RESULTS_COUNT, AiDefaultPreferences.RAG_MAX_RESULTS_COUNT);
defaults.put(AI_RAG_MIN_SCORE, AiDefaultPreferences.RAG_MIN_SCORE);
+
+ // region: AI templates
+ defaults.put(AI_CHATTING_SYSTEM_MESSAGE_TEMPLATE, AiDefaultPreferences.TEMPLATES.get(AiTemplate.CHATTING_SYSTEM_MESSAGE));
+ defaults.put(AI_CHATTING_USER_MESSAGE_TEMPLATE, AiDefaultPreferences.TEMPLATES.get(AiTemplate.CHATTING_USER_MESSAGE));
+ defaults.put(AI_SUMMARIZATION_CHUNK_TEMPLATE, AiDefaultPreferences.TEMPLATES.get(AiTemplate.SUMMARIZATION_CHUNK));
+ defaults.put(AI_SUMMARIZATION_COMBINE_TEMPLATE, AiDefaultPreferences.TEMPLATES.get(AiTemplate.SUMMARIZATION_COMBINE));
+ // endregion
+
// endregion
}
@@ -1853,7 +1867,13 @@ public AiPreferences getAiPreferences() {
getInt(AI_DOCUMENT_SPLITTER_CHUNK_SIZE),
getInt(AI_DOCUMENT_SPLITTER_OVERLAP_SIZE),
getInt(AI_RAG_MAX_RESULTS_COUNT),
- getDouble(AI_RAG_MIN_SCORE));
+ getDouble(AI_RAG_MIN_SCORE),
+ Map.of(
+ AiTemplate.CHATTING_SYSTEM_MESSAGE, get(AI_CHATTING_SYSTEM_MESSAGE_TEMPLATE),
+ AiTemplate.CHATTING_USER_MESSAGE, get(AI_CHATTING_USER_MESSAGE_TEMPLATE),
+ AiTemplate.SUMMARIZATION_CHUNK, get(AI_SUMMARIZATION_CHUNK_TEMPLATE),
+ AiTemplate.SUMMARIZATION_COMBINE, get(AI_SUMMARIZATION_COMBINE_TEMPLATE)
+ ));
EasyBind.listen(aiPreferences.enableAiProperty(), (obs, oldValue, newValue) -> putBoolean(AI_ENABLED, newValue));
EasyBind.listen(aiPreferences.autoGenerateEmbeddingsProperty(), (obs, oldValue, newValue) -> putBoolean(AI_AUTO_GENERATE_EMBEDDINGS, newValue));
@@ -1884,6 +1904,11 @@ public AiPreferences getAiPreferences() {
EasyBind.listen(aiPreferences.ragMaxResultsCountProperty(), (obs, oldValue, newValue) -> putInt(AI_RAG_MAX_RESULTS_COUNT, newValue));
EasyBind.listen(aiPreferences.ragMinScoreProperty(), (obs, oldValue, newValue) -> putDouble(AI_RAG_MIN_SCORE, newValue.doubleValue()));
+ EasyBind.listen(aiPreferences.templateProperty(AiTemplate.CHATTING_SYSTEM_MESSAGE), (obs, oldValue, newValue) -> put(AI_CHATTING_SYSTEM_MESSAGE_TEMPLATE, newValue));
+ EasyBind.listen(aiPreferences.templateProperty(AiTemplate.CHATTING_USER_MESSAGE), (obs, oldValue, newValue) -> put(AI_CHATTING_USER_MESSAGE_TEMPLATE, newValue));
+ EasyBind.listen(aiPreferences.templateProperty(AiTemplate.SUMMARIZATION_CHUNK), (obs, oldValue, newValue) -> put(AI_SUMMARIZATION_CHUNK_TEMPLATE, newValue));
+ EasyBind.listen(aiPreferences.templateProperty(AiTemplate.SUMMARIZATION_COMBINE), (obs, oldValue, newValue) -> put(AI_SUMMARIZATION_COMBINE_TEMPLATE, newValue));
+
return aiPreferences;
}
diff --git a/src/main/resources/l10n/JabRef_en.properties b/src/main/resources/l10n/JabRef_en.properties
index 20c0d3f015a..fb3ed0d759b 100644
--- a/src/main/resources/l10n/JabRef_en.properties
+++ b/src/main/resources/l10n/JabRef_en.properties
@@ -2563,7 +2563,6 @@ Are\ you\ sure\ you\ want\ to\ clear\ the\ chat\ history\ of\ this\ entry?=Are y
Context\ window\ size=Context window size
Context\ window\ size\ must\ be\ greater\ than\ 0=Context window size must be greater than 0
Instruction\ for\ AI\ (also\ known\ as\ prompt\ or\ system\ message)=Instruction for AI (also known as prompt or system message)
-The\ instruction\ has\ to\ be\ provided=The instruction has to be provided
An\ I/O\ error\ occurred\ while\ opening\ the\ embedding\ model\ by\ URL\ %0=An I/O error occurred while opening the embedding model by URL %0
Got\ error\ while\ processing\ the\ file\:=Got error while processing the file:
The\ model\ by\ URL\ %0\ is\ malformed=The model by URL %0 is malformed
@@ -2648,6 +2647,12 @@ Generate\ summaries\ for\ entries\ in\ the\ group=Generate summaries for entries
Generating\ summaries\ for\ %0=Generating summaries for %0
Ingestion\ started\ for\ group\ "%0".=Ingestion started for group "%0".
Summarization\ started\ for\ group\ "%0".=Summarization started for group "%0".
+Reset\ templates\ to\ default=Reset templates to default
+Templates=Templates
+System\ message\ for\ chatting=System message for chatting
+User\ message\ for\ chatting=User message for chatting
+Completion\ text\ for\ summarization\ of\ a\ chunk=Completion text for summarization of a chunk
+Completion\ text\ for\ summarization\ of\ several\ chunks=Completion text for summarization of several chunks
Link=Link
Source\ URL=Source URL