Skip to content

Commit

Permalink
Add response_format and timestamp_granularities to transcription/translation requests
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Feb 7, 2024
1 parent 38faef4 commit 8d9b476
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 22 deletions.
26 changes: 20 additions & 6 deletions src/main/java/io/github/stefanbratanov/jvm/openai/AudioClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public String createTranscript(TranscriptionRequest request) {
HttpRequest httpRequest = createTranscriptPostRequest(request);

HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
return deserializeResponseAsTree(httpResponse.body()).get("text").asText();
return new String(httpResponse.body());
}

/**
Expand All @@ -74,8 +74,7 @@ public CompletableFuture<String> createTranscriptAsync(TranscriptionRequest requ
HttpRequest httpRequest = createTranscriptPostRequest(request);

return sendHttpRequestAsync(httpRequest)
.thenApply(
httpResponse -> deserializeResponseAsTree(httpResponse.body()).get("text").asText());
.thenApply(httpResponse -> new String(httpResponse.body()));
}

/**
Expand All @@ -87,7 +86,7 @@ public String createTranslation(TranslationRequest request) {
HttpRequest httpRequest = createTranslationPostRequest(request);

HttpResponse<byte[]> httpResponse = sendHttpRequest(httpRequest);
return deserializeResponseAsTree(httpResponse.body()).get("text").asText();
return new String(httpResponse.body());
}

/**
Expand All @@ -98,8 +97,7 @@ public CompletableFuture<String> createTranslationAsync(TranslationRequest reque
HttpRequest httpRequest = createTranslationPostRequest(request);

return sendHttpRequestAsync(httpRequest)
.thenApply(
httpResponse -> deserializeResponseAsTree(httpResponse.body()).get("text").asText());
.thenApply(httpResponse -> new String(httpResponse.body()));
}

private void createParentDirectories(Path path) {
Expand Down Expand Up @@ -129,10 +127,21 @@ private HttpRequest createTranscriptPostRequest(TranscriptionRequest request) {
.language()
.ifPresent(language -> multipartBodyPublisherBuilder.textPart("language", language));
request.prompt().ifPresent(prompt -> multipartBodyPublisherBuilder.textPart("prompt", prompt));
request
.responseFormat()
.ifPresent(
responseFormat ->
multipartBodyPublisherBuilder.textPart("response_format", responseFormat));
request
.temperature()
.ifPresent(
temperature -> multipartBodyPublisherBuilder.textPart("temperature", temperature));
request
.timestampGranularities()
.ifPresent(
timestampGranularities ->
multipartBodyPublisherBuilder.textPart(
"timestamp_granularities", timestampGranularities));

MultipartBodyPublisher multipartBodyPublisher = multipartBodyPublisherBuilder.build();

Expand All @@ -149,6 +158,11 @@ private HttpRequest createTranslationPostRequest(TranslationRequest request) {
.filePart("file", request.file())
.textPart("model", request.model());
request.prompt().ifPresent(prompt -> multipartBodyPublisherBuilder.textPart("prompt", prompt));
request
.responseFormat()
.ifPresent(
responseFormat ->
multipartBodyPublisherBuilder.textPart("response_format", responseFormat));
request
.temperature()
.ifPresent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,6 @@ <T> List<T> deserializeDataInResponseAsList(byte[] response, Class<T> elementTyp
}
}

/**
 * Parses the given response bytes into a {@code JsonNode} tree using {@code objectMapper}.
 *
 * @param response the raw response body to parse
 * @return the root node produced by {@code objectMapper.readTree}
 * @throws UncheckedIOException wrapping the {@link IOException} raised while reading the bytes
 */
JsonNode deserializeResponseAsTree(byte[] response) {
  final JsonNode root;
  try {
    root = objectMapper.readTree(response);
  } catch (IOException ioException) {
    // Surface the checked failure as unchecked so callers need no throws clause.
    throw new UncheckedIOException(ioException);
  }
  return root;
}

private String[] getAuthenticationHeaders(String apiKey, Optional<String> organization) {
List<String> authHeaders = new ArrayList<>();
authHeaders.add("Authorization");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package io.github.stefanbratanov.jvm.openai;

import java.nio.file.Path;
import java.util.List;
import java.util.Optional;

public record TranscriptionRequest(
Path file,
String model,
Optional<String> language,
Optional<String> prompt,
Optional<Double> temperature) {
Optional<String> responseFormat,
Optional<Double> temperature,
Optional<List<String>> timestampGranularities) {

public static Builder newBuilder() {
return new Builder();
Expand All @@ -22,7 +25,9 @@ public static class Builder {
private String model = DEFAULT_MODEL;
private Optional<String> language = Optional.empty();
private Optional<String> prompt = Optional.empty();
private Optional<String> responseFormat = Optional.empty();
private Optional<Double> temperature = Optional.empty();
private Optional<List<String>> timestampGranularities = Optional.empty();

/**
* @param file The audio file object (not file name) to transcribe, in one of these formats:
Expand Down Expand Up @@ -62,6 +67,14 @@ public Builder prompt(String prompt) {
return this;
}

/**
 * Sets the response format to send with the transcription request.
 *
 * @param responseFormat The format of the transcript output
 * @return this builder, so calls can be chained
 * @throws NullPointerException if {@code responseFormat} is null
 */
public Builder responseFormat(String responseFormat) {
  // Optional.of rejects null, so a present value is always non-null.
  Optional<String> requestedFormat = Optional.of(responseFormat);
  this.responseFormat = requestedFormat;
  return this;
}

/**
* @param temperature The sampling temperature, between 0 and 1. Higher values like 0.8 will
* make the output more random, while lower values like 0.2 will make it more focused and
Expand All @@ -78,11 +91,22 @@ public Builder temperature(double temperature) {
return this;
}

/**
 * Sets the timestamp granularities to send with the transcription request.
 *
 * <p>The given list is copied into an unmodifiable snapshot, so changes the caller makes to its
 * own list afterwards do not leak into the request being built.
 *
 * @param timestampGranularities The timestamp granularities to populate for this transcription.
 *     Any of these options: word, or segment. Note: There is no additional latency for segment
 *     timestamps, but generating word timestamps incurs additional latency.
 * @return this builder, so calls can be chained
 * @throws NullPointerException if {@code timestampGranularities} or any of its elements is null
 */
public Builder timestampGranularities(List<String> timestampGranularities) {
  // Defensive copy: storing the caller's list directly would let it be mutated after build().
  List<String> snapshot = List.copyOf(timestampGranularities);
  this.timestampGranularities = Optional.of(snapshot);
  return this;
}

public TranscriptionRequest build() {
if (file == null) {
throw new IllegalStateException("file must be set");
}
return new TranscriptionRequest(file, model, language, prompt, temperature);
return new TranscriptionRequest(
file, model, language, prompt, responseFormat, temperature, timestampGranularities);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import java.util.Optional;

public record TranslationRequest(
Path file, String model, Optional<String> prompt, Optional<Double> temperature) {
Path file,
String model,
Optional<String> prompt,
Optional<String> responseFormat,
Optional<Double> temperature) {

public static Builder newBuilder() {
return new Builder();
Expand All @@ -17,6 +21,7 @@ public static class Builder {
private Path file;
private String model = DEFAULT_MODEL;
private Optional<String> prompt = Optional.empty();
private Optional<String> responseFormat = Optional.empty();
private Optional<Double> temperature = Optional.empty();

/**
Expand Down Expand Up @@ -47,6 +52,14 @@ public Builder prompt(String prompt) {
return this;
}

/**
 * Sets the response format to send with the translation request.
 *
 * @param responseFormat The format of the translation output
 * @return this builder, so calls can be chained
 * @throws NullPointerException if {@code responseFormat} is null
 */
public Builder responseFormat(String responseFormat) {
  // Optional.of rejects null, so a present value is always non-null.
  Optional<String> requestedFormat = Optional.of(responseFormat);
  this.responseFormat = requestedFormat;
  return this;
}

/**
* @param temperature The sampling temperature, between 0 and 1. Higher values like 0.8 will
* make the output more random, while lower values like 0.2 will make it more focused and
Expand All @@ -67,7 +80,7 @@ public TranslationRequest build() {
if (file == null) {
throw new IllegalStateException("file must be set");
}
return new TranslationRequest(file, model, prompt, temperature);
return new TranslationRequest(file, model, prompt, responseFormat, temperature);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,30 @@ void testAudioClient(@TempDir Path tempDir) {
assertThat(speech).exists().isNotEmptyFile();

TranscriptionRequest transcriptionRequest =
TranscriptionRequest.newBuilder().file(speech).model("whisper-1").build();
TranscriptionRequest.newBuilder()
.file(speech)
.model("whisper-1")
.responseFormat("text")
.build();

String transcript = audioClient.createTranscript(transcriptionRequest);

assertThat(transcript).isEqualToIgnoringCase("The quick brown fox jumped over the lazy dog.");
assertThat(transcript)
.isEqualToIgnoringNewLines("The quick brown fox jumped over the lazy dog.");

Path greeting = getTestResource("/italian-greeting.mp3");

TranslationRequest translationRequest =
TranslationRequest.newBuilder().file(greeting).model("whisper-1").build();
TranslationRequest.newBuilder()
.file(greeting)
.model("whisper-1")
.responseFormat("json")
.build();

String translation = audioClient.createTranslation(translationRequest);

assertThat(translation).isEqualTo("My name is Diego. What's your name?");
assertThat(translation)
.isEqualToIgnoringWhitespace("{\"text\":\"My name is Diego. What's your name?\"}");
}

@Test // using mock server because image models are costly
Expand Down

0 comments on commit 8d9b476

Please sign in to comment.