Skip to content

Commit

Permalink
Merge pull request #1011 from cescoffier/output-guardrail-streaming
Browse files Browse the repository at this point in the history
Implement support for output guardrail on streamed responses
  • Loading branch information
geoand authored Oct 25, 2024
2 parents 7279fe5 + 1747654 commit 6369d00
Show file tree
Hide file tree
Showing 14 changed files with 1,689 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser;
Expand Down Expand Up @@ -751,11 +752,16 @@ public void markUsedOutputGuardRailsUnremovable(List<AiServicesMethodBuildItem>
for (String cn : list) {
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(DotName.createSimple(cn)));
}
DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
if (method.methodInfo.hasAnnotation(dotName)) {
unremovableProducer.produce(
UnremovableBeanBuildItem.beanTypes(method.methodInfo.annotation(dotName).value().asClass().name()));
}
}
}

@BuildStep
public void detectMissingGuardRails(SynthesisFinishedBuildItem synthesisFinished,
public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished,
List<AiServicesMethodBuildItem> methods,
BuildProducer<ValidationPhaseBuildItem.ValidationErrorBuildItem> errors) {

Expand All @@ -768,6 +774,33 @@ public void detectMissingGuardRails(SynthesisFinishedBuildItem synthesisFinished
new DeploymentException("Missing guardrail bean: " + cn)));
}
}

DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
if (method.methodInfo.hasAnnotation(dotName)) {
// We have an accumulator
// Check that the accumulator exists
var bean = method.methodInfo.annotation(dotName).value().asClass().name();
if (synthesisFinished.beanStream().withBeanType(bean).isEmpty()) {
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new DeploymentException("Missing accumulator bean: " + bean.toString())));
}

// Check that the accumulator is used on a method retuning a Multi
DotName returnedType = method.methodInfo.returnType().name();
if (!DotName.createSimple(Multi.class).equals(returnedType)) {
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new DeploymentException("OutputGuardrailAccumulator can only be used on method returning a " +
"`Multi<X>`: found `%s` for method `%s.%s`".formatted(returnedType,
method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
}

// Check that the method have output guardrails
if (method.outputGuardrails.isEmpty()) {
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new DeploymentException("OutputGuardrailAccumulator used without OutputGuardrails in method `%s.%s`"
.formatted(method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
}
}
}
}

Expand Down Expand Up @@ -1165,11 +1198,13 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
List<String> outputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, OUTPUT_GUARDRAILS);
List<String> inputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, INPUT_GUARDRAILS);

String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, inputGuardrails,
outputGuardrails);
outputGuardrails, accumulatorClassName);
}

private void validateReturnType(MethodInfo method) {
Expand Down Expand Up @@ -1690,5 +1725,18 @@ public static List<String> gatherGuardrails(MethodInfo methodInfo, DotName annot
}
return guardrails;
}

public static String gatherAccumulator(MethodInfo methodInfo) {
DotName annotation = DotName.createSimple(OutputGuardrailAccumulator.class);
AnnotationInstance instance = methodInfo.annotation(annotation);
if (instance == null) {
// Check on class
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
}
if (instance != null) {
return instance.value().asClass().name().toString();
}
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package io.quarkiverse.langchain4j.test.guardrails;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Fail.fail;

import java.util.List;
import java.util.function.Supplier;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.enterprise.inject.spi.DeploymentException;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Multi;

public class InvalidOutputGuardrailAccumulatorTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class,
MyMemoryProviderSupplier.class))
.assertException(t -> {
assertThat(t).isInstanceOf(DeploymentException.class);
assertThat(t).hasMessageContaining(
"io.quarkiverse.langchain4j.test.guardrails.InvalidOutputGuardrailAccumulatorTest$MyAiService.hi");
});

@Test
@ActivateRequestContext
void testThatInvalidAccumulatorAreReported() {
fail("Should not be called");
}

@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

@UserMessage("Say Hi!")
@OutputGuardrails(MyGuardRail.class)
@OutputGuardrailAccumulator(MyAccumulator.class)
String hi(@MemoryId String mem);

}

@ApplicationScoped
public static class MyAccumulator implements OutputTokenAccumulator {

@Override
public Multi<String> accumulate(Multi<String> tokens) {
return tokens;
}
}

@ApplicationScoped
public static class MyGuardRail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
throw new RuntimeException("Should not be invoked");
}

}

public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
@Override
public ChatMemoryProvider get() {
return new ChatMemoryProvider() {
@Override
public ChatMemory get(Object memoryId) {
return new MessageWindowChatMemory.Builder().maxMessages(5).build();
}
};
}
}

public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> {

@Override
public StreamingChatLanguageModel get() {
return new StreamingChatLanguageModel() {
@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
handler.onNext("Stream");
handler.onNext("ing");
handler.onNext(" ");
handler.onNext("world");
handler.onNext("!");
handler.onComplete(Response.from(AiMessage.from("")));
}
};
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package io.quarkiverse.langchain4j.test.guardrails;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Fail.fail;

import java.util.List;
import java.util.function.Supplier;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.enterprise.inject.spi.DeploymentException;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Multi;

public class OutputGuardrailAccumulatorNotFoundTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class,
MyMemoryProviderSupplier.class))
.assertException(t -> {
assertThat(t).isInstanceOf(DeploymentException.class);
assertThat(t).hasMessageContaining(
"io.quarkiverse.langchain4j.test.guardrails.OutputGuardrailAccumulatorNotFoundTest$MissingAccumulator");
});

@Test
@ActivateRequestContext
void testThatNotFoundAccumulatorAreReported() {
fail("Should not be called");
}

@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

@UserMessage("Say Hi!")
@OutputGuardrails(MyGuardRail.class)
@OutputGuardrailAccumulator(MissingAccumulator.class)
Multi<String> hi(@MemoryId String mem);

}

// Not a bean
public static class MissingAccumulator implements OutputTokenAccumulator {

@Override
public Multi<String> accumulate(Multi<String> tokens) {
return tokens;
}
}

@ApplicationScoped
public static class MyGuardRail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
throw new RuntimeException("Should not be invoked");
}

}

public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
@Override
public ChatMemoryProvider get() {
return new ChatMemoryProvider() {
@Override
public ChatMemory get(Object memoryId) {
return new MessageWindowChatMemory.Builder().maxMessages(5).build();
}
};
}
}

public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> {

@Override
public StreamingChatLanguageModel get() {
return new StreamingChatLanguageModel() {
@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
handler.onNext("Stream");
handler.onNext("ing");
handler.onNext(" ");
handler.onNext("world");
handler.onNext("!");
handler.onComplete(Response.from(AiMessage.from("")));
}
};
}
}

}
Loading

0 comments on commit 6369d00

Please sign in to comment.