diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt index 8afe0c330..03ec0ce25 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt @@ -1,11 +1,7 @@ package com.xebia.functional.xef.auto -import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.llm.AIClient import com.xebia.functional.xef.vectorstores.VectorStore -@DslMarker annotation class AiDsl - /** * An [AI] value represents an action relying on artificial intelligence that can be run to produce * an `A`. This value is _lazy_ and can be combined with other `AI` values using @@ -18,10 +14,3 @@ typealias AI = suspend CoreAIScope.() -> A /** A DSL block that makes it more convenient to construct [AI] values. */ inline fun ai(noinline block: suspend CoreAIScope.() -> A): AI = block - -suspend fun AIScope(runtime: AIRuntime, block: AI, orElse: suspend (AIError) -> A): A = - try { - runtime.runtime(block) - } catch (e: AIError) { - orElse(e) - } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIRuntime.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIRuntime.kt deleted file mode 100644 index 312c8531e..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIRuntime.kt +++ /dev/null @@ -1,12 +0,0 @@ -package com.xebia.functional.xef.auto - -import com.xebia.functional.xef.embeddings.Embeddings -import com.xebia.functional.xef.llm.AIClient - -data class AIRuntime( - val client: AIClient, - val embeddings: Embeddings, - val runtime: suspend (block: AI) -> A -) { - companion object -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AiDsl.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AiDsl.kt new file mode 100644 index 000000000..aa19f97b9 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AiDsl.kt @@ -0,0 +1,3 @@ +package com.xebia.functional.xef.auto + +@DslMarker annotation class AiDsl diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt index e22ba22d0..4ca848e44 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt @@ -1,23 +1,11 @@ package com.xebia.functional.xef.auto -import arrow.core.nonFatalOrThrow -import arrow.core.raise.catch -import com.xebia.functional.tokenizer.Encoding -import com.xebia.functional.tokenizer.ModelType -import com.xebia.functional.tokenizer.truncateText -import com.xebia.functional.xef.AIError import com.xebia.functional.xef.embeddings.Embeddings -import com.xebia.functional.xef.llm.AIClient -import com.xebia.functional.xef.llm.LLM -import com.xebia.functional.xef.llm.LLMModel -import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest -import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions -import com.xebia.functional.xef.llm.models.chat.Message -import com.xebia.functional.xef.llm.models.chat.Role +import com.xebia.functional.xef.llm.Chat +import com.xebia.functional.xef.llm.ChatWithFunctions +import com.xebia.functional.xef.llm.Images import com.xebia.functional.xef.llm.models.functions.CFunction -import com.xebia.functional.xef.llm.models.images.ImagesGenerationRequest import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse 
-import com.xebia.functional.xef.llm.models.text.CompletionRequest import com.xebia.functional.xef.prompt.Prompt import com.xebia.functional.xef.vectorstores.CombinedVectorStore import com.xebia.functional.xef.vectorstores.LocalVectorStore @@ -32,18 +20,8 @@ import kotlin.jvm.JvmName * programs. */ class CoreAIScope( - val defaultModel: LLM.Chat, - val defaultSerializationModel: LLM.ChatWithFunctions, - val aiClient: AIClient, - val context: VectorStore, val embeddings: Embeddings, - val maxDeserializationAttempts: Int = 3, - val user: String = "user", - val echo: Boolean = false, - val temperature: Double = 0.4, - val numberOfPredictions: Int = 1, - val docsInContext: Int = 20, - val minResponseTokens: Int = 500 + val context: VectorStore = LocalVectorStore(embeddings), ) { val logger: KLogger = KotlinLogging.logger {} @@ -98,11 +76,8 @@ class CoreAIScope( @AiDsl suspend fun contextScope(store: VectorStore, block: AI): A = CoreAIScope( - defaultModel, - defaultSerializationModel, - this@CoreAIScope.aiClient, - CombinedVectorStore(store, this@CoreAIScope.context), this@CoreAIScope.embeddings, + CombinedVectorStore(store, this@CoreAIScope.context), ) .block() @@ -119,252 +94,26 @@ class CoreAIScope( @AiDsl @JvmName("promptWithSerializer") - suspend fun prompt( - prompt: Prompt, + suspend fun ChatWithFunctions.prompt( + prompt: String, functions: List, serializer: (json: String) -> A, - maxDeserializationAttempts: Int = this.maxDeserializationAttempts, - model: LLM.ChatWithFunctions = defaultSerializationModel, - user: String = this.user, - echo: Boolean = this.echo, - numberOfPredictions: Int = this.numberOfPredictions, - temperature: Double = this.temperature, - bringFromContext: Int = this.docsInContext, - minResponseTokens: Int = this.minResponseTokens, - ): A { - return tryDeserialize(serializer, maxDeserializationAttempts) { - promptMessage( - prompt = prompt, - model = model, - functions = functions, - user = user, - echo = echo, - numberOfPredictions = numberOfPredictions, - temperature = temperature, - bringFromContext = bringFromContext, - minResponseTokens = minResponseTokens - ) - } - } - - suspend fun tryDeserialize( - serializer: (json: String) -> A, - maxDeserializationAttempts: Int, - agent: AI> - ): A { - (0 until maxDeserializationAttempts).forEach { currentAttempts -> - val result = agent().firstOrNull() ?: throw AIError.NoResponse() - catch({ - return@tryDeserialize serializer(result) - }) { e: Throwable -> - logger.error(e) { "Error deserializing response: $result\n${e.message}" } - if (currentAttempts == maxDeserializationAttempts) - throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow()) - // TODO else log attempt ? 
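For orientation, a minimal sketch of how the slimmed-down scope is constructed after this change; the `Embeddings` value stands for any implementation (for example the HuggingFace-based embeddings used by the examples further down), so treat the helper name and wiring as illustrative only:

```kotlin
import com.xebia.functional.xef.auto.AI
import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.vectorstores.VectorStore

// CoreAIScope now only needs an Embeddings implementation; its context
// defaults to LocalVectorStore(embeddings). contextScope layers an extra
// store on top of the current context for the duration of the block.
suspend fun <A> withDocs(embeddings: Embeddings, docs: VectorStore, block: AI<A>): A =
  CoreAIScope(embeddings).contextScope(docs, block)
```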
- } - } - throw AIError.NoResponse() - } - - @AiDsl - suspend fun promptMessage( - question: String, - model: LLM.Chat = defaultModel, - functions: List = emptyList(), - user: String = this.user, - echo: Boolean = this.echo, - n: Int = this.numberOfPredictions, - temperature: Double = this.temperature, - bringFromContext: Int = this.docsInContext, - minResponseTokens: Int = this.minResponseTokens - ): List = - promptMessage( - Prompt(question), - model, - functions, - user, - echo, - n, - temperature, - bringFromContext, - minResponseTokens + promptConfiguration: PromptConfiguration, + ): A = + prompt( + prompt = Prompt(prompt), + context = context, + functions = functions, + serializer = serializer, + promptConfiguration = promptConfiguration, ) @AiDsl - suspend fun promptMessage( - prompt: Prompt, - model: LLM.Chat = defaultModel, + suspend fun Chat.promptMessage( + question: String, functions: List = emptyList(), - user: String = this.user, - echo: Boolean = this.echo, - numberOfPredictions: Int = this.numberOfPredictions, - temperature: Double = this.temperature, - bringFromContext: Int = this.docsInContext, - minResponseTokens: Int - ): List { - - val promptWithContext: String = - createPromptWithContextAwareOfTokens( - ctxInfo = context.similaritySearch(prompt.message, bringFromContext), - modelType = model.modelType, - prompt = prompt.message, - minResponseTokens = minResponseTokens - ) - - fun checkTotalLeftTokens(role: String): Int = - with(model.modelType) { - val roleTokens: Int = encoding.countTokens(role) - val padding = 20 // reserve 20 tokens for additional symbols around the context - val promptTokens: Int = encoding.countTokens(promptWithContext) - val takenTokens: Int = roleTokens + promptTokens + padding - val totalLeftTokens: Int = maxContextLength - takenTokens - if (totalLeftTokens < 0) { - throw AIError.PromptExceedsMaxTokenLength( - promptWithContext, - takenTokens, - maxContextLength - ) - } - logger.debug { - "Tokens -- used: $takenTokens, model max: $maxContextLength, left: $totalLeftTokens" - } - totalLeftTokens - } - - suspend fun buildCompletionRequest(): CompletionRequest = - CompletionRequest( - model = model.name, - user = user, - prompt = promptWithContext, - echo = echo, - n = numberOfPredictions, - temperature = temperature, - maxTokens = checkTotalLeftTokens("") - ) - - fun checkTotalLeftChatTokens(messages: List): Int { - val maxContextLength: Int = model.modelType.maxContextLength - val messagesTokens: Int = tokensFromMessages(messages, model) - val totalLeftTokens: Int = maxContextLength - messagesTokens - if (totalLeftTokens < 0) { - throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength) - } - logger.debug { - "Tokens -- used: $messagesTokens, model max: $maxContextLength, left: $totalLeftTokens" - } - return totalLeftTokens - } - - suspend fun buildChatRequest(): ChatCompletionRequest { - val messages: List = listOf(Message(Role.SYSTEM.name, promptWithContext)) - return ChatCompletionRequest( - model = model.name, - user = user, - messages = messages, - n = numberOfPredictions, - temperature = temperature, - maxTokens = checkTotalLeftChatTokens(messages) - ) - } - - suspend fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions { - val role: String = Role.USER.name - val firstFnName: String? 
= functions.firstOrNull()?.name - val messages: List = listOf(Message(role, promptWithContext)) - return ChatCompletionRequestWithFunctions( - model = model.name, - user = user, - messages = messages, - n = numberOfPredictions, - temperature = temperature, - maxTokens = checkTotalLeftChatTokens(messages), - functions = functions, - functionCall = mapOf("name" to (firstFnName ?: "")) - ) - } - - return when (model) { - is LLM.Completion -> - aiClient.createCompletion(buildCompletionRequest()).choices.map { it.text } - is LLM.ChatWithFunctions -> - aiClient.createChatCompletionWithFunctions(chatWithFunctionsRequest()).choices.mapNotNull { - it.message?.functionCall?.arguments - } - else -> - aiClient.createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content } - } - } - - private fun createPromptWithContextAwareOfTokens( - ctxInfo: List, - modelType: ModelType, - prompt: String, - minResponseTokens: Int, - ): String { - val maxContextLength: Int = modelType.maxContextLength - val promptTokens: Int = modelType.encoding.countTokens(prompt) - val remainingTokens: Int = maxContextLength - promptTokens - minResponseTokens - - return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) { - val ctx: String = ctxInfo.joinToString("\n") - - if (promptTokens >= maxContextLength) { - throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength) - } - // truncate the context if it's too long based on the max tokens calculated considering the - // existing prompt tokens - // alternatively we could summarize the context, but that's not implemented yet - val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens) - - """|```Context - |${ctxTruncated} - |``` - |The context is related to the question try to answer the `goal` as best as you can - |or provide information about the found content - |```goal - |${prompt} - |``` - |ANSWER: - |""" - .trimMargin() - } else prompt - } - - private fun tokensFromMessages(messages: List, model: LLM): Int { - fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int = - messages.sumOf { message -> - countTokens(message.role) + - (message.content?.let { countTokens(it) } ?: 0) + - tokensPerMessage + - (message.name?.let { tokensPerName } ?: 0) - } + 3 - - fun fallBackTo(fallbackModel: LLM, paddingTokens: Int): Int { - logger.debug { - "Warning: ${model.name} may change over time. " + - "Returning messages num tokens assuming ${fallbackModel.name} + $paddingTokens padding tokens." 
- } - return tokensFromMessages(messages, fallbackModel) + paddingTokens - } - - return when (model) { - LLMModel.GPT_3_5_TURBO_FUNCTIONS -> - // paddingToken = 200: reserved for functions - fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 200) - LLMModel.GPT_3_5_TURBO -> - // otherwise if the model changes, it might later fail - fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 5) - LLMModel.GPT_4, - LLMModel.GPT_4_32K -> - // otherwise if the model changes, it might later fail - fallBackTo(fallbackModel = LLMModel.GPT_4_0314, paddingTokens = 5) - LLMModel.GPT_3_5_TURBO_0301 -> - model.modelType.encoding.countTokensFromMessages(tokensPerMessage = 4, tokensPerName = 0) - LLMModel.GPT_4_0314 -> - model.modelType.encoding.countTokensFromMessages(tokensPerMessage = 3, tokensPerName = 2) - else -> fallBackTo(fallbackModel = LLMModel.GPT_3_5_TURBO_0301, paddingTokens = 20) - } - } + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): List = promptMessage(Prompt(question), context, functions, promptConfiguration) /** * Run a [prompt] describes the images you want to generate within the context of [CoreAIScope]. @@ -374,13 +123,12 @@ class CoreAIScope( * @param numberImages number of images to generate. * @param size the size of the images to generate. */ - suspend fun images( + suspend fun Images.images( prompt: String, - user: String = "testing", numberImages: Int = 1, size: String = "1024x1024", - bringFromContext: Int = 10 - ): ImagesGenerationResponse = images(Prompt(prompt), user, numberImages, size, bringFromContext) + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): ImagesGenerationResponse = this.images(Prompt(prompt), numberImages, size, promptConfiguration) /** * Run a [prompt] describes the images you want to generate within the context of [CoreAIScope]. @@ -390,33 +138,10 @@ class CoreAIScope( * @param numberImages number of images to generate. * @param size the size of the images to generate. */ - suspend fun images( + suspend fun Images.images( prompt: Prompt, - user: String = "testing", numberImages: Int = 1, size: String = "1024x1024", - bringFromContext: Int = 10 - ): ImagesGenerationResponse { - val ctxInfo = context.similaritySearch(prompt.message, bringFromContext) - val promptWithContext = - if (ctxInfo.isNotEmpty()) { - """|Instructions: Use the [Information] below delimited by 3 backticks to accomplish - |the [Objective] at the end of the prompt. - |Try to match the data returned in the [Objective] with this [Information] as best as you can. 
- |[Information]: - |``` - |${ctxInfo.joinToString("\n")} - |``` - |$prompt""" - .trimMargin() - } else prompt.message - val request = - ImagesGenerationRequest( - prompt = promptWithContext, - numberImages = numberImages, - size = size, - user = user - ) - return aiClient.createImages(request) - } + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): ImagesGenerationResponse = images(prompt, context, numberImages, size, promptConfiguration) } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt new file mode 100644 index 000000000..4bc979405 --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/PromptConfiguration.kt @@ -0,0 +1,59 @@ +package com.xebia.functional.xef.auto + +import com.xebia.functional.xef.llm.models.chat.Role +import kotlin.jvm.JvmField +import kotlin.jvm.JvmName + +class PromptConfiguration( + val maxDeserializationAttempts: Int = 3, + val user: String = Role.USER.name, + val temperature: Double = 0.4, + val numberOfPredictions: Int = 1, + val docsInContext: Int = 20, + val minResponseTokens: Int = 500, +) { + companion object { + + class Builder { + private var maxDeserializationAttempts: Int = 3 + private var user: String = Role.USER.name + private var temperature: Double = 0.4 + private var numberOfPredictions: Int = 1 + private var docsInContext: Int = 20 + private var minResponseTokens: Int = 500 + + fun maxDeserializationAttempts(maxDeserializationAttempts: Int) = apply { + this.maxDeserializationAttempts = maxDeserializationAttempts + } + + fun user(user: String) = apply { this.user = user } + + fun temperature(temperature: Double) = apply { this.temperature = temperature } + + fun numberOfPredictions(numberOfPredictions: Int) = apply { + this.numberOfPredictions = numberOfPredictions + } + + fun docsInContext(docsInContext: Int) = apply { this.docsInContext = docsInContext } + + fun minResponseTokens(minResponseTokens: Int) = apply { + this.minResponseTokens = minResponseTokens + } + + fun build() = + PromptConfiguration( + maxDeserializationAttempts, + user, + temperature, + numberOfPredictions, + docsInContext, + minResponseTokens + ) + } + + @JvmName("build") + operator fun invoke(block: Builder.() -> Unit) = Builder().apply(block).build() + + @JvmField val DEFAULTS = PromptConfiguration() + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/AIClientError.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/AIClientError.kt deleted file mode 100644 index 580d47c4c..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/AIClientError.kt +++ /dev/null @@ -1,5 +0,0 @@ -package com.xebia.functional.xef.llm - -import kotlinx.serialization.json.JsonElement - -data class AIClientError(val json: JsonElement) : Exception("AI client error") diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt new file mode 100644 index 000000000..2df27608e --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt @@ -0,0 +1,125 @@ +package com.xebia.functional.xef.llm + +import com.xebia.functional.tokenizer.ModelType +import com.xebia.functional.tokenizer.truncateText +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.auto.AiDsl +import com.xebia.functional.xef.auto.PromptConfiguration +import 
com.xebia.functional.xef.llm.models.chat.* +import com.xebia.functional.xef.llm.models.functions.CFunction +import com.xebia.functional.xef.prompt.Prompt +import com.xebia.functional.xef.vectorstores.VectorStore + +interface Chat : LLM { + val modelType: ModelType + + suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse + + fun tokensFromMessages(messages: List): Int + + @AiDsl + suspend fun promptMessage( + question: String, + context: VectorStore, + functions: List = emptyList(), + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): List = promptMessage(Prompt(question), context, functions, promptConfiguration) + + @AiDsl + suspend fun promptMessage( + prompt: Prompt, + context: VectorStore, + functions: List = emptyList(), + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): List { + + val promptWithContext: String = + createPromptWithContextAwareOfTokens( + ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext), + modelType = modelType, + prompt = prompt.message, + minResponseTokens = promptConfiguration.minResponseTokens + ) + + fun checkTotalLeftChatTokens(messages: List): Int { + val maxContextLength: Int = modelType.maxContextLength + val messagesTokens: Int = tokensFromMessages(messages) + val totalLeftTokens: Int = maxContextLength - messagesTokens + if (totalLeftTokens < 0) { + throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength) + } + return totalLeftTokens + } + + val userMessage = Message(Role.USER, promptWithContext, Role.USER.name) + fun buildChatRequest(): ChatCompletionRequest = + ChatCompletionRequest( + model = name, + user = promptConfiguration.user, + messages = listOf(userMessage), + n = promptConfiguration.numberOfPredictions, + temperature = promptConfiguration.temperature, + maxTokens = checkTotalLeftChatTokens(listOf(userMessage)) + ) + + fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions = + ChatCompletionRequestWithFunctions( + model = name, + user = promptConfiguration.user, + messages = listOf(userMessage), + n = promptConfiguration.numberOfPredictions, + temperature = promptConfiguration.temperature, + maxTokens = checkTotalLeftChatTokens(listOf(userMessage)), + functions = functions, + functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: "")) + ) + + return when (this) { + is ChatWithFunctions -> + // we only support functions for now with GPT_3_5_TURBO_FUNCTIONS + if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) { + createChatCompletionWithFunctions(chatWithFunctionsRequest()).choices.mapNotNull { + it.message?.functionCall?.arguments + } + } else { + createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content } + } + else -> createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content } + } + } + + private fun createPromptWithContextAwareOfTokens( + ctxInfo: List, + modelType: ModelType, + prompt: String, + minResponseTokens: Int, + ): String { + val maxContextLength: Int = modelType.maxContextLength + val promptTokens: Int = modelType.encoding.countTokens(prompt) + val remainingTokens: Int = maxContextLength - promptTokens - minResponseTokens + + return if (ctxInfo.isNotEmpty() && remainingTokens > minResponseTokens) { + val ctx: String = ctxInfo.joinToString("\n") + + if (promptTokens >= maxContextLength) { + throw AIError.PromptExceedsMaxTokenLength(prompt, promptTokens, maxContextLength) + } + // truncate the context if it's too 
long based on the max tokens calculated considering the
+    // existing prompt tokens
+    // alternatively we could summarize the context, but that's not implemented yet
+    val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)
+
+    """|```Context
+       |${ctxTruncated}
+       |```
+       |The context is related to the question try to answer the `goal` as best as you can
+       |or provide information about the found content
+       |```goal
+       |${prompt}
+       |```
+       |ANSWER:
+       |"""
+      .trimMargin()
+    } else prompt
+  }
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
new file mode 100644
index 000000000..2de5b177b
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt
@@ -0,0 +1,66 @@
+package com.xebia.functional.xef.llm
+
+import arrow.core.nonFatalOrThrow
+import arrow.core.raise.catch
+import com.xebia.functional.xef.AIError
+import com.xebia.functional.xef.auto.AiDsl
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions
+import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
+import com.xebia.functional.xef.llm.models.functions.CFunction
+import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.vectorstores.VectorStore
+
+interface ChatWithFunctions : Chat {
+
+  suspend fun createChatCompletionWithFunctions(
+    request: ChatCompletionRequestWithFunctions
+  ): ChatCompletionResponseWithFunctions
+
+  @AiDsl
+  suspend fun <A> prompt(
+    prompt: String,
+    context: VectorStore,
+    functions: List<CFunction>,
+    serializer: (json: String) -> A,
+    promptConfiguration: PromptConfiguration,
+  ): A =
+    tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+      promptMessage(
+        prompt = Prompt(prompt),
+        context = context,
+        functions = functions,
+        promptConfiguration
+      )
+    }
+
+  @AiDsl
+  suspend fun <A> prompt(
+    prompt: Prompt,
+    context: VectorStore,
+    functions: List<CFunction>,
+    serializer: (json: String) -> A,
+    promptConfiguration: PromptConfiguration,
+  ): A =
+    tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+      promptMessage(prompt = prompt, context = context, functions = functions, promptConfiguration)
+    }
+
+  private suspend fun <A> tryDeserialize(
+    serializer: (json: String) -> A,
+    maxDeserializationAttempts: Int,
+    agent: suspend () -> List<String>
+  ): A {
+    (0 until maxDeserializationAttempts).forEach { currentAttempts ->
+      val result = agent().firstOrNull() ?: throw AIError.NoResponse()
+      catch({
+        return@tryDeserialize serializer(result)
+      }) { e: Throwable ->
+        if (currentAttempts == maxDeserializationAttempts)
+          throw AIError.JsonParsing(result, maxDeserializationAttempts, e.nonFatalOrThrow())
+        // TODO else log attempt ?
+      }
+    }
+    throw AIError.NoResponse()
+  }
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt
new file mode 100644
index 000000000..96f20771b
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Completion.kt
@@ -0,0 +1,11 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.llm.models.text.CompletionRequest
+import com.xebia.functional.xef.llm.models.text.CompletionResult
+
+interface Completion : LLM {
+  val modelType: ModelType
+
+  suspend fun createCompletion(request: CompletionRequest): CompletionResult
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Embeddings.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Embeddings.kt
new file mode 100644
index 000000000..bf8805658
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Embeddings.kt
@@ -0,0 +1,8 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
+import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
+
+interface Embeddings : LLM {
+  suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult
+}
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Images.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Images.kt
new file mode 100644
index 000000000..d9c273b4f
--- /dev/null
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Images.kt
@@ -0,0 +1,45 @@
+package com.xebia.functional.xef.llm
+
+import com.xebia.functional.xef.auto.PromptConfiguration
+import com.xebia.functional.xef.llm.models.images.ImagesGenerationRequest
+import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse
+import com.xebia.functional.xef.prompt.Prompt
+import com.xebia.functional.xef.vectorstores.VectorStore
+
+interface Images : LLM {
+  suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse
+
+  suspend fun images(
+    prompt: String,
+    context: VectorStore,
+    numberImages: Int = 1,
+    size: String = "1024x1024",
+    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+  ): ImagesGenerationResponse =
+    images(Prompt(prompt), context, numberImages, size, promptConfiguration)
+
+  /**
+   * Run a [prompt] describes the images you want to generate within the context of [CoreAIScope].
+   * Returns a [ImagesGenerationResponse] containing time and urls with images generated.
+   *
+   * @param prompt a [Prompt] describing the images you want to generate.
+   * @param numberImages number of images to generate.
+   * @param size the size of the images to generate.
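As a usage sketch of the serializer-based prompting defined in `ChatWithFunctions` above; the `CFunction` describing the expected JSON schema is passed in as an assumption (how that schema is built lies outside this diff), and the model/store values are placeholders:

```kotlin
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.vectorstores.VectorStore
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

@Serializable data class Recipe(val title: String, val ingredients: List<String>)

// Sketch: ask a ChatWithFunctions model for JSON matching `recipeSchema` and decode it.
// On malformed JSON the call retries up to maxDeserializationAttempts times.
suspend fun recipe(
  model: ChatWithFunctions,
  store: VectorStore,
  recipeSchema: CFunction // assumed to describe Recipe's JSON schema
): Recipe =
  model.prompt(
    prompt = Prompt("Give me a recipe for tomato soup"),
    context = store,
    functions = listOf(recipeSchema),
    serializer = { json -> Json.decodeFromString(Recipe.serializer(), json) },
    promptConfiguration = PromptConfiguration.DEFAULTS
  )
```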
+ */ + suspend fun images( + prompt: Prompt, + context: VectorStore, + numberImages: Int = 1, + size: String = "1024x1024", + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS + ): ImagesGenerationResponse { + val request = + ImagesGenerationRequest( + prompt = prompt.message, + numberImages = numberImages, + size = size, + user = promptConfiguration.user + ) + return createImages(request) + } +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt new file mode 100644 index 000000000..95937867b --- /dev/null +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt @@ -0,0 +1,7 @@ +package com.xebia.functional.xef.llm + +sealed interface LLM : AutoCloseable { + val name: String + + override fun close() {} +} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLMModel.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLMModel.kt deleted file mode 100644 index 8a19464c6..000000000 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLMModel.kt +++ /dev/null @@ -1,65 +0,0 @@ -package com.xebia.functional.xef.llm - -import com.xebia.functional.tokenizer.ModelType -import kotlin.jvm.JvmStatic - -sealed interface LLM { - val name: String - val modelType: ModelType - - interface Chat : LLM - - interface Completion : LLM - - interface ChatWithFunctions : Chat - - interface Embedding : LLM - - interface Images : LLM { - suspend fun createImage() - } -} - -sealed class LLMModel(override val name: String, override val modelType: ModelType) : LLM { - - data class Chat(override val name: String, override val modelType: ModelType) : - LLMModel(name, modelType), LLM.Chat - - data class Completion(override val name: String, override val modelType: ModelType) : - LLMModel(name, modelType), LLM.Completion - - data class ChatWithFunctions(override val name: String, override val modelType: ModelType) : - LLMModel(name, modelType), LLM.ChatWithFunctions - - data class Embedding(override val name: String, override val modelType: ModelType) : - LLMModel(name, modelType), LLM.Embedding - - companion object { - @JvmStatic val GPT_4 = Chat("gpt-4", ModelType.GPT_4) - - @JvmStatic val GPT_4_0314 = Chat("gpt-4-0314", ModelType.GPT_4) - - @JvmStatic val GPT_4_32K = Chat("gpt-4-32k", ModelType.GPT_4_32K) - - @JvmStatic val GPT_3_5_TURBO = Chat("gpt-3.5-turbo", ModelType.GPT_3_5_TURBO) - - @JvmStatic val GPT_3_5_TURBO_16K = Chat("gpt-3.5-turbo-16k", ModelType.GPT_3_5_TURBO_16_K) - - @JvmStatic - val GPT_3_5_TURBO_FUNCTIONS = - ChatWithFunctions("gpt-3.5-turbo-0613", ModelType.GPT_3_5_TURBO_FUNCTIONS) - - @JvmStatic val GPT_3_5_TURBO_0301 = Chat("gpt-3.5-turbo-0301", ModelType.GPT_3_5_TURBO) - - @JvmStatic val TEXT_DAVINCI_003 = Completion("text-davinci-003", ModelType.TEXT_DAVINCI_003) - - @JvmStatic val TEXT_DAVINCI_002 = Completion("text-davinci-002", ModelType.TEXT_DAVINCI_002) - - @JvmStatic - val TEXT_CURIE_001 = Completion("text-curie-001", ModelType.TEXT_SIMILARITY_CURIE_001) - - @JvmStatic val TEXT_BABBAGE_001 = Completion("text-babbage-001", ModelType.TEXT_BABBAGE_001) - - @JvmStatic val TEXT_ADA_001 = Completion("text-ada-001", ModelType.TEXT_ADA_001) - } -} diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt index 30fc0e8f5..9f6ef6ed6 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt +++ 
b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Message.kt @@ -1,3 +1,3 @@ package com.xebia.functional.xef.llm.models.chat -data class Message(val role: String, val content: String?, val name: String? = Role.ASSISTANT.name) +data class Message(val role: Role, val content: String, val name: String) diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt index 7eb3fa41b..61d0d8da8 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/chat/Role.kt @@ -3,6 +3,5 @@ package com.xebia.functional.xef.llm.models.chat enum class Role { SYSTEM, USER, - ASSISTANT, - FUNCTION + ASSISTANT } diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt index 28bc34285..e13fb208d 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/models/text/CompletionRequest.kt @@ -3,7 +3,7 @@ package com.xebia.functional.xef.llm.models.text data class CompletionRequest( val model: String, val user: String, - val prompt: String? = null, + val prompt: String, val suffix: String? = null, val maxTokens: Int? = null, val temperature: Double? = null, diff --git a/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt b/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt index b64ec2846..e24c3f596 100644 --- a/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt +++ b/core/src/jvmMain/kotlin/com/xebia/functional/xef/loaders/ScrapeURLTextLoader.kt @@ -4,26 +4,31 @@ import it.skrape.core.htmlDocument import it.skrape.fetcher.BrowserFetcher import it.skrape.fetcher.response import it.skrape.fetcher.skrape +import java.lang.IllegalStateException /** Creates a TextLoader based on a Path */ suspend fun ScrapeURLTextLoader(url: String): BaseLoader = object : BaseLoader { override suspend fun load(): List = buildList { - skrape(BrowserFetcher) { - request { this.url = url } - response { - htmlDocument { - val cleanedText = cleanUpText(wholeText) - add( - """| + try { + skrape(BrowserFetcher) { + request { this.url = url } + response { + htmlDocument { + val cleanedText = cleanUpText(wholeText) + add( + """| |Title: $titleText |Info: $cleanedText """ - .trimIndent() - ) + .trimIndent() + ) + } } } + } catch (e: IllegalStateException) { + // ignore } } diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java index 6b50cb4cf..26821c2ae 100644 --- a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java +++ b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/Animals.java @@ -27,20 +27,20 @@ public CompletableFuture story(Animal animal, Invention invention) { return scope.prompt(storyPrompt, Story.class); } - private static class Animal { + public static class Animal { public String name; public String habitat; public String diet; } - private static class Invention { + public static class Invention { public String name; public String inventor; public int year; public String purpose; } - private static class Story { + public static class Story { 
public Animal animal; public Invention invention; public String text; @@ -60,4 +60,4 @@ public static void main(String[] args) throws ExecutionException, InterruptedExc )).get(); } } -} \ No newline at end of file +} diff --git a/examples/kotlin/build.gradle.kts b/examples/kotlin/build.gradle.kts index 5e5fd632f..fcb49ce1d 100644 --- a/examples/kotlin/build.gradle.kts +++ b/examples/kotlin/build.gradle.kts @@ -37,3 +37,5 @@ tasks.getByName("processResources") { from("${projects.xefGpt4all.dependencyProject.buildDir}/processedResources/jvm/main") into("$buildDir/resources/main") } + + diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt index 80ddc5994..aecf5116a 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ASCIIArt.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable @@ -7,7 +9,7 @@ data class ASCIIArt(val art: String) suspend fun main() { val art: AI = ai { - prompt("ASCII art of a cat dancing") + prompt( "ASCII art of a cat dancing") } println(art.getOrElse { ASCIIArt("¯\\_(ツ)_/¯" + "\n" + it.message) }) } diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt index 0acab2387..e160b532e 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Animal.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt index 01f136e48..1ef28f3b6 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Book.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt index 97ca4174e..e95d6485d 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/BreakingNews.kt @@ -1,6 +1,8 @@ package com.xebia.functional.xef.auto import com.xebia.functional.xef.agents.search +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable import java.text.SimpleDateFormat import java.util.Date diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt index aa98fa445..08d02eaa2 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt +++ 
b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/ChessAI.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt index a5e7d1a1e..76012c018 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Colors.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/CustomRuntime.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/CustomRuntime.kt deleted file mode 100644 index ec918c472..000000000 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/CustomRuntime.kt +++ /dev/null @@ -1,36 +0,0 @@ -package com.xebia.functional.xef.auto - -import com.xebia.functional.xef.auto.llm.openai.MockAIScope -import com.xebia.functional.xef.auto.llm.openai.simpleMockAIClient -import com.xebia.functional.xef.embeddings.Embedding -import com.xebia.functional.xef.embeddings.Embeddings -import com.xebia.functional.xef.llm.models.embeddings.RequestConfig - -suspend fun main() { - val program = ai { - val love: List = promptMessage("tell me you like me with just emojis") - println(love) - } - program.getOrElse(customRuntime()) { println(it) } -} - -private fun fakeEmbeddings(): Embeddings = object : Embeddings { - override suspend fun embedDocuments( - texts: List, - chunkSize: Int?, - requestConfig: RequestConfig - ): List = emptyList() - - override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List = - emptyList() -} - -private fun customRuntime(): AIRuntime { - val client = simpleMockAIClient { it } - return AIRuntime(client, fakeEmbeddings()) { block -> - MockAIScope( - client, - block - ) { throw it } - } -} diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt index 06ceb404e..22dfe1935 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/DivergentTasks.kt @@ -1,6 +1,8 @@ package com.xebia.functional.xef.auto import com.xebia.functional.xef.agents.search +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt index 57bdb34c4..df768dc3f 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Employee.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git 
a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt index ce7b9fd61..4e7360e0e 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Fact.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt index 37167b62a..2bd525e21 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Love.kt @@ -1,5 +1,8 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.promptMessage + suspend fun main() { ai { val love: List = promptMessage("tell me you like me with just emojis") diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt index 7d5c556d3..d4cfc94e0 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Markets.kt @@ -1,6 +1,8 @@ package com.xebia.functional.xef.auto import com.xebia.functional.xef.agents.search +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable import java.text.SimpleDateFormat import java.util.Date diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt index 6b14a25f5..9a671a7c5 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MealPlan.kt @@ -1,6 +1,8 @@ package com.xebia.functional.xef.auto import com.xebia.functional.xef.agents.search +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt index e3fc3ce8c..7b8861362 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/MeaningOfLife.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt index 0bc15fde1..d4c12e0e9 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Movie.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import 
com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt index 3ca5ab0e5..7e8867a02 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/PDFDocument.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrThrow +import com.xebia.functional.xef.auto.llm.openai.prompt import com.xebia.functional.xef.pdf.pdf import kotlinx.serialization.Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt index 9cb5aeaec..c55600482 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Person.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt index 0666121aa..74a0636fc 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Planet.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt index 94dad1846..e12eefda2 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Poem.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt index 1572f1c9e..cd186fcc7 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Population.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.image import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt index 3de145f1b..2208d83b7 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Recipe.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import 
kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt index 78591cf72..a5111569d 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TopAttraction.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt index dfad77734..32c1636bd 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/TouristAttraction.kt @@ -1,5 +1,7 @@ package com.xebia.functional.xef.auto +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt index 77db41578..bc5bb60f0 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/Weather.kt @@ -1,6 +1,8 @@ package com.xebia.functional.xef.auto import com.xebia.functional.xef.agents.search +import com.xebia.functional.xef.auto.llm.openai.getOrElse +import com.xebia.functional.xef.auto.llm.openai.promptMessage import io.github.oshai.kotlinlogging.KotlinLogging suspend fun main() { diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt index 28ca1be72..8eb84204e 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/gpt4all/Chat.kt @@ -1,50 +1,52 @@ package com.xebia.functional.xef.auto.gpt4all -import arrow.core.raise.either -import arrow.core.raise.ensure -import arrow.core.raise.recover -import com.xebia.functional.gpt4all.* +import com.xebia.functional.gpt4all.GPT4All +import com.xebia.functional.gpt4all.LLModel +import com.xebia.functional.gpt4all.getOrThrow +import com.xebia.functional.gpt4all.huggingFaceUrl +import com.xebia.functional.xef.auto.PromptConfiguration +import com.xebia.functional.xef.auto.ai +import com.xebia.functional.xef.auto.llm.openai.OpenAI +import com.xebia.functional.xef.pdf.pdf import java.nio.file.Path -import java.util.* - -data class ChatError(val content: String) suspend fun main() { - recover({ - val resources = "models/gpt4all" - val path = "$resources/ggml-gpt4all-j-v1.3-groovy.bin" - val modelType = LLModel.Type.GPTJ - - val modelPath: Path = Path.of(path) - ensure(modelPath.toFile().exists()) { - ChatError("Model at ${modelPath.toAbsolutePath()} cannot be found.") - } - - Scanner(System.`in`).use { scanner -> - println("Loading model...") - - GPT4All(modelPath, modelType).use { gpt4All -> - println("Model loaded!") - print("Prompt: ") - - buildList { - while (scanner.hasNextLine()) { - val prompt: String = scanner.nextLine() - if 
(prompt.equals("exit", ignoreCase = true)) { break } - - println("...") - val promptMessage = Message(Message.Role.USER, prompt) - add(promptMessage) - - val request = ChatCompletionRequest(this, GenerationConfig()) - val response: ChatCompletionResponse = gpt4All.createChatCompletion(request) - println("Response: ${response.choices[0].content}") - - add(response.choices[0]) - print("Prompt: ") - } - } - } - } - }) { println(it) } + val userDir = System.getProperty("user.dir") + val path = "$userDir/models/gpt4all/ggml-gpt4all-j-v1.3-groovy.bin" + val url = huggingFaceUrl("orel12", "ggml-gpt4all-j-v1.3-groovy", "bin") + val modelType = LLModel.Type.GPTJ + val modelPath: Path = Path.of(path) + val GPT4All = GPT4All(url, modelPath, modelType) + + println("🤖 GPT4All loaded: $GPT4All") + + val pdfUrl = "https://www.europarl.europa.eu/RegData/etudes/STUD/2023/740063/IPOL_STU(2023)740063_EN.pdf" + + /** + * Uses internally [HuggingFaceLocalEmbeddings] default of "sentence-transformers", "msmarco-distilbert-dot-v5" + * to provide embeddings for docs in contextScope. + */ + + ai { + println("🤖 Loading PDF: $pdfUrl") + contextScope(pdf(pdfUrl)) { + println("🤖 Context loaded: $context") + GPT4All.use { gpT4All: GPT4All -> + println("🤖 Generating prompt for context") + val prompt = gpT4All.promptMessage( + "Describe in one sentence what the context is about.", + promptConfiguration = PromptConfiguration { + docsInContext(2) + }) + println("🤖 Generating images for prompt: \n\n$prompt") + val images = + OpenAI.DEFAULT_IMAGES.images(prompt.joinToString("\n"), promptConfiguration = PromptConfiguration { + docsInContext(1) + }) + println("🤖 Generated images: \n\n${images.data.joinToString("\n") { it.url }}") + } + } + }.getOrThrow() } + + diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/manual/NoAI.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/manual/NoAI.kt new file mode 100644 index 000000000..0dfcfcc3b --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/manual/NoAI.kt @@ -0,0 +1,16 @@ +package com.xebia.functional.xef.auto.manual + +import com.xebia.functional.gpt4all.HuggingFaceLocalEmbeddings +import com.xebia.functional.xef.auto.llm.openai.OpenAI +import com.xebia.functional.xef.pdf.pdf +import com.xebia.functional.xef.vectorstores.LocalVectorStore + +suspend fun main() { + val chat = OpenAI.DEFAULT_CHAT + val huggingFaceEmbeddings = HuggingFaceLocalEmbeddings.DEFAULT + val vectorStore = LocalVectorStore(huggingFaceEmbeddings) + val results = pdf("https://www.europarl.europa.eu/RegData/etudes/STUD/2023/740063/IPOL_STU(2023)740063_EN.pdf") + vectorStore.addTexts(results) + val result: List = chat.promptMessage("What is the content about?", vectorStore) + println(result) +} diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt index 223aa46d1..8ee08476c 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/sql/DatabaseExample.kt @@ -1,12 +1,15 @@ package com.xebia.functional.xef.auto.sql import arrow.core.raise.catch -import com.xebia.functional.tokenizer.ModelType +import com.xebia.functional.xef.auto.PromptConfiguration import com.xebia.functional.xef.auto.ai -import com.xebia.functional.xef.auto.getOrThrow +import com.xebia.functional.xef.auto.llm.openai.OpenAI +import 
com.xebia.functional.xef.auto.llm.openai.getOrThrow import com.xebia.functional.xef.sql.SQL import com.xebia.functional.xef.sql.jdbc.JdbcConfig +val model = OpenAI.DEFAULT_CHAT + val config = JdbcConfig( vendor = System.getenv("XEF_SQL_DB_VENDOR") ?: "mysql", host = System.getenv("XEF_SQL_DB_HOST") ?: "localhost", @@ -14,7 +17,7 @@ val config = JdbcConfig( password = System.getenv("XEF_SQL_DB_PASSWORD") ?: "password", port = System.getenv("XEF_SQL_DB_PORT")?.toInt() ?: 3306, database = System.getenv("XEF_SQL_DB_DATABASE") ?: "database", - llmModelType = ModelType.GPT_3_5_TURBO + model = model ) suspend fun main() = ai { @@ -34,16 +37,21 @@ suspend fun main() = ai { if (input == "exit") break catch({ extendContext(*promptQuery(input).toTypedArray()) - val result = promptMessage("""| - |You are a database assistant that helps users to query and summarize results from the database. - |Instructions: - |1. Summarize the information provided in the `Context` and follow to step 2. - |2. If the information relates to the `input` then answer the question otherwise return just the summary. - |```input - |$input - |``` - |3. Try to answer and provide information with as much detail as you can - """.trimMargin(), bringFromContext = 200) + val result = model.promptMessage( + """| + |You are a database assistant that helps users to query and summarize results from the database. + |Instructions: + |1. Summarize the information provided in the `Context` and follow to step 2. + |2. If the information relates to the `input` then answer the question otherwise return just the summary. + |```input + |$input + |``` + |3. Try to answer and provide information with as much detail as you can + """.trimMargin(), + promptConfiguration = PromptConfiguration.invoke { + docsInContext(50) + } + ) result.forEach { println("llmdb> ${it}") } diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt index 7c460a5f4..80c7239f0 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/ControlSignal.kt @@ -1,7 +1,7 @@ package com.xebia.functional.xef.auto.tot import com.xebia.functional.xef.auto.CoreAIScope -import com.xebia.functional.xef.auto.prompt +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt index 4f1d5c4bf..0a28a58e1 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Critique.kt @@ -1,7 +1,7 @@ package com.xebia.functional.xef.auto.tot import com.xebia.functional.xef.auto.CoreAIScope -import com.xebia.functional.xef.auto.prompt +import com.xebia.functional.xef.auto.llm.openai.prompt import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt index f8a2cf499..ce11ec7bc 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Main.kt @@ -1,7 +1,7 @@ package com.xebia.functional.xef.auto.tot 
import com.xebia.functional.xef.auto.ai -import com.xebia.functional.xef.auto.getOrThrow +import com.xebia.functional.xef.auto.llm.openai.getOrThrow import kotlinx.serialization.Serializable @Serializable diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt deleted file mode 100644 index 0c96b3be6..000000000 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Search.kt +++ /dev/null @@ -1,19 +0,0 @@ -package com.xebia.functional.xef.auto.tot - -import com.xebia.functional.xef.auto.CoreAIScope - -suspend fun CoreAIScope.generateSearchPrompts(problem: Problem): List = - promptMessage( - """|You are an expert SEO consultant. - |You generate search prompts for a problem. - |You are given the following problem: - |${problem.description} - |Instructions: - |1. Generate 1 search prompt to get the best results for this problem. - |2. Ensure the search prompt are relevant to the problem. - |3. Ensure the search prompt expands into the keywords needed to solve the problem. - | - """.trimMargin(), - n = 5 - ).distinct().map { it.replace("\"", "").trim() } - diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt index f6cfec90e..da44df73e 100644 --- a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt +++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/auto/tot/Solution.kt @@ -1,7 +1,8 @@ package com.xebia.functional.xef.auto.tot import com.xebia.functional.xef.auto.CoreAIScope -import com.xebia.functional.xef.auto.prompt +import com.xebia.functional.xef.auto.llm.openai.OpenAI +import com.xebia.functional.xef.auto.llm.openai.prompt import com.xebia.functional.xef.prompt.Prompt import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable @@ -45,7 +46,7 @@ internal suspend fun CoreAIScope.solution( |10. If the solution is valid set the `isValid` field to `true` and the `value` field to the value of the solution. 
| |""".trimMargin() - return prompt(Prompt(enhancedPrompt), serializer).also { + return prompt(OpenAI.DEFAULT_SERIALIZATION, Prompt(enhancedPrompt), serializer).also { println("🤖 Generated solution: ${truncateText(it.answer)}") } } diff --git a/examples/kotlin/src/main/resources/logback.xml b/examples/kotlin/src/main/resources/logback.xml index 8b62c8b02..dfb690eb2 100644 --- a/examples/kotlin/src/main/resources/logback.xml +++ b/examples/kotlin/src/main/resources/logback.xml @@ -16,6 +16,10 @@ + + + + diff --git a/gpt4all-kotlin/build.gradle.kts b/gpt4all-kotlin/build.gradle.kts index 49b86a305..12d76721e 100644 --- a/gpt4all-kotlin/build.gradle.kts +++ b/gpt4all-kotlin/build.gradle.kts @@ -20,17 +20,17 @@ java { kotlin { jvm() + js(IR) { browser() - nodejs() } - linuxX64() - macosX64() - macosArm64() - mingwX64() sourceSets { - val commonMain by getting {} + val commonMain by getting { + dependencies { + implementation(projects.xefCore) + } + } commonTest { dependencies { @@ -44,32 +44,19 @@ kotlin { val jvmMain by getting { dependencies { implementation("net.java.dev.jna:jna-platform:5.13.0") + implementation("ai.djl.huggingface:tokenizers:+") } } + val jsMain by getting { + } + val jvmTest by getting { dependencies { implementation(libs.kotest.junit5) } } - js { - nodejs {} - browser {} - } - - val linuxX64Main by getting - val macosX64Main by getting - val macosArm64Main by getting - val mingwX64Main by getting - - create("nativeMain") { - dependsOn(commonMain) - linuxX64Main.dependsOn(this) - macosX64Main.dependsOn(this) - macosArm64Main.dependsOn(this) - mingwX64Main.dependsOn(this) - } } } diff --git a/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt b/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt index 0ab06d727..33be8d3f8 100644 --- a/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt +++ b/gpt4all-kotlin/src/commonMain/kotlin/com/xebia/functional/gpt4all/models.kt @@ -14,52 +14,3 @@ data class GenerationConfig( val repeatLastN: Int = 64, val contextErase: Double = 0.5 ) - -data class Completion(val context: String) - -data class Message(val role: Role, val content: String) { - enum class Role { - SYSTEM, - USER, - ASSISTANT - } -} - -data class Embedding( - val embedding: List -) - -data class EmbeddingRequest( - val input: List -) - -data class EmbeddingResponse( - val model: String, - val data: List -) - -data class CompletionRequest( - val prompt: String, - val generationConfig: GenerationConfig -) - -data class ChatCompletionRequest( - val messages: List, - val generationConfig: GenerationConfig -) - -data class CompletionResponse( - val model: String, - val promptTokens: Int, - val completionTokens: Int, - val totalTokens: Int, - val choices: List -) - -data class ChatCompletionResponse( - val model: String, - val promptTokens: Int, - val completionTokens: Int, - val totalTokens: Int, - val choices: List -) diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt index d818f7da6..17a1273ed 100644 --- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt +++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt @@ -1,84 +1,114 @@ package com.xebia.functional.gpt4all +import ai.djl.training.util.DownloadUtils +import ai.djl.training.util.ProgressBar import com.sun.jna.platform.unix.LibCAPI import 
com.xebia.functional.gpt4all.libraries.LLModelContext +import com.xebia.functional.tokenizer.EncodingType +import com.xebia.functional.tokenizer.ModelType +import com.xebia.functional.xef.llm.Chat +import com.xebia.functional.xef.llm.Completion +import com.xebia.functional.xef.llm.models.chat.* +import com.xebia.functional.xef.llm.models.text.CompletionChoice +import com.xebia.functional.xef.llm.models.text.CompletionRequest +import com.xebia.functional.xef.llm.models.text.CompletionResult +import com.xebia.functional.xef.llm.models.usage.Usage +import java.net.URL +import java.nio.file.Files import java.nio.file.Path +import java.nio.file.StandardCopyOption +import java.util.* +import kotlin.io.path.name -interface GPT4All : AutoCloseable { - suspend fun createCompletion(request: CompletionRequest): CompletionResponse - suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse - suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResponse - - companion object { - operator fun invoke( - path: Path, - modelType: LLModel.Type - ): GPT4All = object : GPT4All { - val gpt4allModel: GPT4AllModel = GPT4AllModel(path, modelType) - - override suspend fun createCompletion(request: CompletionRequest): CompletionResponse = - with(request) { - val response: String = generateCompletion(prompt, generationConfig) - return CompletionResponse( - gpt4allModel.llModel.name, - prompt.length, - response.length, - totalTokens = prompt.length + response.length, - listOf(Completion(response)) - ) - } - - override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse = - with(request) { - val prompt: String = messages.buildPrompt() - val response: String = generateCompletion(prompt, generationConfig) - return ChatCompletionResponse( - gpt4allModel.llModel.name, - prompt.length, - response.length, - totalTokens = prompt.length + response.length, - listOf(Message(com.xebia.functional.gpt4all.Message.Role.ASSISTANT, response)) - ) - } - - override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResponse { - TODO("Not yet implemented") - } - - override fun close(): Unit = gpt4allModel.close() - - private fun List.buildPrompt(): String { - val messages: String = joinToString("") { message -> - when (message.role) { - Message.Role.SYSTEM -> message.content - Message.Role.USER -> "\n### Human: ${message.content}" - Message.Role.ASSISTANT -> "\n### Response: ${message.content}" - } - } - return "$messages\n### Response:" - } - - private fun generateCompletion( - prompt: String, - generationConfig: GenerationConfig - ): String { - val context = LLModelContext( - logits_size = LibCAPI.size_t(generationConfig.logitsSize.toLong()), - tokens_size = LibCAPI.size_t(generationConfig.tokensSize.toLong()), - n_past = generationConfig.nPast, - n_ctx = generationConfig.nCtx, - n_predict = generationConfig.nPredict, - top_k = generationConfig.topK, - top_p = generationConfig.topP.toFloat(), - temp = generationConfig.temp.toFloat(), - n_batch = generationConfig.nBatch, - repeat_penalty = generationConfig.repeatPenalty.toFloat(), - repeat_last_n = generationConfig.repeatLastN, - context_erase = generationConfig.contextErase.toFloat() - ) - - return gpt4allModel.prompt(prompt, context) - } +interface GPT4All : AutoCloseable, Chat, Completion { + + val gpt4allModel: GPT4AllModel + + override fun close() { + } + + companion object { + + operator fun invoke( + url: String, + path: Path, + modelType: LLModel.Type, + generationConfig: 
GenerationConfig = GenerationConfig(), + ): GPT4All = object : GPT4All { + + init { + if (!Files.exists(path)) { + DownloadUtils.download(url, path.toFile().absolutePath , ProgressBar()) + } + } + + override val gpt4allModel = GPT4AllModel.invoke(path, modelType) + + override suspend fun createCompletion(request: CompletionRequest): CompletionResult = + with(request) { + val response: String = gpt4allModel.prompt(prompt, llmModelContext(generationConfig)) + return CompletionResult( + UUID.randomUUID().toString(), + path.name, + System.currentTimeMillis(), + path.name, + listOf(CompletionChoice(response, 0, null, null)), + Usage.ZERO + ) + } + + override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse = + with(request) { + val response: String = + gpt4allModel.prompt(messages.buildPrompt(), llmModelContext(generationConfig)) + return ChatCompletionResponse( + UUID.randomUUID().toString(), + path.name, + System.currentTimeMillis().toInt(), + path.name, + Usage.ZERO, + listOf(Choice(Message(Role.ASSISTANT, response, Role.ASSISTANT.name), null, 0)), + ) + } + + override fun tokensFromMessages(messages: List): Int = 0 + + override val name: String = path.name + + override fun close(): Unit = gpt4allModel.close() + + override val modelType: ModelType = ModelType.LocalModel(name, EncodingType.CL100K_BASE, 4096) + + private fun List.buildPrompt(): String { + val messages: String = joinToString("") { message -> + when (message.role) { + Role.SYSTEM -> message.content + Role.USER -> "\n### Human: ${message.content}" + Role.ASSISTANT -> "\n### Response: ${message.content}" + } + } + return "$messages\n### Response:" + } + + private fun llmModelContext(generationConfig: GenerationConfig): LLModelContext = + with(generationConfig) { + LLModelContext( + logits_size = LibCAPI.size_t(logitsSize.toLong()), + tokens_size = LibCAPI.size_t(tokensSize.toLong()), + n_past = nPast, + n_ctx = nCtx, + n_predict = nPredict, + top_k = topK, + top_p = topP.toFloat(), + temp = temp.toFloat(), + n_batch = nBatch, + repeat_penalty = repeatPenalty.toFloat(), + repeat_last_n = repeatLastN, + context_erase = contextErase.toFloat() + ) } } + + } } + diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt index 04e2c201a..73c9253e0 100644 --- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt +++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4AllModel.kt @@ -7,7 +7,6 @@ interface GPT4AllModel : AutoCloseable { val llModel: LLModel fun prompt(prompt: String, context: LLModelContext): String - fun embeddings(prompt: String): List companion object { operator fun invoke( @@ -34,8 +33,6 @@ interface GPT4AllModel : AutoCloseable { responseBuffer.trim().toString() } - override fun embeddings(prompt: String): List = TODO("Not yet implemented") - override fun close(): Unit = llModel.close() } } diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Gpt4AllRuntime.kt similarity index 73% rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt rename to gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Gpt4AllRuntime.kt index 150714ad8..733610c4e 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/AIScope.kt +++ 
b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Gpt4AllRuntime.kt @@ -1,12 +1,12 @@ -package com.xebia.functional.xef.auto +package com.xebia.functional.gpt4all import arrow.core.Either import arrow.core.left import arrow.core.right import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime - -typealias AIScope = CoreAIScope +import com.xebia.functional.xef.auto.AI +import com.xebia.functional.xef.auto.CoreAIScope +import com.xebia.functional.xef.auto.ai /** * Run the [AI] value to produce an [A], this method initialises all the dependencies required to @@ -14,10 +14,8 @@ typealias AIScope = CoreAIScope * * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. */ -suspend inline fun AI.getOrElse( - runtime: AIRuntime = OpenAIRuntime.defaults(), - crossinline orElse: suspend (AIError) -> A -): A = AIScope(runtime, this) { orElse(it) } +suspend inline fun AI.getOrElse(crossinline orElse: suspend (AIError) -> A): A = + AIScope(this) { orElse(it) } /** * Run the [AI] value to produce [A]. this method initialises all the dependencies required to run @@ -41,3 +39,12 @@ suspend inline fun AI.getOrThrow(): A = getOrElse { throw it } */ suspend inline fun AI.toEither(): Either = ai { invoke().right() }.getOrElse { it.left() } + +suspend fun AIScope(block: AI, orElse: suspend (AIError) -> A): A = + try { + val scope = CoreAIScope(HuggingFaceLocalEmbeddings.DEFAULT) + block(scope) + } catch (e: AIError) { + orElse(e) + } + diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt new file mode 100644 index 000000000..f2b31980f --- /dev/null +++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt @@ -0,0 +1,39 @@ +package com.xebia.functional.gpt4all + +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer +import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding +import com.xebia.functional.xef.embeddings.Embeddings +import com.xebia.functional.xef.llm.models.embeddings.Embedding +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult +import com.xebia.functional.xef.llm.models.embeddings.RequestConfig +import com.xebia.functional.xef.llm.models.usage.Usage + +class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.functional.xef.llm.Embeddings, Embeddings { + + private val tokenizer = HuggingFaceTokenizer.newInstance("$name/$artifact") + + override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName + + override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult { + val embeddings = tokenizer.batchEncode(request.input).mapIndexed { ix, em -> + Embedding("embedding", em.ids.map { it.toFloat() }, ix) + } + return EmbeddingResult(embeddings, Usage.ZERO) + } + + override suspend fun embedDocuments( + texts: List, + chunkSize: Int?, + requestConfig: RequestConfig + ): List = + tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) } + + override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List = + embedDocuments(listOf(text), null, requestConfig) + + companion object { + @JvmField + val DEFAULT = HuggingFaceLocalEmbeddings("sentence-transformers", "msmarco-distilbert-dot-v5") + } +} diff --git 
a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceUtils.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceUtils.kt new file mode 100644 index 000000000..ac2f1f4fb --- /dev/null +++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceUtils.kt @@ -0,0 +1,4 @@ +package com.xebia.functional.gpt4all + +fun huggingFaceUrl(name: String, artifact:String, extension: String): String = + "https://huggingface.co/$name/$artifact/resolve/main/$artifact.$extension" diff --git a/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt b/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt index fe6ac9764..384731c3d 100644 --- a/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt +++ b/integrations/pdf/src/main/kotlin/com/xebia/functional/xef/pdf/PDFLoader.kt @@ -1,3 +1,5 @@ +@file:JvmName("Loader") +@file:JvmMultifileClass package com.xebia.functional.xef.pdf import com.xebia.functional.tokenizer.ModelType @@ -26,7 +28,6 @@ suspend fun pdf( pdf(file, splitter) } - suspend fun pdf( file: File, splitter: TextSplitter = TokenTextSplitter(modelType = ModelType.GPT_3_5_TURBO, chunkSize = 100, chunkOverlap = 50) diff --git a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt index 48cc7426b..dc2e1de54 100644 --- a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt +++ b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/SQL.kt @@ -68,7 +68,7 @@ private class JDBCSQLImpl( override suspend fun CoreAIScope.selectTablesForPrompt( tableNames: String, prompt: String - ): List = promptMessage( + ): List = config.model.promptMessage( """|You are an AI assistant which selects the best tables from which the `goal` can be accomplished. |Select from this list of SQL `tables` the tables that you may need to solve the following `goal` |```tables @@ -93,7 +93,7 @@ private class JDBCSQLImpl( val results = resultSet.toDocuments(prompt) logger.debug { "Found: ${results.size} records" } val splitter = TokenTextSplitter( - modelType = config.llmModelType, chunkSize = config.llmModelType.maxContextLength / 2, chunkOverlap = 10 + modelType = config.model.modelType, chunkSize = config.model.modelType.maxContextLength / 2, chunkOverlap = 10 ) val splitDocuments = splitter.splitDocuments(results) logger.debug { "Split into: ${splitDocuments.size} documents" } @@ -101,7 +101,7 @@ private class JDBCSQLImpl( } } - override suspend fun CoreAIScope.sql(ddl: String, input: String): List = promptMessage( + override suspend fun CoreAIScope.sql(ddl: String, input: String): List = config.model.promptMessage( """| |You are an AI assistant which produces SQL SELECT queries in SQL format. |You only reply in valid SQL SELECT queries. @@ -126,7 +126,7 @@ private class JDBCSQLImpl( """.trimMargin() ) - override suspend fun CoreAIScope.getInterestingPromptsForDatabase(): List = promptMessage( + override suspend fun CoreAIScope.getInterestingPromptsForDatabase(): List = config.model.promptMessage( """|You are an AI assistant which replies with a list of the best prompts based on the content of this database: |Instructions: |1. 
Select from this `ddl` 3 top prompts that the user could ask about this database diff --git a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt index f63f3f905..8a0da9b33 100644 --- a/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt +++ b/integrations/sql/src/main/kotlin/com/xebia/functional/xef/sql/jdbc/JdbcConfig.kt @@ -1,6 +1,6 @@ package com.xebia.functional.xef.sql.jdbc -import com.xebia.functional.tokenizer.ModelType +import com.xebia.functional.xef.llm.Chat class JdbcConfig( val vendor: String, @@ -9,7 +9,7 @@ class JdbcConfig( val password: String, val port: Int, val database: String, - val llmModelType: ModelType, + val model: Chat ) { fun toJDBCUrl(): String = "jdbc:$vendor://$host:$port/$database" } diff --git a/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java b/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java index ac76ddea8..7a4cbd1ea 100644 --- a/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java +++ b/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java @@ -4,27 +4,25 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.module.jsonSchema.JsonSchema; import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator; -import com.xebia.functional.xef.auto.AIRuntime; import com.xebia.functional.xef.auto.CoreAIScope; -import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime; -import com.xebia.functional.xef.llm.AIClient; -import com.xebia.functional.xef.llm.LLM; -import com.xebia.functional.xef.llm.LLMModel; +import com.xebia.functional.xef.auto.PromptConfiguration; +import com.xebia.functional.xef.auto.llm.openai.OpenAI; +import com.xebia.functional.xef.auto.llm.openai.OpenAIEmbeddings; +import com.xebia.functional.xef.embeddings.Embeddings; +import com.xebia.functional.xef.llm.Chat; +import com.xebia.functional.xef.llm.ChatWithFunctions; +import com.xebia.functional.xef.llm.Images; import com.xebia.functional.xef.llm.models.functions.CFunction; import com.xebia.functional.xef.llm.models.images.ImageGenerationUrl; import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse; -import com.xebia.functional.xef.pdf.PDFLoaderKt; +import com.xebia.functional.xef.pdf.Loader; import com.xebia.functional.xef.textsplitters.TextSplitter; import com.xebia.functional.xef.vectorstores.LocalVectorStore; import com.xebia.functional.xef.vectorstores.VectorStore; import kotlin.collections.CollectionsKt; import kotlin.coroutines.Continuation; import kotlin.jvm.functions.Function1; -import kotlinx.coroutines.CoroutineScope; -import kotlinx.coroutines.CoroutineScopeKt; -import kotlinx.coroutines.ExecutorsKt; -import kotlinx.coroutines.JobKt; -import kotlinx.coroutines.CoroutineStart; +import kotlinx.coroutines.*; import kotlinx.coroutines.future.FutureKt; import org.jetbrains.annotations.NotNull; @@ -41,26 +39,24 @@ public class AIScope implements AutoCloseable { private final CoreAIScope scope; private final ObjectMapper om; private final JsonSchemaGenerator schemaGen; - private final AIClient client; private final ExecutorService executorService; private final CoroutineScope coroutineScope; - public AIScope(ObjectMapper om, AIRuntime runtime, ExecutorService executorService) { + public AIScope(ObjectMapper om, Embeddings embeddings, ExecutorService executorService) { this.om = om; this.executorService = executorService; this.coroutineScope 
= () -> ExecutorsKt.from(executorService).plus(JobKt.Job(null)); this.schemaGen = new JsonSchemaGenerator(om); - this.client = runtime.getClient(); - VectorStore vectorStore = new LocalVectorStore(runtime.getEmbeddings()); - this.scope = new CoreAIScope(LLMModel.getGPT_3_5_TURBO(), LLMModel.getGPT_3_5_TURBO_FUNCTIONS(), client, vectorStore, runtime.getEmbeddings(), 3, "user", false, 0.4, 1, 20, 500); + VectorStore vectorStore = new LocalVectorStore(embeddings); + this.scope = new CoreAIScope(embeddings, vectorStore); } - public AIScope(AIRuntime runtime, ExecutorService executorService) { - this(new ObjectMapper(), runtime, executorService); + public AIScope(Embeddings embeddings, ExecutorService executorService) { + this(new ObjectMapper(), embeddings, executorService); } public AIScope() { - this(new ObjectMapper(), OpenAIRuntime.defaults(), Executors.newCachedThreadPool(new AIScopeThreadFactory())); + this(new ObjectMapper(),new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING), Executors.newCachedThreadPool(new AIScopeThreadFactory())); } private AIScope(CoreAIScope nested, AIScope outer) { @@ -68,15 +64,15 @@ private AIScope(CoreAIScope nested, AIScope outer) { this.executorService = outer.executorService; this.coroutineScope = outer.coroutineScope; this.schemaGen = outer.schemaGen; - this.client = outer.client; this.scope = nested; } public CompletableFuture prompt(String prompt, Class cls) { - return prompt(prompt, cls, scope.getMaxDeserializationAttempts(), scope.getDefaultSerializationModel(), scope.getUser(), scope.getEcho(), scope.getNumberOfPredictions(), scope.getTemperature(), scope.getDocsInContext(), scope.getMinResponseTokens()); + return prompt(prompt, cls, OpenAI.DEFAULT_SERIALIZATION, PromptConfiguration.DEFAULTS); } - public CompletableFuture prompt(String prompt, Class cls, Integer maxAttempts, LLM.ChatWithFunctions llmModel, String user, Boolean echo, Integer n, Double temperature, Integer bringFromContext, Integer minResponseTokens) { + + public CompletableFuture prompt(String prompt, Class cls, ChatWithFunctions llmModel, PromptConfiguration promptConfiguration) { Function1 decoder = json -> { try { return om.readValue(json, cls); @@ -100,30 +96,11 @@ public CompletableFuture prompt(String prompt, Class cls, Integer maxA new CFunction(cls.getSimpleName(), "Generated function for " + cls.getSimpleName(), schema) ); - return future(continuation -> scope.promptWithSerializer(prompt, functions, decoder, maxAttempts, llmModel, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation)); - } - - public CompletableFuture> promptMessage(String prompt, LLM.Chat llmModel, List functions, String user, Boolean echo, Integer n, Double temperature, Integer bringFromContext, Integer minResponseTokens) { - return future(continuation -> scope.promptMessage(prompt, llmModel, functions, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation)); + return future(continuation -> scope.promptWithSerializer(llmModel, prompt, functions, decoder, promptConfiguration, continuation)); } - public CompletableFuture extendContext(String[] docs) { - return future(continuation -> scope.extendContext(docs, continuation)) - .thenApply(unit -> null); - } - - public CompletableFuture contextScope(Function1> f) { - return future(continuation -> scope.contextScope((coreAIScope, continuation1) -> { - AIScope nestedScope = new AIScope(coreAIScope, AIScope.this); - return FutureKt.await(f.invoke(nestedScope), continuation); - }, continuation)); - } - - public 
CompletableFuture contextScope(VectorStore store, Function1> f) { - return future(continuation -> scope.contextScope(store, (coreAIScope, continuation1) -> { - AIScope nestedScope = new AIScope(coreAIScope, AIScope.this); - return FutureKt.await(f.invoke(nestedScope), continuation); - }, continuation)); + public CompletableFuture> promptMessage(Chat llmModel, String prompt, List functions, PromptConfiguration promptConfiguration) { + return future(continuation -> scope.promptMessage(llmModel, prompt, functions, promptConfiguration, continuation)); } public CompletableFuture contextScope(List docs, Function1> f) { @@ -134,15 +111,15 @@ public CompletableFuture contextScope(List docs, Function1> pdf(String url, TextSplitter splitter) { - return future(continuation -> PDFLoaderKt.pdf(url, splitter, continuation)); + return future(continuation -> Loader.pdf(url, splitter, continuation)); } public CompletableFuture> pdf(File file, TextSplitter splitter) { - return future(continuation -> PDFLoaderKt.pdf(file, splitter, continuation)); + return future(continuation -> Loader.pdf(file, splitter, continuation)); } - public CompletableFuture> images(String prompt, String user, String size, Integer bringFromContext, Integer n) { - return this.future(continuation -> scope.images(prompt, user, n, size, bringFromContext, continuation)) + public CompletableFuture> images(Images model, String prompt, Integer numberOfImages, String size, PromptConfiguration promptConfiguration) { + return this.future(continuation -> scope.images(model, prompt, numberOfImages, size, promptConfiguration, continuation)) .thenApply(response -> CollectionsKt.map(response.getData(), ImageGenerationUrl::getUrl)); } @@ -157,7 +134,6 @@ private CompletableFuture future(Function1 AIScope.prompt( +suspend inline fun CoreAIScope.prompt( + model: ChatWithFunctions, question: String, json: Json = Json { ignoreUnknownKeys = true isLenient = true }, - maxDeserializationAttempts: Int = this.maxDeserializationAttempts, - model: LLM.ChatWithFunctions = this.defaultSerializationModel, - user: String = this.user, - echo: Boolean = this.echo, - n: Int = this.numberOfPredictions, - temperature: Double = this.temperature, - bringFromContext: Int = this.docsInContext -): A = - prompt( - Prompt(question), - json, - maxDeserializationAttempts, - model, - user, - echo, - n, - temperature, - bringFromContext - ) + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): A = prompt(model, Prompt(question), json, promptConfiguration) /** * Run a [prompt] describes the task you want to solve within the context of [AIScope]. Returns a @@ -55,65 +41,34 @@ suspend inline fun AIScope.prompt( * @throws IllegalArgumentException if any of [A]'s type arguments contains star projection. 
*/ @AiDsl -suspend inline fun AIScope.prompt( +suspend inline fun CoreAIScope.prompt( + model: ChatWithFunctions, prompt: Prompt, json: Json = Json { ignoreUnknownKeys = true isLenient = true }, - maxDeserializationAttempts: Int = this.maxDeserializationAttempts, - model: LLM.ChatWithFunctions = this.defaultSerializationModel, - user: String = this.user, - echo: Boolean = this.echo, - n: Int = this.numberOfPredictions, - temperature: Double = this.temperature, - bringFromContext: Int = this.docsInContext -): A = - prompt( - prompt, - serializer(), - json, - maxDeserializationAttempts, - model, - user, - echo, - n, - temperature, - bringFromContext - ) + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): A = prompt(model, prompt, serializer(), json, promptConfiguration) @AiDsl -suspend fun AIScope.prompt( +suspend fun CoreAIScope.prompt( + model: ChatWithFunctions, prompt: Prompt, serializer: KSerializer, json: Json = Json { ignoreUnknownKeys = true isLenient = true }, - maxDeserializationAttempts: Int = this.maxDeserializationAttempts, - model: LLM.ChatWithFunctions = this.defaultSerializationModel, - user: String = this.user, - echo: Boolean = this.echo, - n: Int = this.numberOfPredictions, - temperature: Double = this.temperature, - bringFromContext: Int = this.docsInContext, - minResponseTokens: Int = this.minResponseTokens, -): A { - val functions = generateCFunction(serializer.descriptor) - return prompt( + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): A = + model.prompt( prompt, - functions, + context, + generateCFunction(serializer.descriptor), { json.decodeFromString(serializer, it) }, - maxDeserializationAttempts, - model, - user, - echo, - n, - temperature, - bringFromContext, - minResponseTokens + promptConfiguration ) -} @OptIn(ExperimentalSerializationApi::class) private fun generateCFunction(descriptor: SerialDescriptor): List { diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/ImageGenerationAgent.kt similarity index 66% rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt rename to openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/ImageGenerationAgent.kt index ff058b84b..e6dd15d4c 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/ImageGenerationAgent.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/ImageGenerationAgent.kt @@ -1,6 +1,10 @@ -package com.xebia.functional.xef.auto +package com.xebia.functional.xef.auto.llm.openai import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.auto.CoreAIScope +import com.xebia.functional.xef.auto.PromptConfiguration +import com.xebia.functional.xef.llm.ChatWithFunctions +import com.xebia.functional.xef.llm.Images import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse import com.xebia.functional.xef.prompt.Prompt @@ -12,14 +16,16 @@ import com.xebia.functional.xef.prompt.Prompt * @param size the size of the images to generate. 
*/ suspend inline fun CoreAIScope.image( + imageModel: Images, + serializationModel: ChatWithFunctions, prompt: String, - user: String = "testing", size: String = "1024x1024", - bringFromContext: Int = 10 + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, ): A { - val imageResponse = images(prompt, user, 1, size, bringFromContext) + val imageResponse = imageModel.images(prompt, context, 1, size, promptConfiguration) val url = imageResponse.data.firstOrNull() ?: throw AIError.NoResponse() return prompt( + serializationModel, """|Instructions: Format this [URL] and [PROMPT] information in the desired JSON response format |specified at the end of the message. |[URL]: diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/JsonSchema.kt similarity index 98% rename from kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt rename to openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/JsonSchema.kt index c7ddf8e5b..16741b7fe 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/JsonSchema.kt @@ -1,6 +1,8 @@ +@file:JvmName("Json") +@file:JvmMultifileClass @file:OptIn(ExperimentalSerializationApi::class) -package com.xebia.functional.xef.auto.serialization +package com.xebia.functional.xef.auto.llm.openai /* Ported over from https://github.com/Ricky12Awesome/json-schema-serialization @@ -11,6 +13,8 @@ which states the following: // TODO: We should consider a fork and maintain it ourselves. */ +import kotlin.jvm.JvmMultifileClass +import kotlin.jvm.JvmName import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialInfo import kotlinx.serialization.descriptors.PolymorphicKind diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt index f204e20a2..89557430f 100644 --- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/MockAIClient.kt @@ -3,11 +3,11 @@ package com.xebia.functional.xef.auto.llm.openai import arrow.core.Either import arrow.core.left import arrow.core.right +import com.xebia.functional.tokenizer.ModelType import com.xebia.functional.xef.AIError import com.xebia.functional.xef.auto.AI import com.xebia.functional.xef.auto.CoreAIScope -import com.xebia.functional.xef.llm.AIClient -import com.xebia.functional.xef.llm.LLMModel +import com.xebia.functional.xef.llm.* import com.xebia.functional.xef.llm.models.chat.* import com.xebia.functional.xef.llm.models.embeddings.Embedding import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest @@ -37,7 +37,11 @@ class MockOpenAIClient( private val images: (ImagesGenerationRequest) -> ImagesGenerationResponse = { throw NotImplementedError("images not implemented") }, -) : AIClient { +) : ChatWithFunctions, Images, Completion, Embeddings { + + override val name: String = "mock" + override val modelType: ModelType = ModelType.GPT_3_5_TURBO + override suspend fun createCompletion(request: CompletionRequest): CompletionResult = completion(request) @@ -55,6 +59,8 @@ class MockOpenAIClient( override suspend fun createImages(request: ImagesGenerationRequest): 
ImagesGenerationResponse = images(request) + override fun tokensFromMessages(messages: List): Int = 0 + override fun close() {} } @@ -78,7 +84,7 @@ fun simpleMockAIClient(execute: (String) -> String): MockOpenAIClient = val responses = req.messages.mapIndexed { ix, msg -> val response = execute(msg.content ?: "") - Choice(Message(msg.role, response), "end", ix) + Choice(Message(msg.role, response, msg.role.name), "end", ix) } val requestTokens = req.messages.sumOf { it.content?.split(' ')?.size ?: 0 } val responseTokens = responses.sumOf { it.message?.content?.split(' ')?.size ?: 0 } @@ -96,14 +102,7 @@ suspend fun MockAIScope( try { val embeddings = OpenAIEmbeddings(mockClient) val vectorStore = LocalVectorStore(embeddings) - val scope = - CoreAIScope( - LLMModel.GPT_3_5_TURBO, - LLMModel.GPT_3_5_TURBO_FUNCTIONS, - mockClient, - vectorStore, - embeddings - ) + val scope = CoreAIScope(embeddings, vectorStore) block(scope) } catch (e: AIError) { orElse(e) diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt new file mode 100644 index 000000000..aacd9cb51 --- /dev/null +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt @@ -0,0 +1,57 @@ +package com.xebia.functional.xef.auto.llm.openai + +import arrow.core.nonEmptyListOf +import com.xebia.functional.tokenizer.ModelType +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.env.getenv +import kotlin.jvm.JvmField + +class OpenAI(internal val token: String) { + val GPT_4 = OpenAIModel(this, "gpt-4", ModelType.GPT_4) + + val GPT_4_0314 = OpenAIModel(this, "gpt-4-0314", ModelType.GPT_4) + + val GPT_4_32K = OpenAIModel(this, "gpt-4-32k", ModelType.GPT_4_32K) + + val GPT_3_5_TURBO = OpenAIModel(this, "gpt-3.5-turbo", ModelType.GPT_3_5_TURBO) + + val GPT_3_5_TURBO_16K = OpenAIModel(this, "gpt-3.5-turbo-16k", ModelType.GPT_3_5_TURBO_16_K) + + val GPT_3_5_TURBO_FUNCTIONS = + OpenAIModel(this, "gpt-3.5-turbo-0613", ModelType.GPT_3_5_TURBO_FUNCTIONS) + + val GPT_3_5_TURBO_0301 = OpenAIModel(this, "gpt-3.5-turbo-0301", ModelType.GPT_3_5_TURBO) + + val TEXT_DAVINCI_003 = OpenAIModel(this, "text-davinci-003", ModelType.TEXT_DAVINCI_003) + + val TEXT_DAVINCI_002 = OpenAIModel(this, "text-davinci-002", ModelType.TEXT_DAVINCI_002) + + val TEXT_CURIE_001 = OpenAIModel(this, "text-curie-001", ModelType.TEXT_SIMILARITY_CURIE_001) + + val TEXT_BABBAGE_001 = OpenAIModel(this, "text-babbage-001", ModelType.TEXT_BABBAGE_001) + + val TEXT_ADA_001 = OpenAIModel(this, "text-ada-001", ModelType.TEXT_ADA_001) + + val TEXT_EMBEDDING_ADA_002 = + OpenAIModel(this, "text-embedding-ada-002", ModelType.TEXT_EMBEDDING_ADA_002) + + val DALLE_2 = OpenAIModel(this, "dalle-2", ModelType.GPT_3_5_TURBO) + + companion object { + + fun openAITokenFromEnv(): String { + return getenv("OPENAI_TOKEN") + ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var")) + } + + @JvmField val DEFAULT = OpenAI(openAITokenFromEnv()) + + @JvmField val DEFAULT_CHAT = DEFAULT.GPT_3_5_TURBO_16K + + @JvmField val DEFAULT_SERIALIZATION = DEFAULT.GPT_3_5_TURBO_FUNCTIONS + + @JvmField val DEFAULT_EMBEDDING = DEFAULT.TEXT_EMBEDDING_ADA_002 + + @JvmField val DEFAULT_IMAGES = DEFAULT.DALLE_2 + } +} diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt index 5810789a1..7b9022089 100644 --- 
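The new OpenAI class above ties an API token to a catalogue of model objects, replacing the previous global LLMModel constants. A brief sketch (illustrative only; the token literal is a placeholder):

import com.xebia.functional.xef.auto.llm.openai.OpenAI

// Models hang off an OpenAI instance, so instances with different tokens can
// coexist; OpenAI.DEFAULT instead reads OPENAI_TOKEN from the environment.
val openAI = OpenAI("sk-placeholder-token")
val chatModel = openAI.GPT_3_5_TURBO_16K
val functionsModel = openAI.GPT_3_5_TURBO_FUNCTIONS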
a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIClient.kt @@ -8,7 +8,6 @@ import com.aallam.openai.api.completion.CompletionRequest as OpenAICompletionReq import com.aallam.openai.api.completion.TextCompletion import com.aallam.openai.api.completion.completionRequest import com.aallam.openai.api.core.Usage as OpenAIUsage -import com.aallam.openai.api.embedding.EmbeddingRequest as OpenAIEmbeddingRequest import com.aallam.openai.api.embedding.EmbeddingResponse import com.aallam.openai.api.embedding.embeddingRequest import com.aallam.openai.api.image.ImageCreation @@ -16,8 +15,10 @@ import com.aallam.openai.api.image.ImageSize import com.aallam.openai.api.image.ImageURL import com.aallam.openai.api.image.imageCreation import com.aallam.openai.api.model.ModelId -import com.aallam.openai.client.OpenAI -import com.xebia.functional.xef.llm.AIClient +import com.aallam.openai.client.OpenAI as OpenAIClient +import com.xebia.functional.tokenizer.Encoding +import com.xebia.functional.tokenizer.ModelType +import com.xebia.functional.xef.llm.* import com.xebia.functional.xef.llm.models.chat.* import com.xebia.functional.xef.llm.models.embeddings.Embedding import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest @@ -32,10 +33,16 @@ import com.xebia.functional.xef.llm.models.text.CompletionResult import com.xebia.functional.xef.llm.models.usage.Usage import kotlinx.serialization.json.Json -class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { +class OpenAIModel( + private val openAI: OpenAI, + override val name: String, + override val modelType: ModelType +) : Chat, ChatWithFunctions, Images, Completion, Embeddings, AutoCloseable { + + private val client = OpenAIClient(openAI.token) override suspend fun createCompletion(request: CompletionRequest): CompletionResult { - val response = openAI.completion(toCompletionRequest(request)) + val response = client.completion(toCompletionRequest(request)) return completionResult(response) } @@ -43,28 +50,61 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { override suspend fun createChatCompletion( request: ChatCompletionRequest ): ChatCompletionResponse { - val response = openAI.chatCompletion(toChatCompletionRequest(request)) - return chatCompletionResult(response) + val response = client.chatCompletion(toChatCompletionRequest(request)) + return ChatCompletionResponse( + id = response.id, + `object` = response.model.id, + created = response.created, + model = response.model.id, + choices = response.choices.map { chatCompletionChoice(it) }, + usage = usage(response.usage) + ) } @OptIn(BetaOpenAI::class) override suspend fun createChatCompletionWithFunctions( request: ChatCompletionRequestWithFunctions ): ChatCompletionResponseWithFunctions { - val response = openAI.chatCompletion(toChatCompletionRequestWithFunctions(request)) - return chatCompletionResultWithFunctions(response) + val response = client.chatCompletion(toChatCompletionRequestWithFunctions(request)) + + fun chatCompletionChoiceWithFunctions(choice: ChatChoice): ChoiceWithFunctions = + ChoiceWithFunctions( + message = + choice.message?.let { + MessageWithFunctionCall( + role = it.role.role, + content = it.content, + name = it.name, + functionCall = it.functionCall?.let { FnCall(it.name, it.arguments) } + ) + }, + finishReason = choice.finishReason, + index = choice.index, + ) + + return ChatCompletionResponseWithFunctions( + id = 
response.id, + `object` = response.model.id, + created = response.created, + model = response.model.id, + choices = response.choices.map { chatCompletionChoiceWithFunctions(it) }, + usage = usage(response.usage) + ) } override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult { - val response = openAI.embeddings(toEmbeddingRequest(request)) - return embeddingResult(response) + val openAIRequest = embeddingRequest { + model = ModelId(request.model) + input = request.input + user = request.user + } + + return embeddingResult(client.embeddings(openAIRequest)) } @OptIn(BetaOpenAI::class) - override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse { - val response = openAI.imageURL(toImageCreationRequest(request)) - return imageResult(response) - } + override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse = + imageResult(client.imageURL(toImageCreationRequest(request))) private fun toCompletionRequest(request: CompletionRequest): OpenAICompletionRequest = completionRequest { @@ -110,61 +150,39 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { totalTokens = usage?.totalTokens, ) - @OptIn(BetaOpenAI::class) - private fun chatCompletionResult(response: ChatCompletion): ChatCompletionResponse = - ChatCompletionResponse( - id = response.id, - `object` = response.model.id, - created = response.created, - model = response.model.id, - choices = response.choices.map { chatCompletionChoice(it) }, - usage = usage(response.usage) - ) - - @OptIn(BetaOpenAI::class) - private fun chatCompletionResultWithFunctions( - response: ChatCompletion - ): ChatCompletionResponseWithFunctions = - ChatCompletionResponseWithFunctions( - id = response.id, - `object` = response.model.id, - created = response.created, - model = response.model.id, - choices = response.choices.map { chatCompletionChoiceWithFunctions(it) }, - usage = usage(response.usage) - ) - - @OptIn(BetaOpenAI::class) - private fun chatCompletionChoiceWithFunctions(choice: ChatChoice): ChoiceWithFunctions = - ChoiceWithFunctions( - message = - choice.message?.let { - MessageWithFunctionCall( - role = it.role.role, - content = it.content, - name = it.name, - functionCall = it.functionCall?.let { FnCall(it.name, it.arguments) } - ) - }, - finishReason = choice.finishReason, - index = choice.index, - ) - @OptIn(BetaOpenAI::class) private fun chatCompletionChoice(choice: ChatChoice): Choice = Choice( message = choice.message?.let { Message( - role = it.role.role, - content = it.content, - name = it.name, + role = toRole(it), + content = it.content ?: "", + name = it.name ?: "", ) }, finishReason = choice.finishReason, index = choice.index, ) + @OptIn(BetaOpenAI::class) + private fun toRole(it: ChatMessage) = + when (it.role) { + ChatRole.User -> Role.USER + ChatRole.Assistant -> Role.ASSISTANT + ChatRole.System -> Role.SYSTEM + ChatRole.Function -> Role.SYSTEM + else -> Role.ASSISTANT + } + + @OptIn(BetaOpenAI::class) + private fun fromRole(it: Role) = + when (it) { + Role.USER -> ChatRole.User + Role.ASSISTANT -> ChatRole.Assistant + Role.SYSTEM -> ChatRole.System + } + @OptIn(BetaOpenAI::class) private fun toChatCompletionRequest(request: ChatCompletionRequest): OpenAIChatCompletionRequest = chatCompletionRequest { @@ -172,7 +190,7 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { messages = request.messages.map { ChatMessage( - role = ChatRole(it.role), + role = fromRole(it.role), content = it.content, name = it.name, ) @@ 
-195,7 +213,7 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { model = ModelId(request.model) messages = request.messages.map { - ChatMessage(role = ChatRole(it.role), content = it.content, name = it.name) + ChatMessage(role = fromRole(it.role), content = it.content, name = it.name) } functions = @@ -232,13 +250,6 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { usage = usage(response.usage) ) - private fun toEmbeddingRequest(request: EmbeddingRequest): OpenAIEmbeddingRequest = - embeddingRequest { - model = ModelId(request.model) - input = request.input - user = request.user - } - @OptIn(BetaOpenAI::class) private fun imageResult(response: List): ImagesGenerationResponse = ImagesGenerationResponse(data = response.map { ImageGenerationUrl(it.url) }) @@ -252,7 +263,39 @@ class OpenAIClient(val openAI: OpenAI) : AIClient, AutoCloseable { user = request.user } + override fun tokensFromMessages(messages: List): Int { + fun Encoding.countTokensFromMessages(tokensPerMessage: Int, tokensPerName: Int): Int = + messages.sumOf { message -> + countTokens(message.role.name) + + countTokens(message.content) + + tokensPerMessage + + tokensPerName + } + 3 + + fun fallBackTo(fallbackModel: Chat, paddingTokens: Int): Int { + return fallbackModel.tokensFromMessages(messages) + paddingTokens + } + + return when (this) { + openAI.GPT_3_5_TURBO_FUNCTIONS -> + // paddingToken = 200: reserved for functions + fallBackTo(fallbackModel = openAI.GPT_3_5_TURBO_0301, paddingTokens = 200) + openAI.GPT_3_5_TURBO -> + // otherwise if the model changes, it might later fail + fallBackTo(fallbackModel = openAI.GPT_3_5_TURBO_0301, paddingTokens = 5) + openAI.GPT_4, + openAI.GPT_4_32K -> + // otherwise if the model changes, it might later fail + fallBackTo(fallbackModel = openAI.GPT_4_0314, paddingTokens = 5) + openAI.GPT_3_5_TURBO_0301 -> + modelType.encoding.countTokensFromMessages(tokensPerMessage = 4, tokensPerName = 0) + openAI.GPT_4_0314 -> + modelType.encoding.countTokensFromMessages(tokensPerMessage = 3, tokensPerName = 2) + else -> fallBackTo(fallbackModel = openAI.GPT_3_5_TURBO_0301, paddingTokens = 20) + } + } + override fun close() { - openAI.close() + client.close() } } diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt index ca24ab607..c7017b1a5 100644 --- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIEmbeddings.kt @@ -3,13 +3,13 @@ package com.xebia.functional.xef.auto.llm.openai import arrow.fx.coroutines.parMap import com.xebia.functional.xef.embeddings.Embedding import com.xebia.functional.xef.embeddings.Embeddings -import com.xebia.functional.xef.llm.AIClient import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest import com.xebia.functional.xef.llm.models.embeddings.RequestConfig import kotlin.time.ExperimentalTime @ExperimentalTime -class OpenAIEmbeddings(private val oaiClient: AIClient) : Embeddings { +class OpenAIEmbeddings(private val oaiClient: com.xebia.functional.xef.llm.Embeddings) : + Embeddings { override suspend fun embedDocuments( texts: List, diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt index a02c9e069..daae1fe92 100644 --- 
a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIRuntime.kt @@ -1,47 +1,54 @@ +@file:JvmName("OpenAIRuntime") + package com.xebia.functional.xef.auto.llm.openai -import com.aallam.openai.api.logging.LogLevel -import com.aallam.openai.api.logging.Logger -import com.aallam.openai.client.LoggingConfig -import com.aallam.openai.client.OpenAI -import com.aallam.openai.client.OpenAIConfig -import com.xebia.functional.xef.auto.AIRuntime +import arrow.core.Either +import arrow.core.left +import arrow.core.right +import com.xebia.functional.xef.AIError +import com.xebia.functional.xef.auto.AI import com.xebia.functional.xef.auto.CoreAIScope -import com.xebia.functional.xef.env.getenv -import com.xebia.functional.xef.llm.LLMModel -import com.xebia.functional.xef.vectorstores.LocalVectorStore -import kotlin.jvm.JvmStatic +import com.xebia.functional.xef.auto.ai +import kotlin.jvm.JvmName import kotlin.time.ExperimentalTime -object OpenAIRuntime { - @JvmStatic fun defaults(): AIRuntime = openAI(null) +/** + * Run the [AI] value to produce an [A], this method initialises all the dependencies required to + * run the [AI] value and once it finishes it closes all the resources. + * + * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. + */ +suspend inline fun AI.getOrElse(crossinline orElse: suspend (AIError) -> A): A = + AIScope(this) { orElse(it) } + +/** + * Run the [AI] value to produce [A]. this method initialises all the dependencies required to run + * the [AI] value and once it finishes it closes all the resources. + * + * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. + * + * @throws AIError in case something went wrong. + * @see getOrElse for an operator that allow directly handling the [AIError] case instead of + * throwing. + */ +suspend inline fun AI.getOrThrow(): A = getOrElse { throw it } + +/** + * Run the [AI] value to produce _either_ an [AIError], or [A]. this method initialises all the + * dependencies required to run the [AI] value and once it finishes it closes all the resources. + * + * This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions. + * + * @see getOrElse for an operator that allow directly handling the [AIError] case. + */ +suspend inline fun AI.toEither(): Either = + ai { invoke().right() }.getOrElse { it.left() } - @OptIn(ExperimentalTime::class) - @JvmStatic - fun openAI(config: OpenAIConfig? = null): AIRuntime { - val openAIConfig = - config - ?: OpenAIConfig( - logging = LoggingConfig(logLevel = LogLevel.None, logger = Logger.Empty), - token = - requireNotNull(getenv("OPENAI_TOKEN")) { "OpenAI Token missing from environment." 
}, - ) - val openAI = OpenAI(openAIConfig) - val client = OpenAIClient(openAI) - val embeddings = OpenAIEmbeddings(client) - return AIRuntime(client, embeddings) { block -> - client.use { openAiClient -> - val vectorStore = LocalVectorStore(embeddings) - val scope = - CoreAIScope( - defaultModel = LLMModel.GPT_3_5_TURBO_16K, - defaultSerializationModel = LLMModel.GPT_3_5_TURBO_FUNCTIONS, - aiClient = openAiClient, - context = vectorStore, - embeddings = embeddings - ) - block(scope) - } - } +@OptIn(ExperimentalTime::class) +suspend fun AIScope(block: AI, orElse: suspend (AIError) -> A): A = + try { + val scope = CoreAIScope(OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING)) + block(scope) + } catch (e: AIError) { + orElse(e) } -} diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIScopeExtensions.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIScopeExtensions.kt new file mode 100644 index 000000000..262a456a0 --- /dev/null +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAIScopeExtensions.kt @@ -0,0 +1,52 @@ +package com.xebia.functional.xef.auto.llm.openai + +import com.xebia.functional.xef.auto.AiDsl +import com.xebia.functional.xef.auto.CoreAIScope +import com.xebia.functional.xef.auto.PromptConfiguration +import com.xebia.functional.xef.llm.Chat +import com.xebia.functional.xef.llm.ChatWithFunctions +import com.xebia.functional.xef.llm.models.functions.CFunction +import com.xebia.functional.xef.prompt.Prompt +import kotlinx.serialization.serializer + +@AiDsl +suspend fun CoreAIScope.promptMessage( + prompt: String, + model: Chat = OpenAI.DEFAULT_CHAT, + functions: List = emptyList(), + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): List = model.promptMessage(prompt, context, functions, promptConfiguration) + +@AiDsl +suspend fun CoreAIScope.promptMessage( + prompt: Prompt, + model: Chat = OpenAI.DEFAULT_CHAT, + functions: List = emptyList(), + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): List = model.promptMessage(prompt, context, functions, promptConfiguration) + +@AiDsl +suspend inline fun CoreAIScope.prompt( + prompt: String, + model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION, + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): A = + prompt( + model = model, + prompt = Prompt(prompt), + serializer = serializer(), + promptConfiguration = promptConfiguration + ) + +@AiDsl +suspend inline fun CoreAIScope.image( + prompt: String, + model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION, + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS, +): A = + prompt( + model = model, + prompt = Prompt(prompt), + serializer = serializer(), + promptConfiguration = promptConfiguration + ) diff --git a/scala/build.gradle.kts b/scala/build.gradle.kts index 45186b4ac..678dddf8e 100644 --- a/scala/build.gradle.kts +++ b/scala/build.gradle.kts @@ -10,7 +10,8 @@ plugins { } dependencies { - implementation(projects.xefKotlin) + implementation(projects.xefCore) + implementation(projects.xefOpenai) implementation(projects.kotlinLoom) // TODO split to separate Scala library diff --git a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala index 81aff208e..a4557c169 100644 --- a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala +++ 
b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/AIScope.scala @@ -1,7 +1,7 @@ package com.xebia.functional.xef.scala.auto -import com.xebia.functional.xef.auto.CoreAIScope as KtAIScope +import com.xebia.functional.xef.auto.CoreAIScope -final case class AIScope(kt: KtAIScope) +final case class AIScope(kt: CoreAIScope) private object AIScope: - def fromCore(coreAIScope: KtAIScope): AIScope = new AIScope(coreAIScope) + def fromCore(coreAIScope: CoreAIScope): AIScope = new AIScope(coreAIScope) diff --git a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala index 3a357fb13..8ff91a7a0 100644 --- a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala +++ b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala @@ -2,21 +2,22 @@ package com.xebia.functional.xef.scala.auto import com.xebia.functional.loom.LoomAdapter import com.xebia.functional.xef.AIError -import com.xebia.functional.xef.llm.LLM -import com.xebia.functional.xef.llm.LLMModel +import com.xebia.functional.xef.llm.Chat +import com.xebia.functional.xef.llm.ChatWithFunctions +import com.xebia.functional.xef.llm.Images import com.xebia.functional.xef.llm.models.functions.CFunction import io.circe.Decoder import io.circe.parser.parse -import com.xebia.functional.xef.auto.AIKt -import com.xebia.functional.xef.auto.AIRuntime -import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime -import com.xebia.functional.xef.auto.serialization.JsonSchemaKt -import com.xebia.functional.xef.pdf.PDFLoaderKt +import com.xebia.functional.xef.auto.llm.openai.Json +import com.xebia.functional.xef.pdf.Loader import com.xebia.functional.tokenizer.ModelType import com.xebia.functional.xef.llm._ +import com.xebia.functional.xef.auto.PromptConfiguration import com.xebia.functional.xef.auto.llm.openai._ import com.xebia.functional.xef.scala.textsplitters.TextSplitter import com.xebia.functional.xef.llm.models.images.* +import com.xebia.functional.xef.auto.llm.openai.OpenAI +import com.xebia.functional.xef.auto.llm.openai.OpenAIRuntime import java.io.File import scala.jdk.CollectionConverters.* @@ -26,8 +27,7 @@ type AI[A] = AIScope ?=> A def ai[A](block: AI[A]): A = LoomAdapter.apply { cont => - AIKt.AIScope[A]( - OpenAIRuntime.defaults[A](), + OpenAIRuntime.AIScope[A]( { (coreAIScope, _) => given AIScope = AIScope.fromCore(coreAIScope) @@ -45,28 +45,16 @@ extension [A](block: AI[A]) { def prompt[A: Decoder: SerialDescriptor]( prompt: String, - maxAttempts: Int = 5, - llmModel: LLM.ChatWithFunctions = LLMModel.getGPT_3_5_TURBO_FUNCTIONS, - user: String = "testing", - echo: Boolean = false, - n: Int = 1, - temperature: Double = 0.0, - bringFromContext: Int = 10, - minResponseTokens: Int = 400 + llmModel: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION, + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS )(using scope: AIScope): A = LoomAdapter.apply((cont) => scope.kt.promptWithSerializer[A]( + llmModel, prompt, generateCFunctions.asJava, (json: String) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity), - maxAttempts, - llmModel, - user, - echo, - n, - temperature, - bringFromContext, - minResponseTokens, + promptConfiguration, cont ) ) @@ -77,53 +65,48 @@ private def generateCFunctions[A: SerialDescriptor]: List[CFunction] = val fnName = if (serialName.contains(".")) serialName.substring(serialName.lastIndexOf("."), serialName.length) else serialName - List(CFunction(fnName, 
"Generated function for $fnName", JsonSchemaKt.encodeJsonSchema(descriptor))) + List(CFunction(fnName, "Generated function for $fnName", Json.encodeJsonSchema(descriptor))) def contextScope[A: Decoder: SerialDescriptor](docs: List[String])(block: AI[A])(using scope: AIScope): A = LoomAdapter.apply(scope.kt.contextScopeWithDocs[A](docs.asJava, (_, _) => block, _)) def promptMessage( prompt: String, - llmModel: LLM.Chat = LLMModel.getGPT_3_5_TURBO, + llmModel: Chat = OpenAI.DEFAULT_CHAT, functions: List[CFunction] = List.empty, - user: String = "testing", - echo: Boolean = false, - n: Int = 1, - temperature: Double = 0.0, - bringFromContext: Int = 10, - minResponseTokens: Int = 500 + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS )(using scope: AIScope): List[String] = LoomAdapter .apply[java.util.List[String]]( - scope.kt.promptMessage(prompt, llmModel, functions.asJava, user, echo, n, temperature, bringFromContext, minResponseTokens, _) + scope.kt.promptMessage(llmModel, prompt, functions.asJava, promptConfiguration, _) ).asScala.toList def pdf( resource: String | File, - splitter: TextSplitter = TextSplitter.tokenTextSplitter(ModelType.GPT_3_5_TURBO, 100, 50) + splitter: TextSplitter = TextSplitter.tokenTextSplitter(ModelType.getDEFAULT_SPLITTER_MODEL, 100, 50) ): List[String] = LoomAdapter .apply[java.util.List[String]](count => resource match - case url: String => PDFLoaderKt.pdf(url, splitter.core, count) - case file: File => PDFLoaderKt.pdf(file, splitter.core, count) + case url: String => Loader.pdf(url, splitter.core, count) + case file: File => Loader.pdf(file, splitter.core, count) ).asScala.toList def images( prompt: String, - user: String = "testing", + model: Images = OpenAI.DEFAULT_IMAGES, + n: Int = 1, size: String = "1024x1024", - bringFromContext: Int = 10, - n: Int = 1 + promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS )(using scope: AIScope): List[String] = LoomAdapter .apply[ImagesGenerationResponse](cont => scope.kt.images( + model, prompt, - user, n, size, - bringFromContext, + promptConfiguration, cont ) ).getData.asScala.map(_.getUrl).toList diff --git a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt index b806fd614..b892b9b65 100644 --- a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt +++ b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt @@ -3,15 +3,16 @@ package com.xebia.functional.tokenizer import com.xebia.functional.tokenizer.EncodingType.CL100K_BASE import com.xebia.functional.tokenizer.EncodingType.P50K_BASE import com.xebia.functional.tokenizer.EncodingType.R50K_BASE +import kotlin.jvm.JvmStatic -enum class ModelType( +sealed class ModelType( /** * Returns the name of the model type as used by the OpenAI API. * * @return the name of the model type */ - name: String, - val encodingType: EncodingType, + open val name: String, + open val encodingType: EncodingType, /** * Returns the maximum context length that is supported by this model type. 
diff --git a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
index b806fd614..b892b9b65 100644
--- a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
+++ b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
@@ -3,15 +3,16 @@ package com.xebia.functional.tokenizer
 import com.xebia.functional.tokenizer.EncodingType.CL100K_BASE
 import com.xebia.functional.tokenizer.EncodingType.P50K_BASE
 import com.xebia.functional.tokenizer.EncodingType.R50K_BASE
+import kotlin.jvm.JvmStatic
 
-enum class ModelType(
+sealed class ModelType(
   /**
    * Returns the name of the model type as used by the OpenAI API.
    *
    * @return the name of the model type
    */
-  name: String,
-  val encodingType: EncodingType,
+  open val name: String,
+  open val encodingType: EncodingType,
   /**
    * Returns the maximum context length that is supported by this model type. Note that
    * the maximum context length consists of the amount of prompt tokens and the amount of
@@ -19,58 +20,61 @@ enum class ModelType(
    *
    * @return the maximum context length for this model type
    */
-  val maxContextLength: Int
+  open val maxContextLength: Int
 ) {
+
+  companion object {
+    @JvmStatic
+    val DEFAULT_SPLITTER_MODEL = GPT_3_5_TURBO
+  }
+
+  data class LocalModel(override val name: String, override val encodingType: EncodingType, override val maxContextLength: Int) : ModelType(name, encodingType, maxContextLength)
+
   // chat
-  GPT_4("gpt-4", CL100K_BASE, 8192),
-  GPT_4_32K("gpt-4-32k", CL100K_BASE, 32768),
-  GPT_3_5_TURBO("gpt-3.5-turbo", CL100K_BASE, 4097),
-  GPT_3_5_TURBO_16_K("gpt-3.5-turbo-16k", CL100K_BASE, 4097 * 4),
-  GPT_3_5_TURBO_FUNCTIONS("gpt-3.5-turbo-0613", CL100K_BASE, 4097),
+  object GPT_4 : ModelType("gpt-4", CL100K_BASE, 8192)
+  object GPT_4_32K : ModelType("gpt-4-32k", CL100K_BASE, 32768)
+  object GPT_3_5_TURBO : ModelType("gpt-3.5-turbo", CL100K_BASE, 4097)
+  object GPT_3_5_TURBO_16_K : ModelType("gpt-3.5-turbo-16k", CL100K_BASE, 4097 * 4)
+  object GPT_3_5_TURBO_FUNCTIONS : ModelType("gpt-3.5-turbo-0613", CL100K_BASE, 4097)
 
   // text
-  TEXT_DAVINCI_003("text-davinci-003", P50K_BASE, 4097),
-  TEXT_DAVINCI_002("text-davinci-002", P50K_BASE, 4097),
-  TEXT_DAVINCI_001("text-davinci-001", R50K_BASE, 2049),
-  TEXT_CURIE_001("text-curie-001", R50K_BASE, 2049),
-  TEXT_BABBAGE_001("text-babbage-001", R50K_BASE, 2049),
-  TEXT_ADA_001("text-ada-001", R50K_BASE, 2049),
-  DAVINCI("davinci", R50K_BASE, 2049),
-  CURIE("curie", R50K_BASE, 2049),
-  BABBAGE("babbage", R50K_BASE, 2049),
-  ADA("ada", R50K_BASE, 2049),
+  object TEXT_DAVINCI_003 : ModelType("text-davinci-003", P50K_BASE, 4097)
+  object TEXT_DAVINCI_002 : ModelType("text-davinci-002", P50K_BASE, 4097)
+  object TEXT_DAVINCI_001 : ModelType("text-davinci-001", R50K_BASE, 2049)
+  object TEXT_CURIE_001 : ModelType("text-curie-001", R50K_BASE, 2049)
+  object TEXT_BABBAGE_001 : ModelType("text-babbage-001", R50K_BASE, 2049)
+  object TEXT_ADA_001 : ModelType("text-ada-001", R50K_BASE, 2049)
+  object DAVINCI : ModelType("davinci", R50K_BASE, 2049)
+  object CURIE : ModelType("curie", R50K_BASE, 2049)
+  object BABBAGE : ModelType("babbage", R50K_BASE, 2049)
+  object ADA : ModelType("ada", R50K_BASE, 2049)
 
   // code
-  CODE_DAVINCI_002("code-davinci-002", P50K_BASE, 8001),
-  CODE_DAVINCI_001("code-davinci-001", P50K_BASE, 8001),
-  CODE_CUSHMAN_002("code-cushman-002", P50K_BASE, 2048),
-  CODE_CUSHMAN_001("code-cushman-001", P50K_BASE, 2048),
-  DAVINCI_CODEX("davinci-codex", P50K_BASE, 4096),
-  CUSHMAN_CODEX("cushman-codex", P50K_BASE, 2048),
+  object CODE_DAVINCI_002 : ModelType("code-davinci-002", P50K_BASE, 8001)
+  object CODE_DAVINCI_001 : ModelType("code-davinci-001", P50K_BASE, 8001)
+  object CODE_CUSHMAN_002 : ModelType("code-cushman-002", P50K_BASE, 2048)
+  object CODE_CUSHMAN_001 : ModelType("code-cushman-001", P50K_BASE, 2048)
+  object DAVINCI_CODEX : ModelType("davinci-codex", P50K_BASE, 4096)
+  object CUSHMAN_CODEX : ModelType("cushman-codex", P50K_BASE, 2048)
 
   // edit
-  TEXT_DAVINCI_EDIT_001("text-davinci-edit-001", EncodingType.P50K_EDIT, 3000),
-  CODE_DAVINCI_EDIT_001("code-davinci-edit-001", EncodingType.P50K_EDIT, 3000),
+  object TEXT_DAVINCI_EDIT_001 : ModelType("text-davinci-edit-001", EncodingType.P50K_EDIT, 3000)
+  object CODE_DAVINCI_EDIT_001 : ModelType("code-davinci-edit-001", EncodingType.P50K_EDIT, 3000)
 
   // embeddings
-  TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", CL100K_BASE, 8191),
+  object TEXT_EMBEDDING_ADA_002 : ModelType("text-embedding-ada-002", CL100K_BASE, 8191)
 
   // old embeddings
-  TEXT_SIMILARITY_DAVINCI_001("text-similarity-davinci-001", R50K_BASE, 2046),
-  TEXT_SIMILARITY_CURIE_001("text-similarity-curie-001", R50K_BASE, 2046),
-  TEXT_SIMILARITY_BABBAGE_001("text-similarity-babbage-001", R50K_BASE, 2046),
-  TEXT_SIMILARITY_ADA_001("text-similarity-ada-001", R50K_BASE, 2046),
-  TEXT_SEARCH_DAVINCI_DOC_001("text-search-davinci-doc-001", R50K_BASE, 2046),
-  TEXT_SEARCH_CURIE_DOC_001("text-search-curie-doc-001", R50K_BASE, 2046),
-  TEXT_SEARCH_BABBAGE_DOC_001("text-search-babbage-doc-001", R50K_BASE, 2046),
-  TEXT_SEARCH_ADA_DOC_001("text-search-ada-doc-001", R50K_BASE, 2046),
-  CODE_SEARCH_BABBAGE_CODE_001("code-search-babbage-code-001", R50K_BASE, 2046),
-  CODE_SEARCH_ADA_CODE_001("code-search-ada-code-001", R50K_BASE, 2046);
+  object TEXT_SIMILARITY_DAVINCI_001 : ModelType("text-similarity-davinci-001", R50K_BASE, 2046)
+  object TEXT_SIMILARITY_CURIE_001 : ModelType("text-similarity-curie-001", R50K_BASE, 2046)
+  object TEXT_SIMILARITY_BABBAGE_001 : ModelType("text-similarity-babbage-001", R50K_BASE, 2046)
+  object TEXT_SIMILARITY_ADA_001 : ModelType("text-similarity-ada-001", R50K_BASE, 2046)
+  object TEXT_SEARCH_DAVINCI_DOC_001 : ModelType("text-search-davinci-doc-001", R50K_BASE, 2046)
+  object TEXT_SEARCH_CURIE_DOC_001 : ModelType("text-search-curie-doc-001", R50K_BASE, 2046)
+  object TEXT_SEARCH_BABBAGE_DOC_001 : ModelType("text-search-babbage-doc-001", R50K_BASE, 2046)
+  object TEXT_SEARCH_ADA_DOC_001 : ModelType("text-search-ada-doc-001", R50K_BASE, 2046)
+  object CODE_SEARCH_BABBAGE_CODE_001 : ModelType("code-search-babbage-code-001", R50K_BASE, 2046)
+  object CODE_SEARCH_ADA_CODE_001 : ModelType("code-search-ada-code-001", R50K_BASE, 2046)
 
   inline val encoding: Encoding
     inline get() = encodingType.encoding
 
-  companion object {
-    fun fromName(name: String): ModelType? =
-      values().firstOrNull { it.name == name }
-  }
 }
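With `ModelType` now a sealed class rather than an enum, callers can introduce their own models without modifying the library. A small sketch of what that enables; the `wizardlm-7b` model and the caller-side lookup table are hypothetical, and `fromName` itself was removed by this change.

```kotlin
import com.xebia.functional.tokenizer.EncodingType
import com.xebia.functional.tokenizer.ModelType

// Custom/local models can now be described without touching the library.
val myLocalModel = ModelType.LocalModel("wizardlm-7b", EncodingType.CL100K_BASE, 4096)

val splitterDefault = ModelType.DEFAULT_SPLITTER_MODEL // GPT_3_5_TURBO
val gpt4Window = ModelType.GPT_4.maxContextLength      // 8192
val localEncoding = myLocalModel.encoding              // resolved via encodingType

// The enum-based `fromName` lookup is gone; callers that need one can keep their own
// table of the models they actually use (illustrative only).
val knownModels: Map<String, ModelType> =
  listOf(ModelType.GPT_4, ModelType.GPT_3_5_TURBO, myLocalModel).associateBy { it.name }
```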