Skip to content

Commit

Permalink
[FEATURE] ToolProvider | Select tools dynamically on incoming message (
Browse files Browse the repository at this point in the history
…#989)

* ✨ (ToolProvider): Make ToolProvider Quarkus ready

* ♻️ (DevUi): Feedback from jmartisk (see below)
- Support StreamingChat
- Inject Supplier<ToolProvider>
- Ignore tools when toolProvider exists

* ✏️ (DevUi): Typo in setToolsViaProviderIfAvailable

* 📝 (ToolProvider): Update agent-and-tools.adoc

* Update docs/modules/ROOT/pages/agent-and-tools.adoc

Co-authored-by: Jan Martiska <jmartisk@redhat.com>

---------

Co-authored-by: Jan Martiska <jmartisk@redhat.com>
  • Loading branch information
MiggiV2 and jmartisk authored Oct 22, 2024
1 parent 4851856 commit db01040
Show file tree
Hide file tree
Showing 14 changed files with 455 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.ToolBox;
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.deployment.devui.ToolProviderInfo;
import io.quarkiverse.langchain4j.deployment.items.MethodParameterAllowedAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
Expand Down Expand Up @@ -208,13 +210,15 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
BuildProducer<RequestImageModelBeanBuildItem> requestImageModelBeanProducer,
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
BuildProducer<ToolProviderMetaBuildItem> toolProviderProducer,
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
BuildProducer<GeneratedClassBuildItem> generatedClassProducer) {
IndexView index = indexBuildItem.getIndex();

Set<String> chatModelNames = new HashSet<>();
Set<String> moderationModelNames = new HashSet<>();
Set<String> imageModelNames = new HashSet<>();
List<ToolProviderInfo> toolProviderInfos = new ArrayList<>();
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
for (AnnotationInstance instance : index.getAnnotations(LangChain4jDotNames.REGISTER_AI_SERVICES)) {
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
Expand Down Expand Up @@ -323,6 +327,15 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
validateSupplierAndRegisterForReflection(moderationModelSupplierClassName, index, reflectiveClassProducer);
}

DotName toolProviderClassName = LangChain4jDotNames.BEAN_IF_EXISTS_TOOL_PROVIDER_SUPPLIER;
AnnotationValue toolProviderValue = instance.value("toolProviderSupplier");
if (toolProviderValue != null) {
toolProviderClassName = toolProviderValue.asClass().name();
validateSupplierAndRegisterForReflection(toolProviderClassName, index, reflectiveClassProducer);
toolProviderInfos.add(new ToolProviderInfo(toolProviderClassName.toString(),
declarativeAiServiceClassInfo.simpleName()));
}

DotName imageModelSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_IMAGE_MODEL_SUPPLIER;
AnnotationValue imageModelSupplierValue = instance.value("imageModelSupplier");
if (imageModelSupplierValue != null) {
Expand Down Expand Up @@ -381,8 +394,10 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
cdiScope,
chatModelName,
moderationModelName,
imageModelName));
imageModelName,
toolProviderClassName));
}
toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos));

for (String chatModelName : chatModelNames) {
requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(chatModelName));
Expand Down Expand Up @@ -462,6 +477,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsModerationModelBean = false;
boolean needsImageModelBean = false;
Set<DotName> allToolNames = new HashSet<>();
Set<DotName> allToolProviders = new HashSet<>();

for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
ClassInfo declarativeAiServiceClassInfo = bi.getServiceClassInfo();
Expand All @@ -477,6 +493,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,

List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());

String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null
? bi.getToolProviderClassDotName().toString()
: null);

String chatMemoryProviderSupplierClassName = bi.getChatMemoryProviderSupplierClassDotName() != null
? bi.getChatMemoryProviderSupplierClassDotName().toString()
: null;
Expand Down Expand Up @@ -556,7 +576,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
serviceClassName,
chatLanguageModelSupplierClassName,
streamingChatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName,
toolClassNames,
toolProviderSupplierClassName,
chatMemoryProviderSupplierClassName, retrieverClassName,
retrievalAugmentorSupplierClassName,
auditServiceClassSupplierName,
moderationModelSupplierClassName,
Expand Down Expand Up @@ -668,6 +690,13 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsImageModelBean = true;
}

if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
.equals(toolProviderSupplierClassName) && toolProviderSupplierClassName != null) {
DotName toolProvider = DotName.createSimple(toolProviderSupplierClassName);
configurator.addInjectionPoint(ClassType.create(toolProvider));
allToolProviders.add(toolProvider);
}

configurator
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(OutputGuardrail.class) }, null))
Expand Down Expand Up @@ -700,6 +729,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsImageModelBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.IMAGE_MODEL));
}
if (!allToolProviders.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolProviders));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
}
Expand Down Expand Up @@ -795,7 +827,7 @@ public void handleAiServices(
for (ClassInfo classInfo : index.getKnownUsers(LangChain4jDotNames.AI_SERVICES)) {
String className = classInfo.name().toString();
if (className.startsWith("io.quarkiverse.langchain4j") || className.startsWith("dev.langchain4j")) { // TODO: this can be made smarter if
// needed
// needed
continue;
}
try (InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final DotName chatLanguageModelSupplierClassDotName;
private final DotName streamingChatLanguageModelSupplierClassDotName;
private final List<DotName> toolDotNames;
private final DotName toolProviderClassDotName;

private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverClassDotName;
Expand Down Expand Up @@ -46,7 +47,8 @@ public DeclarativeAiServiceBuildItem(
DotName cdiScope,
String chatModelName,
String moderationModelName,
String imageModelName) {
String imageModelName,
DotName toolProviderClassDotName) {
this.serviceClassInfo = serviceClassInfo;
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
Expand All @@ -63,6 +65,7 @@ public DeclarativeAiServiceBuildItem(
this.chatModelName = chatModelName;
this.moderationModelName = moderationModelName;
this.imageModelName = imageModelName;
this.toolProviderClassDotName = toolProviderClassDotName;
}

public ClassInfo getServiceClassInfo() {
Expand Down Expand Up @@ -128,4 +131,8 @@ public String getModerationModelName() {
public String getImageModelName() {
return imageModelName;
}

public DotName getToolProviderClassDotName() {
return toolProviderClassDotName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ public class LangChain4jDotNames {
static final DotName BEAN_IF_EXISTS_IMAGE_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsImageModelSupplier.class);

static final DotName BEAN_IF_EXISTS_TOOL_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsToolProviderSupplier.class);

static final DotName QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER = DotName.createSimple(
QuarkusAiServiceContextQualifier.class);

Expand All @@ -108,4 +111,5 @@ public class LangChain4jDotNames {
static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class);
static final DotName IMAGE = DotName.createSimple(Image.class);
static final DotName RESULT = DotName.createSimple(Result.class);
static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProcessor.class);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package io.quarkiverse.langchain4j.deployment;

import java.util.List;

import io.quarkiverse.langchain4j.deployment.devui.ToolProviderInfo;
import io.quarkus.builder.item.SimpleBuildItem;

/**
* Holds metadata about toolProviders discovered at build time
*/
public final class ToolProviderMetaBuildItem extends SimpleBuildItem {
List<ToolProviderInfo> metadata;

public ToolProviderMetaBuildItem(List<ToolProviderInfo> metaData) {
this.metadata = metaData;
}

public List<ToolProviderInfo> getMetadata() {
return metadata;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.quarkiverse.langchain4j.deployment.DeclarativeAiServiceBuildItem;
import io.quarkiverse.langchain4j.deployment.EmbeddingStoreBuildItem;
import io.quarkiverse.langchain4j.deployment.LangChain4jDotNames;
import io.quarkiverse.langchain4j.deployment.ToolProviderMetaBuildItem;
import io.quarkiverse.langchain4j.deployment.ToolsMetadataBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
Expand All @@ -28,6 +29,7 @@ public class LangChain4jDevUIProcessor {

@BuildStep(onlyIf = IsDevelopment.class)
CardPageBuildItem cardPage(List<DeclarativeAiServiceBuildItem> aiServices,
ToolProviderMetaBuildItem toolProviderMetaBuildItem,
ToolsMetadataBuildItem toolsMetadataBuildItem,
List<EmbeddingModelProviderCandidateBuildItem> embeddingModelCandidateBuildItems,
List<InProcessEmbeddingBuildItem> inProcessEmbeddingModelBuildItems,
Expand Down Expand Up @@ -60,6 +62,10 @@ CardPageBuildItem cardPage(List<DeclarativeAiServiceBuildItem> aiServices,

additionalDevUiCardBuildItem.getBuildTimeData().forEach((k, v) -> card.addBuildTimeData(k, v));
}

List<ToolProviderInfo> toolProviderInfos = toolProviderMetaBuildItem.getMetadata();
card.addBuildTimeData("toolProviders", toolProviderInfos);

return card;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.quarkiverse.langchain4j.deployment.devui;

public class ToolProviderInfo {
private String className;
private String aiServiceName;

public ToolProviderInfo(String className, String aiServiceName) {
this.className = className;
this.aiServiceName = aiServiceName;
}

public String getClassName() {
return className;
}

public void setClassName(String className) {
this.className = className;
}

public String getAiServiceName() {
return aiServiceName;
}

public void setAiServiceName(String aiServiceName) {
this.aiServiceName = aiServiceName;
}
}
54 changes: 37 additions & 17 deletions core/deployment/src/main/resources/dev-ui/qwc-tools.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { LitElement, html, css} from 'lit';
import {css, html, LitElement} from 'lit';
import '@vaadin/grid';
import '@vaadin/grid/vaadin-grid-sort-column.js';

import {tools} from 'build-time-data';
import {toolProviders, tools} from 'build-time-data';


export class QwcTools extends LitElement {
Expand All @@ -12,6 +12,7 @@ export class QwcTools extends LitElement {
height: 100%;
display: flex;
}
vaadin-grid {
margin-left: 15px;
margin-right: 15px;
Expand All @@ -21,38 +22,57 @@ export class QwcTools extends LitElement {

static properties = {
"_tools": {state: true},
"_toolProviders": {state: true},
}

constructor() {
super();
this._tools = tools;
this._toolProviders = toolProviders;
}

render() {
if (this._tools) {
if (this._toolProviders.length > 0) {
return this._renderToolProvider();
} else if (this._tools) {
return this._renderToolTable();
} else {
return html`<span>No tools found</span>`;
}
}

_renderToolProvider() {
return html`
<vaadin-grid .items="${this._toolProviders}" theme="no-border">
<vaadin-grid-sort-column auto-width
path="className"
header="Class name">
</vaadin-grid-sort-column>
<vaadin-grid-column auto-width
path="aiServiceName"
header="AiService">
</vaadin-grid-column>
</vaadin-grid>`;
}

_renderToolTable() {
return html`
<vaadin-grid .items="${this._tools}" theme="no-border">
<vaadin-grid-sort-column auto-width
path="className"
header="Class name">
</vaadin-grid-sort-column>
<vaadin-grid-column auto-width
path="name"
header="Tool name">
</vaadin-grid-column>
<vaadin-grid-column auto-width
path="description"
header="Description">
</vaadin-grid-column>
</vaadin-grid>`;
<vaadin-grid .items="${this._tools}" theme="no-border">
<vaadin-grid-sort-column auto-width
path="className"
header="Class name">
</vaadin-grid-sort-column>
<vaadin-grid-column auto-width
path="name"
header="Tool name">
</vaadin-grid-column>
<vaadin-grid-column auto-width
path="description"
header="Description">
</vaadin-grid-column>
</vaadin-grid>`;
}

}

customElements.define('qwc-tools', QwcTools);
Loading

0 comments on commit db01040

Please sign in to comment.