From d28acf0026bd21f281e91f3d80d2e81f3a78784e Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Tue, 7 May 2024 10:09:50 +0100 Subject: [PATCH] Add support for usage stats when streaming with the Chat Completions API --- .../jvm/openai/ChatCompletionChunk.java | 7 +++++- .../openai/CreateChatCompletionRequest.java | 23 +++++++++++++++++++ .../RecordNamingStrategyPatchModule.java | 6 +++-- .../jvm/openai/OpenAIIntegrationTest.java | 17 ++++++++++++++ .../jvm/openai/TestDataUtil.java | 2 ++ 5 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/ChatCompletionChunk.java b/src/main/java/io/github/stefanbratanov/jvm/openai/ChatCompletionChunk.java index 586429a..3973b82 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/ChatCompletionChunk.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/ChatCompletionChunk.java @@ -7,7 +7,12 @@ * provided input. */ public record ChatCompletionChunk( - String id, List choices, long created, String model, String systemFingerprint) { + String id, + List choices, + long created, + String model, + String systemFingerprint, + Usage usage) { public record Choice(Delta delta, int index, Logprobs logprobs, String finishReason) { diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/CreateChatCompletionRequest.java b/src/main/java/io/github/stefanbratanov/jvm/openai/CreateChatCompletionRequest.java index 2774031..61df912 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/CreateChatCompletionRequest.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/CreateChatCompletionRequest.java @@ -19,6 +19,7 @@ public record CreateChatCompletionRequest( Optional seed, Optional> stop, Optional stream, + Optional streamOptions, Optional temperature, Optional topP, Optional> tools, @@ -29,6 +30,18 @@ public static Builder newBuilder() { return new Builder(); } + /** + * @param includeUsage If set, an additional chunk will be streamed before the data: [DONE] + * message. The usage field on this chunk shows the token usage statistics for the entire + * request, and the choices field will always be an empty array. All other chunks will also + * include a usage field, but with a null value. + */ + public record StreamOptions(Boolean includeUsage) { + public static StreamOptions withUsageIncluded() { + return new StreamOptions(true); + } + } + public static class Builder { private static final String DEFAULT_MODEL = OpenAIModel.GPT_3_5_TURBO.getId(); @@ -48,6 +61,7 @@ public static class Builder { private Optional seed = Optional.empty(); private final List stop = new LinkedList<>(); private Optional stream = Optional.empty(); + private Optional streamOptions = Optional.empty(); private Optional temperature = Optional.empty(); private Optional topP = Optional.empty(); private final List tools = new LinkedList<>(); @@ -195,6 +209,14 @@ public Builder stream(boolean stream) { return this; } + /** + * @param streamOptions Options for streaming response. Only set this when you set stream: true. + */ + public Builder streamOptions(StreamOptions streamOptions) { + this.streamOptions = Optional.of(streamOptions); + return this; + } + /** * @param temperature What sampling temperature to use, between 0 and 2. Higher values like 0.8 * will make the output more random, while lower values like 0.2 will make it more focused @@ -287,6 +309,7 @@ public CreateChatCompletionRequest build() { seed, stop.isEmpty() ? Optional.empty() : Optional.of(List.copyOf(stop)), stream, + streamOptions, temperature, topP, tools.isEmpty() ? Optional.empty() : Optional.of(List.copyOf(tools)), diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/RecordNamingStrategyPatchModule.java b/src/main/java/io/github/stefanbratanov/jvm/openai/RecordNamingStrategyPatchModule.java index ea09c7b..d335246 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/RecordNamingStrategyPatchModule.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/RecordNamingStrategyPatchModule.java @@ -22,9 +22,11 @@ public void setupModule(SetupContext context) { } /** - * Remove when the following issue is resolved: Properties naming strategy do - * not work with Record #2992 + * not work with Record #2992 and Rewrite Bean Property + * Introspection logic in Jackson 2.x (ideally for 2.18) #4515 */ private static class ValueInstantiatorsModifier extends ValueInstantiators.Base { diff --git a/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIIntegrationTest.java b/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIIntegrationTest.java index a66bfcd..4bb1721 100644 --- a/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIIntegrationTest.java +++ b/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIIntegrationTest.java @@ -5,6 +5,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import io.github.stefanbratanov.jvm.openai.ChatMessage.UserMessage.UserMessageWithContentParts.ContentPart.TextContentPart; +import io.github.stefanbratanov.jvm.openai.CreateChatCompletionRequest.StreamOptions; import java.io.UncheckedIOException; import java.net.http.HttpTimeoutException; import java.nio.file.Path; @@ -107,11 +108,21 @@ void testChatClient() { // test sending content part .message(ChatMessage.userMessage(new TextContentPart("Say this is a test"))) .stream(true) + // test usage stats + .streamOptions(StreamOptions.withUsageIncluded()) .build(); String joinedContent = chatClient .streamChatCompletion(streamRequest) + .filter( + chunk -> { + if (chunk.choices().isEmpty()) { + assertThat(chunk.usage()).isNotNull(); + return false; + } + return true; + }) .map(ChatCompletionChunk::choices) .map( choices -> { @@ -123,6 +134,12 @@ void testChatClient() { assertThat(joinedContent).containsPattern("(?i)this is (a|the) test"); + streamRequest = + CreateChatCompletionRequest.newBuilder() + .message(ChatMessage.userMessage("Say this is a test")) + .stream(true) + .build(); + // test streaming with a subscriber CompletableFuture joinedContentFuture = new CompletableFuture<>(); chatClient.streamChatCompletion( diff --git a/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java b/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java index 98f461d..7b3ad27 100644 --- a/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java +++ b/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java @@ -1,6 +1,7 @@ package io.github.stefanbratanov.jvm.openai; import io.github.stefanbratanov.jvm.openai.ChatMessage.UserMessage.UserMessageWithContentParts.ContentPart; +import io.github.stefanbratanov.jvm.openai.CreateChatCompletionRequest.StreamOptions; import io.github.stefanbratanov.jvm.openai.FineTuningJobIntegration.Wandb; import io.github.stefanbratanov.jvm.openai.RunStepsClient.PaginatedThreadRunSteps; import io.github.stefanbratanov.jvm.openai.ThreadMessage.Content.ImageFileContent; @@ -46,6 +47,7 @@ public CreateChatCompletionRequest randomCreateChatCompletionRequest() { .seed(randomInt()) .stop(arrayOf(randomInt(0, 4), () -> randomString(5), String[]::new)) .stream(randomBoolean()) + .streamOptions(StreamOptions.withUsageIncluded()) .temperature(randomDouble(0.0, 2.0)) .topP(randomDouble(0.0, 1.0)) .tools(listOf(randomInt(0, 5), this::randomFunctionTool));