-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement support for output guardrail on streamed responses
The user can implement a custom accumulator to decide when to invoke the guardrail chain.
- Loading branch information
1 parent
31600d7
commit 6ec2930
Showing
14 changed files
with
1,690 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
115 changes: 115 additions & 0 deletions
115
...ava/io/quarkiverse/langchain4j/test/guardrails/InvalidOutputGuardrailAccumulatorTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package io.quarkiverse.langchain4j.test.guardrails; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.api.Fail.fail; | ||
|
||
import java.util.List; | ||
import java.util.function.Supplier; | ||
|
||
import jakarta.enterprise.context.ApplicationScoped; | ||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.enterprise.inject.spi.DeploymentException; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.memory.ChatMemory; | ||
import dev.langchain4j.memory.chat.ChatMemoryProvider; | ||
import dev.langchain4j.memory.chat.MessageWindowChatMemory; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; | ||
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
import io.smallrye.mutiny.Multi; | ||
|
||
public class InvalidOutputGuardrailAccumulatorTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(MyAiService.class, | ||
MyMemoryProviderSupplier.class)) | ||
.assertException(t -> { | ||
assertThat(t).isInstanceOf(DeploymentException.class); | ||
assertThat(t).hasMessageContaining( | ||
"io.quarkiverse.langchain4j.test.guardrails.InvalidOutputGuardrailAccumulatorTest$MyAiService.hi"); | ||
}); | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void testThatInvalidAccumulatorAreReported() { | ||
fail("Should not be called"); | ||
} | ||
|
||
@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) | ||
public interface MyAiService { | ||
|
||
@UserMessage("Say Hi!") | ||
@OutputGuardrails(MyGuardRail.class) | ||
@OutputGuardrailAccumulator(MyAccumulator.class) | ||
String hi(@MemoryId String mem); | ||
|
||
} | ||
|
||
@ApplicationScoped | ||
public static class MyAccumulator implements OutputTokenAccumulator { | ||
|
||
@Override | ||
public Multi<String> accumulate(Multi<String> tokens) { | ||
return tokens; | ||
} | ||
} | ||
|
||
@ApplicationScoped | ||
public static class MyGuardRail implements OutputGuardrail { | ||
|
||
@Override | ||
public OutputGuardrailResult validate(AiMessage responseFromLLM) { | ||
throw new RuntimeException("Should not be invoked"); | ||
} | ||
|
||
} | ||
|
||
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> { | ||
@Override | ||
public ChatMemoryProvider get() { | ||
return new ChatMemoryProvider() { | ||
@Override | ||
public ChatMemory get(Object memoryId) { | ||
return new MessageWindowChatMemory.Builder().maxMessages(5).build(); | ||
} | ||
}; | ||
} | ||
} | ||
|
||
public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> { | ||
|
||
@Override | ||
public StreamingChatLanguageModel get() { | ||
return new StreamingChatLanguageModel() { | ||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
handler.onNext("Stream"); | ||
handler.onNext("ing"); | ||
handler.onNext(" "); | ||
handler.onNext("world"); | ||
handler.onNext("!"); | ||
handler.onComplete(Response.from(AiMessage.from(""))); | ||
} | ||
}; | ||
} | ||
} | ||
|
||
} |
115 changes: 115 additions & 0 deletions
115
...va/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailAccumulatorNotFoundTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package io.quarkiverse.langchain4j.test.guardrails; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.assertj.core.api.Fail.fail; | ||
|
||
import java.util.List; | ||
import java.util.function.Supplier; | ||
|
||
import jakarta.enterprise.context.ApplicationScoped; | ||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.enterprise.inject.spi.DeploymentException; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.memory.ChatMemory; | ||
import dev.langchain4j.memory.chat.ChatMemoryProvider; | ||
import dev.langchain4j.memory.chat.MessageWindowChatMemory; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; | ||
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; | ||
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
import io.smallrye.mutiny.Multi; | ||
|
||
public class OutputGuardrailAccumulatorNotFoundTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(MyAiService.class, | ||
MyMemoryProviderSupplier.class)) | ||
.assertException(t -> { | ||
assertThat(t).isInstanceOf(DeploymentException.class); | ||
assertThat(t).hasMessageContaining( | ||
"io.quarkiverse.langchain4j.test.guardrails.OutputGuardrailAccumulatorNotFoundTest$MissingAccumulator"); | ||
}); | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void testThatNotFoundAccumulatorAreReported() { | ||
fail("Should not be called"); | ||
} | ||
|
||
@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) | ||
public interface MyAiService { | ||
|
||
@UserMessage("Say Hi!") | ||
@OutputGuardrails(MyGuardRail.class) | ||
@OutputGuardrailAccumulator(MissingAccumulator.class) | ||
Multi<String> hi(@MemoryId String mem); | ||
|
||
} | ||
|
||
// Not a bean | ||
public static class MissingAccumulator implements OutputTokenAccumulator { | ||
|
||
@Override | ||
public Multi<String> accumulate(Multi<String> tokens) { | ||
return tokens; | ||
} | ||
} | ||
|
||
@ApplicationScoped | ||
public static class MyGuardRail implements OutputGuardrail { | ||
|
||
@Override | ||
public OutputGuardrailResult validate(AiMessage responseFromLLM) { | ||
throw new RuntimeException("Should not be invoked"); | ||
} | ||
|
||
} | ||
|
||
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> { | ||
@Override | ||
public ChatMemoryProvider get() { | ||
return new ChatMemoryProvider() { | ||
@Override | ||
public ChatMemory get(Object memoryId) { | ||
return new MessageWindowChatMemory.Builder().maxMessages(5).build(); | ||
} | ||
}; | ||
} | ||
} | ||
|
||
public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> { | ||
|
||
@Override | ||
public StreamingChatLanguageModel get() { | ||
return new StreamingChatLanguageModel() { | ||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
handler.onNext("Stream"); | ||
handler.onNext("ing"); | ||
handler.onNext(" "); | ||
handler.onNext("world"); | ||
handler.onNext("!"); | ||
handler.onComplete(Response.from(AiMessage.from(""))); | ||
} | ||
}; | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.