Skip to content

Commit

Permalink
Merge branch 'main' into fix-docs-scala
Browse files Browse the repository at this point in the history
  • Loading branch information
Yawolf authored Jul 12, 2023
2 parents bb1ed43 + 6f4b917 commit 1cb1223
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 23 deletions.
1 change: 0 additions & 1 deletion examples/java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ plugins {
}

dependencies {
implementation(projects.xefCore)
implementation(projects.xefJava)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.xebia.functional.xef.java.auto;

import jakarta.validation.constraints.NotNull;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

public class Books {

private final AIScope scope;

public Books(AIScope scope) {
this.scope = scope;
}

public static class Book {
@NotNull public String title;
@NotNull public String author;
@NotNull public int year;
@NotNull public String genre;

@Override
public String toString() {
return "Book{" +
"title='" + title + '\'' +
", author='" + author + '\'' +
", year=" + year +
", genre='" + genre + '\'' +
'}';
}
}

public CompletableFuture<Books.Book> bookSelection(String topic) {
return scope.prompt("Give me a selection of books about " + topic, Books.Book.class);
}

public static void main(String[] args) throws ExecutionException, InterruptedException {
try (AIScope scope = new AIScope()) {
Books books = new Books(scope);
books.bookSelection("artificial intelligence")
.thenAccept(System.out::println)
.get();
}
}
}
7 changes: 7 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ scala = "3.3.0"
openai-client-version = "3.3.1"
gpt4all-java = "1.1.3"
ai-djl = "0.22.1"
jackson = "2.15.2"
jsonschema = "4.31.1"
jakarta = "3.0.2"

[libraries]
arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
Expand Down Expand Up @@ -82,6 +85,10 @@ scala-lang = { module = "org.scala-lang:scala3-library_3", version.ref = "scala"
openai-client = { module = "com.aallam.openai:openai-client", version.ref = "openai-client-version" }
gpt4all-java-bindings = { module = "com.hexadevlabs:gpt4all-java-binding", version.ref = "gpt4all-java" }
ai-djl-huggingface-tokenizers = { module = "ai.djl.huggingface:tokenizers", version.ref = "ai-djl" }
jackson = { module = "com.fasterxml.jackson.core:jackson-databind", version.ref = "jackson" }
jackson-schema = { module = "com.github.victools:jsonschema-generator", version.ref = "jsonschema" }
jackson-schema-jakarta = { module = "com.github.victools:jsonschema-module-jakarta-validation", version.ref = "jsonschema" }
jakarta-validation = { module = "jakarta.validation:jakarta.validation-api", version.ref = "jakarta" }

[bundles]
arrow = [
Expand Down
14 changes: 8 additions & 6 deletions java/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
@file:Suppress("DSL_SCOPE_VIOLATION")

plugins {
java
`java-library`
alias(libs.plugins.semver.gradle)
alias(libs.plugins.spotless)
}

dependencies {
implementation(projects.xefCore)
implementation(projects.xefOpenai)
implementation(projects.xefPdf)
implementation("com.fasterxml.jackson.core:jackson-databind:2.15.2")
implementation("com.fasterxml.jackson.module:jackson-module-jsonSchema:2.15.2")
api(projects.xefCore)
api(projects.xefOpenai)
api(projects.xefPdf)
api(libs.jackson)
api(libs.jackson.schema)
api(libs.jackson.schema.jakarta)
api(libs.jakarta.validation)
}

tasks.withType<Test>().configureEach {
Expand Down
35 changes: 19 additions & 16 deletions java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
import com.github.victools.jsonschema.generator.*;
import com.github.victools.jsonschema.module.jakarta.validation.JakartaValidationModule;
import com.github.victools.jsonschema.module.jakarta.validation.JakartaValidationOption;
import com.xebia.functional.xef.auto.CoreAIScope;
import com.xebia.functional.xef.auto.PromptConfiguration;
import com.xebia.functional.xef.auto.llm.openai.OpenAI;
Expand Down Expand Up @@ -38,15 +39,22 @@
public class AIScope implements AutoCloseable {
private final CoreAIScope scope;
private final ObjectMapper om;
private final JsonSchemaGenerator schemaGen;
private final SchemaGenerator schemaGenerator;
private final ExecutorService executorService;
private final CoroutineScope coroutineScope;

public AIScope(ObjectMapper om, Embeddings embeddings, ExecutorService executorService) {
this.om = om;
this.executorService = executorService;
this.coroutineScope = () -> ExecutorsKt.from(executorService).plus(JobKt.Job(null));
this.schemaGen = new JsonSchemaGenerator(om);
JakartaValidationModule module = new JakartaValidationModule(
JakartaValidationOption.NOT_NULLABLE_FIELD_IS_REQUIRED,
JakartaValidationOption.INCLUDE_PATTERN_EXPRESSIONS
);
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_7, OptionPreset.PLAIN_JSON)
.with(module);
SchemaGeneratorConfig config = configBuilder.build();
this.schemaGenerator = new SchemaGenerator(config);
VectorStore vectorStore = new LocalVectorStore(embeddings);
this.scope = new CoreAIScope(embeddings, vectorStore);
}
Expand All @@ -56,22 +64,21 @@ public AIScope(Embeddings embeddings, ExecutorService executorService) {
}

public AIScope() {
this(new ObjectMapper(),new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING), Executors.newCachedThreadPool(new AIScopeThreadFactory()));
this(new ObjectMapper(), new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING), Executors.newCachedThreadPool(new AIScopeThreadFactory()));
}

private AIScope(CoreAIScope nested, AIScope outer) {
this.om = outer.om;
this.executorService = outer.executorService;
this.coroutineScope = outer.coroutineScope;
this.schemaGen = outer.schemaGen;
this.schemaGenerator = outer.schemaGenerator;
this.scope = nested;
}

public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls) {
return prompt(prompt, cls, OpenAI.DEFAULT_SERIALIZATION, PromptConfiguration.DEFAULTS);
}


public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, ChatWithFunctions llmModel, PromptConfiguration promptConfiguration) {
Function1<? super String, ? extends A> decoder = json -> {
try {
Expand All @@ -82,15 +89,7 @@ public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, ChatWithFunc
}
};

String schema;
try {
JsonSchema jsonSchema = schemaGen.generateSchema(cls);
jsonSchema.setId(null);
schema = om.writeValueAsString(jsonSchema);
} catch (JsonProcessingException e) {
// TODO AIError ex = new AIError.JsonParsing(json, maxAttempts, e);
throw new RuntimeException(e);
}
String schema = schemaGenerator.generateSchema(cls).toString();

List<CFunction> functions = Collections.singletonList(
new CFunction(cls.getSimpleName(), "Generated function for " + cls.getSimpleName(), schema)
Expand All @@ -99,6 +98,10 @@ public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, ChatWithFunc
return future(continuation -> scope.promptWithSerializer(llmModel, prompt, functions, decoder, promptConfiguration, continuation));
}

public CompletableFuture<String> promptMessage(String prompt) {
return promptMessage(OpenAI.DEFAULT_CHAT, prompt, PromptConfiguration.DEFAULTS);
}

public CompletableFuture<String> promptMessage(Chat llmModel, String prompt, PromptConfiguration promptConfiguration) {
return future(continuation -> scope.promptMessage(llmModel, prompt, promptConfiguration, continuation));
}
Expand Down

0 comments on commit 1cb1223

Please sign in to comment.