diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/InProcessEmbeddingProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/InProcessEmbeddingProcessor.java index 4fa5d9cbc..07e6e1e12 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/InProcessEmbeddingProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/InProcessEmbeddingProcessor.java @@ -1,14 +1,19 @@ package io.quarkiverse.langchain4j.deployment; import java.util.List; +import java.util.Optional; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.DotName; import dev.langchain4j.model.embedding.EmbeddingModel; +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.items.InProcessEmbeddingBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; import io.quarkiverse.langchain4j.runtime.InProcessEmbeddingRecorder; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.bootstrap.classloading.QuarkusClassLoader; import io.quarkus.deployment.annotations.BuildProducer; @@ -120,20 +125,30 @@ InProcessEmbeddingBuildItem e5_small_v2() { @Record(ExecutionTime.RUNTIME_INIT) void exposeInProcessEmbeddingBeans(InProcessEmbeddingRecorder recorder, List embeddings, + List selectedEmbedding, BuildProducer beanProducer) { for (InProcessEmbeddingBuildItem embedding : embeddings) { - beanProducer.produce(SyntheticBeanBuildItem + Optional modelName = selectedEmbedding.stream() + .filter(se -> se.getProvider().equals(embedding.getProvider())) + .map(SelectedEmbeddingModelCandidateBuildItem::getModelName) + .findFirst(); + var builder = SyntheticBeanBuildItem .configure(DotName.createSimple(embedding.className())) .types(EmbeddingModel.class) .defaultBean() .setRuntimeInit() .unremovable() .scope(ApplicationScoped.class) - .supplier(recorder.instantiate(embedding.className())) - .done()); + .supplier(recorder.instantiate(embedding.className())); + modelName.ifPresent(m -> addQualifierIfNecessary(builder, m)); + beanProducer.produce(builder.done()); } - } + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); + } + } } diff --git a/integration-tests/multiple-providers/pom.xml b/integration-tests/multiple-providers/pom.xml index 7a2e5bbcf..1f83ff94d 100644 --- a/integration-tests/multiple-providers/pom.xml +++ b/integration-tests/multiple-providers/pom.xml @@ -41,6 +41,11 @@ quarkus-langchain4j-ollama ${project.version} + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2 + ${langchain4j-embeddings.version} + io.quarkiverse.langchain4j quarkus-langchain4j-openshift-ai diff --git a/integration-tests/multiple-providers/src/main/resources/application.properties b/integration-tests/multiple-providers/src/main/resources/application.properties index ca647d657..02a7a6806 100644 --- a/integration-tests/multiple-providers/src/main/resources/application.properties +++ b/integration-tests/multiple-providers/src/main/resources/application.properties @@ -40,4 +40,4 @@ quarkus.langchain4j.watsonx.c7.project-id=proj quarkus.langchain4j.e1.embedding-model.provider=openai quarkus.langchain4j.openai.e1.api-key=test5 quarkus.langchain4j.e2.embedding-model.provider=ollama - +quarkus.langchain4j.e3.embedding-model.provider=dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java index a189589b9..6cffbc730 100644 --- a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java @@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test; +import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.openai.OpenAiEmbeddingModel; import io.quarkiverse.langchain4j.ModelName; @@ -25,6 +26,10 @@ public class MultipleEmbeddingModelsTest { @ModelName("e2") EmbeddingModel secondNamedModel; + @Inject + @ModelName("e3") + EmbeddingModel fifthNamedModel; + @Inject @ModelName("c1") EmbeddingModel thirdNamedModel; @@ -52,4 +57,9 @@ void thirdNamedModel() { void fourthNamedModel() { assertThat(ClientProxy.unwrap(fourthNamedModel)).isInstanceOf(AzureOpenAiEmbeddingModel.class); } + + @Test + void fifthNamedModel() { + assertThat(ClientProxy.unwrap(fifthNamedModel)).isInstanceOf(AllMiniLmL6V2EmbeddingModel.class); + } }