diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
index 8afe0c330..03ec0ce25 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
@@ -1,11 +1,7 @@
package com.xebia.functional.xef.auto
-import com.xebia.functional.xef.AIError
-import com.xebia.functional.xef.llm.AIClient
import com.xebia.functional.xef.vectorstores.VectorStore
-@DslMarker annotation class AiDsl
-
/**
* An [AI] value represents an action relying on artificial intelligence that can be run to produce
* an `A`. This value is _lazy_ and can be combined with other `AI` values using
@@ -18,10 +14,3 @@ typealias AI<A> = suspend CoreAIScope.() -> A
/** A DSL block that makes it more convenient to construct [AI] values. */
inline fun <A> ai(noinline block: suspend CoreAIScope.() -> A): AI<A> = block
-
-suspend fun <A> AIScope(runtime: AIRuntime<A>, block: AI<A>, orElse: suspend (AIError) -> A): A =
- try {
- runtime.runtime(block)
- } catch (e: AIError) {
- orElse(e)
- }
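
With `AIRuntime` and the `AIScope` runner gone from core, an `AI<A>` is now just a suspended function over `CoreAIScope`, and execution plus error handling move to provider modules. A minimal usage sketch, assuming the OpenAI module extensions (`getOrElse`, `promptMessage`) that the examples further down import:

```kotlin
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.getOrElse
import com.xebia.functional.xef.auto.llm.openai.promptMessage

suspend fun main() {
  // ai { } only builds the lazy AI<List<String>> value; nothing runs yet.
  val greetings = ai { promptMessage("Say hello in five languages") }
  // The provider-specific runner executes it and handles AIError.
  println(greetings.getOrElse { error -> listOf("failed: $error") })
}
```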
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIRuntime.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIRuntime.kt
deleted file mode 100644
index 312c8531e..000000000
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIRuntime.kt
+++ /dev/null
@@ -1,12 +0,0 @@
-package com.xebia.functional.xef.auto
-
-import com.xebia.functional.xef.embeddings.Embeddings
-import com.xebia.functional.xef.llm.AIClient
-
-data class AIRuntime<A>(
- val client: AIClient,
- val embeddings: Embeddings,
- val runtime: suspend (block: AI<A>) -> A
-) {
- companion object
-}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AiDsl.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AiDsl.kt
new file mode 100644
index 000000000..aa19f97b9
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AiDsl.kt
@@ -0,0 +1,3 @@
+package com.xebia.functional.xef.auto
+
+@DslMarker annotation class AiDsl
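
`@AiDsl` now lives in its own file. For context, `@DslMarker` is what keeps receivers from leaking across nested DSL blocks; a generic illustration of the behavior it enables (not xef code):

```kotlin
@DslMarker annotation class MenuDsl

@MenuDsl class MenuBuilder { fun item(name: String) {} }
@MenuDsl class ItemBuilder { fun price(value: Double) {} }

fun menu(block: MenuBuilder.() -> Unit) = MenuBuilder().block()
fun MenuBuilder.entry(block: ItemBuilder.() -> Unit) = ItemBuilder().block()

fun demo() = menu {
  entry {
    price(9.99)
    // item("starter") // does not compile: the outer MenuBuilder receiver
    //                 // is hidden inside another @MenuDsl scope.
  }
}
```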
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
index e22ba22d0..4ca848e44 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
@@ -1,23 +1,11 @@
package com.xebia.functional.xef.auto
-import arrow.core.nonFatalOrThrow
-import arrow.core.raise.catch
-import com.xebia.functional.tokenizer.Encoding
-import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.tokenizer.truncateText
-import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.embeddings.Embeddings
-import com.xebia.functional.xef.llm.AIClient
-import com.xebia.functional.xef.llm.LLM
-import com.xebia.functional.xef.llm.LLMModel
-import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest
-import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions
-import com.xebia.functional.xef.llm.models.chat.Message
-import com.xebia.functional.xef.llm.models.chat.Role
+import com.xebia.functional.xef.llm.Chat
+import com.xebia.functional.xef.llm.ChatWithFunctions
+import com.xebia.functional.xef.llm.Images
import com.xebia.functional.xef.llm.models.functions.CFunction
-import com.xebia.functional.xef.llm.models.images.ImagesGenerationRequest
import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse
-import com.xebia.functional.xef.llm.models.text.CompletionRequest
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.vectorstores.CombinedVectorStore
import com.xebia.functional.xef.vectorstores.LocalVectorStore
@@ -32,18 +20,8 @@ import kotlin.jvm.JvmName
* programs.
*/
class CoreAIScope(
- val defaultModel: LLM.Chat,
- val defaultSerializationModel: LLM.ChatWithFunctions,
- val aiClient: AIClient,
- val context: VectorStore,
val embeddings: Embeddings,
- val maxDeserializationAttempts: Int = 3,
- val user: String = "user",
- val echo: Boolean = false,
- val temperature: Double = 0.4,
- val numberOfPredictions: Int = 1,
- val docsInContext: Int = 20,
- val minResponseTokens: Int = 500
+ val context: VectorStore = LocalVectorStore(embeddings),
) {
val logger: KLogger = KotlinLogging.logger {}
@@ -98,11 +76,8 @@ class CoreAIScope(
@AiDsl
suspend fun <A> contextScope(store: VectorStore, block: AI<A>): A =
CoreAIScope(
- defaultModel,
- defaultSerializationModel,
- this@CoreAIScope.aiClient,
- CombinedVectorStore(store, this@CoreAIScope.context),
this@CoreAIScope.embeddings,
+ CombinedVectorStore(store, this@CoreAIScope.context),
)
.block()
@@ -119,252 +94,26 @@ class CoreAIScope(
@AiDsl
@JvmName("promptWithSerializer")
- suspend fun <A> prompt(
- prompt: Prompt,
+ suspend fun <A> ChatWithFunctions.prompt(
+ prompt: String,
functions: List<CFunction>,
serializer: (json: String) -> A,
- maxDeserializationAttempts: Int = this.maxDeserializationAttempts,
- model: LLM.ChatWithFunctions = defaultSerializationModel,
- user: String = this.user,
- echo: Boolean = this.echo,
- numberOfPredictions: Int = this.numberOfPredictions,
- temperature: Double = this.temperature,
- bringFromContext: Int = this.docsInContext,
- minResponseTokens: Int = this.minResponseTokens,
- ): A {
- return tryDeserialize(serializer, maxDeserializationAttempts) {
- promptMessage(
- prompt = prompt,
- model = model,
- functions = functions,
- user = user,
- echo = echo,
- numberOfPredictions = numberOfPredictions,
- temperature = temperature,
- bringFromContext = bringFromContext,
- minResponseTokens = minResponseTokens
- )
- }
- }
-
- suspend fun <A> tryDeserialize(
- serializer: (json: String) -> A,
- maxDeserializationAttempts: Int,
- agent: AI<List<String>>
- ): A {
- (0 until maxDeserializationAttempts).forEach { currentAttempts ->
- val result = agent().firstOrNull() ?: throw AIError.NoResponse()
- catch({
- return@tryDeserialize serializer(result)
- }) { e: Throwable ->
- logger.error(e) { "Error deserializing response: $result\n${e.message}" }
- if (currentAttempts == maxDeserializationAttempts)
- throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())
- // TODO else log attempt ?
- }
- }
- throw AIError.NoResponse()
- }
-
- @AiDsl
- suspend fun promptMessage(
- question: String,
- model: LLM.Chat = defaultModel,
- functions: List<CFunction> = emptyList(),
- user: String = this.user,
- echo: Boolean = this.echo,
- n: Int = this.numberOfPredictions,
- temperature: Double = this.temperature,
- bringFromContext: Int = this.docsInContext,
- minResponseTokens: Int = this.minResponseTokens
- ): List<String> =
- promptMessage(
- Prompt(question),
- model,
- functions,
- user,
- echo,
- n,
- temperature,
- bringFromContext,
- minResponseTokens
+ promptConfiguration: PromptConfiguration,
+ ): A =
+ prompt(
+ prompt = Prompt(prompt),
+ context = context,
+ functions = functions,
+ serializer = serializer,
+ promptConfiguration = promptConfiguration,
)
@AiDsl
- suspend fun promptMessage(
- prompt: Prompt,
- model: LLM.Chat = defaultModel,
+ suspend fun Chat.promptMessage(
+ question: String,
functions: List<CFunction> = emptyList(),
- user: String = this.user,
- echo: Boolean = this.echo,
- numberOfPredictions: Int = this.numberOfPredictions,
- temperature: Double = this.temperature,
- bringFromContext: Int = this.docsInContext,
- minResponseTokens: Int
- ): List<String> {
-
- val promptWithContext: String =
- createPromptWithContextAwareOfTokens(
- ctxInfo = context.similaritySearch(prompt.message, bringFromContext),
- modelType = model.modelType,
- prompt = prompt.message,
- minResponseTokens = minResponseTokens
- )
-
- fun checkTotalLeftTokens(role: String): Int =
- with(model.modelType) {
- val roleTokens: Int = encoding.countTokens(role)
- val padding = 20 // reserve 20 tokens for additional symbols around the context
- val promptTokens: Int = encoding.countTokens(promptWithContext)
- val takenTokens: Int = roleTokens + promptTokens + padding
- val totalLeftTokens: Int = maxContextLength - takenTokens
- if (totalLeftTokens < 0) {
- throw AIError.PromptExceedsMaxTokenLength(
- promptWithContext,
- takenTokens,
- maxContextLength
- )
- }
- logger.debug {
- "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens"
- }
- totalLeftTokens
- }
-
- suspend fun buildCompletionRequest(): CompletionRequest =
- CompletionRequest(
- model = model.name,
- user = user,
- prompt = promptWithContext,
- echo = echo,
- n = numberOfPredictions,
- temperature = temperature,
- maxTokens = checkTotalLeftTokens("")
- )
-
- fun checkTotalLeftChatTokens(messages: List<Message>): Int {
- val maxContextLength: Int = model.modelType.maxContextLength
- val messagesTokens: Int = tokensFromMessages(messages, model)
- val totalLeftTokens: Int = maxContextLength - messagesTokens
- if (totalLeftTokens < 0) {
- throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
- }
- logger.debug {
- "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens"
- }
- return totalLeftTokens
- }
-
- suspend fun buildChatRequest(): ChatCompletionRequest {
- val messages: List<Message> = listOf(Message(Role.SYSTEM.name, promptWithContext))
- return ChatCompletionRequest(
- model = model.name,
- user = user,
- messages = messages,
- n = numberOfPredictions,
- temperature = temperature,
- maxTokens = checkTotalLeftChatTokens(messages)
- )
- }
-
- suspend fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions {
- val role: String = Role.USER.name
- val firstFnName: String? = functions.firstOrNull()?.name
- val messages: List<Message> = listOf(Message(role, promptWithContext))
- return ChatCompletionRequestWithFunctions(
- model = model.name,
- user = user,
- messages = messages,
- n = numberOfPredictions,
- temperature = temperature,
- maxTokens = checkTotalLeftChatTokens(messages),
- functions = functions,
- functionCall = mapOf("name" to (firstFnName ?: ""))
- )
- }
-
- return when (model) {
- is LLM.Completion ->
- aiClient.createCompletion(buildCompletionRequest()).choices.map { it.text }
- is LLM.ChatWithFunctions ->
- aiClient.createChatCompletionWithFunctions(chatWithFunctionsRequest()).choices.mapNotNull {
- it.message?.functionCall?.arguments
- }
- else ->
- aiClient.createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content }
- }
- }
-
- private fun createPromptWithContextAwareOfTokens(
- ctxInfo: List<String>,
- modelType: ModelType,
- prompt: String,
- minResponseTokens: Int,
- ): String {
- val maxContextLength: Int = modelType.maxContextLength
- val promptTokens: Int = modelType.encoding.countTokens(prompt)
- val remainingTokens: Int = maxContextLength - promptTokens - minResponseTokens
-
- return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
- val ctx: String = ctxInfo.joinToString("\n")
-
- if (promptTokens >= maxContextLength) {
- throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
- }
- // truncate the context if it's too long based on the max tokens calculated considering the
- // existing prompt tokens
- // alternatively we could summarize the context, but that's not implemented yet
- val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)
-
- """|```Context
- |${ctxTruncated}
- |```
- |The context is related to the question try to answer the `goal` as best as you can
- |or provide information about the found content
- |```goal
- |${prompt}
- |```
- |ANSWER:
- |"""
- .trimMargin()
- } else prompt
- }
-
- private fun tokensFromMessages(messages: List<Message>, model: LLM): Int {
- fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
- messages.sumOf { message ->
- countTokens(message.role) +
- (message.content?.let { countTokens(it) } ?: 0) +
- tokensPerMessage +
- (message.name?.let { tokensPerName } ?: 0)
- } + 3
-
- fun fallBackTo(fallbackModel: LLM, paddingTokens: Int): Int {
- logger.debug {
- "Warning: ${model.name} may change over time. " +
- "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens."
- }
- return tokensFromMessages(messages, fallbackModel) + paddingTokens
- }
-
- return when (model) {
- LLMModel.GPT_3_5_TURBO_FUNCTIONS ->
- // paddingToken = 200: reserved for functions
- fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 200)
- LLMModel.GPT_3_5_TURBO ->
- // otherwise if the model changes, it might later fail
- fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 5)
- LLMModel.GPT_4,
- LLMModel.GPT_4_32K ->
- // otherwise if the model changes, it might later fail
- fallBackTo(fallbackModel = LLMModel.GPT_4_0314, paddingTokens = 5)
- LLMModel.GPT_3_5_TURBO_0301 ->
- model.modelType.encoding.countTokensFromMessages(tokensPerMessage = 4, tokensPerName = 0)
- LLMModel.GPT_4_0314 ->
- model.modelType.encoding.countTokensFromMessages(tokensPerMessage = 3, tokensPerName = 2)
- else -> fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 20)
- }
- }
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> = promptMessage(Prompt(question), context, functions, promptConfiguration)
/**
* Run a [prompt] that describes the images you want to generate within the context of [CoreAIScope].
@@ -374,13 +123,12 @@ class CoreAIScope(
* @param numberImages number of images to generate.
* @param size the size of the images to generate.
*/
- suspend fun images(
+ suspend fun Images.images(
prompt: String,
- user: String = "testing",
numberImages: Int = 1,
size: String = "1024x1024",
- bringFromContext: Int = 10
- ): ImagesGenerationResponse = images(Prompt(prompt), user, numberImages, size, bringFromContext)
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+ ): ImagesGenerationResponse = this.images(Prompt(prompt), numberImages, size, promptConfiguration)
/**
* Run a [prompt] that describes the images you want to generate within the context of [CoreAIScope].
@@ -390,33 +138,10 @@ class CoreAIScope(
* @param numberImages number of images to generate.
* @param size the size of the images to generate.
*/
- suspend fun images(
+ suspend fun Images.images(
prompt: Prompt,
- user: String = "testing",
numberImages: Int = 1,
size: String = "1024x1024",
- bringFromContext: Int = 10
- ): ImagesGenerationResponse {
- val ctxInfo = context.similaritySearch(prompt.message, bringFromContext)
- val promptWithContext =
- if (ctxInfo.isNotEmpty()) {
- """|Instructions: Use the [Information] below delimited by 3 backticks to accomplish
- |the [Objective] at the end of the prompt.
- |Try to match the data returned in the [Objective] with this [Information] as best as you can.
- |[Information]:
- |```
- |${ctxInfo.joinToString("\n")}
- |```
- |$prompt"""
- .trimMargin()
- } else prompt.message
- val request =
- ImagesGenerationRequest(
- prompt = promptWithContext,
- numberImages = numberImages,
- size = size,
- user = user
- )
- return aiClient.createImages(request)
- }
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+ ): ImagesGenerationResponse = images(prompt, context, numberImages, size, promptConfiguration)
}
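
With the client and default models stripped out, `CoreAIScope` now only carries `embeddings` and `context`; the model is picked explicitly at the call site as the receiver of `promptMessage`/`prompt`/`images`. A sketch of the resulting call style, assuming the `OpenAI.DEFAULT_CHAT` constant and `getOrThrow` runner used in the examples below:

```kotlin
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.getOrThrow

suspend fun main() {
  ai {
    // The Chat model is the receiver; the scope contributes context and embeddings.
    val answers: List<String> = OpenAI.DEFAULT_CHAT.promptMessage("Say hi")
    answers.forEach(::println)
  }.getOrThrow()
}
```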
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt
new file mode 100644
index 000000000..4bc979405
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt
@@ -0,0 +1,59 @@
+package com.xebia.functional.xef.auto
+
+import com.xebia.functional.xef.llm.models.chat.Role
+import kotlin.jvm.JvmField
+import kotlin.jvm.JvmName
+
+class PromptConfiguration(
+ val maxDeserializationAttempts: Int = 3,
+ val user: String = Role.USER.name,
+ val temperature: Double = 0.4,
+ val numberOfPredictions: Int = 1,
+ val docsInContext: Int = 20,
+ val minResponseTokens: Int = 500,
+) {
+ companion object {
+
+ class Builder {
+ private var maxDeserializationAttempts: Int = 3
+ private var user: String = Role.USER.name
+ private var temperature: Double = 0.4
+ private var numberOfPredictions: Int = 1
+ private var docsInContext: Int = 20
+ private var minResponseTokens: Int = 500
+
+ fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply {
+ this.maxDeserializationAttempts = maxDeserializationAttempts
+ }
+
+ fun user(user: String) = apply { this.user = user }
+
+ fun temperature(temperature: Double) = apply { this.temperature = temperature }
+
+ fun numberOfPredictions(numberOfPredictions: Int) = apply {
+ this.numberOfPredictions = numberOfPredictions
+ }
+
+ fun docsInContext(docsInContext: Int) = apply { this.docsInContext = docsInContext }
+
+ fun minResponseTokens(minResponseTokens: Int) = apply {
+ this.minResponseTokens = minResponseTokens
+ }
+
+ fun build() =
+ PromptConfiguration(
+ maxDeserializationAttempts,
+ user,
+ temperature,
+ numberOfPredictions,
+ docsInContext,
+ minResponseTokens
+ )
+ }
+
+ @JvmName("build")
+ operator fun invoke(block: Builder.() -> Unit) = Builder().apply(block).build()
+
+ @JvmField val DEFAULTS = PromptConfiguration()
+ }
+}
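
Per-call knobs that used to be `CoreAIScope` constructor parameters are now grouped here. From Kotlin, the companion `invoke` gives a small builder DSL, and anything not set explicitly keeps the defaults above:

```kotlin
val custom = PromptConfiguration {
  temperature(0.7)
  docsInContext(5)
}
// Unset knobs fall back to the defaults:
check(custom.numberOfPredictions == 1)
check(custom.minResponseTokens == 500)
```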
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/AIClientError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/AIClientError.kt
deleted file mode 100644
index 580d47c4c..000000000
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/AIClientError.kt
+++ /dev/null
@@ -1,5 +0,0 @@
-package com.xebia.functional.xef.llm
-
-import kotlinx.serialization.json.JsonElement
-
-data class AIClientError(val json: JsonElement) : Exception("AI client error")
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
new file mode 100644
index 000000000..2df27608e
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -0,0 +1,125 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.tokenizer.truncateText
+import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.auto.AiDsl
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.models.chat.*
+import com.xebia.functional.xef.llm.models.functions.CFunction
+import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.vectorstores.VectorStore
+
+interface Chat : LLM {
+ val modelType: ModelType
+
+ suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse
+
+ fun tokensFromMessages(messages: List<Message>): Int
+
+ @AiDsl
+ suspend fun promptMessage(
+ question: String,
+ context: VectorStore,
+ functions: List<CFunction> = emptyList(),
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+ ): List<String> = promptMessage(Prompt(question), context, functions, promptConfiguration)
+
+ @AiDsl
+ suspend fun promptMessage(
+ prompt: Prompt,
+ context: VectorStore,
+ functions: List<CFunction> = emptyList(),
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+ ): List<String> {
+
+ val promptWithContext: String =
+ createPromptWithContextAwareOfTokens(
+ ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
+ modelType = modelType,
+ prompt = prompt.message,
+ minResponseTokens = promptConfiguration.minResponseTokens
+ )
+
+ fun checkTotalLeftChatTokens(messages: List<Message>): Int {
+ val maxContextLength: Int = modelType.maxContextLength
+ val messagesTokens: Int = tokensFromMessages(messages)
+ val totalLeftTokens: Int = maxContextLength - messagesTokens
+ if (totalLeftTokens < 0) {
+ throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
+ }
+ return totalLeftTokens
+ }
+
+ val userMessage = Message(Role.USER, promptWithContext, Role.USER.name)
+ fun buildChatRequest(): ChatCompletionRequest =
+ ChatCompletionRequest(
+ model = name,
+ user = promptConfiguration.user,
+ messages = listOf(userMessage),
+ n = promptConfiguration.numberOfPredictions,
+ temperature = promptConfiguration.temperature,
+ maxTokens = checkTotalLeftChatTokens(listOf(userMessage))
+ )
+
+ fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions =
+ ChatCompletionRequestWithFunctions(
+ model = name,
+ user = promptConfiguration.user,
+ messages = listOf(userMessage),
+ n = promptConfiguration.numberOfPredictions,
+ temperature = promptConfiguration.temperature,
+ maxTokens = checkTotalLeftChatTokens(listOf(userMessage)),
+ functions = functions,
+ functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: ""))
+ )
+
+ return when (this) {
+ is ChatWithFunctions ->
+ // we only support functions for now with GPT_3_5_TURBO_FUNCTIONS
+ if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) {
+ createChatCompletionWithFunctions(chatWithFunctionsRequest()).choices.mapNotNull {
+ it.message?.functionCall?.arguments
+ }
+ } else {
+ createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content }
+ }
+ else -> createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content }
+ }
+ }
+
+ private fun createPromptWithContextAwareOfTokens(
+ ctxInfo: List<String>,
+ modelType: ModelType,
+ prompt: String,
+ minResponseTokens: Int,
+ ): String {
+ val maxContextLength: Int = modelType.maxContextLength
+ val promptTokens: Int = modelType.encoding.countTokens(prompt)
+ val remainingTokens: Int = maxContextLength - promptTokens - minResponseTokens
+
+ return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) {
+ val ctx: String = ctxInfo.joinToString("\n")
+
+ if (promptTokens >= maxContextLength) {
+ throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength)
+ }
+ // truncate the context if it's too long based on the max tokens calculated considering the
+ // existing prompt tokens
+ // alternatively we could summarize the context, but that's not implemented yet
+ val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)
+
+ """|```Context
+ |${ctxTruncated}
+ |```
+ |The context is related to the question; try to answer the `goal` as best as you can
+ |or provide information about the found content
+ |```goal
+ |${prompt}
+ |```
+ |ANSWER:
+ |"""
+ .trimMargin()
+ } else prompt
+ }
+}
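
`Chat` folds the old `CoreAIScope.promptMessage` machinery into the model interface itself, so a backend only has to supply `createChatCompletion` and a token counter. A minimal sketch of a custom backend (the shape of `ChatCompletionResponse` is not shown in this diff, so the call is left as a `TODO`):

```kotlin
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponse
import com.xebia.functional.xef.llm.models.chat.Message

class EchoChat : Chat {
  override val name: String = "echo"
  override val modelType: ModelType = ModelType.GPT_3_5_TURBO

  override suspend fun createChatCompletion(
    request: ChatCompletionRequest
  ): ChatCompletionResponse = TODO("wire up a real backend here")

  // Rough estimate: encoded content plus a small per-message overhead,
  // in the spirit of the heuristics removed from CoreAIScope above.
  override fun tokensFromMessages(messages: List<Message>): Int =
    messages.sumOf { modelType.encoding.countTokens(it.content) + 4 } + 3
}
```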
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
new file mode 100644
index 000000000..2de5b177b
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -0,0 +1,66 @@
+package com.xebia.functional.xef.llm
+
+import arrow.core.nonFatalOrThrow
+import arrow.core.raise.catch
+import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.auto.AiDsl
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions
+import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
+import com.xebia.functional.xef.llm.models.functions.CFunction
+import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.vectorstores.VectorStore
+
+interface ChatWithFunctions : Chat {
+
+ suspend fun createChatCompletionWithFunctions(
+ request: ChatCompletionRequestWithFunctions
+ ): ChatCompletionResponseWithFunctions
+
+ @AiDsl
+ suspend fun <A> prompt(
+ prompt: String,
+ context: VectorStore,
+ functions: List<CFunction>,
+ serializer: (json: String) -> A,
+ promptConfiguration: PromptConfiguration,
+ ): A =
+ tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+ promptMessage(
+ prompt = Prompt(prompt),
+ context = context,
+ functions = functions,
+ promptConfiguration
+ )
+ }
+
+ @AiDsl
+ suspend fun <A> prompt(
+ prompt: Prompt,
+ context: VectorStore,
+ functions: List<CFunction>,
+ serializer: (json: String) -> A,
+ promptConfiguration: PromptConfiguration,
+ ): A =
+ tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+ promptMessage(prompt = prompt, context = context, functions = functions, promptConfiguration)
+ }
+
+ private suspend fun <A> tryDeserialize(
+ serializer: (json: String) -> A,
+ maxDeserializationAttempts: Int,
+ agent: suspend () -> List<String>
+ ): A {
+ (0 until maxDeserializationAttempts).forEach { currentAttempts ->
+ val result = agent().firstOrNull() ?: throw AIError.NoResponse()
+ catch({
+ return@tryDeserialize serializer(result)
+ }) { e: Throwable ->
+ if (currentAttempts == maxDeserializationAttempts)
+ throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())
+ // TODO else log attempt ?
+ }
+ }
+ throw AIError.NoResponse()
+ }
+}
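
The `serializer` parameter is just `(String) -> A`, so any JSON codec fits. A sketch with kotlinx.serialization; how the `CFunction` list is derived from the target schema is not shown in this diff, so it is passed in here:

```kotlin
import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

@Serializable
data class Sentiment(val label: String, val score: Double)

suspend fun CoreAIScope.classify(
  model: ChatWithFunctions,
  functions: List<CFunction>,
): Sentiment =
  model.prompt(
    prompt = "Classify the sentiment of: 'I love this library'",
    functions = functions,
    // Decode the function-call arguments JSON into the target type.
    serializer = { json -> Json.decodeFromString(Sentiment.serializer(), json) },
    promptConfiguration = PromptConfiguration.DEFAULTS,
  )
```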
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt
new file mode 100644
index 000000000..96f20771b
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt
@@ -0,0 +1,11 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.llm.models.text.CompletionRequest
+import com.xebia.functional.xef.llm.models.text.CompletionResult
+
+interface Completion : LLM {
+ val modelType: ModelType
+
+ suspend fun createCompletion(request: CompletionRequest): CompletionResult
+}
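
`Completion` keeps the plain-text API. A usage sketch; `OpenAI.DEFAULT_COMPLETION` is an assumed constant, by analogy with the `DEFAULT_CHAT`/`DEFAULT_IMAGES` constants used in the examples below:

```kotlin
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.llm.models.text.CompletionRequest

suspend fun main() {
  val model = OpenAI.DEFAULT_COMPLETION // assumed constant, not from this diff
  val result = model.createCompletion(
    CompletionRequest(model = model.name, user = "user", prompt = "Continue: 1, 2, 3,")
  )
  // `choices.map { it.text }` mirrors the removed CoreAIScope code above.
  println(result.choices.map { it.text })
}
```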
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Embeddings.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Embeddings.kt
new file mode 100644
index 000000000..bf8805658
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Embeddings.kt
@@ -0,0 +1,8 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
+import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
+
+interface Embeddings : LLM {
+ suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Images.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Images.kt
new file mode 100644
index 000000000..d9c273b4f
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Images.kt
@@ -0,0 +1,45 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.models.images.ImagesGenerationRequest
+import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse
+import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.vectorstores.VectorStore
+
+interface Images : LLM {
+ suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse
+
+ suspend fun images(
+ prompt: String,
+ context: VectorStore,
+ numberImages: Int = 1,
+ size: String = "1024x1024",
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+ ): ImagesGenerationResponse =
+ images(Prompt(prompt), context, numberImages, size, promptConfiguration)
+
+ /**
+ * Run a [prompt] that describes the images you want to generate within the context of [CoreAIScope].
+ * Returns an [ImagesGenerationResponse] containing the creation time and the URLs of the generated images.
+ *
+ * @param prompt a [Prompt] describing the images you want to generate.
+ * @param numberImages number of images to generate.
+ * @param size the size of the images to generate.
+ */
+ suspend fun images(
+ prompt: Prompt,
+ context: VectorStore,
+ numberImages: Int = 1,
+ size: String = "1024x1024",
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+ ): ImagesGenerationResponse {
+ val request =
+ ImagesGenerationRequest(
+ prompt = prompt.message,
+ numberImages = numberImages,
+ size = size,
+ user = promptConfiguration.user
+ )
+ return createImages(request)
+ }
+}
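
A usage sketch inside an `ai` block, using the `OpenAI.DEFAULT_IMAGES` constant from the examples below; the scope-level overload in CoreAIScope supplies the `context` argument:

```kotlin
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.getOrThrow

suspend fun main() {
  ai {
    val response = OpenAI.DEFAULT_IMAGES.images("A watercolor lighthouse", numberImages = 2)
    response.data.forEach { println(it.url) }
  }.getOrThrow()
}
```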
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
new file mode 100644
index 000000000..95937867b
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
@@ -0,0 +1,7 @@
+package com.xebia.functional.xef.llm
+
+sealed interface LLM : AutoCloseable {
+ val name: String
+
+ override fun close() {}
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLMModel.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLMModel.kt
deleted file mode 100644
index 8a19464c6..000000000
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLMModel.kt
+++ /dev/null
@@ -1,65 +0,0 @@
-package com.xebia.functional.xef.llm
-
-import com.xebia.functional.tokenizer.ModelType
-import kotlin.jvm.JvmStatic
-
-sealed interface LLM {
- val name: String
- val modelType: ModelType
-
- interface Chat : LLM
-
- interface Completion : LLM
-
- interface ChatWithFunctions : Chat
-
- interface Embedding : LLM
-
- interface Images : LLM {
- suspend fun createImage()
- }
-}
-
-sealed class LLMModel(override val name: String, override val modelType: ModelType) : LLM {
-
- data class Chat(override val name: String, override val modelType: ModelType) :
- LLMModel(name, modelType), LLM.Chat
-
- data class Completion(override val name: String, override val modelType: ModelType) :
- LLMModel(name, modelType), LLM.Completion
-
- data class ChatWithFunctions(override val name: String, override val modelType: ModelType) :
- LLMModel(name, modelType), LLM.ChatWithFunctions
-
- data class Embedding(override val name: String, override val modelType: ModelType) :
- LLMModel(name, modelType), LLM.Embedding
-
- companion object {
- @JvmStatic val GPT_4 = Chat("gpt-4", ModelType.GPT_4)
-
- @JvmStatic val GPT_4_0314 = Chat("gpt-4-0314", ModelType.GPT_4)
-
- @JvmStatic val GPT_4_32K = Chat("gpt-4-32k", ModelType.GPT_4_32K)
-
- @JvmStatic val GPT_3_5_TURBO = Chat("gpt-3.5-turbo", ModelType.GPT_3_5_TURBO)
-
- @JvmStatic val GPT_3_5_TURBO_16K = Chat("gpt-3.5-turbo-16k", ModelType.GPT_3_5_TURBO_16_K)
-
- @JvmStatic
- val GPT_3_5_TURBO_FUNCTIONS =
- ChatWithFunctions("gpt-3.5-turbo-0613", ModelType.GPT_3_5_TURBO_FUNCTIONS)
-
- @JvmStatic val GPT_3_5_TURBO_0301 = Chat("gpt-3.5-turbo-0301", ModelType.GPT_3_5_TURBO)
-
- @JvmStatic val TEXT_DAVINCI_003 = Completion("text-davinci-003", ModelType.TEXT_DAVINCI_003)
-
- @JvmStatic val TEXT_DAVINCI_002 = Completion("text-davinci-002", ModelType.TEXT_DAVINCI_002)
-
- @JvmStatic
- val TEXT_CURIE_001 = Completion("text-curie-001", ModelType.TEXT_SIMILARITY_CURIE_001)
-
- @JvmStatic val TEXT_BABBAGE_001 = Completion("text-babbage-001", ModelType.TEXT_BABBAGE_001)
-
- @JvmStatic val TEXT_ADA_001 = Completion("text-ada-001", ModelType.TEXT_ADA_001)
- }
-}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
index 30fc0e8f5..9f6ef6ed6 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt
@@ -1,3 +1,3 @@
package com.xebia.functional.xef.llm.models.chat
-data class Message(val role: String, val content: String?, val name: String? = Role.ASSISTANT.name)
+data class Message(val role: Role, val content: String, val name: String)
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt
index 7eb3fa41b..61d0d8da8 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt
@@ -3,6 +3,5 @@ package com.xebia.functional.xef.llm.models.chat
enum class Role {
SYSTEM,
USER,
- ASSISTANT,
- FUNCTION
+ ASSISTANT
}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt
index 28bc34285..e13fb208d 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt
@@ -3,7 +3,7 @@ package com.xebia.functional.xef.llm.models.text
data class CompletionRequest(
val model: String,
val user: String,
- val prompt: String? = null,
+ val prompt: String,
val suffix: String? = null,
val maxTokens: Int? = null,
val temperature: Double? = null,
diff --git a/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt b/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt
index b64ec2846..e24c3f596 100644
--- a/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt
+++ b/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt
@@ -4,26 +4,31 @@ import it.skrape.core.htmlDocument
import it.skrape.fetcher.BrowserFetcher
import it.skrape.fetcher.response
import it.skrape.fetcher.skrape
+import java.lang.IllegalStateException
/** Creates a TextLoader based on a Path */
suspend fun ScrapeURLTextLoader(url: String): BaseLoader =
object : BaseLoader {
override suspend fun load(): List = buildList {
- skrape(BrowserFetcher) {
- request { this.url = url }
- response {
- htmlDocument {
- val cleanedText = cleanUpText(wholeText)
- add(
- """|
+ try {
+ skrape(BrowserFetcher) {
+ request { this.url = url }
+ response {
+ htmlDocument {
+ val cleanedText = cleanUpText(wholeText)
+ add(
+ """|
|Title: $titleText
|Info: $cleanedText
"""
- .trimIndent()
- )
+ .trimIndent()
+ )
+ }
}
}
+ } catch (e: IllegalStateException) {
+ // ignore
}
}
diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java
index 6b50cb4cf..26821c2ae 100644
--- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java
+++ b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java
@@ -27,20 +27,20 @@ public CompletableFuture<Story> story(Animal animal, Invention invention) {
return scope.prompt(storyPrompt, Story.class);
}
- private static class Animal {
+ public static class Animal {
public String name;
public String habitat;
public String diet;
}
- private static class Invention {
+ public static class Invention {
public String name;
public String inventor;
public int year;
public String purpose;
}
- private static class Story {
+ public static class Story {
public Animal animal;
public Invention invention;
public String text;
@@ -60,4 +60,4 @@ public static void main(String[] args) throws ExecutionException, InterruptedExc
)).get();
}
}
-}
\ No newline at end of file
+}
diff --git a/examples/kotlin/build.gradle.kts b/examples/kotlin/build.gradle.kts
index 5e5fd632f..fcb49ce1d 100644
--- a/examples/kotlin/build.gradle.kts
+++ b/examples/kotlin/build.gradle.kts
@@ -37,3 +37,5 @@ tasks.getByName("processResources") {
from("${projects.xefGpt4all.dependencyProject.buildDir}/processedResources/jvm/main")
into("$buildDir/resources/main")
}
+
+
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt
index 80ddc5994..aecf5116a 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
@@ -7,7 +9,7 @@ data class ASCIIArt(val art: String)
suspend fun main() {
val art: AI<ASCIIArt> = ai {
- prompt("ASCII art of a cat dancing")
+ prompt( "ASCII art of a cat dancing")
}
println(art.getOrElse { ASCIIArt("¯\\_(ツ)_/¯" + "\n" + it.message) })
}
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt
index 0acab2387..e160b532e 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt
index 01f136e48..1ef28f3b6 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt
index 97ca4174e..e95d6485d 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt
@@ -1,6 +1,8 @@
package com.xebia.functional.xef.auto
import com.xebia.functional.xef.agents.search
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
import java.text.SimpleDateFormat
import java.util.Date
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt
index aa98fa445..08d02eaa2 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt
index a5e7d1a1e..76012c018 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/CustomRuntime.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/CustomRuntime.kt
deleted file mode 100644
index ec918c472..000000000
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/CustomRuntime.kt
+++ /dev/null
@@ -1,36 +0,0 @@
-package com.xebia.functional.xef.auto
-
-import com.xebia.functional.xef.auto.llm.openai.MockAIScope
-import com.xebia.functional.xef.auto.llm.openai.simpleMockAIClient
-import com.xebia.functional.xef.embeddings.Embedding
-import com.xebia.functional.xef.embeddings.Embeddings
-import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
-
-suspend fun main() {
- val program = ai {
- val love: List<String> = promptMessage("tell me you like me with just emojis")
- println(love)
- }
- program.getOrElse(customRuntime()) { println(it) }
-}
-
-private fun fakeEmbeddings(): Embeddings = object : Embeddings {
- override suspend fun embedDocuments(
- texts: List<String>,
- chunkSize: Int?,
- requestConfig: RequestConfig
- ): List<Embedding> = emptyList()
-
- override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
- emptyList()
-}
-
-private fun customRuntime(): AIRuntime {
- val client = simpleMockAIClient { it }
- return AIRuntime(client, fakeEmbeddings()) { block ->
- MockAIScope(
- client,
- block
- ) { throw it }
- }
-}
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt
index 06ceb404e..22dfe1935 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt
@@ -1,6 +1,8 @@
package com.xebia.functional.xef.auto
import com.xebia.functional.xef.agents.search
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt
index 57bdb34c4..df768dc3f 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt
index ce7b9fd61..4e7360e0e 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt
index 37167b62a..2bd525e21 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt
@@ -1,5 +1,8 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.promptMessage
+
suspend fun main() {
ai {
val love: List<String> = promptMessage("tell me you like me with just emojis")
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt
index 7d5c556d3..d4cfc94e0 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt
@@ -1,6 +1,8 @@
package com.xebia.functional.xef.auto
import com.xebia.functional.xef.agents.search
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
import java.text.SimpleDateFormat
import java.util.Date
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt
index 6b14a25f5..9a671a7c5 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt
@@ -1,6 +1,8 @@
package com.xebia.functional.xef.auto
import com.xebia.functional.xef.agents.search
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt
index e3fc3ce8c..7b8861362 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt
index 0bc15fde1..d4c12e0e9 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt
index 3ca5ab0e5..7e8867a02 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrThrow
+import com.xebia.functional.xef.auto.llm.openai.prompt
import com.xebia.functional.xef.pdf.pdf
import kotlinx.serialization.Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt
index 9cb5aeaec..c55600482 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt
index 0666121aa..74a0636fc 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt
index 94dad1846..e12eefda2 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt
index 1572f1c9e..cd186fcc7 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.image
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt
index 3de145f1b..2208d83b7 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt
index 78591cf72..a5111569d 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt
index dfad77734..32c1636bd 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt
@@ -1,5 +1,7 @@
package com.xebia.functional.xef.auto
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt
index 77db41578..bc5bb60f0 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt
@@ -1,6 +1,8 @@
package com.xebia.functional.xef.auto
import com.xebia.functional.xef.agents.search
+import com.xebia.functional.xef.auto.llm.openai.getOrElse
+import com.xebia.functional.xef.auto.llm.openai.promptMessage
import io.github.oshai.kotlinlogging.KotlinLogging
suspend fun main() {
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt
index 28ca1be72..8eb84204e 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt
@@ -1,50 +1,52 @@
package com.xebia.functional.xef.auto.gpt4all
-import arrow.core.raise.either
-import arrow.core.raise.ensure
-import arrow.core.raise.recover
-import com.xebia.functional.gpt4all.*
+import com.xebia.functional.gpt4all.GPT4All
+import com.xebia.functional.gpt4all.LLModel
+import com.xebia.functional.gpt4all.getOrThrow
+import com.xebia.functional.gpt4all.huggingFaceUrl
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.auto.ai
+import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.xef.pdf.pdf
import java.nio.file.Path
-import java.util.*
-
-data class ChatError(val content: String)
suspend fun main() {
- recover({
- val resources = "models/gpt4all"
- val path = "$resources/ggml-gpt4all-j-v1.3-groovy.bin"
- val modelType = LLModel.Type.GPTJ
-
- val modelPath: Path = Path.of(path)
- ensure(modelPath.toFile().exists()) {
- ChatError("Model at ${modelPath.toAbsolutePath()} cannot be found.")
- }
-
- Scanner(System.`in`).use { scanner ->
- println("Loading model...")
-
- GPT4All(modelPath, modelType).use { gpt4All ->
- println("Model loaded!")
- print("Prompt: ")
-
- buildList {
- while (scanner.hasNextLine()) {
- val prompt: String = scanner.nextLine()
- if (prompt.equals("exit", ignoreCase = true)) { break }
-
- println("...")
- val promptMessage = Message(Message.Role.USER, prompt)
- add(promptMessage)
-
- val request = ChatCompletionRequest(this, GenerationConfig())
- val response: ChatCompletionResponse = gpt4All.createChatCompletion(request)
- println("Response: ${response.choices[0].content}")
-
- add(response.choices[0])
- print("Prompt: ")
- }
- }
- }
- }
- }) { println(it) }
+ val userDir = System.getProperty("user.dir")
+ val path = "$userDir/models/gpt4all/ggml-gpt4all-j-v1.3-groovy.bin"
+ val url = huggingFaceUrl("orel12", "ggml-gpt4all-j-v1.3-groovy", "bin")
+ val modelType = LLModel.Type.GPTJ
+ val modelPath: Path = Path.of(path)
+ val GPT4All = GPT4All(url, modelPath, modelType)
+
+ println("🤖 GPT4All loaded: $GPT4All")
+
+ val pdfUrl = "https://www.europarl.europa.eu/RegData/etudes/STUD/2023/740063/IPOL_STU(2023)740063_EN.pdf"
+
+ /**
+ * Internally uses the [HuggingFaceLocalEmbeddings] default ("sentence-transformers", "msmarco-distilbert-dot-v5")
+ * to provide embeddings for the docs in contextScope.
+ */
+
+ ai {
+ println("🤖 Loading PDF: $pdfUrl")
+ contextScope(pdf(pdfUrl)) {
+ println("🤖 Context loaded: $context")
+ GPT4All.use { gpT4All: GPT4All ->
+ println("🤖 Generating prompt for context")
+ val prompt = gpT4All.promptMessage(
+ "Describe in one sentence what the context is about.",
+ promptConfiguration = PromptConfiguration {
+ docsInContext(2)
+ })
+ println("🤖 Generating images for prompt: \n\n$prompt")
+ val images =
+ OpenAI.DEFAULT_IMAGES.images(prompt.joinToString("\n"), promptConfiguration = PromptConfiguration {
+ docsInContext(1)
+ })
+ println("🤖 Generated images: \n\n${images.data.joinToString("\n") { it.url }}")
+ }
+ }
+ }.getOrThrow()
}
+
+
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/manual/NoAI.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/manual/NoAI.kt
new file mode 100644
index 000000000..0dfcfcc3b
--- /dev/null
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/manual/NoAI.kt
@@ -0,0 +1,16 @@
+package com.xebia.functional.xef.auto.manual
+
+import com.xebia.functional.gpt4all.HuggingFaceLocalEmbeddings
+import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.xef.pdf.pdf
+import com.xebia.functional.xef.vectorstores.LocalVectorStore
+
+suspend fun main() {
+ val chat = OpenAI.DEFAULT_CHAT
+ val huggingFaceEmbeddings = HuggingFaceLocalEmbeddings.DEFAULT
+ val vectorStore = LocalVectorStore(huggingFaceEmbeddings)
+ val results = pdf("https://www.europarl.europa.eu/RegData/etudes/STUD/2023/740063/IPOL_STU(2023)740063_EN.pdf")
+ vectorStore.addTexts(results)
+ val result: List<String> = chat.promptMessage("What is the content about?", vectorStore)
+ println(result)
+}
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt
index 223aa46d1..8ee08476c 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt
@@ -1,12 +1,15 @@
package com.xebia.functional.xef.auto.sql
import arrow.core.raise.catch
-import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.auto.ai
-import com.xebia.functional.xef.auto.getOrThrow
+import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.xef.auto.llm.openai.getOrThrow
import com.xebia.functional.xef.sql.SQL
import com.xebia.functional.xef.sql.jdbc.JdbcConfig
+val model = OpenAI.DEFAULT_CHAT
+
val config = JdbcConfig(
vendor = System.getenv("XEF_SQL_DB_VENDOR") ?: "mysql",
host = System.getenv("XEF_SQL_DB_HOST") ?: "localhost",
@@ -14,7 +17,7 @@ val config = JdbcConfig(
password = System.getenv("XEF_SQL_DB_PASSWORD") ?: "password",
port = System.getenv("XEF_SQL_DB_PORT")?.toInt() ?: 3306,
database = System.getenv("XEF_SQL_DB_DATABASE") ?: "database",
- llmModelType = ModelType.GPT_3_5_TURBO
+ model = model
)
suspend fun main() = ai {
@@ -34,16 +37,21 @@ suspend fun main() = ai {
if (input == "exit") break
catch({
extendContext(*promptQuery(input).toTypedArray())
- val result = promptMessage("""|
- |You are a database assistant that helps users to query and summarize results from the database.
- |Instructions:
- |1. Summarize the information provided in the `Context` and follow to step 2.
- |2. If the information relates to the `input` then answer the question otherwise return just the summary.
- |```input
- |$input
- |```
- |3. Try to answer and provide information with as much detail as you can
- """.trimMargin(), bringFromContext = 200)
+ val result = model.promptMessage(
+ """|
+ |You are a database assistant that helps users to query and summarize results from the database.
+ |Instructions:
+ |1. Summarize the information provided in the `Context`, then continue to step 2.
+ |2. If the information relates to the `input`, answer the question; otherwise return just the summary.
+ |```input
+ |$input
+ |```
+ |3. Try to answer and provide information with as much detail as you can
+ """.trimMargin(),
+ promptConfiguration = PromptConfiguration {
+ docsInContext(50)
+ }
+ )
result.forEach {
println("llmdb> ${it}")
}
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt
index 7c460a5f4..80c7239f0 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.auto.tot
import com.xebia.functional.xef.auto.CoreAIScope
-import com.xebia.functional.xef.auto.prompt
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt
index 4f1d5c4bf..0a28a58e1 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.auto.tot
import com.xebia.functional.xef.auto.CoreAIScope
-import com.xebia.functional.xef.auto.prompt
+import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt
index f8a2cf499..ce11ec7bc 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.auto.tot
import com.xebia.functional.xef.auto.ai
-import com.xebia.functional.xef.auto.getOrThrow
+import com.xebia.functional.xef.auto.llm.openai.getOrThrow
import kotlinx.serialization.Serializable
@Serializable
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt
deleted file mode 100644
index 0c96b3be6..000000000
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt
+++ /dev/null
@@ -1,19 +0,0 @@
-package com.xebia.functional.xef.auto.tot
-
-import com.xebia.functional.xef.auto.CoreAIScope
-
-suspend fun CoreAIScope.generateSearchPrompts(problem: Problem): List<String> =
- promptMessage(
- """|You are an expert SEO consultant.
- |You generate search prompts for a problem.
- |You are given the following problem:
- |${problem.description}
- |Instructions:
- |1. Generate 1 search prompt to get the best results for this problem.
- |2. Ensure the search prompt are relevant to the problem.
- |3. Ensure the search prompt expands into the keywords needed to solve the problem.
- |
- """.trimMargin(),
- n = 5
- ).distinct().map { it.replace("\"", "").trim() }
-
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt
index f6cfec90e..da44df73e 100644
--- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt
@@ -1,7 +1,8 @@
package com.xebia.functional.xef.auto.tot
import com.xebia.functional.xef.auto.CoreAIScope
-import com.xebia.functional.xef.auto.prompt
+import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.xef.auto.llm.openai.prompt
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
@@ -45,7 +46,7 @@ internal suspend fun <A> CoreAIScope.solution(
|10. If the solution is valid set the `isValid` field to `true` and the `value` field to the value of the solution.
|
|""".trimMargin()
- return prompt(Prompt(enhancedPrompt), serializer).also {
+ return prompt(OpenAI.DEFAULT_SERIALIZATION, Prompt(enhancedPrompt), serializer).also {
println("🤖 Generated solution: ${truncateText(it.answer)}")
}
}
diff --git a/examples/kotlin/src/main/resources/logback.xml b/examples/kotlin/src/main/resources/logback.xml
index 8b62c8b02..dfb690eb2 100644
--- a/examples/kotlin/src/main/resources/logback.xml
+++ b/examples/kotlin/src/main/resources/logback.xml
@@ -16,6 +16,10 @@
+
+
+
+
diff --git a/gpt4all-kotlin/build.gradle.kts b/gpt4all-kotlin/build.gradle.kts
index 49b86a305..12d76721e 100644
--- a/gpt4all-kotlin/build.gradle.kts
+++ b/gpt4all-kotlin/build.gradle.kts
@@ -20,17 +20,17 @@ java {
kotlin {
jvm()
+
js(IR) {
browser()
- nodejs()
}
- linuxX64()
- macosX64()
- macosArm64()
- mingwX64()
sourceSets {
- val commonMain by getting {}
+ val commonMain by getting {
+ dependencies {
+ implementation(projects.xefCore)
+ }
+ }
commonTest {
dependencies {
@@ -44,32 +44,19 @@ kotlin {
val jvmMain by getting {
dependencies {
implementation("net.java.dev.jna:jna-platform:5.13.0")
+ implementation("ai.djl.huggingface:tokenizers:+")
}
}
+ val jsMain by getting {
+ }
+
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
}
}
- js {
- nodejs {}
- browser {}
- }
-
- val linuxX64Main by getting
- val macosX64Main by getting
- val macosArm64Main by getting
- val mingwX64Main by getting
-
- create("nativeMain") {
- dependsOn(commonMain)
- linuxX64Main.dependsOn(this)
- macosX64Main.dependsOn(this)
- macosArm64Main.dependsOn(this)
- mingwX64Main.dependsOn(this)
- }
}
}
diff --git a/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt b/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt
index 0ab06d727..33be8d3f8 100644
--- a/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt
+++ b/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt
@@ -14,52 +14,3 @@ data class GenerationConfig(
val repeatLastN: Int = 64,
val contextErase: Double = 0.5
)
-
-data class Completion(val context: String)
-
-data class Message(val role: Role, val content: String) {
- enum class Role {
- SYSTEM,
- USER,
- ASSISTANT
- }
-}
-
-data class Embedding(
- val embedding: List<Float>
-)
-
-data class EmbeddingRequest(
- val input: List<String>
-)
-
-data class EmbeddingResponse(
- val model: String,
- val data: List<Embedding>
-)
-
-data class CompletionRequest(
- val prompt: String,
- val generationConfig: GenerationConfig
-)
-
-data class ChatCompletionRequest(
- val messages: List<Message>,
- val generationConfig: GenerationConfig
-)
-
-data class CompletionResponse(
- val model: String,
- val promptTokens: Int,
- val completionTokens: Int,
- val totalTokens: Int,
- val choices: List<Completion>
-)
-
-data class ChatCompletionResponse(
- val model: String,
- val promptTokens: Int,
- val completionTokens: Int,
- val totalTokens: Int,
- val choices: List<Message>
-)
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
index d818f7da6..17a1273ed 100644
--- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
@@ -1,84 +1,114 @@
package com.xebia.functional.gpt4all
+import ai.djl.training.util.DownloadUtils
+import ai.djl.training.util.ProgressBar
import com.sun.jna.platform.unix.LibCAPI
import com.xebia.functional.gpt4all.libraries.LLModelContext
+import com.xebia.functional.tokenizer.EncodingType
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.llm.Chat
+import com.xebia.functional.xef.llm.Completion
+import com.xebia.functional.xef.llm.models.chat.*
+import com.xebia.functional.xef.llm.models.text.CompletionChoice
+import com.xebia.functional.xef.llm.models.text.CompletionRequest
+import com.xebia.functional.xef.llm.models.text.CompletionResult
+import com.xebia.functional.xef.llm.models.usage.Usage
+import java.net.URL
+import java.nio.file.Files
import java.nio.file.Path
+import java.nio.file.StandardCopyOption
+import java.util.*
+import kotlin.io.path.name
-interface GPT4All : AutoCloseable {
- suspend fun createCompletion(request: CompletionRequest): CompletionResponse
- suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse
- suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResponse
-
- companion object {
- operator fun invoke(
- path: Path,
- modelType: LLModel.Type
- ): GPT4All = object : GPT4All {
- val gpt4allModel: GPT4AllModel = GPT4AllModel(path, modelType)
-
- override suspend fun createCompletion(request: CompletionRequest): CompletionResponse =
- with(request) {
- val response: String = generateCompletion(prompt, generationConfig)
- return CompletionResponse(
- gpt4allModel.llModel.name,
- prompt.length,
- response.length,
- totalTokens = prompt.length + response.length,
- listOf(Completion(response))
- )
- }
-
- override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
- with(request) {
- val prompt: String = messages.buildPrompt()
- val response: String = generateCompletion(prompt, generationConfig)
- return ChatCompletionResponse(
- gpt4allModel.llModel.name,
- prompt.length,
- response.length,
- totalTokens = prompt.length + response.length,
- listOf(Message(com.xebia.functional.gpt4all.Message.Role.ASSISTANT, response))
- )
- }
-
- override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResponse {
- TODO("Not yet implemented")
- }
-
- override fun close(): Unit = gpt4allModel.close()
-
- private fun List<Message>.buildPrompt(): String {
- val messages: String = joinToString("") { message ->
- when (message.role) {
- Message.Role.SYSTEM -> message.content
- Message.Role.USER -> "\n### Human: ${message.content}"
- Message.Role.ASSISTANT -> "\n### Response: ${message.content}"
- }
- }
- return "$messages\n### Response:"
- }
-
- private fun generateCompletion(
- prompt: String,
- generationConfig: GenerationConfig
- ): String {
- val context = LLModelContext(
- logits_size = LibCAPI.size_t(generationConfig.logitsSize.toLong()),
- tokens_size = LibCAPI.size_t(generationConfig.tokensSize.toLong()),
- n_past = generationConfig.nPast,
- n_ctx = generationConfig.nCtx,
- n_predict = generationConfig.nPredict,
- top_k = generationConfig.topK,
- top_p = generationConfig.topP.toFloat(),
- temp = generationConfig.temp.toFloat(),
- n_batch = generationConfig.nBatch,
- repeat_penalty = generationConfig.repeatPenalty.toFloat(),
- repeat_last_n = generationConfig.repeatLastN,
- context_erase = generationConfig.contextErase.toFloat()
- )
-
- return gpt4allModel.prompt(prompt, context)
- }
+interface GPT4All : AutoCloseable, Chat, Completion {
+
+ val gpt4allModel: GPT4AllModel
+
+ override fun close() {
+ }
+
+ companion object {
+
+ operator fun invoke(
+ url: String,
+ path: Path,
+ modelType: LLModel.Type,
+ generationConfig: GenerationConfig = GenerationConfig(),
+ ): GPT4All = object : GPT4All {
+
+ init {
+ if (!Files.exists(path)) {
+ DownloadUtils.download(url, path.toFile().absolutePath, ProgressBar())
+ }
+ }
+
+ override val gpt4allModel = GPT4AllModel.invoke(path, modelType)
+
+ override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
+ with(request) {
+ val response: String = gpt4allModel.prompt(prompt, llmModelContext(generationConfig))
+ return CompletionResult(
+ UUID.randomUUID().toString(),
+ path.name,
+ System.currentTimeMillis(),
+ path.name,
+ listOf(CompletionChoice(response, 0, null, null)),
+ Usage.ZERO
+ )
+ }
+
+ override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
+ with(request) {
+ val response: String =
+ gpt4allModel.prompt(messages.buildPrompt(), llmModelContext(generationConfig))
+ return ChatCompletionResponse(
+ UUID.randomUUID().toString(),
+ path.name,
+ System.currentTimeMillis().toInt(),
+ path.name,
+ Usage.ZERO,
+ listOf(Choice(Message(Role.ASSISTANT, response, Role.ASSISTANT.name), null, 0)),
+ )
+ }
+
+ override fun tokensFromMessages(messages: List<Message>): Int = 0
+
+ override val name: String = path.name
+
+ override fun close(): Unit = gpt4allModel.close()
+
+ override val modelType: ModelType = ModelType.LocalModel(name, EncodingType.CL100K_BASE, 4096)
+
+ private fun List<Message>.buildPrompt(): String {
+ val messages: String = joinToString("") { message ->
+ when (message.role) {
+ Role.SYSTEM -> message.content
+ Role.USER -> "\n### Human: ${message.content}"
+ Role.ASSISTANT -> "\n### Response: ${message.content}"
+ }
+ }
+ return "$messages\n### Response:"
+ }
+
+ private fun llmModelContext(generationConfig: GenerationConfig): LLModelContext =
+ with(generationConfig) {
+ LLModelContext(
+ logits_size = LibCAPI.size_t(logitsSize.toLong()),
+ tokens_size = LibCAPI.size_t(tokensSize.toLong()),
+ n_past = nPast,
+ n_ctx = nCtx,
+ n_predict = nPredict,
+ top_k = topK,
+ top_p = topP.toFloat(),
+ temp = temp.toFloat(),
+ n_batch = nBatch,
+ repeat_penalty = repeatPenalty.toFloat(),
+ repeat_last_n = repeatLastN,
+ context_erase = contextErase.toFloat()
+ )
}
}
+
+ }
}
+
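A minimal usage sketch for the new `GPT4All` factory above: it downloads the model file on first use and then speaks the shared `Chat` interface. The model coordinates and the `LLModel.Type.GPTJ` variant are illustrative assumptions, not taken from this patch.

import com.xebia.functional.gpt4all.GPT4All
import com.xebia.functional.gpt4all.LLModel
import com.xebia.functional.gpt4all.huggingFaceUrl
import java.nio.file.Path

suspend fun main() {
  // Hypothetical GGML model hosted on Hugging Face; any compatible model works.
  val url = huggingFaceUrl("orel12", "ggml-gpt4all-j-v1.3-groovy", "bin")
  val path = Path.of("models", "ggml-gpt4all-j-v1.3-groovy.bin")
  GPT4All(url, path, LLModel.Type.GPTJ).use { model ->
    // promptMessage comes from Chat; the default PromptConfiguration is used.
    model.promptMessage("Name the capital of France.").forEach(::println)
  }
}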
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt
index 04e2c201a..73c9253e0 100644
--- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt
@@ -7,7 +7,6 @@ interface GPT4AllModel : AutoCloseable {
val llModel: LLModel
fun prompt(prompt: String, context: LLModelContext): String
- fun embeddings(prompt: String): List<Float>
companion object {
operator fun invoke(
@@ -34,8 +33,6 @@ interface GPT4AllModel : AutoCloseable {
responseBuffer.trim().toString()
}
- override fun embeddings(prompt: String): List<Float> = TODO("Not yet implemented")
-
override fun close(): Unit = llModel.close()
}
}
diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Gpt4AllRuntime.kt
similarity index 73%
rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt
rename to gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Gpt4AllRuntime.kt
index 150714ad8..733610c4e 100644
--- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Gpt4AllRuntime.kt
@@ -1,12 +1,12 @@
-package com.xebia.functional.xef.auto
+package com.xebia.functional.gpt4all
import arrow.core.Either
import arrow.core.left
import arrow.core.right
import com.xebia.functional.xef.AIError
-import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime
-
-typealias AIScope = CoreAIScope
+import com.xebia.functional.xef.auto.AI
+import com.xebia.functional.xef.auto.CoreAIScope
+import com.xebia.functional.xef.auto.ai
/**
* Run the [AI] value to produce an [A], this method initialises all the dependencies required to
@@ -14,10 +14,8 @@ typealias AIScope = CoreAIScope
*
* This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
*/
-suspend inline fun <A> AI<A>.getOrElse(
- runtime: AIRuntime<A> = OpenAIRuntime.defaults(),
- crossinline orElse: suspend (AIError) -> A
-): A = AIScope(runtime, this) { orElse(it) }
+suspend inline fun <A> AI<A>.getOrElse(crossinline orElse: suspend (AIError) -> A): A =
+ AIScope(this) { orElse(it) }
/**
* Run the [AI] value to produce [A]. this method initialises all the dependencies required to run
@@ -41,3 +39,12 @@ suspend inline fun <A> AI<A>.getOrThrow(): A = getOrElse { throw it }
*/
suspend inline fun <A> AI<A>.toEither(): Either<AIError, A> =
ai { invoke().right() }.getOrElse { it.left() }
+
+suspend fun <A> AIScope(block: AI<A>, orElse: suspend (AIError) -> A): A =
+ try {
+ val scope = CoreAIScope(HuggingFaceLocalEmbeddings.DEFAULT)
+ block(scope)
+ } catch (e: AIError) {
+ orElse(e)
+ }
+
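With the runtime reduced to a plain `CoreAIScope` over the local embeddings, the terminal operators compose directly; a sketch of `toEither` (the block body is illustrative):

import arrow.core.Either
import com.xebia.functional.gpt4all.toEither
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.auto.ai

suspend fun main() {
  // toEither is terminal: it builds the scope, runs the chain, and
  // surfaces AIError as a value instead of an exception.
  val result: Either<AIError, String> = ai { "ran inside the local scope" }.toEither()
  result.fold({ err -> println("failed: $err") }, ::println)
}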
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt
new file mode 100644
index 000000000..f2b31980f
--- /dev/null
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt
@@ -0,0 +1,39 @@
+package com.xebia.functional.gpt4all
+
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
+import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding
+import com.xebia.functional.xef.embeddings.Embeddings
+import com.xebia.functional.xef.llm.models.embeddings.Embedding
+import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
+import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
+import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
+import com.xebia.functional.xef.llm.models.usage.Usage
+
+class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.functional.xef.llm.Embeddings, Embeddings {
+
+ private val tokenizer = HuggingFaceTokenizer.newInstance("$name/$artifact")
+
+ override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName
+
+ override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
+ val embeddings = tokenizer.batchEncode(request.input).mapIndexed { ix, em ->
+ Embedding("embedding", em.ids.map { it.toFloat() }, ix)
+ }
+ return EmbeddingResult(embeddings, Usage.ZERO)
+ }
+
+ override suspend fun embedDocuments(
+ texts: List<String>,
+ chunkSize: Int?,
+ requestConfig: RequestConfig
+ ): List<XefEmbedding> =
+ tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) }
+
+ override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<XefEmbedding> =
+ embedDocuments(listOf(text), null, requestConfig)
+
+ companion object {
+ @JvmField
+ val DEFAULT = HuggingFaceLocalEmbeddings("sentence-transformers", "msmarco-distilbert-dot-v5")
+ }
+}
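The local embedder plugs into the same `VectorStore` machinery as the OpenAI one. A sketch, assuming a `similaritySearch(query, limit)` method on `VectorStore` (the stored texts are illustrative):

import com.xebia.functional.gpt4all.HuggingFaceLocalEmbeddings
import com.xebia.functional.xef.vectorstores.LocalVectorStore

suspend fun main() {
  val vectorStore = LocalVectorStore(HuggingFaceLocalEmbeddings.DEFAULT)
  vectorStore.addTexts(listOf("xef adds LLM superpowers to Kotlin", "bananas are yellow"))
  // Tokenizer-id embeddings are crude, but everything stays offline.
  println(vectorStore.similaritySearch("what is xef?", 1))
}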
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceUtils.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceUtils.kt
new file mode 100644
index 000000000..ac2f1f4fb
--- /dev/null
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceUtils.kt
@@ -0,0 +1,4 @@
+package com.xebia.functional.gpt4all
+
+fun huggingFaceUrl(name: String, artifact:String, extension: String): String =
+ "https://huggingface.co/$name/$artifact/resolve/main/$artifact.$extension"
diff --git a/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt b/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt
index fe6ac9764..384731c3d 100644
--- a/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt
+++ b/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt
@@ -1,3 +1,5 @@
+@file:JvmName("Loader")
+@file:JvmMultifileClass
package com.xebia.functional.xef.pdf
import com.xebia.functional.tokenizer.ModelType
@@ -26,7 +28,6 @@ suspend fun pdf(
pdf(file, splitter)
}
-
suspend fun pdf(
file: File,
splitter: TextSplitter = TokenTextSplitter(modelType = ModelType.GPT_3_5_TURBO, chunkSize = 100, chunkOverlap = 50)
diff --git a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt
index 48cc7426b..dc2e1de54 100644
--- a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt
+++ b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt
@@ -68,7 +68,7 @@ private class JDBCSQLImpl(
override suspend fun CoreAIScope.selectTablesForPrompt(
tableNames: String, prompt: String
- ): List<String> = promptMessage(
+ ): List<String> = config.model.promptMessage(
"""|You are an AI assistant which selects the best tables from which the `goal` can be accomplished.
|Select from this list of SQL `tables` the tables that you may need to solve the following `goal`
|```tables
@@ -93,7 +93,7 @@ private class JDBCSQLImpl(
val results = resultSet.toDocuments(prompt)
logger.debug { "Found: ${results.size} records" }
val splitter = TokenTextSplitter(
- modelType = config.llmModelType, chunkSize = config.llmModelType.maxContextLength / 2, chunkOverlap = 10
+ modelType = config.model.modelType, chunkSize = config.model.modelType.maxContextLength / 2, chunkOverlap = 10
)
val splitDocuments = splitter.splitDocuments(results)
logger.debug { "Split into: ${splitDocuments.size} documents" }
@@ -101,7 +101,7 @@ private class JDBCSQLImpl(
}
}
- override suspend fun CoreAIScope.sql(ddl: String, input: String): List<String> = promptMessage(
+ override suspend fun CoreAIScope.sql(ddl: String, input: String): List<String> = config.model.promptMessage(
"""|
|You are an AI assistant which produces SQL SELECT queries in SQL format.
|You only reply in valid SQL SELECT queries.
@@ -126,7 +126,7 @@ private class JDBCSQLImpl(
""".trimMargin()
)
- override suspend fun CoreAIScope.getInterestingPromptsForDatabase(): List<String> = promptMessage(
+ override suspend fun CoreAIScope.getInterestingPromptsForDatabase(): List<String> = config.model.promptMessage(
"""|You are an AI assistant which replies with a list of the best prompts based on the content of this database:
|Instructions:
|1. Select from this `ddl` 3 top prompts that the user could ask about this database
diff --git a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt
index f63f3f905..8a0da9b33 100644
--- a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt
+++ b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt
@@ -1,6 +1,6 @@
package com.xebia.functional.xef.sql.jdbc
-import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.llm.Chat
class JdbcConfig(
val vendor: String,
@@ -9,7 +9,7 @@ class JdbcConfig(
val password: String,
val port: Int,
val database: String,
- val llmModelType: ModelType,
+ val model: Chat
) {
fun toJDBCUrl(): String = "jdbc:$vendor://$host:$port/$database"
}
diff --git a/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java b/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
index ac76ddea8..7a4cbd1ea 100644
--- a/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
+++ b/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
@@ -4,27 +4,25 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
-import com.xebia.functional.xef.auto.AIRuntime;
import com.xebia.functional.xef.auto.CoreAIScope;
-import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime;
-import com.xebia.functional.xef.llm.AIClient;
-import com.xebia.functional.xef.llm.LLM;
-import com.xebia.functional.xef.llm.LLMModel;
+import com.xebia.functional.xef.auto.PromptConfiguration;
+import com.xebia.functional.xef.auto.llm.openai.OpenAI;
+import com.xebia.functional.xef.auto.llm.openai.OpenAIEmbeddings;
+import com.xebia.functional.xef.embeddings.Embeddings;
+import com.xebia.functional.xef.llm.Chat;
+import com.xebia.functional.xef.llm.ChatWithFunctions;
+import com.xebia.functional.xef.llm.Images;
import com.xebia.functional.xef.llm.models.functions.CFunction;
import com.xebia.functional.xef.llm.models.images.ImageGenerationUrl;
import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse;
-import com.xebia.functional.xef.pdf.PDFLoaderKt;
+import com.xebia.functional.xef.pdf.Loader;
import com.xebia.functional.xef.textsplitters.TextSplitter;
import com.xebia.functional.xef.vectorstores.LocalVectorStore;
import com.xebia.functional.xef.vectorstores.VectorStore;
import kotlin.collections.CollectionsKt;
import kotlin.coroutines.Continuation;
import kotlin.jvm.functions.Function1;
-import kotlinx.coroutines.CoroutineScope;
-import kotlinx.coroutines.CoroutineScopeKt;
-import kotlinx.coroutines.ExecutorsKt;
-import kotlinx.coroutines.JobKt;
-import kotlinx.coroutines.CoroutineStart;
+import kotlinx.coroutines.*;
import kotlinx.coroutines.future.FutureKt;
import org.jetbrains.annotations.NotNull;
@@ -41,26 +39,24 @@ public class AIScope implements AutoCloseable {
private final CoreAIScope scope;
private final ObjectMapper om;
private final JsonSchemaGenerator schemaGen;
- private final AIClient client;
private final ExecutorService executorService;
private final CoroutineScope coroutineScope;
- public AIScope(ObjectMapper om, AIRuntime runtime, ExecutorService executorService) {
+ public AIScope(ObjectMapper om, Embeddings embeddings, ExecutorService executorService) {
this.om = om;
this.executorService = executorService;
this.coroutineScope = () -> ExecutorsKt.from(executorService).plus(JobKt.Job(null));
this.schemaGen = new JsonSchemaGenerator(om);
- this.client = runtime.getClient();
- VectorStore vectorStore = new LocalVectorStore(runtime.getEmbeddings());
- this.scope = new CoreAIScope(LLMModel.getGPT_3_5_TURBO(), LLMModel.getGPT_3_5_TURBO_FUNCTIONS(), client, vectorStore, runtime.getEmbeddings(), 3, "user", false, 0.4, 1, 20, 500);
+ VectorStore vectorStore = new LocalVectorStore(embeddings);
+ this.scope = new CoreAIScope(embeddings, vectorStore);
}
- public AIScope(AIRuntime runtime, ExecutorService executorService) {
- this(new ObjectMapper(), runtime, executorService);
+ public AIScope(Embeddings embeddings, ExecutorService executorService) {
+ this(new ObjectMapper(), embeddings, executorService);
}
public AIScope() {
- this(new ObjectMapper(), OpenAIRuntime.defaults(), Executors.newCachedThreadPool(new AIScopeThreadFactory()));
+ this(new ObjectMapper(), new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING), Executors.newCachedThreadPool(new AIScopeThreadFactory()));
}
private AIScope(CoreAIScope nested, AIScope outer) {
@@ -68,15 +64,15 @@ private AIScope(CoreAIScope nested, AIScope outer) {
this.executorService = outer.executorService;
this.coroutineScope = outer.coroutineScope;
this.schemaGen = outer.schemaGen;
- this.client = outer.client;
this.scope = nested;
}
public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls) {
- return prompt(prompt, cls, scope.getMaxDeserializationAttempts(), scope.getDefaultSerializationModel(), scope.getUser(), scope.getEcho(), scope.getNumberOfPredictions(), scope.getTemperature(), scope.getDocsInContext(), scope.getMinResponseTokens());
+ return prompt(prompt, cls, OpenAI.DEFAULT_SERIALIZATION, PromptConfiguration.DEFAULTS);
}
- public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, Integer maxAttempts, LLM.ChatWithFunctions llmModel, String user, Boolean echo, Integer n, Double temperature, Integer bringFromContext, Integer minResponseTokens) {
+
+ public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, ChatWithFunctions llmModel, PromptConfiguration promptConfiguration) {
Function1<? super String, ? extends A> decoder = json -> {
try {
return om.readValue(json, cls);
@@ -100,30 +96,11 @@ public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, Integer maxA
new CFunction(cls.getSimpleName(), "Generated function for " + cls.getSimpleName(), schema)
);
- return future(continuation -> scope.promptWithSerializer(prompt, functions, decoder, maxAttempts, llmModel, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation));
- }
-
- public CompletableFuture<List<String>> promptMessage(String prompt, LLM.Chat llmModel, List<CFunction> functions, String user, Boolean echo, Integer n, Double temperature, Integer bringFromContext, Integer minResponseTokens) {
- return future(continuation -> scope.promptMessage(prompt, llmModel, functions, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation));
+ return future(continuation -> scope.promptWithSerializer(llmModel, prompt, functions, decoder, promptConfiguration, continuation));
}
- public CompletableFuture<Void> extendContext(String[] docs) {
- return future(continuation -> scope.extendContext(docs, continuation))
- .thenApply(unit -> null);
- }
-
- public <A> CompletableFuture<A> contextScope(Function1<AIScope, CompletableFuture<A>> f) {
- return future(continuation -> scope.contextScope((coreAIScope, continuation1) -> {
- AIScope nestedScope = new AIScope(coreAIScope, AIScope.this);
- return FutureKt.await(f.invoke(nestedScope), continuation);
- }, continuation));
- }
-
- public <A> CompletableFuture<A> contextScope(VectorStore store, Function1<AIScope, CompletableFuture<A>> f) {
- return future(continuation -> scope.contextScope(store, (coreAIScope, continuation1) -> {
- AIScope nestedScope = new AIScope(coreAIScope, AIScope.this);
- return FutureKt.await(f.invoke(nestedScope), continuation);
- }, continuation));
+ public CompletableFuture<List<String>> promptMessage(Chat llmModel, String prompt, List<CFunction> functions, PromptConfiguration promptConfiguration) {
+ return future(continuation -> scope.promptMessage(llmModel, prompt, functions, promptConfiguration, continuation));
}
public <A> CompletableFuture<A> contextScope(List<String> docs, Function1<AIScope, CompletableFuture<A>> f) {
@@ -134,15 +111,15 @@ public <A> CompletableFuture<A> contextScope(List<String> docs, Function1<AIScope, CompletableFuture<A>> f) {
public CompletableFuture<List<String>> pdf(String url, TextSplitter splitter) {
- return future(continuation -> PDFLoaderKt.pdf(url, splitter, continuation));
+ return future(continuation -> Loader.pdf(url, splitter, continuation));
}
public CompletableFuture<List<String>> pdf(File file, TextSplitter splitter) {
- return future(continuation -> PDFLoaderKt.pdf(file, splitter, continuation));
+ return future(continuation -> Loader.pdf(file, splitter, continuation));
}
- public CompletableFuture<List<String>> images(String prompt, String user, String size, Integer bringFromContext, Integer n) {
- return this.future(continuation -> scope.images(prompt, user, n, size, bringFromContext, continuation))
+ public CompletableFuture<List<String>> images(Images model, String prompt, Integer numberOfImages, String size, PromptConfiguration promptConfiguration) {
+ return this.future(continuation -> scope.images(model, prompt, numberOfImages, size, promptConfiguration, continuation))
.thenApply(response -> CollectionsKt.map(response.getData(), ImageGenerationUrl::getUrl));
}
@@ -157,7 +134,6 @@ private <A> CompletableFuture<A> future(Function1<? super Continuation<? super A
@Override
public void close() {
- client.close();
CoroutineScopeKt.cancel(coroutineScope, null);
executorService.shutdown();
}
diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt
similarity index 51%
rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
rename to openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt
index 68c10dd99..f2c8275f8 100644
--- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/DeserializerLLMAgent.kt
@@ -1,7 +1,9 @@
-package com.xebia.functional.xef.auto
+package com.xebia.functional.xef.auto.llm.openai
-import com.xebia.functional.xef.auto.serialization.encodeJsonSchema
-import com.xebia.functional.xef.llm.LLM
+import com.xebia.functional.xef.auto.AiDsl
+import com.xebia.functional.xef.auto.CoreAIScope
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.serialization.ExperimentalSerializationApi
@@ -20,31 +22,15 @@ import kotlinx.serialization.serializer
* @throws IllegalArgumentException if any of [A]'s type arguments contains star projection.
*/
@AiDsl
-suspend inline fun <reified A> AIScope.prompt(
+suspend inline fun <reified A> CoreAIScope.prompt(
+ model: ChatWithFunctions,
question: String,
json: Json = Json {
ignoreUnknownKeys = true
isLenient = true
},
- maxDeserializationAttempts: Int = this.maxDeserializationAttempts,
- model: LLM.ChatWithFunctions = this.defaultSerializationModel,
- user: String = this.user,
- echo: Boolean = this.echo,
- n: Int = this.numberOfPredictions,
- temperature: Double = this.temperature,
- bringFromContext: Int = this.docsInContext
-): A =
- prompt(
- Prompt(question),
- json,
- maxDeserializationAttempts,
- model,
- user,
- echo,
- n,
- temperature,
- bringFromContext
- )
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): A = prompt(model, Prompt(question), json, promptConfiguration)
/**
* Run a [prompt] describes the task you want to solve within the context of [AIScope]. Returns a
@@ -55,65 +41,34 @@ suspend inline fun <reified A> AIScope.prompt(
* @throws IllegalArgumentException if any of [A]'s type arguments contains star projection.
*/
@AiDsl
-suspend inline fun <reified A> AIScope.prompt(
+suspend inline fun <reified A> CoreAIScope.prompt(
+ model: ChatWithFunctions,
prompt: Prompt,
json: Json = Json {
ignoreUnknownKeys = true
isLenient = true
},
- maxDeserializationAttempts: Int = this.maxDeserializationAttempts,
- model: LLM.ChatWithFunctions = this.defaultSerializationModel,
- user: String = this.user,
- echo: Boolean = this.echo,
- n: Int = this.numberOfPredictions,
- temperature: Double = this.temperature,
- bringFromContext: Int = this.docsInContext
-): A =
- prompt(
- prompt,
- serializer(),
- json,
- maxDeserializationAttempts,
- model,
- user,
- echo,
- n,
- temperature,
- bringFromContext
- )
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): A = prompt(model, prompt, serializer(), json, promptConfiguration)
@AiDsl
-suspend fun <A> AIScope.prompt(
+suspend fun <A> CoreAIScope.prompt(
+ model: ChatWithFunctions,
prompt: Prompt,
serializer: KSerializer<A>,
json: Json = Json {
ignoreUnknownKeys = true
isLenient = true
},
- maxDeserializationAttempts: Int = this.maxDeserializationAttempts,
- model: LLM.ChatWithFunctions = this.defaultSerializationModel,
- user: String = this.user,
- echo: Boolean = this.echo,
- n: Int = this.numberOfPredictions,
- temperature: Double = this.temperature,
- bringFromContext: Int = this.docsInContext,
- minResponseTokens: Int = this.minResponseTokens,
-): A {
- val functions = generateCFunction(serializer.descriptor)
- return prompt(
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): A =
+ model.prompt(
prompt,
- functions,
+ context,
+ generateCFunction(serializer.descriptor),
{ json.decodeFromString(serializer, it) },
- maxDeserializationAttempts,
- model,
- user,
- echo,
- n,
- temperature,
- bringFromContext,
- minResponseTokens
+ promptConfiguration
)
-}
@OptIn(ExperimentalSerializationApi::class)
private fun generateCFunction(descriptor: SerialDescriptor): List<CFunction> {
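A sketch of the trimmed-down serialized `prompt`, which now threads everything through `PromptConfiguration` (the `Recipe` type is illustrative):

import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.getOrThrow
import com.xebia.functional.xef.auto.llm.openai.prompt
import kotlinx.serialization.Serializable

@Serializable
data class Recipe(val name: String, val ingredients: List<String>)

suspend fun main() {
  val recipe: Recipe = ai {
    // The reified overload derives the serializer and CFunction schema for Recipe.
    prompt(OpenAI.DEFAULT_SERIALIZATION, "A simple pancake recipe.")
  }.getOrThrow()
  println(recipe)
}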
diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/ImageGenerationAgent.kt
similarity index 66%
rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
rename to openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/ImageGenerationAgent.kt
index ff058b84b..e6dd15d4c 100644
--- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/ImageGenerationAgent.kt
@@ -1,6 +1,10 @@
-package com.xebia.functional.xef.auto
+package com.xebia.functional.xef.auto.llm.openai
import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.auto.CoreAIScope
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.ChatWithFunctions
+import com.xebia.functional.xef.llm.Images
import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse
import com.xebia.functional.xef.prompt.Prompt
@@ -12,14 +16,16 @@ import com.xebia.functional.xef.prompt.Prompt
* @param size the size of the images to generate.
*/
suspend inline fun <reified A> CoreAIScope.image(
+ imageModel: Images,
+ serializationModel: ChatWithFunctions,
prompt: String,
- user: String = "testing",
size: String = "1024x1024",
- bringFromContext: Int = 10
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A {
- val imageResponse = images(prompt, user, 1, size, bringFromContext)
+ val imageResponse = imageModel.images(prompt, context, 1, size, promptConfiguration)
val url = imageResponse.data.firstOrNull() ?: throw AIError.NoResponse()
return prompt(
+ serializationModel,
"""|Instructions: Format this [URL] and [PROMPT] information in the desired JSON response format
|specified at the end of the message.
|[URL]:
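A sketch of the two-model `image` call above, where the images model renders and the functions model formats the reply (the `ImageDescription` type is illustrative):

import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.getOrThrow
import com.xebia.functional.xef.auto.llm.openai.image
import kotlinx.serialization.Serializable

@Serializable
data class ImageDescription(val url: String, val prompt: String)

suspend fun main() {
  val description: ImageDescription = ai {
    // size and promptConfiguration fall back to their declared defaults.
    image(OpenAI.DEFAULT_IMAGES, OpenAI.DEFAULT_SERIALIZATION, "a lighthouse at dawn")
  }.getOrThrow()
  println(description.url)
}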
diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/JsonSchema.kt
similarity index 98%
rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt
rename to openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/JsonSchema.kt
index c7ddf8e5b..16741b7fe 100644
--- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/JsonSchema.kt
@@ -1,6 +1,8 @@
+@file:JvmName("Json")
+@file:JvmMultifileClass
@file:OptIn(ExperimentalSerializationApi::class)
-package com.xebia.functional.xef.auto.serialization
+package com.xebia.functional.xef.auto.llm.openai
/*
Ported over from https://github.com/Ricky12Awesome/json-schema-serialization
@@ -11,6 +13,8 @@ which states the following:
// TODO: We should consider a fork and maintain it ourselves.
*/
+import kotlin.jvm.JvmMultifileClass
+import kotlin.jvm.JvmName
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialInfo
import kotlinx.serialization.descriptors.PolymorphicKind
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt
index f204e20a2..89557430f 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt
@@ -3,11 +3,11 @@ package com.xebia.functional.xef.auto.llm.openai
import arrow.core.Either
import arrow.core.left
import arrow.core.right
+import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.auto.AI
import com.xebia.functional.xef.auto.CoreAIScope
-import com.xebia.functional.xef.llm.AIClient
-import com.xebia.functional.xef.llm.LLMModel
+import com.xebia.functional.xef.llm.*
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
@@ -37,7 +37,11 @@ class MockOpenAIClient(
private val images: (ImagesGenerationRequest) -> ImagesGenerationResponse = {
throw NotImplementedError("images not implemented")
},
-) : AIClient {
+) : ChatWithFunctions, Images, Completion, Embeddings {
+
+ override val name: String = "mock"
+ override val modelType: ModelType = ModelType.GPT_3_5_TURBO
+
override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
completion(request)
@@ -55,6 +59,8 @@ class MockOpenAIClient(
override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse =
images(request)
+ override fun tokensFromMessages(messages: List<Message>): Int = 0
+
override fun close() {}
}
@@ -78,7 +84,7 @@ fun simpleMockAIClient(execute: (String) -> String): MockOpenAIClient =
val responses =
req.messages.mapIndexed { ix, msg ->
val response = execute(msg.content ?: "")
- Choice(Message(msg.role, response), "end", ix)
+ Choice(Message(msg.role, response, msg.role.name), "end", ix)
}
val requestTokens = req.messages.sumOf { it.content?.split(' ')?.size ?: 0 }
val responseTokens = responses.sumOf { it.message?.content?.split(' ')?.size ?: 0 }
@@ -96,14 +102,7 @@ suspend fun <A> MockAIScope(
try {
val embeddings = OpenAIEmbeddings(mockClient)
val vectorStore = LocalVectorStore(embeddings)
- val scope =
- CoreAIScope(
- LLMModel.GPT_3_5_TURBO,
- LLMModel.GPT_3_5_TURBO_FUNCTIONS,
- mockClient,
- vectorStore,
- embeddings
- )
+ val scope = CoreAIScope(embeddings, vectorStore)
block(scope)
} catch (e: AIError) {
orElse(e)
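A sketch of running a chain against the mock client, assuming the `MockAIScope(mockClient, block, orElse)` signature shown above (the echo function is illustrative):

import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.MockAIScope
import com.xebia.functional.xef.auto.llm.openai.simpleMockAIClient

suspend fun main() {
  // Every chat request is answered deterministically; no network involved.
  val mock = simpleMockAIClient { query -> "echo: $query" }
  val result = MockAIScope(mock, ai { "ran inside the mock scope" }) { err -> "failed: $err" }
  println(result)
}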
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt
new file mode 100644
index 000000000..aacd9cb51
--- /dev/null
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt
@@ -0,0 +1,57 @@
+package com.xebia.functional.xef.auto.llm.openai
+
+import arrow.core.nonEmptyListOf
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.env.getenv
+import kotlin.jvm.JvmField
+
+class OpenAI(internal val token: String) {
+ val GPT_4 = OpenAIModel(this, "gpt-4", ModelType.GPT_4)
+
+ val GPT_4_0314 = OpenAIModel(this, "gpt-4-0314", ModelType.GPT_4)
+
+ val GPT_4_32K = OpenAIModel(this, "gpt-4-32k", ModelType.GPT_4_32K)
+
+ val GPT_3_5_TURBO = OpenAIModel(this, "gpt-3.5-turbo", ModelType.GPT_3_5_TURBO)
+
+ val GPT_3_5_TURBO_16K = OpenAIModel(this, "gpt-3.5-turbo-16k", ModelType.GPT_3_5_TURBO_16_K)
+
+ val GPT_3_5_TURBO_FUNCTIONS =
+ OpenAIModel(this, "gpt-3.5-turbo-0613", ModelType.GPT_3_5_TURBO_FUNCTIONS)
+
+ val GPT_3_5_TURBO_0301 = OpenAIModel(this, "gpt-3.5-turbo-0301", ModelType.GPT_3_5_TURBO)
+
+ val TEXT_DAVINCI_003 = OpenAIModel(this, "text-davinci-003", ModelType.TEXT_DAVINCI_003)
+
+ val TEXT_DAVINCI_002 = OpenAIModel(this, "text-davinci-002", ModelType.TEXT_DAVINCI_002)
+
+ val TEXT_CURIE_001 = OpenAIModel(this, "text-curie-001", ModelType.TEXT_SIMILARITY_CURIE_001)
+
+ val TEXT_BABBAGE_001 = OpenAIModel(this, "text-babbage-001", ModelType.TEXT_BABBAGE_001)
+
+ val TEXT_ADA_001 = OpenAIModel(this, "text-ada-001", ModelType.TEXT_ADA_001)
+
+ val TEXT_EMBEDDING_ADA_002 =
+ OpenAIModel(this, "text-embedding-ada-002", ModelType.TEXT_EMBEDDING_ADA_002)
+
+ val DALLE_2 = OpenAIModel(this, "dalle-2", ModelType.GPT_3_5_TURBO)
+
+ companion object {
+
+ fun openAITokenFromEnv(): String {
+ return getenv("OPENAI_TOKEN")
+ ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))
+ }
+
+ @JvmField val DEFAULT = OpenAI(openAITokenFromEnv())
+
+ @JvmField val DEFAULT_CHAT = DEFAULT.GPT_3_5_TURBO_16K
+
+ @JvmField val DEFAULT_SERIALIZATION = DEFAULT.GPT_3_5_TURBO_FUNCTIONS
+
+ @JvmField val DEFAULT_EMBEDDING = DEFAULT.TEXT_EMBEDDING_ADA_002
+
+ @JvmField val DEFAULT_IMAGES = DEFAULT.DALLE_2
+ }
+}
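A sketch of the zero-configuration entry point (`OPENAI_TOKEN` must be set in the environment; the prompt is illustrative):

import com.xebia.functional.xef.auto.llm.openai.OpenAI

suspend fun main() {
  // DEFAULT_CHAT resolves to GPT_3_5_TURBO_16K per the companion object above.
  val replies = OpenAI.DEFAULT_CHAT.promptMessage("Summarize what xef does in one line.")
  replies.forEach(::println)
}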
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt
index 5810789a1..7b9022089 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt
@@ -8,7 +8,6 @@ import com.aallam.openai.api.completion.CompletionRequest as OpenAICompletionReq
import com.aallam.openai.api.completion.TextCompletion
import com.aallam.openai.api.completion.completionRequest
import com.aallam.openai.api.core.Usage as OpenAIUsage
-import com.aallam.openai.api.embedding.EmbeddingRequest as OpenAIEmbeddingRequest
import com.aallam.openai.api.embedding.EmbeddingResponse
import com.aallam.openai.api.embedding.embeddingRequest
import com.aallam.openai.api.image.ImageCreation
@@ -16,8 +15,10 @@ import com.aallam.openai.api.image.ImageSize
import com.aallam.openai.api.image.ImageURL
import com.aallam.openai.api.image.imageCreation
import com.aallam.openai.api.model.ModelId
-import com.aallam.openai.client.OpenAI
-import com.xebia.functional.xef.llm.AIClient
+import com.aallam.openai.client.OpenAI as OpenAIClient
+import com.xebia.functional.tokenizer.Encoding
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.llm.*
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
@@ -32,10 +33,16 @@ import com.xebia.functional.xef.llm.models.text.CompletionResult
import com.xebia.functional.xef.llm.models.usage.Usage
import kotlinx.serialization.json.Json
-class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
+class OpenAIModel(
+ private val openAI: OpenAI,
+ override val name: String,
+ override val modelType: ModelType
+) : Chat, ChatWithFunctions, Images, Completion, Embeddings, AutoCloseable {
+
+ private val client = OpenAIClient(openAI.token)
override suspend fun createCompletion(request: CompletionRequest): CompletionResult {
- val response = openAI.completion(toCompletionRequest(request))
+ val response = client.completion(toCompletionRequest(request))
return completionResult(response)
}
@@ -43,28 +50,61 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
override suspend fun createChatCompletion(
request: ChatCompletionRequest
): ChatCompletionResponse {
- val response = openAI.chatCompletion(toChatCompletionRequest(request))
- return chatCompletionResult(response)
+ val response = client.chatCompletion(toChatCompletionRequest(request))
+ return ChatCompletionResponse(
+ id = response.id,
+ `object` = response.model.id,
+ created = response.created,
+ model = response.model.id,
+ choices = response.choices.map { chatCompletionChoice(it) },
+ usage = usage(response.usage)
+ )
}
@OptIn(BetaOpenAI::class)
override suspend fun createChatCompletionWithFunctions(
request: ChatCompletionRequestWithFunctions
): ChatCompletionResponseWithFunctions {
- val response = openAI.chatCompletion(toChatCompletionRequestWithFunctions(request))
- return chatCompletionResultWithFunctions(response)
+ val response = client.chatCompletion(toChatCompletionRequestWithFunctions(request))
+
+ fun chatCompletionChoiceWithFunctions(choice: ChatChoice): ChoiceWithFunctions =
+ ChoiceWithFunctions(
+ message =
+ choice.message?.let {
+ MessageWithFunctionCall(
+ role = it.role.role,
+ content = it.content,
+ name = it.name,
+ functionCall = it.functionCall?.let { FnCall(it.name, it.arguments) }
+ )
+ },
+ finishReason = choice.finishReason,
+ index = choice.index,
+ )
+
+ return ChatCompletionResponseWithFunctions(
+ id = response.id,
+ `object` = response.model.id,
+ created = response.created,
+ model = response.model.id,
+ choices = response.choices.map { chatCompletionChoiceWithFunctions(it) },
+ usage = usage(response.usage)
+ )
}
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
- val response = openAI.embeddings(toEmbeddingRequest(request))
- return embeddingResult(response)
+ val openAIRequest = embeddingRequest {
+ model = ModelId(request.model)
+ input = request.input
+ user = request.user
+ }
+
+ return embeddingResult(client.embeddings(openAIRequest))
}
@OptIn(BetaOpenAI::class)
- override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse {
- val response = openAI.imageURL(toImageCreationRequest(request))
- return imageResult(response)
- }
+ override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse =
+ imageResult(client.imageURL(toImageCreationRequest(request)))
private fun toCompletionRequest(request: CompletionRequest): OpenAICompletionRequest =
completionRequest {
@@ -110,61 +150,39 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
totalTokens = usage?.totalTokens,
)
- @OptIn(BetaOpenAI::class)
- private fun chatCompletionResult(response: ChatCompletion): ChatCompletionResponse =
- ChatCompletionResponse(
- id = response.id,
- `object` = response.model.id,
- created = response.created,
- model = response.model.id,
- choices = response.choices.map { chatCompletionChoice(it) },
- usage = usage(response.usage)
- )
-
- @OptIn(BetaOpenAI::class)
- private fun chatCompletionResultWithFunctions(
- response: ChatCompletion
- ): ChatCompletionResponseWithFunctions =
- ChatCompletionResponseWithFunctions(
- id = response.id,
- `object` = response.model.id,
- created = response.created,
- model = response.model.id,
- choices = response.choices.map { chatCompletionChoiceWithFunctions(it) },
- usage = usage(response.usage)
- )
-
- @OptIn(BetaOpenAI::class)
- private fun chatCompletionChoiceWithFunctions(choice: ChatChoice): ChoiceWithFunctions =
- ChoiceWithFunctions(
- message =
- choice.message?.let {
- MessageWithFunctionCall(
- role = it.role.role,
- content = it.content,
- name = it.name,
- functionCall = it.functionCall?.let { FnCall(it.name, it.arguments) }
- )
- },
- finishReason = choice.finishReason,
- index = choice.index,
- )
-
@OptIn(BetaOpenAI::class)
private fun chatCompletionChoice(choice: ChatChoice): Choice =
Choice(
message =
choice.message?.let {
Message(
- role = it.role.role,
- content = it.content,
- name = it.name,
+ role = toRole(it),
+ content = it.content ?: "",
+ name = it.name ?: "",
)
},
finishReason = choice.finishReason,
index = choice.index,
)
+ @OptIn(BetaOpenAI::class)
+ private fun toRole(it: ChatMessage) =
+ when (it.role) {
+ ChatRole.User -> Role.USER
+ ChatRole.Assistant -> Role.ASSISTANT
+ ChatRole.System -> Role.SYSTEM
+ ChatRole.Function -> Role.SYSTEM
+ else -> Role.ASSISTANT
+ }
+
+ @OptIn(BetaOpenAI::class)
+ private fun fromRole(it: Role) =
+ when (it) {
+ Role.USER -> ChatRole.User
+ Role.ASSISTANT -> ChatRole.Assistant
+ Role.SYSTEM -> ChatRole.System
+ }
+
@OptIn(BetaOpenAI::class)
private fun toChatCompletionRequest(request: ChatCompletionRequest): OpenAIChatCompletionRequest =
chatCompletionRequest {
@@ -172,7 +190,7 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
messages =
request.messages.map {
ChatMessage(
- role = ChatRole(it.role),
+ role = fromRole(it.role),
content = it.content,
name = it.name,
)
@@ -195,7 +213,7 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
model = ModelId(request.model)
messages =
request.messages.map {
- ChatMessage(role = ChatRole(it.role), content = it.content, name = it.name)
+ ChatMessage(role = fromRole(it.role), content = it.content, name = it.name)
}
functions =
@@ -232,13 +250,6 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
usage = usage(response.usage)
)
- private fun toEmbeddingRequest(request: EmbeddingRequest): OpenAIEmbeddingRequest =
- embeddingRequest {
- model = ModelId(request.model)
- input = request.input
- user = request.user
- }
-
@OptIn(BetaOpenAI::class)
private fun imageResult(response: List<ImageURL>): ImagesGenerationResponse =
ImagesGenerationResponse(data = response.map { ImageGenerationUrl(it.url) })
@@ -252,7 +263,39 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable {
user = request.user
}
+ override fun tokensFromMessages(messages: List<Message>): Int {
+ fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int =
+ messages.sumOf { message ->
+ countTokens(message.role.name) +
+ countTokens(message.content) +
+ tokensPerMessage +
+ tokensPerName
+ } + 3
+
+ fun fallBackTo(fallbackModel: Chat, paddingTokens: Int): Int {
+ return fallbackModel.tokensFromMessages(messages) + paddingTokens
+ }
+
+ return when (this) {
+ openAI.GPT_3_5_TURBO_FUNCTIONS ->
+ // paddingToken = 200: reserved for functions
+ fallBackTo(fallbackModel = openAI.GPT_3_5_TURBO_0301, paddingTokens = 200)
+ openAI.GPT_3_5_TURBO ->
+ // otherwise if the model changes, it might later fail
+ fallBackTo(fallbackModel = openAI.GPT_3_5_TURBO_0301, paddingTokens = 5)
+ openAI.GPT_4,
+ openAI.GPT_4_32K ->
+ // otherwise if the model changes, it might later fail
+ fallBackTo(fallbackModel = openAI.GPT_4_0314, paddingTokens = 5)
+ openAI.GPT_3_5_TURBO_0301 ->
+ modelType.encoding.countTokensFromMessages(tokensPerMessage = 4, tokensPerName = 0)
+ openAI.GPT_4_0314 ->
+ modelType.encoding.countTokensFromMessages(tokensPerMessage = 3, tokensPerName = 2)
+ else -> fallBackTo(fallbackModel = openAI.GPT_3_5_TURBO_0301, paddingTokens = 20)
+ }
+ }
+
override fun close() {
- openAI.close()
+ client.close()
}
}
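A sketch of the token accounting above: per-model padding guards against drift when OpenAI revs a model alias (the message is illustrative; counts depend on the CL100K encoding):

import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role

fun main() {
  val messages = listOf(Message(Role.USER, "How many tokens is this?", Role.USER.name))
  // GPT_3_5_TURBO delegates to GPT_3_5_TURBO_0301's per-message table plus 5 padding tokens.
  println(OpenAI.DEFAULT.GPT_3_5_TURBO.tokensFromMessages(messages))
}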
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt
index ca24ab607..c7017b1a5 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt
@@ -3,13 +3,13 @@ package com.xebia.functional.xef.auto.llm.openai
import arrow.fx.coroutines.parMap
import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
-import com.xebia.functional.xef.llm.AIClient
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import kotlin.time.ExperimentalTime
@ExperimentalTime
-class OpenAIEmbeddings(private val oaiClient: AIClient) : Embeddings {
+class OpenAIEmbeddings(private val oaiClient: com.xebia.functional.xef.llm.Embeddings) :
+ Embeddings {
override suspend fun embedDocuments(
texts: List<String>,
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt
index a02c9e069..daae1fe92 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt
@@ -1,47 +1,54 @@
+@file:JvmName("OpenAIRuntime")
+
package com.xebia.functional.xef.auto.llm.openai
-import com.aallam.openai.api.logging.LogLevel
-import com.aallam.openai.api.logging.Logger
-import com.aallam.openai.client.LoggingConfig
-import com.aallam.openai.client.OpenAI
-import com.aallam.openai.client.OpenAIConfig
-import com.xebia.functional.xef.auto.AIRuntime
+import arrow.core.Either
+import arrow.core.left
+import arrow.core.right
+import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.auto.AI
import com.xebia.functional.xef.auto.CoreAIScope
-import com.xebia.functional.xef.env.getenv
-import com.xebia.functional.xef.llm.LLMModel
-import com.xebia.functional.xef.vectorstores.LocalVectorStore
-import kotlin.jvm.JvmStatic
+import com.xebia.functional.xef.auto.ai
+import kotlin.jvm.JvmName
import kotlin.time.ExperimentalTime
-object OpenAIRuntime {
- @JvmStatic fun <A> defaults(): AIRuntime<A> = openAI(null)
+/**
+ * Run the [AI] value to produce an [A], this method initialises all the dependencies required to
+ * run the [AI] value and once it finishes it closes all the resources.
+ *
+ * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
+ */
+suspend inline fun <A> AI<A>.getOrElse(crossinline orElse: suspend (AIError) -> A): A =
+ AIScope(this) { orElse(it) }
+
+/**
+ * Run the [AI] value to produce [A]. this method initialises all the dependencies required to run
+ * the [AI] value and once it finishes it closes all the resources.
+ *
+ * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
+ *
+ * @throws AIError in case something went wrong.
+ * @see getOrElse for an operator that allow directly handling the [AIError] case instead of
+ * throwing.
+ */
+suspend inline fun <A> AI<A>.getOrThrow(): A = getOrElse { throw it }
+
+/**
+ * Run the [AI] value to produce _either_ an [AIError], or [A]. this method initialises all the
+ * dependencies required to run the [AI] value and once it finishes it closes all the resources.
+ *
+ * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
+ *
+ * @see getOrElse for an operator that allow directly handling the [AIError] case.
+ */
+suspend inline fun <A> AI<A>.toEither(): Either<AIError, A> =
+ ai { invoke().right() }.getOrElse { it.left() }
- @OptIn(ExperimentalTime::class)
- @JvmStatic
- fun <A> openAI(config: OpenAIConfig? = null): AIRuntime<A> {
- val openAIConfig =
- config
- ?: OpenAIConfig(
- logging = LoggingConfig(logLevel = LogLevel.None, logger = Logger.Empty),
- token =
- requireNotNull(getenv("OPENAI_TOKEN")) { "OpenAI Token missing from environment." },
- )
- val openAI = OpenAI(openAIConfig)
- val client = OpenAIClient(openAI)
- val embeddings = OpenAIEmbeddings(client)
- return AIRuntime(client, embeddings) { block ->
- client.use { openAiClient ->
- val vectorStore = LocalVectorStore(embeddings)
- val scope =
- CoreAIScope(
- defaultModel = LLMModel.GPT_3_5_TURBO_16K,
- defaultSerializationModel = LLMModel.GPT_3_5_TURBO_FUNCTIONS,
- aiClient = openAiClient,
- context = vectorStore,
- embeddings = embeddings
- )
- block(scope)
- }
- }
+@OptIn(ExperimentalTime::class)
+suspend fun <A> AIScope(block: AI<A>, orElse: suspend (AIError) -> A): A =
+ try {
+ val scope = CoreAIScope(OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING))
+ block(scope)
+ } catch (e: AIError) {
+ orElse(e)
}
-}
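Taken together, these operators give callers three ways to terminate an `AI` chain. A minimal sketch of the new entry points, assuming the `prompt` extension defined in OpenAIScopeExtensions.kt below is on the classpath; the `Recipe` type and prompt text are invented for illustration:

import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.getOrElse
import com.xebia.functional.xef.auto.llm.openai.prompt
import com.xebia.functional.xef.auto.llm.openai.toEither
import kotlinx.serialization.Serializable

@Serializable // hypothetical result type, not part of this PR
data class Recipe(val title: String, val ingredients: List<String>)

suspend fun main() {
  // Nothing runs until a terminal operator is applied to the AI<Recipe> value.
  val program = ai { prompt<Recipe>("Give me a simple pancake recipe") }

  // 1. Recover from AIError with a fallback value.
  val recipe = program.getOrElse { Recipe("fallback", emptyList()) }

  // 2. Or surface the error as a value instead of throwing.
  val either = program.toEither()

  println(recipe)
  println(either)
}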
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIScopeExtensions.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIScopeExtensions.kt
new file mode 100644
index 000000000..262a456a0
--- /dev/null
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIScopeExtensions.kt
@@ -0,0 +1,52 @@
+package com.xebia.functional.xef.auto.llm.openai
+
+import com.xebia.functional.xef.auto.AiDsl
+import com.xebia.functional.xef.auto.CoreAIScope
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.Chat
+import com.xebia.functional.xef.llm.ChatWithFunctions
+import com.xebia.functional.xef.llm.models.functions.CFunction
+import com.xebia.functional.xef.prompt.Prompt
+import kotlinx.serialization.serializer
+
+@AiDsl
+suspend fun CoreAIScope.promptMessage(
+ prompt: String,
+ model: Chat = OpenAI.DEFAULT_CHAT,
+ functions: List<CFunction> = emptyList(),
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): List<String> = model.promptMessage(prompt, context, functions, promptConfiguration)
+
+@AiDsl
+suspend fun CoreAIScope.promptMessage(
+ prompt: Prompt,
+ model: Chat = OpenAI.DEFAULT_CHAT,
+ functions: List<CFunction> = emptyList(),
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): List<String> = model.promptMessage(prompt, context, functions, promptConfiguration)
+
+@AiDsl
+suspend inline fun <reified A> CoreAIScope.prompt(
+ prompt: String,
+ model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): A =
+ prompt(
+ model = model,
+ prompt = Prompt(prompt),
+ serializer = serializer(),
+ promptConfiguration = promptConfiguration
+ )
+
+@AiDsl
+suspend inline fun <reified A> CoreAIScope.image(
+ prompt: String,
+ model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
+): A =
+ prompt(
+ model = model,
+ prompt = Prompt(prompt),
+ serializer = serializer(),
+ promptConfiguration = promptConfiguration
+ )
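A short sketch of how these extensions read at a call site; the defaults shown (`OpenAI.DEFAULT_CHAT`, `PromptConfiguration.DEFAULTS`) are the ones declared above, and the prompt text is invented:

import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.getOrThrow
import com.xebia.functional.xef.auto.llm.openai.promptMessage

suspend fun main() {
  val answers: List<String> =
    ai {
      // Defaults to OpenAI.DEFAULT_CHAT with PromptConfiguration.DEFAULTS;
      // pass `model = ...` or `promptConfiguration = ...` to override.
      promptMessage("Name three Kotlin multiplatform targets")
    }.getOrThrow()
  answers.forEach(::println)
}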
diff --git a/scala/build.gradle.kts b/scala/build.gradle.kts
index 45186b4ac..678dddf8e 100644
--- a/scala/build.gradle.kts
+++ b/scala/build.gradle.kts
@@ -10,7 +10,8 @@ plugins {
}
dependencies {
- implementation(projects.xefKotlin)
+ implementation(projects.xefCore)
+ implementation(projects.xefOpenai)
implementation(projects.kotlinLoom)
// TODO split to separate Scala library
diff --git a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala
index 81aff208e..a4557c169 100644
--- a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala
+++ b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala
@@ -1,7 +1,7 @@
package com.xebia.functional.xef.scala.auto
-import com.xebia.functional.xef.auto.CoreAIScope as KtAIScope
+import com.xebia.functional.xef.auto.CoreAIScope
-final case class AIScope(kt: KtAIScope)
+final case class AIScope(kt: CoreAIScope)
private object AIScope:
- def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope)
+ def fromCore(coreAIScope: CoreAIScope): AIScope = new AIScope(coreAIScope)
diff --git a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala
index 3a357fb13..8ff91a7a0 100644
--- a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala
+++ b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala
@@ -2,21 +2,22 @@ package com.xebia.functional.xef.scala.auto
import com.xebia.functional.loom.LoomAdapter
import com.xebia.functional.xef.AIError
-import com.xebia.functional.xef.llm.LLM
-import com.xebia.functional.xef.llm.LLMModel
+import com.xebia.functional.xef.llm.Chat
+import com.xebia.functional.xef.llm.ChatWithFunctions
+import com.xebia.functional.xef.llm.Images
import com.xebia.functional.xef.llm.models.functions.CFunction
import io.circe.Decoder
import io.circe.parser.parse
-import com.xebia.functional.xef.auto.AIKt
-import com.xebia.functional.xef.auto.AIRuntime
-import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime
-import com.xebia.functional.xef.auto.serialization.JsonSchemaKt
-import com.xebia.functional.xef.pdf.PDFLoaderKt
+import com.xebia.functional.xef.auto.llm.openai.Json
+import com.xebia.functional.xef.pdf.Loader
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm._
+import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.auto.llm.openai._
import com.xebia.functional.xef.scala.textsplitters.TextSplitter
import com.xebia.functional.xef.llm.models.images.*
+import com.xebia.functional.xef.auto.llm.openai.OpenAI
+import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime
import java.io.File
import scala.jdk.CollectionConverters.*
@@ -26,8 +27,7 @@ type AI[A] = AIScope ?=> A
def ai[A](block: AI[A]): A =
LoomAdapter.apply { cont =>
- AIKt.AIScope[A](
- OpenAIRuntime.defaults[A](),
+ OpenAIRuntime.AIScope[A](
{ (coreAIScope, _) =>
given AIScope = AIScope.fromCore(coreAIScope)
@@ -45,28 +45,16 @@ extension [A](block: AI[A]) {
def prompt[A: Decoder: SerialDescriptor](
prompt: String,
- maxAttempts: Int = 5,
- llmModel: LLM.ChatWithFunctions = LLMModel.getGPT_3_5_TURBO_FUNCTIONS,
- user: String = "testing",
- echo: Boolean = false,
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int = 10,
- minResponseTokens: Int = 400
+ llmModel: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
)(using scope: AIScope): A =
LoomAdapter.apply((cont) =>
scope.kt.promptWithSerializer[A](
+ llmModel,
prompt,
generateCFunctions.asJava,
(json: String) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity),
- maxAttempts,
- llmModel,
- user,
- echo,
- n,
- temperature,
- bringFromContext,
- minResponseTokens,
+ promptConfiguration,
cont
)
)
@@ -77,53 +65,48 @@ private def generateCFunctions[A: SerialDescriptor]: List[CFunction] =
val fnName =
if (serialName.contains(".")) serialName.substring(serialName.lastIndexOf("."), serialName.length)
else serialName
- List(CFunction(fnName, "Generated function for $fnName", JsonSchemaKt.encodeJsonSchema(descriptor)))
+ List(CFunction(fnName, s"Generated function for $fnName", Json.encodeJsonSchema(descriptor)))
def contextScope[A: Decoder: SerialDescriptor](docs: List[String])(block: AI[A])(using scope: AIScope): A =
LoomAdapter.apply(scope.kt.contextScopeWithDocs[A](docs.asJava, (_, _) => block, _))
def promptMessage(
prompt: String,
- llmModel: LLM.Chat = LLMModel.getGPT_3_5_TURBO,
+ llmModel: Chat = OpenAI.DEFAULT_CHAT,
functions: List[CFunction] = List.empty,
- user: String = "testing",
- echo: Boolean = false,
- n: Int = 1,
- temperature: Double = 0.0,
- bringFromContext: Int = 10,
- minResponseTokens: Int = 500
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
)(using scope: AIScope): List[String] =
LoomAdapter
.apply[java.util.List[String]](
- scope.kt.promptMessage(prompt, llmModel, functions.asJava, user, echo, n, temperature, bringFromContext, minResponseTokens, _)
+ scope.kt.promptMessage(llmModel, prompt, functions.asJava, promptConfiguration, _)
).asScala.toList
def pdf(
resource: String | File,
- splitter: TextSplitter = TextSplitter.tokenTextSplitter(ModelType.GPT_3_5_TURBO, 100, 50)
+ splitter: TextSplitter = TextSplitter.tokenTextSplitter(ModelType.getDEFAULT_SPLITTER_MODEL, 100, 50)
): List[String] =
LoomAdapter
.apply[java.util.List[String]](count =>
resource match
- case url: String => PDFLoaderKt.pdf(url, splitter.core, count)
- case file: File => PDFLoaderKt.pdf(file, splitter.core, count)
+ case url: String => Loader.pdf(url, splitter.core, count)
+ case file: File => Loader.pdf(file, splitter.core, count)
).asScala.toList
def images(
prompt: String,
- user: String = "testing",
+ model: Images = OpenAI.DEFAULT_IMAGES,
+ n: Int = 1,
size: String = "1024x1024",
- bringFromContext: Int = 10,
- n: Int = 1
+ promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
)(using scope: AIScope): List[String] =
LoomAdapter
.apply[ImagesGenerationResponse](cont =>
scope.kt.images(
+ model,
prompt,
- user,
n,
size,
- bringFromContext,
+ promptConfiguration,
cont
)
).getData.asScala.map(_.getUrl).toList
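The Scala facade now takes a single `PromptConfiguration` where it previously threaded seven tuning parameters (user, echo, n, temperature, bringFromContext, minResponseTokens, plus maxAttempts). A Kotlin sketch of the same idea at a call site, using only names that appear in this diff; constructing a non-default configuration is omitted because its builder is not shown here:

import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.auto.ai
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.getOrThrow
import com.xebia.functional.xef.auto.llm.openai.promptMessage

suspend fun main() {
  // One value object replaces the old per-call parameter list.
  val config = PromptConfiguration.DEFAULTS
  val summary =
    ai {
      promptMessage(
        prompt = "Summarise the change in one line",
        model = OpenAI.DEFAULT_CHAT,
        promptConfiguration = config
      )
    }.getOrThrow()
  summary.forEach(::println)
}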
diff --git a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
index b806fd614..b892b9b65 100644
--- a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
+++ b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
@@ -3,15 +3,16 @@ package com.xebia.functional.tokenizer
import com.xebia.functional.tokenizer.EncodingType.CL100K_BASE
import com.xebia.functional.tokenizer.EncodingType.P50K_BASE
import com.xebia.functional.tokenizer.EncodingType.R50K_BASE
+import kotlin.jvm.JvmStatic
-enum class ModelType(
+sealed class ModelType(
/**
* Returns the name of the model type as used by the OpenAI API.
*
* @return the name of the model type
*/
- name: String,
- val encodingType: EncodingType,
+ open val name: String,
+ open val encodingType: EncodingType,
/**
* Returns the maximum context length that is supported by this model type. Note that
* the maximum context length consists of the amount of prompt tokens and the amount of
@@ -19,58 +20,61 @@ enum class ModelType(
*
* @return the maximum context length for this model type
*/
- val maxContextLength: Int
+ open val maxContextLength: Int
) {
+
+ companion object {
+ @JvmStatic
+ val DEFAULT_SPLITTER_MODEL = GPT_3_5_TURBO
+ }
+
+ data class LocalModel(
+   override val name: String,
+   override val encodingType: EncodingType,
+   override val maxContextLength: Int
+ ) : ModelType(name, encodingType, maxContextLength)
// chat
- GPT_4("gpt-4", CL100K_BASE, 8192),
- GPT_4_32K("gpt-4-32k", CL100K_BASE, 32768),
- GPT_3_5_TURBO("gpt-3.5-turbo", CL100K_BASE, 4097),
- GPT_3_5_TURBO_16_K("gpt-3.5-turbo-16k", CL100K_BASE, 4097 * 4),
- GPT_3_5_TURBO_FUNCTIONS("gpt-3.5-turbo-0613", CL100K_BASE, 4097),
+ object GPT_4 : ModelType("gpt-4", CL100K_BASE, 8192)
+ object GPT_4_32K : ModelType("gpt-4-32k", CL100K_BASE, 32768)
+ object GPT_3_5_TURBO : ModelType("gpt-3.5-turbo", CL100K_BASE, 4097)
+ object GPT_3_5_TURBO_16_K : ModelType("gpt-3.5-turbo-16k", CL100K_BASE, 4097 * 4)
+ object GPT_3_5_TURBO_FUNCTIONS : ModelType("gpt-3.5-turbo-0613", CL100K_BASE, 4097)
// text
- TEXT_DAVINCI_003("text-davinci-003", P50K_BASE, 4097),
- TEXT_DAVINCI_002("text-davinci-002", P50K_BASE, 4097),
- TEXT_DAVINCI_001("text-davinci-001", R50K_BASE, 2049),
- TEXT_CURIE_001("text-curie-001", R50K_BASE, 2049),
- TEXT_BABBAGE_001("text-babbage-001", R50K_BASE, 2049),
- TEXT_ADA_001("text-ada-001", R50K_BASE, 2049),
- DAVINCI("davinci", R50K_BASE, 2049),
- CURIE("curie", R50K_BASE, 2049),
- BABBAGE("babbage", R50K_BASE, 2049),
- ADA("ada", R50K_BASE, 2049),
+ object TEXT_DAVINCI_003 : ModelType("text-davinci-003", P50K_BASE, 4097)
+ object TEXT_DAVINCI_002 : ModelType("text-davinci-002", P50K_BASE, 4097)
+ object TEXT_DAVINCI_001 : ModelType("text-davinci-001", R50K_BASE, 2049)
+ object TEXT_CURIE_001 : ModelType("text-curie-001", R50K_BASE, 2049)
+ object TEXT_BABBAGE_001 : ModelType("text-babbage-001", R50K_BASE, 2049)
+ object TEXT_ADA_001 : ModelType("text-ada-001", R50K_BASE, 2049)
+ object DAVINCI : ModelType("davinci", R50K_BASE, 2049)
+ object CURIE : ModelType("curie", R50K_BASE, 2049)
+ object BABBAGE : ModelType("babbage", R50K_BASE, 2049)
+ object ADA : ModelType("ada", R50K_BASE, 2049)
// code
- CODE_DAVINCI_002("code-davinci-002", P50K_BASE, 8001),
- CODE_DAVINCI_001("code-davinci-001", P50K_BASE, 8001),
- CODE_CUSHMAN_002("code-cushman-002", P50K_BASE, 2048),
- CODE_CUSHMAN_001("code-cushman-001", P50K_BASE, 2048),
- DAVINCI_CODEX("davinci-codex", P50K_BASE, 4096),
- CUSHMAN_CODEX("cushman-codex", P50K_BASE, 2048),
+ object CODE_DAVINCI_002 : ModelType("code-davinci-002", P50K_BASE, 8001)
+ object CODE_DAVINCI_001 : ModelType("code-davinci-001", P50K_BASE, 8001)
+ object CODE_CUSHMAN_002 : ModelType("code-cushman-002", P50K_BASE, 2048)
+ object CODE_CUSHMAN_001 : ModelType("code-cushman-001", P50K_BASE, 2048)
+ object DAVINCI_CODEX : ModelType("davinci-codex", P50K_BASE, 4096)
+ object CUSHMAN_CODEX : ModelType("cushman-codex", P50K_BASE, 2048)
// edit
- TEXT_DAVINCI_EDIT_001("text-davinci-edit-001", EncodingType.P50K_EDIT, 3000),
- CODE_DAVINCI_EDIT_001("code-davinci-edit-001", EncodingType.P50K_EDIT, 3000),
+ object TEXT_DAVINCI_EDIT_001 : ModelType("text-davinci-edit-001", EncodingType.P50K_EDIT, 3000)
+ object CODE_DAVINCI_EDIT_001 : ModelType("code-davinci-edit-001", EncodingType.P50K_EDIT, 3000)
// embeddings
- TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", CL100K_BASE, 8191),
+ object TEXT_EMBEDDING_ADA_002 : ModelType("text-embedding-ada-002", CL100K_BASE, 8191)
// old embeddings
- TEXT_SIMILARITY_DAVINCI_001("text-similarity-davinci-001", R50K_BASE, 2046),
- TEXT_SIMILARITY_CURIE_001("text-similarity-curie-001", R50K_BASE, 2046),
- TEXT_SIMILARITY_BABBAGE_001("text-similarity-babbage-001", R50K_BASE, 2046),
- TEXT_SIMILARITY_ADA_001("text-similarity-ada-001", R50K_BASE, 2046),
- TEXT_SEARCH_DAVINCI_DOC_001("text-search-davinci-doc-001", R50K_BASE, 2046),
- TEXT_SEARCH_CURIE_DOC_001("text-search-curie-doc-001", R50K_BASE, 2046),
- TEXT_SEARCH_BABBAGE_DOC_001("text-search-babbage-doc-001", R50K_BASE, 2046),
- TEXT_SEARCH_ADA_DOC_001("text-search-ada-doc-001", R50K_BASE, 2046),
- CODE_SEARCH_BABBAGE_CODE_001("code-search-babbage-code-001", R50K_BASE, 2046),
- CODE_SEARCH_ADA_CODE_001("code-search-ada-code-001", R50K_BASE, 2046);
+ object TEXT_SIMILARITY_DAVINCI_001 : ModelType("text-similarity-davinci-001", R50K_BASE, 2046)
+ object TEXT_SIMILARITY_CURIE_001 : ModelType("text-similarity-curie-001", R50K_BASE, 2046)
+ object TEXT_SIMILARITY_BABBAGE_001 : ModelType("text-similarity-babbage-001", R50K_BASE, 2046)
+ object TEXT_SIMILARITY_ADA_001 : ModelType("text-similarity-ada-001", R50K_BASE, 2046)
+ object TEXT_SEARCH_DAVINCI_DOC_001 : ModelType("text-search-davinci-doc-001", R50K_BASE, 2046)
+ object TEXT_SEARCH_CURIE_DOC_001 : ModelType("text-search-curie-doc-001", R50K_BASE, 2046)
+ object TEXT_SEARCH_BABBAGE_DOC_001 : ModelType("text-search-babbage-doc-001", R50K_BASE, 2046)
+ object TEXT_SEARCH_ADA_DOC_001 : ModelType("text-search-ada-doc-001", R50K_BASE, 2046)
+ object CODE_SEARCH_BABBAGE_CODE_001 : ModelType("code-search-babbage-code-001", R50K_BASE, 2046)
+ object CODE_SEARCH_ADA_CODE_001 : ModelType("code-search-ada-code-001", R50K_BASE, 2046)
inline val encoding: Encoding
inline get() = encodingType.encoding
- companion object {
- fun fromName(name: String): ModelType? =
- values().firstOrNull { it.name == name }
- }
}
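The move from `enum class` to `sealed class` removes `values()`-based lookups like `fromName`, but it lets callers define model types the library does not ship, such as self-hosted models, via `LocalModel`. A minimal sketch; the model name and context length are invented:

import com.xebia.functional.tokenizer.EncodingType
import com.xebia.functional.tokenizer.ModelType

// A hypothetical self-hosted model: impossible with the old enum,
// just another ModelType value with the sealed class.
val llama: ModelType =
  ModelType.LocalModel(
    name = "llama-2-7b-chat", // illustrative only
    encodingType = EncodingType.CL100K_BASE,
    maxContextLength = 4096
  )

fun describe(model: ModelType): String =
  when (model) {
    is ModelType.LocalModel -> "local model ${model.name} (${model.maxContextLength} tokens)"
    else -> "built-in model ${model.name} (${model.maxContextLength} tokens)"
  }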