Skip to content

Commit

Permalink
Add file search details and ranker options
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Sep 4, 2024
1 parent d2dcac2 commit c9b5a12
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ private Constants() {}
static final String LIMIT_QUERY_PARAMETER = "limit";
static final String AFTER_QUERY_PARAMETER = "after";
static final String BEFORE_QUERY_PARAMETER = "before";
static final String INCLUDE_QUERY_PARAMETER = "include[]";

static final String AUTO_CHUNKING_STRATEGY = "auto";
static final String STATIC_CHUNKING_STRATEGY = "static";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -32,10 +33,14 @@ public final class RunStepsClient extends OpenAIAssistantsClient {
/**
* Returns a list of run steps belonging to a run.
*
* @param include A list of additional fields to include in the response.
* @throws OpenAIException in case of API errors
*/
public PaginatedThreadRunSteps listRunSteps(
String threadId, String runId, PaginationQueryParameters paginationQueryParameters) {
String threadId,
String runId,
PaginationQueryParameters paginationQueryParameters,
Optional<List<String>> include) {
HttpRequest httpRequest =
newHttpRequestBuilder()
.uri(
Expand All @@ -47,7 +52,9 @@ public PaginatedThreadRunSteps listRunSteps(
+ "/"
+ runId
+ STEPS_SEGMENT
+ createQueryParameters(paginationQueryParameters)))
+ createQueryParameters(
paginationQueryParameters,
Map.of(Constants.INCLUDE_QUERY_PARAMETER, include))))
.GET()
.build();
HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
Expand All @@ -60,9 +67,11 @@ public record PaginatedThreadRunSteps(
/**
* Retrieves a run step.
*
* @param include A list of additional fields to include in the response.
* @throws OpenAIException in case of API errors
*/
public ThreadRunStep retrieveRunStep(String threadId, String runId, String stepId) {
public ThreadRunStep retrieveRunStep(
String threadId, String runId, String stepId, Optional<List<String>> include) {
HttpRequest httpRequest =
newHttpRequestBuilder()
.uri(
Expand All @@ -75,7 +84,9 @@ public ThreadRunStep retrieveRunStep(String threadId, String runId, String stepI
+ runId
+ STEPS_SEGMENT
+ "/"
+ stepId))
+ stepId
+ createQueryParameters(
Map.of(Constants.INCLUDE_QUERY_PARAMETER, include))))
.GET()
.build();
HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
Expand Down
36 changes: 27 additions & 9 deletions src/main/java/io/github/stefanbratanov/jvm/openai/RunsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.time.Duration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Spliterator;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -37,37 +38,43 @@ public final class RunsClient extends OpenAIAssistantsClient {
/**
* Create a run.
*
* @param include A list of additional fields to include in the response.
* @throws OpenAIException in case of API errors
*/
public ThreadRun createRun(String threadId, CreateRunRequest request) {
HttpRequest httpRequest = createRunPostRequest(threadId, request);
public ThreadRun createRun(
String threadId, Optional<List<String>> include, CreateRunRequest request) {
HttpRequest httpRequest = createRunPostRequest(threadId, include, request);
HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
return deserializeResponse(httpResponse.body(), ThreadRun.class);
}

/**
* Create a run and stream the result of executing it.
*
* @param include A list of additional fields to include in the response.
* @throws OpenAIException in case of API errors
*/
public Stream<AssistantStreamEvent> createRunAndStream(
String threadId, CreateRunRequest request) {
String threadId, Optional<List<String>> include, CreateRunRequest request) {
validateStreamRequest(request::stream);
HttpRequest httpRequest = createRunPostRequest(threadId, request);
HttpRequest httpRequest = createRunPostRequest(threadId, include, request);
return getAssistantStreamEvents(httpRequest);
}

/**
* Same as {@link #createRunAndStream(String, CreateRunRequest)} but can pass a {@link
* Same as {@link #createRunAndStream(String, Optional, CreateRunRequest)} but can pass a {@link
* AssistantStreamEventSubscriber} implementation instead of using a {@link
* Stream<AssistantStreamEvent>}
*
* @throws OpenAIException in case of API errors
*/
public void createRunAndStream(
String threadId, CreateRunRequest request, AssistantStreamEventSubscriber subscriber) {
String threadId,
Optional<List<String>> include,
CreateRunRequest request,
AssistantStreamEventSubscriber subscriber) {
validateStreamRequest(request::stream);
HttpRequest httpRequest = createRunPostRequest(threadId, request);
HttpRequest httpRequest = createRunPostRequest(threadId, include, request);
streamAndHandleAssistantEvents(httpRequest, subscriber);
}

Expand Down Expand Up @@ -234,9 +241,16 @@ public ThreadRun cancelRun(String threadId, String runId) {
return deserializeResponse(httpResponse.body(), ThreadRun.class);
}

private HttpRequest createRunPostRequest(String threadId, CreateRunRequest request) {
private HttpRequest createRunPostRequest(
String threadId, Optional<List<String>> include, CreateRunRequest request) {
return newHttpRequestBuilder()
.uri(baseUrl.resolve(Endpoint.THREADS.getPath() + "/" + threadId + RUNS_SEGMENT))
.uri(
baseUrl.resolve(
Endpoint.THREADS.getPath()
+ "/"
+ threadId
+ RUNS_SEGMENT
+ createQueryParameters(include)))
.POST(createBodyPublisher(request))
.build();
}
Expand All @@ -248,6 +262,10 @@ private HttpRequest createThreadAndRunPostRequest(CreateThreadAndRunRequest requ
.build();
}

private String createQueryParameters(Optional<List<String>> include) {
return createQueryParameters(Map.of(Constants.INCLUDE_QUERY_PARAMETER, include));
}

private HttpRequest createSubmitToolOutputsPostRequest(
String threadId, String runId, SubmitToolOutputsRequest request) {
return newHttpRequestBuilder()
Expand Down
21 changes: 19 additions & 2 deletions src/main/java/io/github/stefanbratanov/jvm/openai/Tool.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import io.github.stefanbratanov.jvm.openai.Tool.FileSearchTool.FileSearch.RankingOptions;
import java.util.Optional;

@JsonTypeInfo(
Expand Down Expand Up @@ -31,7 +32,11 @@ public String type() {

record FileSearchTool(Optional<FileSearch> fileSearch) implements Tool {

public record FileSearch(Optional<Integer> maxNumResults) {}
public record FileSearch(
Optional<Integer> maxNumResults, Optional<RankingOptions> rankingOptions) {

public record RankingOptions(String ranker, Double scoreThreshold) {}
}

@Override
public String type() {
Expand Down Expand Up @@ -60,7 +65,19 @@ static FileSearchTool fileSearchTool() {
*/
static FileSearchTool fileSearchTool(int maxNumResults) {
return new FileSearchTool(
Optional.of(new FileSearchTool.FileSearch(Optional.of(maxNumResults))));
Optional.of(new FileSearchTool.FileSearch(Optional.of(maxNumResults), Optional.empty())));
}

/**
* @param maxNumResults The maximum number of results the file search tool should output.
* @param rankingOptions The score threshold for the file search. All values must be a floating
* point number between 0 and 1.
*/
static FileSearchTool fileSearchTool(int maxNumResults, RankingOptions rankingOptions) {
return new FileSearchTool(
Optional.of(
new FileSearchTool.FileSearch(
Optional.of(maxNumResults), Optional.of(rankingOptions))));
}

static FunctionTool functionTool(Function function) {
Expand Down
19 changes: 14 additions & 5 deletions src/main/java/io/github/stefanbratanov/jvm/openai/ToolCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter;
import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter.Output.ImageOutput.Image;
import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch;
import io.github.stefanbratanov.jvm.openai.ToolCall.FunctionToolCall.Function;
import java.util.Collections;
import java.util.List;
import java.util.Map;

@JsonTypeInfo(
use = JsonTypeInfo.Id.NAME,
Expand Down Expand Up @@ -96,11 +95,21 @@ static ImageOutput imageOutput(Image image) {
}
}

record FileSearchToolCall(String id, Map<String, Object> fileSearch) implements ToolCall {
record FileSearchToolCall(String id, FileSearch fileSearch) implements ToolCall {
@Override
public String type() {
return Constants.FILE_SEARCH_TOOL_CALL_TYPE;
}

public record FileSearch(RankingOptions rankingOptions, List<Result> results) {

public record RankingOptions(String ranker, double scoreThreshold) {}

public record Result(String fileId, String fileName, double score, List<Content> content) {

public record Content(String type, String text) {}
}
}
}

record FunctionToolCall(String id, Function function) implements ToolCall {
Expand All @@ -123,8 +132,8 @@ static CodeInterpreterToolCall codeInterpreterToolCall(
return new CodeInterpreterToolCall(id, codeInterpreter);
}

static FileSearchToolCall fileSearchToolCall(String id) {
return new FileSearchToolCall(id, Collections.emptyMap());
static FileSearchToolCall fileSearchToolCall(String id, FileSearch fileSearch) {
return new FileSearchToolCall(id, fileSearch);
}

static FunctionToolCall functionToolCall(String id, Function function) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ void testRunsAndRunStepsClients() {
CreateRunRequest createRunRequest =
CreateRunRequest.newBuilder().assistantId(assistant.id()).build();

ThreadRun run = runsClient.createRun(threadId, createRunRequest);
ThreadRun run = runsClient.createRun(threadId, Optional.empty(), createRunRequest);
String runId = run.id();

assertThat(run.threadId()).isEqualTo(threadId);
Expand Down Expand Up @@ -224,7 +224,7 @@ void testRunsAndRunStepsClients() {
// test with java.util.stream.Stream
Set<String> emittedEvents =
runsClient
.createRunAndStream(threadId, createRunStreamRequest)
.createRunAndStream(threadId, Optional.empty(), createRunStreamRequest)
.map(
assistantStreamEvent -> {
assertThat(assistantStreamEvent.data()).isNotNull();
Expand Down Expand Up @@ -367,14 +367,18 @@ public void onComplete() {

// retrieve run steps
List<ThreadRunStep> runSteps =
runStepsClient.listRunSteps(threadId, runId, PaginationQueryParameters.none()).data();
runStepsClient
.listRunSteps(threadId, runId, PaginationQueryParameters.none(), Optional.empty())
.data();

assertThat(runSteps)
.first()
.satisfies(
runStep ->
assertThat(runStep)
.isEqualTo(runStepsClient.retrieveRunStep(threadId, runId, runStep.id())));
.isEqualTo(
runStepsClient.retrieveRunStep(
threadId, runId, runStep.id(), Optional.empty())));

// modify run
ThreadRun modifiedRun =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import io.github.stefanbratanov.jvm.openai.ThreadRunStepDelta.StepDetails.MessageCreationStepDetails;
import io.github.stefanbratanov.jvm.openai.ThreadRunStepDelta.StepDetails.MessageCreationStepDetails.MessageCreation;
import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter.Output;
import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch;
import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch.RankingOptions;
import java.util.List;
import java.util.Map;
import org.json.JSONException;
Expand Down Expand Up @@ -203,10 +205,13 @@ void doesNotSerializeTypeTwiceForJsonSubTypesAnnotatedClasses() throws JsonProce
assertThat(objectMapper.writeValueAsString(fileSearchTool))
.isEqualTo("{\"type\":\"file_search\"}");

ToolCall.FileSearchToolCall fileSearchToolCall = ToolCall.fileSearchToolCall("foobar");
ToolCall.FileSearchToolCall fileSearchToolCall =
ToolCall.fileSearchToolCall(
"foobar", new FileSearch(new RankingOptions("default_2024_08_21", 0.0), List.of()));

assertThat(objectMapper.writeValueAsString(fileSearchToolCall))
.isEqualTo("{\"id\":\"foobar\",\"file_search\":{},\"type\":\"file_search\"}");
.isEqualTo(
"{\"id\":\"foobar\",\"file_search\":{\"ranking_options\":{\"ranker\":\"default_2024_08_21\",\"score_threshold\":0.0},\"results\":[]},\"type\":\"file_search\"}");

DeltaToolCall.FileSearchToolCall deltaFileSearchToolCall =
DeltaToolCall.fileSearchToolCall(0, "foobar");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@
import io.github.stefanbratanov.jvm.openai.ThreadRunStep.StepDetails;
import io.github.stefanbratanov.jvm.openai.ThreadRunStep.StepDetails.MessageCreationStepDetails;
import io.github.stefanbratanov.jvm.openai.ThreadRunStep.StepDetails.ToolCallsStepDetails;
import io.github.stefanbratanov.jvm.openai.Tool.FileSearchTool.FileSearch.RankingOptions;
import io.github.stefanbratanov.jvm.openai.Tool.FunctionTool;
import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter;
import io.github.stefanbratanov.jvm.openai.ToolCall.CodeInterpreterToolCall.CodeInterpreter.Output.ImageOutput;
import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch;
import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch.Result;
import io.github.stefanbratanov.jvm.openai.ToolCall.FileSearchToolCall.FileSearch.Result.Content;
import io.github.stefanbratanov.jvm.openai.ToolCall.FunctionToolCall;
import io.github.stefanbratanov.jvm.openai.ToolResources.FileSearch.VectorStore;
import java.util.*;
Expand Down Expand Up @@ -1032,7 +1036,18 @@ private ToolCall randomToolCall() {
return oneOf(
randomFunctionToolCall(true),
randomCodeInterpreterToolCall(),
ToolCall.fileSearchToolCall(randomString(5)));
ToolCall.fileSearchToolCall(
randomString(5),
new FileSearch(
new FileSearch.RankingOptions("default_2024_08_21", randomLong(0, 1)),
listOf(
randomInt(1, 5),
() ->
new Result(
randomString(5),
randomString(6),
randomLong(0, 1),
List.of(new Content("text", randomString(10))))))));
}

private Usage randomUsage() {
Expand Down Expand Up @@ -1227,7 +1242,12 @@ private String randomFinishReason() {

private Tool randomTool() {
return oneOf(
randomFunctionTool(), Tool.fileSearchTool(randomInt(1, 50)), Tool.codeInterpreterTool());
randomFunctionTool(),
Tool.fileSearchTool(randomInt(1, 50)),
Tool.fileSearchTool(
randomInt(1, 50),
new RankingOptions(oneOf("auto", "default_2024_08_21"), randomDouble(0, 1))),
Tool.codeInterpreterTool());
}

private DeltaToolCall randomCodeInterpreterDeltaToolCall() {
Expand Down

0 comments on commit c9b5a12

Please sign in to comment.