diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt index 7f59a2569..4ca848e44 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt @@ -99,15 +99,14 @@ class CoreAIScope( functions: List, serializer: (json: String) -> A, promptConfiguration: PromptConfiguration, - ): A { - return prompt( + ): A = + prompt( prompt = Prompt(prompt), context = context, functions = functions, serializer = serializer, promptConfiguration = promptConfiguration, ) - } @AiDsl suspend fun Chat.promptMessage( diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt index 8e628e590..2de5b177b 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/ChatWithFunctions.kt @@ -24,8 +24,8 @@ interface ChatWithFunctions : Chat { functions: List, serializer: (json: String) -> A, promptConfiguration: PromptConfiguration, - ): A { - return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { + ): A = + tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { promptMessage( prompt = Prompt(prompt), context = context, @@ -33,7 +33,6 @@ interface ChatWithFunctions : Chat { promptConfiguration ) } - } @AiDsl suspend fun prompt( @@ -42,11 +41,10 @@ interface ChatWithFunctions : Chat { functions: List, serializer: (json: String) -> A, promptConfiguration: PromptConfiguration, - ): A { - return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { + ): A = + tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) { promptMessage(prompt = prompt, context = context, functions = functions, promptConfiguration) } - } private suspend fun tryDeserialize( serializer: (json: String) -> A, diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt index 05c0c32e3..17a1273ed 100644 --- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt +++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt @@ -46,7 +46,7 @@ interface GPT4All : AutoCloseable, Chat, Completion { override suspend fun createCompletion(request: CompletionRequest): CompletionResult = with(request) { - val response: String = generateCompletion(prompt, generationConfig) + val response: String = gpt4allModel.prompt(prompt, llmModelContext(generationConfig)) return CompletionResult( UUID.randomUUID().toString(), path.name, @@ -59,8 +59,8 @@ interface GPT4All : AutoCloseable, Chat, Completion { override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse = with(request) { - val prompt: String = messages.buildPrompt() - val response: String = generateCompletion(prompt, generationConfig) + val response: String = + gpt4allModel.prompt(messages.buildPrompt(), llmModelContext(generationConfig)) return ChatCompletionResponse( UUID.randomUUID().toString(), path.name, @@ -71,9 +71,7 @@ interface GPT4All : AutoCloseable, Chat, Completion { ) } - override fun tokensFromMessages(messages: List): Int { - return 0 - } + override fun tokensFromMessages(messages: List): Int = 0 override val name: String = path.name @@ -92,31 +90,25 @@ interface GPT4All : AutoCloseable, Chat, Completion { return "$messages\n### Response:" } - private fun generateCompletion( - prompt: String, - generationConfig: GenerationConfig - ): String { - val context = LLModelContext( - logits_size = LibCAPI.size_t(generationConfig.logitsSize.toLong()), - tokens_size = LibCAPI.size_t(generationConfig.tokensSize.toLong()), - n_past = generationConfig.nPast, - n_ctx = generationConfig.nCtx, - n_predict = generationConfig.nPredict, - top_k = generationConfig.topK, - top_p = generationConfig.topP.toFloat(), - temp = generationConfig.temp.toFloat(), - n_batch = generationConfig.nBatch, - repeat_penalty = generationConfig.repeatPenalty.toFloat(), - repeat_last_n = generationConfig.repeatLastN, - context_erase = generationConfig.contextErase.toFloat() - ) - - return gpt4allModel.prompt(prompt, context) - } + private fun llmModelContext(generationConfig: GenerationConfig): LLModelContext = + with(generationConfig) { + LLModelContext( + logits_size = LibCAPI.size_t(logitsSize.toLong()), + tokens_size = LibCAPI.size_t(tokensSize.toLong()), + n_past = nPast, + n_ctx = nCtx, + n_predict = nPredict, + top_k = topK, + top_p = topP.toFloat(), + temp = temp.toFloat(), + n_batch = nBatch, + repeat_penalty = repeatPenalty.toFloat(), + repeat_last_n = repeatLastN, + context_erase = contextErase.toFloat() + ) + } } - } } - diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt index ac4d1b0d3..f2b31980f 100644 --- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt +++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt @@ -16,11 +16,10 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult { - val embeddings = tokenizer.batchEncode(request.input) - return EmbeddingResult( - data = embedings.mapIndexed { n, em -> Embedding("embedding", em.ids.map { it.toFloat() }, n) }, - usage = Usage.ZERO - ) + val embeddings = tokenizer.batchEncode(request.input).mapIndexed { ix, em -> + Embedding("embedding", em.ids.map { it.toFloat() }, ix) + } + return EmbeddingResult(embeddings, Usage.ZERO) } override suspend fun embedDocuments( @@ -28,9 +27,7 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun chunkSize: Int?, requestConfig: RequestConfig ): List = - tokenizer.batchEncode(texts).mapIndexed { n, em -> - XefEmbedding(em.ids.map { it.toFloat() }) - } + tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) } override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List = embedDocuments(listOf(text), null, requestConfig)