GPT4All Java Bindings and supported models list (#216)
* Generic AI client and models, with an OpenAI client impl from https://github.com/aallam/openai-kotlin

* type LLM models based on their capabilities and type the operations

* add token as a parameter to the `openAI` fn, falling back to an env variable

* add config as optional parameter

* remove old config

* adapt to latest changes from main and new java module

* have openai be its own module that depends on xef-core; the Kotlin, Scala, and Java modules depend on the openai module for defaults, and xef-core does not depend on OpenAI

* fix bug in scala fn name for serialization

* make AIClient : AutoCloseable

* Rename enum cases

* Rename to TEXT_EMBEDDING_ADA_002

* Fix AIClient close expectation

* Progress with models

* Refactor to have models typed and increase ergonomics

* Load embeddings and tokenizer from Hugging Face, and dynamically load local models. Local models can be used in the AI DSL and interleaved with any other model.

* remove unused repositories

* Fix the functions model to GPT_3_5_TURBO_FUNCTIONS and the example that uses manual component construction without an AI block

* remove unused import

* GPT4All Java Bindings and supported models list, plus stdout streaming support (see the usage sketch below)
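A rough usage sketch pieced together from the two examples changed in this commit. The model URL, local path, and prompt text are illustrative; the supported-models listing, the GPT4All constructor, the promptMessage parameters, and the streamToStandardOut flag all appear in the diffs below:

    import com.xebia.functional.gpt4all.GPT4All
    import com.xebia.functional.gpt4all.Gpt4AllModel
    import com.xebia.functional.gpt4all.HuggingFaceLocalEmbeddings
    import com.xebia.functional.xef.auto.PromptConfiguration
    import com.xebia.functional.xef.vectorstores.LocalVectorStore
    import java.nio.file.Path

    suspend fun main() {
      // Print the models the GPT4All Java bindings report as supported.
      Gpt4AllModel.supportedModels().forEach { println("${it.name} ${it.url ?: ""}") }

      // Illustrative download URL and local path for one of those models.
      val url = "https://huggingface.co/nomic-ai/ggml-replit-code-v1-3b/resolve/main/ggml-replit-code-v1-3b.bin"
      val modelPath: Path = Path.of("models/gpt4all/ggml-replit-code-v1-3b.bin")

      GPT4All(url, modelPath).use { gpt4All ->
        gpt4All.promptMessage(
          question = "Summarise GPT4All in one sentence.",
          context = LocalVectorStore(HuggingFaceLocalEmbeddings.DEFAULT),
          // streamToStandardOut(true) echoes tokens to stdout as they are generated.
          promptConfiguration = PromptConfiguration { streamToStandardOut(true) }
        )
      }
    }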

---------

Co-authored-by: Simon Vergauwen <nomisRev@users.noreply.github.com>
raulraja and nomisRev authored Jul 3, 2023
1 parent b1ebeea commit b47e93e
Showing 26 changed files with 274 additions and 341 deletions.

@@ -99,14 +99,15 @@ class CoreAIScope(
     functions: List<CFunction>,
     serializer: (json: String) -> A,
     promptConfiguration: PromptConfiguration,
-  ): A =
-    prompt(
+  ): A {
+    return prompt(
       prompt = Prompt(prompt),
       context = context,
       functions = functions,
       serializer = serializer,
       promptConfiguration = promptConfiguration,
     )
+  }
 
   @AiDsl
   suspend fun Chat.promptMessage(

@@ -11,6 +11,7 @@ class PromptConfiguration(
   val numberOfPredictions: Int = 1,
   val docsInContext: Int = 20,
   val minResponseTokens: Int = 500,
+  val streamToStandardOut: Boolean = false
 ) {
   companion object {
 
@@ -21,11 +22,16 @@
     private var numberOfPredictions: Int = 1
     private var docsInContext: Int = 20
     private var minResponseTokens: Int = 500
+    private var streamToStandardOut: Boolean = false
 
     fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply {
       this.maxDeserializationAttempts = maxDeserializationAttempts
     }
 
+    fun streamToStandardOut(streamToStandardOut: Boolean) = apply {
+      this.streamToStandardOut = streamToStandardOut
+    }
+
     fun user(user: String) = apply { this.user = user }
 
     fun temperature(temperature: Double) = apply { this.temperature = temperature }
@@ -47,7 +53,8 @@
         temperature,
         numberOfPredictions,
         docsInContext,
-        minResponseTokens
+        minResponseTokens,
+        streamToStandardOut
       )
     }
 
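With the new flag in place, streaming can be switched on next to the existing builder options. A minimal sketch (the values are illustrative; every builder method used here appears in this diff or in the examples further down):

    val config = PromptConfiguration {
      user("example-user")          // illustrative value
      temperature(0.4)              // illustrative value
      docsInContext(5)
      maxDeserializationAttempts(3)
      streamToStandardOut(true)     // new: echo tokens to stdout as they are generated
    }
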
25 changes: 15 additions & 10 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -51,28 +51,33 @@ interface Chat : LLM {
       return totalLeftTokens
     }
 
-    val userMessage = Message(Role.USER, promptWithContext, Role.USER.name)
-    fun buildChatRequest(): ChatCompletionRequest =
-      ChatCompletionRequest(
+    fun buildChatRequest(): ChatCompletionRequest {
+      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
+      return ChatCompletionRequest(
         model = name,
         user = promptConfiguration.user,
-        messages = listOf(userMessage),
+        messages = messages,
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(listOf(userMessage))
+        maxTokens = checkTotalLeftChatTokens(messages),
+        streamToStandardOut = promptConfiguration.streamToStandardOut
       )
+    }
 
-    fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions =
-      ChatCompletionRequestWithFunctions(
+    fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions {
+      val firstFnName: String? = functions.firstOrNull()?.name
+      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
+      return ChatCompletionRequestWithFunctions(
         model = name,
         user = promptConfiguration.user,
-        messages = listOf(userMessage),
+        messages = messages,
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(listOf(userMessage)),
+        maxTokens = checkTotalLeftChatTokens(messages),
         functions = functions,
-        functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: ""))
+        functionCall = mapOf("name" to (firstFnName ?: ""))
       )
+    }
 
     return when (this) {
       is ChatWithFunctions ->

@@ -24,15 +24,16 @@ interface ChatWithFunctions : Chat {
     functions: List<CFunction>,
     serializer: (json: String) -> A,
     promptConfiguration: PromptConfiguration,
-  ): A =
-    tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+  ): A {
+    return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
       promptMessage(
         prompt = Prompt(prompt),
         context = context,
         functions = functions,
         promptConfiguration
       )
     }
+  }
 
   @AiDsl
   suspend fun <A> prompt(
@@ -41,10 +42,11 @@
     functions: List<CFunction>,
     serializer: (json: String) -> A,
     promptConfiguration: PromptConfiguration,
-  ): A =
-    tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+  ): A {
+    return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
       promptMessage(prompt = prompt, context = context, functions = functions, promptConfiguration)
     }
+  }
 
   private suspend fun <A> tryDeserialize(
     serializer: (json: String) -> A,

@@ -12,5 +12,6 @@ data class ChatCompletionRequest(
   val presencePenalty: Double = 0.0,
   val frequencyPenalty: Double = 0.0,
   val logitBias: Map<String, Int> = emptyMap(),
-  val user: String?
+  val user: String?,
+  val streamToStandardOut: Boolean = false,
 )

@@ -17,4 +17,5 @@ data class CompletionRequest(
   val frequencyPenalty: Double = 0.0,
   val bestOf: Int = 1,
   val logitBias: Map<String, Int> = emptyMap(),
+  val streamToStandardOut: Boolean = false
 )

@@ -1,49 +1,44 @@
 package com.xebia.functional.xef.auto.gpt4all
 
 import com.xebia.functional.gpt4all.GPT4All
-import com.xebia.functional.gpt4all.LLModel
+import com.xebia.functional.gpt4all.Gpt4AllModel
 import com.xebia.functional.gpt4all.getOrThrow
-import com.xebia.functional.gpt4all.huggingFaceUrl
 import com.xebia.functional.xef.auto.PromptConfiguration
 import com.xebia.functional.xef.auto.ai
-import com.xebia.functional.xef.auto.llm.openai.OpenAI
-import com.xebia.functional.xef.pdf.pdf
 import java.nio.file.Path
 
 suspend fun main() {
   val userDir = System.getProperty("user.dir")
-  val path = "$userDir/models/gpt4all/ggml-gpt4all-j-v1.3-groovy.bin"
-  val url = huggingFaceUrl("orel12", "ggml-gpt4all-j-v1.3-groovy", "bin")
-  val modelType = LLModel.Type.GPTJ
-  val modelPath: Path = Path.of(path)
-  val GPT4All = GPT4All(url, modelPath, modelType)
+  val path = "$userDir/models/gpt4all/ggml-replit-code-v1-3b.bin"
 
-  println("🤖 GPT4All loaded: $GPT4All")
+  val supportedModels = Gpt4AllModel.supportedModels()
+  supportedModels.forEach {
+    println("🤖 ${it.name} ${it.url?.let { "- $it" }}")
+  }
 
-  val pdfUrl = "https://www.europarl.europa.eu/RegData/etudes/STUD/2023/740063/IPOL_STU(2023)740063_EN.pdf"
+  val url = "https://huggingface.co/nomic-ai/ggml-replit-code-v1-3b/resolve/main/ggml-replit-code-v1-3b.bin"
+  val modelPath: Path = Path.of(path)
+  val GPT4All = GPT4All(url, modelPath)
 
+  println("🤖 GPT4All loaded: $GPT4All")
   /**
    * Uses internally [HuggingFaceLocalEmbeddings] default of "sentence-transformers", "msmarco-distilbert-dot-v5"
    * to provide embeddings for docs in contextScope.
    */
 
   ai {
-    println("🤖 Loading PDF: $pdfUrl")
-    contextScope(pdf(pdfUrl)) {
-      println("🤖 Context loaded: $context")
-      GPT4All.use { gpT4All: GPT4All ->
-        println("🤖 Generating prompt for context")
-        val prompt = gpT4All.promptMessage(
-          "Describe in one sentence what the context is about.",
+    println("🤖 Context loaded: $context")
+    GPT4All.use { gpT4All: GPT4All ->
+      println("🤖 Generating prompt for context")
+      while (true) {
+        println("🤖 Enter your prompt: ")
+        val userInput = readlnOrNull() ?: break
+        gpT4All.promptMessage(
+          userInput,
           promptConfiguration = PromptConfiguration {
             docsInContext(2)
+            streamToStandardOut(true)
           })
-        println("🤖 Generating images for prompt: \n\n$prompt")
-        val images =
-          OpenAI.DEFAULT_IMAGES.images(prompt.joinToString("\n"), promptConfiguration = PromptConfiguration {
-            docsInContext(1)
-          })
-        println("🤖 Generated images: \n\n${images.data.joinToString("\n") { it.url }}")
       }
     }
   }.getOrThrow()

@@ -1,16 +1,47 @@
 package com.xebia.functional.xef.auto.manual
 
+import com.xebia.functional.gpt4all.GPT4All
 import com.xebia.functional.gpt4all.HuggingFaceLocalEmbeddings
-import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.gpt4all.huggingFaceUrl
+import com.xebia.functional.xef.auto.PromptConfiguration
 import com.xebia.functional.xef.pdf.pdf
 import com.xebia.functional.xef.vectorstores.LocalVectorStore
+import java.nio.file.Path
 
 suspend fun main() {
-  val chat = OpenAI.DEFAULT_CHAT
-  val huggingFaceEmbeddings = HuggingFaceLocalEmbeddings.DEFAULT
-  val vectorStore = LocalVectorStore(huggingFaceEmbeddings)
-  val results = pdf("https://www.europarl.europa.eu/RegData/etudes/STUD/2023/740063/IPOL_STU(2023)740063_EN.pdf")
+  // Choose your base folder for downloaded models
+  val userDir = System.getProperty("user.dir")
+
+  // Specify the local model path
+  val modelPath: Path = Path.of("$userDir/models/gpt4all/ggml-gpt4all-j-v1.3-groovy.bin")
+
+  // Specify the Hugging Face URL for the model
+  val url = huggingFaceUrl("orel12", "ggml-gpt4all-j-v1.3-groovy", "bin")
+
+  // Create an instance of GPT4All with the local model
+  val gpt4All = GPT4All(url, modelPath)
+
+  // Create an instance of the Hugging Face local embeddings
+  val embeddings = HuggingFaceLocalEmbeddings.DEFAULT
+
+  // Create a LocalVectorStore and initialize it with the Hugging Face embeddings
+  val vectorStore = LocalVectorStore(embeddings)
+
+  // Fetch and add texts from a PDF document to the vector store
+  val results = pdf("https://arxiv.org/pdf/2305.10601.pdf")
   vectorStore.addTexts(results)
-  val result: List<String> = chat.promptMessage("What is the content about?", vectorStore)
+
+  // Prompt the GPT4All model with a question and provide the vector store for context
+  val result: List<String> = gpt4All.use {
+    it.promptMessage(
+      question = "What is the Tree of Thoughts framework about?",
+      context = vectorStore,
+      promptConfiguration = PromptConfiguration {
+        docsInContext(5)
+      }
+    )
+  }
+
+  // Print the response
   println(result)
 }

32 changes: 30 additions & 2 deletions gpt4all-kotlin/build.gradle.kts
@@ -4,6 +4,7 @@ plugins {
   alias(libs.plugins.spotless)
   alias(libs.plugins.arrow.gradle.publish)
   alias(libs.plugins.semver.gradle)
+  alias(libs.plugins.kotlinx.serialization)
 }
 
 repositories {
@@ -19,7 +20,25 @@ java {
 }
 
 kotlin {
-  jvm()
+  jvm {
+    compilations {
+      val integrationTest by compilations.creating {
+        // Create a test task to run the tests produced by this compilation:
+        tasks.register<Test>("integrationTest") {
+          description = "Run the integration tests"
+          group = "verification"
+          classpath = compileDependencyFiles + runtimeDependencyFiles + output.allOutputs
+          testClassesDirs = output.classesDirs
+
+          testLogging {
+            events("passed")
+          }
+        }
+      }
+      val test by compilations.getting
+      integrationTest.associateWith(test)
+    }
+  }
 
   js(IR) {
     browser()
@@ -43,7 +62,7 @@
 
     val jvmMain by getting {
       dependencies {
-        implementation("net.java.dev.jna:jna-platform:5.13.0")
+        implementation("com.hexadevlabs:gpt4all-java-binding:+")
         implementation("ai.djl.huggingface:tokenizers:+")
       }
     }
@@ -60,6 +79,15 @@
   }
 }
 
+tasks.withType<Test>().configureEach {
+  maxParallelForks = Runtime.getRuntime().availableProcessors()
+  useJUnitPlatform()
+  testLogging {
+    setExceptionFormat("full")
+    setEvents(listOf("passed", "skipped", "failed", "standardOut", "standardError"))
+  }
+}
+
 tasks.withType<AbstractPublishToMaven> {
   dependsOn(tasks.withType<Sign>())
 }

6 binary files not shown.
