Skip to content

Commit

Permalink
Add stream field for assistants streaming (will implement later)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Apr 3, 2024
1 parent 625fbd3 commit 9f77cdb
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 25 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ chatClient.streamChatCompletion(request, new StreamChatCompletionSubscriber() {
public void onChunk(ChatCompletionChunk chunk) {
System.out.println(chunk);
}

@Override
public void onException(Throwable ex) {
// ...
}

@Override
public void onComplete() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
*/
public final class ChatClient extends OpenAIClient {

private static final String STREAM_TERMINATION = "data: [DONE]";

private final URI endpoint;

ChatClient(
Expand Down Expand Up @@ -103,10 +101,7 @@ private void validateStreamRequest(CreateChatCompletionRequest request) {
}

private Stream<ChatCompletionChunk> getStreamedResponses(HttpRequest httpRequest) {
return sendHttpRequest(httpRequest, HttpResponse.BodyHandlers.ofLines())
.body()
.filter(sseEvent -> !sseEvent.isBlank())
.takeWhile(sseEvent -> !sseEvent.equals(STREAM_TERMINATION))
return streamServerSentEvents(httpRequest)
.map(
sseEvent -> {
String chatChunkResponse = sseEvent.substring(sseEvent.indexOf("{"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public record CreateRunRequest(
Optional<String> additionalInstructions,
Optional<List<Tool>> tools,
Optional<Map<String, String>> metadata,
Optional<Double> temperature) {
Optional<Double> temperature,
Optional<Boolean> stream) {

public static Builder newBuilder() {
return new Builder();
Expand All @@ -27,6 +28,7 @@ public static class Builder {
private Optional<List<Tool>> tools = Optional.empty();
private Optional<Map<String, String>> metadata = Optional.empty();
private Optional<Double> temperature = Optional.empty();
private Optional<Boolean> stream = Optional.empty();

/**
* @param assistantId The ID of the assistant to use to execute this run.
Expand Down Expand Up @@ -94,9 +96,25 @@ public Builder temperature(Double temperature) {
return this;
}

/**
* @param stream If true, returns a stream of events that happen during the Run as server-sent
* events
*/
public Builder stream(Boolean stream) {
this.stream = Optional.of(stream);
return this;
}

public CreateRunRequest build() {
return new CreateRunRequest(
assistantId, model, instructions, additionalInstructions, tools, metadata, temperature);
assistantId,
model,
instructions,
additionalInstructions,
tools,
metadata,
temperature,
stream);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ public record CreateThreadAndRunRequest(
Optional<String> model,
Optional<String> instructions,
Optional<List<Tool>> tools,
Optional<Map<String, String>> metadata) {
Optional<Map<String, String>> metadata,
Optional<Double> temperature,
Optional<Boolean> stream) {

public static Builder newBuilder() {
return new Builder();
Expand All @@ -25,6 +27,8 @@ public static class Builder {
private Optional<String> instructions = Optional.empty();
private Optional<List<Tool>> tools = Optional.empty();
private Optional<Map<String, String>> metadata = Optional.empty();
private Optional<Double> temperature = Optional.empty();
private Optional<Boolean> stream = Optional.empty();

/**
* @param assistantId The ID of the assistant to use to execute this run.
Expand Down Expand Up @@ -80,9 +84,28 @@ public Builder metadata(Map<String, String> metadata) {
return this;
}

/**
* @param temperature What sampling temperature to use, between 0 and 2. Higher values like 0.8
* will make the output more random, while lower values like 0.2 will make it more focused
* and deterministic.
*/
public Builder temperature(Double temperature) {
this.temperature = Optional.of(temperature);
return this;
}

/**
* @param stream If true, returns a stream of events that happen during the Run as server-sent
* events
*/
public Builder stream(Boolean stream) {
this.stream = Optional.of(stream);
return this;
}

public CreateThreadAndRunRequest build() {
return new CreateThreadAndRunRequest(
assistantId, thread, model, instructions, tools, metadata);
assistantId, thread, model, instructions, tools, metadata, temperature, stream);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*/
abstract class OpenAIClient {

private static final String STREAM_TERMINATION = "data: [DONE]";

private final ObjectMapper objectMapper = ObjectMapperSingleton.getInstance();

private final String[] authenticationHeaders;
Expand Down Expand Up @@ -96,6 +98,13 @@ <T> CompletableFuture<HttpResponse<T>> sendHttpRequestAsync(
});
}

Stream<String> streamServerSentEvents(HttpRequest httpRequest) {
return sendHttpRequest(httpRequest, HttpResponse.BodyHandlers.ofLines())
.body()
.filter(sseEvent -> !sseEvent.isBlank())
.takeWhile(sseEvent -> !sseEvent.contains(STREAM_TERMINATION));
}

void validateHttpResponse(HttpResponse<?> httpResponse) {
int statusCode = httpResponse.statusCode();
if (statusCode < 200 || statusCode > 299) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
package io.github.stefanbratanov.jvm.openai;

import java.util.LinkedList;
import java.util.List;
import java.util.Optional;

public record SubmitToolOutputsRequest(List<ToolOutput> toolOutputs) {
public record SubmitToolOutputsRequest(List<ToolOutput> toolOutputs, Optional<Boolean> stream) {

public static Builder newBuilder() {
return new Builder();
}

public static class Builder {

private final List<ToolOutput> toolOutputs = new LinkedList<>();

private Optional<Boolean> stream = Optional.empty();

/**
* @param toolOutput Tool output to append to the list of tools for which the outputs are being
* submitted.
*/
public Builder toolOutput(ToolOutput toolOutput) {
toolOutputs.add(toolOutput);
return this;
}

/**
* @param toolOutputs Tool outputs to append to the list of tools for which the outputs are
* being submitted.
*/
public Builder toolOutputs(List<ToolOutput> toolOutputs) {
this.toolOutputs.addAll(toolOutputs);
return this;
}

/**
* @param stream If true, returns a stream of events that happen during the Run as server-sent
* events
*/
public Builder stream(boolean stream) {
this.stream = Optional.of(stream);
return this;
}

public SubmitToolOutputsRequest build() {
return new SubmitToolOutputsRequest(List.copyOf(toolOutputs), stream);
}
}

public record ToolOutput(Optional<String> toolCallId, Optional<String> output) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,14 @@ void testRunsClient() {
assertThat(modifiedRun.metadata()).isEqualTo(METADATA);

SubmitToolOutputsRequest submitToolOutputsRequest =
new SubmitToolOutputsRequest(
List.of(
SubmitToolOutputsRequest.ToolOutput.newBuilder()
.toolCallId("call_abc123")
.output("28C")
.build()));
SubmitToolOutputsRequest.newBuilder()
.toolOutputs(
List.of(
SubmitToolOutputsRequest.ToolOutput.newBuilder()
.toolCallId("call_abc123")
.output("28C")
.build()))
.build();

OpenAIException submitToolOutputException =
assertThrows(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ public CreateRunRequest randomCreateRunRequest() {
.tools(listOf(randomInt(1, 20), this::randomTool))
.metadata(randomMetadata())
.temperature(randomDouble(0, 2))
.stream(randomBoolean())
.build();
}

Expand All @@ -321,6 +322,8 @@ public CreateThreadAndRunRequest randomCreateThreadAndRunRequest() {
.instructions(randomString(10, 100))
.tools(listOf(randomInt(1, 20), this::randomTool))
.metadata(randomMetadata())
.temperature(randomDouble(0, 2))
.stream(randomBoolean())
.build();
}

Expand Down Expand Up @@ -379,14 +382,17 @@ public ThreadRunStep randomThreadRunStep() {
}

public SubmitToolOutputsRequest randomSubmitToolOutputsRequest() {
return new SubmitToolOutputsRequest(
listOf(
randomInt(1, 5),
() ->
SubmitToolOutputsRequest.ToolOutput.newBuilder()
.toolCallId(randomString(6))
.output(randomString(5, 20))
.build()));
return SubmitToolOutputsRequest.newBuilder()
.toolOutputs(
listOf(
randomInt(1, 5),
() ->
SubmitToolOutputsRequest.ToolOutput.newBuilder()
.toolCallId(randomString(6))
.output(randomString(5, 20))
.build()))
.stream(randomBoolean())
.build();
}

private StepDetails randomStepDetails() {
Expand Down

0 comments on commit 9f77cdb

Please sign in to comment.