Skip to content

Commit

Permalink
Merge pull request #83 from quarkiverse/audit
Browse files Browse the repository at this point in the history
Sketch out API for auditing
  • Loading branch information
geoand authored Dec 6, 2023
2 parents 374abca + 5471aa9 commit 574e0e1
Show file tree
Hide file tree
Showing 17 changed files with 633 additions and 37 deletions.
13 changes: 0 additions & 13 deletions .github/workflows/build-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ jobs:
build:
name: Build on ${{ matrix.os }} - ${{ matrix.java }}
strategy:
# PineconeEmbeddingStoreTest uses a single shared index, we can't run multiple CI runs on it at once
# If we have PINECONE_API_KEY available, then the test will run, so set max-parallel to 1
max-parallel: ${{ github.secret_source == 'Actions' && 1 || 16 }}
fail-fast: false
matrix:
os: [ubuntu-latest]
Expand All @@ -46,16 +43,6 @@ jobs:

- name: Build with Maven
run: mvn -B clean install -Dno-format
env: # note that secrets are not available when triggered by PR from a fork, so some tests will be skipped
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT }}
PINECONE_INDEX_NAME: ${{ secrets.PINECONE_INDEX_NAME }}
PINECONE_PROJECT_ID: ${{ secrets.PINECONE_PROJECT_ID }}

- name: Build with Maven (Native)
run: mvn -B install -Dnative -Dquarkus.native.container-build -Dnative.surefire.skip
env: # note that secrets are not available when triggered by PR from a fork, so some tests will be skipped
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT }}
PINECONE_INDEX_NAME: ${{ secrets.PINECONE_INDEX_NAME }}
PINECONE_PROJECT_ID: ${{ secrets.PINECONE_PROJECT_ID }}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.objectweb.asm.tree.analysis.AnalyzerException;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.V;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand All @@ -53,6 +52,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
Expand Down Expand Up @@ -101,6 +101,7 @@ public class AiServicesProcessor {
private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(
AiServiceMethodImplementationSupport.class,
"implement", Object.class, AiServiceMethodImplementationSupport.Input.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

@BuildStep
public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
Expand Down Expand Up @@ -203,13 +204,21 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

DotName auditServiceClassSupplierName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
AnnotationValue auditServiceClassSupplierValue = instance.value("auditServiceSupplier");
if (auditServiceClassSupplierValue != null) {
auditServiceClassSupplierName = auditServiceClassSupplierValue.asClass().name();
validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer);
}

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName));
retrieverSupplierClassDotName,
auditServiceClassSupplierName));
}

if (needChatModelBean) {
Expand Down Expand Up @@ -244,6 +253,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsChatModelBean = false;
boolean needsChatMemoryProviderBean = false;
boolean needsRetrieverBean = false;
boolean needsAuditServiceBean = false;
Set<DotName> allToolNames = new HashSet<>();

for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
Expand All @@ -264,12 +274,17 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getRetrieverSupplierClassDotName().toString()
: null;

String auditServiceClassSupplierName = bi.getAuditServiceClassSupplierDotName() != null
? bi.getAuditServiceClassSupplierDotName().toString()
: null;

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(declarativeAiServiceClassInfo.name())
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName)))
retrieverSupplierClassName,
auditServiceClassSupplierName)))
.setRuntimeInit()
.scope(ApplicationScoped.class);
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
Expand All @@ -290,7 +305,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsChatMemoryProviderBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
.equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
needsChatMemoryProviderBean = true;
}
Expand All @@ -301,13 +316,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsRetrieverBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString()
.equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) },
null));
needsRetrieverBean = true;
}

if (Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER.toString().equals(auditServiceClassSupplierName)) {
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.AUDIT_SERVICE) }, null));
needsAuditServiceBean = true;
}

syntheticBeanProducer.produce(configurator.done());
}

Expand All @@ -320,6 +341,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsRetrieverBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.RETRIEVER));
}
if (needsAuditServiceBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.AUDIT_SERVICE));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
}
Expand Down Expand Up @@ -436,19 +460,19 @@ public void handleAiServices(AiServicesRecorder recorder,
.interfaces(iface.name().toString())
.build()) {

FieldDescriptor contextField = classCreator.getFieldCreator("context", AiServiceContext.class)
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
.getFieldDescriptor();

for (MethodInfo methodInfo : methodsToImplement) {
// The implementation essentially gets method the context and delegates to
// The implementation essentially gets the context and delegates to
// MethodImplementationSupport#implement

String methodId = createMethodId(methodInfo);
perMethodMetadata.put(methodId,
gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan));
MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
AiServiceContext.class);
QuarkusAiServiceContext.class);
constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis());
constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0));
constructor.returnValue(null);
Expand All @@ -466,7 +490,7 @@ public void handleAiServices(AiServicesRecorder recorder,
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);

ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
Expand Down Expand Up @@ -547,7 +571,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
Optional<AiServiceMethodCreateInfo.MetricsInfo> metricsInfo = gatherMetricsInfo(method, addMicrometerMetrics);
Optional<AiServiceMethodCreateInfo.SpanInfo> spanInfo = gatherSpanInfo(method, addOpenTelemetrySpans);

return new AiServiceMethodCreateInfo(systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration,
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnType, metricsInfo, spanInfo);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {

private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName auditServiceClassSupplierDotName;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName) {
DotName retrieverSupplierClassDotName,
DotName auditServiceClassSupplierDotName) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
}

public ClassInfo getServiceClassInfo() {
Expand All @@ -49,4 +52,8 @@ public DotName getChatMemoryProviderSupplierClassDotName() {
public DotName getRetrieverSupplierClassDotName() {
return retrieverSupplierClassDotName;
}

public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dev.langchain4j.service.UserName;
import io.quarkiverse.langchain4j.CreatedAware;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;

public class Langchain4jDotNames {
public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class);
Expand Down Expand Up @@ -53,10 +54,15 @@ public class Langchain4jDotNames {
static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class);

static final DotName AUDIT_SERVICE = DotName.createSimple(AuditService.class);

static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsAuditServiceSupplier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.spi.services.AiServicesFactory;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
Expand All @@ -27,7 +29,12 @@ public class QuarkusAiServicesFactory implements AiServicesFactory {

@Override
public <T> QuarkusAiServices<T> create(AiServiceContext context) {
return new QuarkusAiServices<>(context);
if (context instanceof QuarkusAiServiceContext) {
return new QuarkusAiServices<>(context);
} else {
// the context is always empty (except for the aiServiceClass) anyway and never escapes, so we can just use our own type
return new QuarkusAiServices<>(new QuarkusAiServiceContext(context.aiServiceClass));
}
}

public static class InstanceHolder {
Expand Down Expand Up @@ -70,6 +77,11 @@ public AiServices<T> tools(List<Object> objectsWithTools) {
return this;
}

public AiServices<T> auditService(AuditService auditService) {
((QuarkusAiServiceContext) context).auditService = auditService;
return this;
}

List<ToolMethodCreateInfo> lookup(Object bean, String className) {
Map<String, List<ToolMethodCreateInfo>> metadata = ToolsRecorder.getMetadata();
// Fast path first.
Expand Down Expand Up @@ -116,8 +128,8 @@ public T build() {

try {
return (T) Class.forName(classCreateInfo.getImplClassName(), true, Thread.currentThread()
.getContextClassLoader()).getConstructor(AiServiceContext.class)
.newInstance(context);
.getContextClassLoader()).getConstructor(QuarkusAiServiceContext.class)
.newInstance(((QuarkusAiServiceContext) context));
} catch (Exception e) {
throw new IllegalStateException("Unable to create class '" + classCreateInfo.getImplClassName(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
import io.quarkiverse.langchain4j.audit.AuditService;

/**
* Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by
Expand Down Expand Up @@ -74,6 +75,15 @@
*/
Class<? extends Supplier<Retriever<TextSegment>>> retrieverSupplier() default NoRetrieverSupplier.class;

/**
* Configures the way to obtain the {@link AuditService} to use.
* By default, Quarkus will look for a CDI bean that implements {@link AuditService}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link AuditService} instance is needed, a custom implementation of
* {@link Supplier<AuditService>} needs to be provided.
*/
Class<? extends Supplier<AuditService>> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class;

/**
* Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
Expand Down Expand Up @@ -155,4 +165,16 @@ public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean.
* If no such bean exists, then no audit service will be used.
*/
final class BeanIfExistsAuditServiceSupplier implements Supplier<AuditService> {

@Override
public AuditService get() {
throw new UnsupportedOperationException("should never be called");
}
}
}
Loading

0 comments on commit 574e0e1

Please sign in to comment.