From c9b5a12fc40ea2a8449185bd25ee3db4b6614f3c Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Wed, 4 Sep 2024 21:44:43 +0300 Subject: [PATCH] Add file search details and ranker options --- .../stefanbratanov/jvm/openai/Constants.java | 1 + .../jvm/openai/RunStepsClient.java | 19 +++++++--- .../stefanbratanov/jvm/openai/RunsClient.java | 36 ++++++++++++++----- .../stefanbratanov/jvm/openai/Tool.java | 21 +++++++++-- .../stefanbratanov/jvm/openai/ToolCall.java | 19 +++++++--- .../OpenAIAssistantsApiIntegrationTest.java | 12 ++++--- .../jvm/openai/SerializationTest.java | 9 +++-- .../jvm/openai/TestDataUtil.java | 24 +++++++++++-- 8 files changed, 113 insertions(+), 28 deletions(-) diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/Constants.java b/src/main/java/io/github/stefanbratanov/jvm/openai/Constants.java index c336820..9083bb9 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/Constants.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/Constants.java @@ -45,6 +45,7 @@ private Constants() {} static final String LIMIT_QUERY_PARAMETER = "limit"; static final String AFTER_QUERY_PARAMETER = "after"; static final String BEFORE_QUERY_PARAMETER = "before"; + static final String INCLUDE_QUERY_PARAMETER = "include[]"; static final String AUTO_CHUNKING_STRATEGY = "auto"; static final String STATIC_CHUNKING_STRATEGY = "static"; diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/RunStepsClient.java b/src/main/java/io/github/stefanbratanov/jvm/openai/RunStepsClient.java index 67c4c5d..ea44dec 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/RunStepsClient.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/RunStepsClient.java @@ -6,6 +6,7 @@ import java.net.http.HttpResponse; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -32,10 +33,14 @@ public final class RunStepsClient extends OpenAIAssistantsClient { /** * Returns a list of run steps 
belonging to a run. * + * @param include A list of additional fields to include in the response. * @throws OpenAIException in case of API errors */ public PaginatedThreadRunSteps listRunSteps( - String threadId, String runId, PaginationQueryParameters paginationQueryParameters) { + String threadId, + String runId, + PaginationQueryParameters paginationQueryParameters, + Optional> include) { HttpRequest httpRequest = newHttpRequestBuilder() .uri( @@ -47,7 +52,9 @@ public PaginatedThreadRunSteps listRunSteps( + "/" + runId + STEPS_SEGMENT - + createQueryParameters(paginationQueryParameters))) + + createQueryParameters( + paginationQueryParameters, + Map.of(Constants.INCLUDE_QUERY_PARAMETER, include)))) .GET() .build(); HttpResponse httpResponse = sendHttpRequest(httpRequest); @@ -60,9 +67,11 @@ public record PaginatedThreadRunSteps( /** * Retrieves a run step. * + * @param include A list of additional fields to include in the response. * @throws OpenAIException in case of API errors */ - public ThreadRunStep retrieveRunStep(String threadId, String runId, String stepId) { + public ThreadRunStep retrieveRunStep( + String threadId, String runId, String stepId, Optional> include) { HttpRequest httpRequest = newHttpRequestBuilder() .uri( @@ -75,7 +84,9 @@ public ThreadRunStep retrieveRunStep(String threadId, String runId, String stepI + runId + STEPS_SEGMENT + "/" - + stepId)) + + stepId + + createQueryParameters( + Map.of(Constants.INCLUDE_QUERY_PARAMETER, include)))) .GET() .build(); HttpResponse httpResponse = sendHttpRequest(httpRequest); diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/RunsClient.java b/src/main/java/io/github/stefanbratanov/jvm/openai/RunsClient.java index 26859b8..a4d2bf0 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/RunsClient.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/RunsClient.java @@ -7,6 +7,7 @@ import java.time.Duration; import java.util.Iterator; import java.util.List; +import java.util.Map; 
import java.util.Optional; import java.util.Spliterator; import java.util.concurrent.CompletableFuture; @@ -37,10 +38,12 @@ public final class RunsClient extends OpenAIAssistantsClient { /** * Create a run. * + * @param include A list of additional fields to include in the response. * @throws OpenAIException in case of API errors */ - public ThreadRun createRun(String threadId, CreateRunRequest request) { - HttpRequest httpRequest = createRunPostRequest(threadId, request); + public ThreadRun createRun( + String threadId, Optional> include, CreateRunRequest request) { + HttpRequest httpRequest = createRunPostRequest(threadId, include, request); HttpResponse httpResponse = sendHttpRequest(httpRequest); return deserializeResponse(httpResponse.body(), ThreadRun.class); } @@ -48,26 +51,30 @@ public ThreadRun createRun(String threadId, CreateRunRequest request) { /** * Create a run and stream the result of executing it. * + * @param include A list of additional fields to include in the response. 
* @throws OpenAIException in case of API errors */ public Stream createRunAndStream( - String threadId, CreateRunRequest request) { + String threadId, Optional> include, CreateRunRequest request) { validateStreamRequest(request::stream); - HttpRequest httpRequest = createRunPostRequest(threadId, request); + HttpRequest httpRequest = createRunPostRequest(threadId, include, request); return getAssistantStreamEvents(httpRequest); } /** - * Same as {@link #createRunAndStream(String, CreateRunRequest)} but can pass a {@link + * Same as {@link #createRunAndStream(String, Optional, CreateRunRequest)} but can pass a {@link * AssistantStreamEventSubscriber} implementation instead of using a {@link * Stream} * * @throws OpenAIException in case of API errors */ public void createRunAndStream( - String threadId, CreateRunRequest request, AssistantStreamEventSubscriber subscriber) { + String threadId, + Optional> include, + CreateRunRequest request, + AssistantStreamEventSubscriber subscriber) { validateStreamRequest(request::stream); - HttpRequest httpRequest = createRunPostRequest(threadId, request); + HttpRequest httpRequest = createRunPostRequest(threadId, include, request); streamAndHandleAssistantEvents(httpRequest, subscriber); } @@ -234,9 +241,16 @@ public ThreadRun cancelRun(String threadId, String runId) { return deserializeResponse(httpResponse.body(), ThreadRun.class); } - private HttpRequest createRunPostRequest(String threadId, CreateRunRequest request) { + private HttpRequest createRunPostRequest( + String threadId, Optional> include, CreateRunRequest request) { return newHttpRequestBuilder() - .uri(baseUrl.resolve(Endpoint.THREADS.getPath() + "/" + threadId + RUNS_SEGMENT)) + .uri( + baseUrl.resolve( + Endpoint.THREADS.getPath() + + "/" + + threadId + + RUNS_SEGMENT + + createQueryParameters(include))) .POST(createBodyPublisher(request)) .build(); } @@ -248,6 +262,10 @@ private HttpRequest createThreadAndRunPostRequest(CreateThreadAndRunRequest requ .build(); } 
+ private String createQueryParameters(Optional> include) { + return createQueryParameters(Map.of(Constants.INCLUDE_QUERY_PARAMETER, include)); + } + private HttpRequest createSubmitToolOutputsPostRequest( String threadId, String runId, SubmitToolOutputsRequest request) { return newHttpRequestBuilder() diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/Tool.java b/src/main/java/io/github/stefanbratanov/jvm/openai/Tool.java index d084aca..ca1f931 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/Tool.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/Tool.java @@ -3,6 +3,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.github.stefanbratanov.jvm.openai.Tool.FileSearchTool.FileSearch.RankingOptions; import java.util.Optional; @JsonTypeInfo( @@ -31,7 +32,11 @@ public String type() { record FileSearchTool(Optional fileSearch) implements Tool { - public record FileSearch(Optional maxNumResults) {} + public record FileSearch( + Optional maxNumResults, Optional rankingOptions) { + + public record RankingOptions(String ranker, Double scoreThreshold) {} + } @Override public String type() { @@ -60,7 +65,19 @@ static FileSearchTool fileSearchTool() { */ static FileSearchTool fileSearchTool(int maxNumResults) { return new FileSearchTool( - Optional.of(new FileSearchTool.FileSearch(Optional.of(maxNumResults)))); + Optional.of(new FileSearchTool.FileSearch(Optional.of(maxNumResults), Optional.empty()))); + } + + /** + * @param maxNumResults The maximum number of results the file search tool should output. + * @param rankingOptions The ranking options for the file search: the ranker to use and the + * score threshold, which must be a floating point number between 0 and 1. 
+ */ + static FileSearchTool fileSearchTool(int maxNumResults, RankingOptions rankingOptions) { + return new FileSearchTool( + Optional.of( + new FileSearchTool.FileSearch( + Optional.of(maxNumResults), Optional.of(rankingOptions)))); } static FunctionTool functionTool(Function function) { diff --git a/src/main/java/io/github/stefanbratanov/jvm/openai/ToolCall.java b/src/main/java/io/github/stefanbratanov/jvm/openai/ToolCall.java index bd92bbd..256bcb9 100644 --- a/src/main/java/io/github/stefanbratanov/jvm/openai/ToolCall.java +++ b/src/main/java/io/github/stefanbratanov/jvm/openai/ToolCall.java @@ -5,10 +5,9 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo; import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter; import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter.Output.ImageOutput.Image; +import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch; import io.github.stefanbratanov.jvm.openai.ToolCall.FunctionToolCall.Function; -import java.util.Collections; import java.util.List; -import java.util.Map; @JsonTypeInfo( use = JsonTypeInfo.Id.NAME, @@ -96,11 +95,21 @@ static ImageOutput imageOutput(Image image) { } } - record FileSearchToolCall(String id, Map fileSearch) implements ToolCall { + record FileSearchToolCall(String id, FileSearch fileSearch) implements ToolCall { @Override public String type() { return Constants.FILE_SEARCH_TOOL_CALL_TYPE; } + + public record FileSearch(RankingOptions rankingOptions, List results) { + + public record RankingOptions(String ranker, double scoreThreshold) {} + + public record Result(String fileId, String fileName, double score, List content) { + + public record Content(String type, String text) {} + } + } } record FunctionToolCall(String id, Function function) implements ToolCall { @@ -123,8 +132,8 @@ static CodeInterpreterToolCall codeInterpreterToolCall( return new CodeInterpreterToolCall(id, codeInterpreter); } - 
static FileSearchToolCall fileSearchToolCall(String id) { - return new FileSearchToolCall(id, Collections.emptyMap()); + static FileSearchToolCall fileSearchToolCall(String id, FileSearch fileSearch) { + return new FileSearchToolCall(id, fileSearch); } static FunctionToolCall functionToolCall(String id, Function function) { diff --git a/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIAssistantsApiIntegrationTest.java b/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIAssistantsApiIntegrationTest.java index 8a2c593..aa61ffd 100644 --- a/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIAssistantsApiIntegrationTest.java +++ b/src/test/java/io/github/stefanbratanov/jvm/openai/OpenAIAssistantsApiIntegrationTest.java @@ -192,7 +192,7 @@ void testRunsAndRunStepsClients() { CreateRunRequest createRunRequest = CreateRunRequest.newBuilder().assistantId(assistant.id()).build(); - ThreadRun run = runsClient.createRun(threadId, createRunRequest); + ThreadRun run = runsClient.createRun(threadId, Optional.empty(), createRunRequest); String runId = run.id(); assertThat(run.threadId()).isEqualTo(threadId); @@ -224,7 +224,7 @@ void testRunsAndRunStepsClients() { // test with java.util.stream.Stream Set emittedEvents = runsClient - .createRunAndStream(threadId, createRunStreamRequest) + .createRunAndStream(threadId, Optional.empty(), createRunStreamRequest) .map( assistantStreamEvent -> { assertThat(assistantStreamEvent.data()).isNotNull(); @@ -367,14 +367,18 @@ public void onComplete() { // retrieve run steps List runSteps = - runStepsClient.listRunSteps(threadId, runId, PaginationQueryParameters.none()).data(); + runStepsClient + .listRunSteps(threadId, runId, PaginationQueryParameters.none(), Optional.empty()) + .data(); assertThat(runSteps) .first() .satisfies( runStep -> assertThat(runStep) - .isEqualTo(runStepsClient.retrieveRunStep(threadId, runId, runStep.id()))); + .isEqualTo( + runStepsClient.retrieveRunStep( + threadId, runId, runStep.id(), 
Optional.empty()))); // modify run ThreadRun modifiedRun = diff --git a/src/test/java/io/github/stefanbratanov/jvm/openai/SerializationTest.java b/src/test/java/io/github/stefanbratanov/jvm/openai/SerializationTest.java index 5537b91..d35fdf3 100644 --- a/src/test/java/io/github/stefanbratanov/jvm/openai/SerializationTest.java +++ b/src/test/java/io/github/stefanbratanov/jvm/openai/SerializationTest.java @@ -17,6 +17,8 @@ import io.github.stefanbratanov.jvm.openai.ThreadRunStepDelta.StepDetails.MessageCreationStepDetails; import io.github.stefanbratanov.jvm.openai.ThreadRunStepDelta.StepDetails.MessageCreationStepDetails.MessageCreation; import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter.Output; +import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch; +import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch.RankingOptions; import java.util.List; import java.util.Map; import org.json.JSONException; @@ -203,10 +205,13 @@ void doesNotSerializeTypeTwiceForJsonSubTypesAnnotatedClasses() throws JsonProce assertThat(objectMapper.writeValueAsString(fileSearchTool)) .isEqualTo("{\"type\":\"file_search\"}"); - ToolCall.FileSearchToolCall fileSearchToolCall = ToolCall.fileSearchToolCall("foobar"); + ToolCall.FileSearchToolCall fileSearchToolCall = + ToolCall.fileSearchToolCall( + "foobar", new FileSearch(new RankingOptions("default_2024_08_21", 0.0), List.of())); assertThat(objectMapper.writeValueAsString(fileSearchToolCall)) - .isEqualTo("{\"id\":\"foobar\",\"file_search\":{},\"type\":\"file_search\"}"); + .isEqualTo( + "{\"id\":\"foobar\",\"file_search\":{\"ranking_options\":{\"ranker\":\"default_2024_08_21\",\"score_threshold\":0.0},\"results\":[]},\"type\":\"file_search\"}"); DeltaToolCall.FileSearchToolCall deltaFileSearchToolCall = DeltaToolCall.fileSearchToolCall(0, "foobar"); diff --git a/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java 
b/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java index 5f86051..fe4aec2 100644 --- a/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java +++ b/src/test/java/io/github/stefanbratanov/jvm/openai/TestDataUtil.java @@ -38,9 +38,13 @@ import io.github.stefanbratanov.jvm.openai.ThreadRunStep.StepDetails; import io.github.stefanbratanov.jvm.openai.ThreadRunStep.StepDetails.MessageCreationStepDetails; import io.github.stefanbratanov.jvm.openai.ThreadRunStep.StepDetails.ToolCallsStepDetails; +import io.github.stefanbratanov.jvm.openai.Tool.FileSearchTool.FileSearch.RankingOptions; import io.github.stefanbratanov.jvm.openai.Tool.FunctionTool; import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter; import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter.Output.ImageOutput; +import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch; +import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch.Result; +import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch.Result.Content; import io.github.stefanbratanov.jvm.openai.ToolCall.FunctionToolCall; import io.github.stefanbratanov.jvm.openai.ToolResources.FileSearch.VectorStore; import java.util.*; @@ -1032,7 +1036,18 @@ private ToolCall randomToolCall() { return oneOf( randomFunctionToolCall(true), randomCodeInterpreterToolCall(), - ToolCall.fileSearchToolCall(randomString(5))); + ToolCall.fileSearchToolCall( + randomString(5), + new FileSearch( + new FileSearch.RankingOptions("default_2024_08_21", randomLong(0, 1)), + listOf( + randomInt(1, 5), + () -> + new Result( + randomString(5), + randomString(6), + randomLong(0, 1), + List.of(new Content("text", randomString(10)))))))); } private Usage randomUsage() { @@ -1227,7 +1242,12 @@ private String randomFinishReason() { private Tool randomTool() { return oneOf( - randomFunctionTool(), 
Tool.fileSearchTool(randomInt(1, 50)), Tool.codeInterpreterTool()); + randomFunctionTool(), + Tool.fileSearchTool(randomInt(1, 50)), + Tool.fileSearchTool( + randomInt(1, 50), + new RankingOptions(oneOf("auto", "default_2024_08_21"), randomDouble(0, 1))), + Tool.codeInterpreterTool()); } private DeltaToolCall randomCodeInterpreterDeltaToolCall() {