Skip to content

Commit

Permalink
Add ability to configure a request timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Feb 15, 2024
1 parent d77b933 commit 28bb9c9
Show file tree
Hide file tree
Showing 16 changed files with 185 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Optional;

Expand All @@ -19,8 +20,12 @@ public final class AssistantsClient extends OpenAIAssistantsClient {
private final URI baseUrl;

AssistantsClient(
URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.net.http.HttpResponse.BodyHandlers;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

Expand All @@ -21,8 +22,13 @@ public final class AudioClient extends OpenAIClient {

private final URI baseUrl;

AudioClient(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
AudioClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
Expand All @@ -19,8 +20,13 @@ public final class ChatClient extends OpenAIClient {

private final URI endpoint;

ChatClient(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
ChatClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
endpoint = baseUrl.resolve(Endpoint.CHAT.getPath());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Optional;

/**
Expand All @@ -17,8 +18,12 @@ public final class EmbeddingsClient extends OpenAIClient {
private final URI endpoint;

EmbeddingsClient(
URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
endpoint = baseUrl.resolve(Endpoint.EMBEDDINCS.getPath());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Optional;

Expand All @@ -17,8 +18,13 @@ public final class FilesClient extends OpenAIClient {

private final URI baseUrl;

FilesClient(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
FilesClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -19,8 +20,12 @@ public final class FineTuningClient extends OpenAIClient {
private final URI baseUrl;

FineTuningClient(
URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

Expand All @@ -16,8 +17,13 @@ public final class ImagesClient extends OpenAIClient {

private final URI baseUrl;

ImagesClient(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
ImagesClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Optional;

Expand All @@ -19,8 +20,13 @@ public final class MessagesClient extends OpenAIAssistantsClient {

private final URI baseUrl;

MessagesClient(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
MessagesClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Optional;

Expand All @@ -17,8 +18,13 @@ public final class ModelsClient extends OpenAIClient {

private final URI baseUrl;

ModelsClient(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
ModelsClient(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
this.baseUrl = baseUrl;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Optional;

/**
Expand All @@ -16,8 +17,12 @@ public final class ModerationsClient extends OpenAIClient {
private final URI endpoint;

ModerationsClient(
URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
endpoint = baseUrl.resolve(Endpoint.MODERATIONS.getPath());
}

Expand Down
49 changes: 35 additions & 14 deletions src/main/java/io/github/stefanbratanov/jvm/openai/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.net.URI;
import java.net.http.HttpClient;
import java.time.Duration;
import java.util.Optional;

/**
Expand All @@ -24,19 +25,28 @@ public final class OpenAI {
private final MessagesClient messagesClient;
private final RunsClient runsClient;

private OpenAI(URI baseUrl, String apiKey, Optional<String> organization, HttpClient httpClient) {
audioClient = new AudioClient(baseUrl, apiKey, organization, httpClient);
chatClient = new ChatClient(baseUrl, apiKey, organization, httpClient);
embeddingsClient = new EmbeddingsClient(baseUrl, apiKey, organization, httpClient);
fineTuningClient = new FineTuningClient(baseUrl, apiKey, organization, httpClient);
filesClient = new FilesClient(baseUrl, apiKey, organization, httpClient);
imagesClient = new ImagesClient(baseUrl, apiKey, organization, httpClient);
modelsClient = new ModelsClient(baseUrl, apiKey, organization, httpClient);
moderationsClient = new ModerationsClient(baseUrl, apiKey, organization, httpClient);
assistantsClient = new AssistantsClient(baseUrl, apiKey, organization, httpClient);
threadsClient = new ThreadsClient(baseUrl, apiKey, organization, httpClient);
messagesClient = new MessagesClient(baseUrl, apiKey, organization, httpClient);
runsClient = new RunsClient(baseUrl, apiKey, organization, httpClient);
private OpenAI(
URI baseUrl,
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
audioClient = new AudioClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
chatClient = new ChatClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
embeddingsClient =
new EmbeddingsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
fineTuningClient =
new FineTuningClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
filesClient = new FilesClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
imagesClient = new ImagesClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
modelsClient = new ModelsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
moderationsClient =
new ModerationsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
assistantsClient =
new AssistantsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
threadsClient = new ThreadsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
messagesClient = new MessagesClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
runsClient = new RunsClient(baseUrl, apiKey, organization, httpClient, requestTimeout);
}

/**
Expand Down Expand Up @@ -152,6 +162,7 @@ public static class Builder {

private Optional<String> organization = Optional.empty();
private Optional<HttpClient> httpClient = Optional.empty();
private Optional<Duration> requestTimeout = Optional.empty();

public Builder(String apiKey) {
this.apiKey = apiKey;
Expand Down Expand Up @@ -182,6 +193,15 @@ public Builder httpClient(HttpClient httpClient) {
return this;
}

/**
* @param requestTimeout a timeout in the form of a {@link Duration} which will be set for all
* API requests. If none is set, there will be no timeout.
*/
public Builder requestTimeout(Duration requestTimeout) {
this.requestTimeout = Optional.of(requestTimeout);
return this;
}

public OpenAI build() {
if (!baseUrl.endsWith("/")) {
baseUrl += "/";
Expand All @@ -190,7 +210,8 @@ public OpenAI build() {
URI.create(baseUrl),
apiKey,
organization,
httpClient.orElseGet(HttpClient::newHttpClient));
httpClient.orElseGet(HttpClient::newHttpClient),
requestTimeout);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;

Expand All @@ -11,8 +12,12 @@
*/
class OpenAIAssistantsClient extends OpenAIClient {

OpenAIAssistantsClient(String apiKey, Optional<String> organization, HttpClient httpClient) {
super(apiKey, organization, httpClient);
OpenAIAssistantsClient(
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
super(apiKey, organization, httpClient, requestTimeout);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.net.http.HttpResponse;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
Expand All @@ -21,14 +22,20 @@
*/
abstract class OpenAIClient {

private final String[] authenticationHeaders;

protected final HttpClient httpClient;
protected final ObjectMapper objectMapper = ObjectMapperSingleton.getInstance();
private final ObjectMapper objectMapper = ObjectMapperSingleton.getInstance();

OpenAIClient(String apiKey, Optional<String> organization, HttpClient httpClient) {
private final String[] authenticationHeaders;
private final HttpClient httpClient;
private final Optional<Duration> requestTimeout;

OpenAIClient(
String apiKey,
Optional<String> organization,
HttpClient httpClient,
Optional<Duration> requestTimeout) {
this.authenticationHeaders = getAuthenticationHeaders(apiKey, organization);
this.httpClient = httpClient;
this.requestTimeout = requestTimeout;
}

HttpRequest.Builder newHttpRequestBuilder(String... headers) {
Expand All @@ -37,6 +44,7 @@ HttpRequest.Builder newHttpRequestBuilder(String... headers) {
if (headers.length > 0) {
httpRequestBuilder.headers(headers);
}
requestTimeout.ifPresent(httpRequestBuilder::timeout);
return httpRequestBuilder;
}

Expand Down
Loading

0 comments on commit 28bb9c9

Please sign in to comment.