Commit ef49ba7: Small suggestions

diesalbla committed Jul 3, 2023
1 parent 48dccec commit ef49ba7
Showing 4 changed files with 60 additions and 87 deletions.
24 changes: 10 additions & 14 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -51,32 +51,28 @@ interface Chat : LLM {
       return totalLeftTokens
     }

-    fun buildChatRequest(): ChatCompletionRequest {
-      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
-      return ChatCompletionRequest(
+    val userMessage = Message(Role.USER, promptWithContext, Role.USER.name)
+    fun buildChatRequest(): ChatCompletionRequest =
+      ChatCompletionRequest(
         model = name,
         user = promptConfiguration.user,
-        messages = messages,
+        messages = listOf(userMessage),
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(messages)
+        maxTokens = checkTotalLeftChatTokens(listOf(userMessage))
       )
-    }

-    fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions {
-      val firstFnName: String? = functions.firstOrNull()?.name
-      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
-      return ChatCompletionRequestWithFunctions(
+    fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions =
+      ChatCompletionRequestWithFunctions(
         model = name,
         user = promptConfiguration.user,
-        messages = messages,
+        messages = listOf(userMessage),
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(messages),
+        maxTokens = checkTotalLeftChatTokens(listOf(userMessage)),
         functions = functions,
-        functionCall = mapOf("name" to (firstFnName ?: ""))
+        functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: ""))
       )
-    }

     return when (this) {
       is ChatWithFunctions ->
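A note on the Chat.kt change above: the user message is hoisted out of the two request builders into a single val, so it is built once and shared, and both builders become expression bodies. A minimal, self-contained sketch of that pattern (Message, Request, and the token arithmetic are illustrative stand-ins, not the xef API):

data class Message(val role: String, val content: String)
data class Request(val messages: List<Message>, val maxTokens: Int)

class RequestBuilders(prompt: String, private val contextLength: Int = 4096) {
  // Hoisted: built once, then reused by every builder below.
  val userMessage = Message("user", prompt)

  // Expression body instead of a block that binds a single-use local list.
  fun buildRequest(): Request =
    Request(messages = listOf(userMessage), maxTokens = tokensLeft(listOf(userMessage)))

  // Crude stand-in for checkTotalLeftChatTokens: remaining budget after the messages.
  private fun tokensLeft(messages: List<Message>): Int =
    contextLength - messages.sumOf { it.content.length }
}

fun main() {
  println(RequestBuilders("hello").buildRequest())
}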
HuggingFaceLocalEmbeddings.kt
@@ -1,6 +1,7 @@
 package com.xebia.functional.gpt4all

 import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
+import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding
 import com.xebia.functional.xef.embeddings.Embeddings
 import com.xebia.functional.xef.llm.models.embeddings.Embedding
 import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
@@ -15,7 +16,7 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
   override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName

   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
-    val embedings = tokenizer.batchEncode(request.input)
+    val embeddings = tokenizer.batchEncode(request.input)
     return EmbeddingResult(
-      data = embedings.mapIndexed { n, em -> Embedding("embedding", em.ids.map { it.toFloat() }, n) },
+      data = embeddings.mapIndexed { n, em -> Embedding("embedding", em.ids.map { it.toFloat() }, n) },
       usage = Usage.ZERO
@@ -26,19 +27,12 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
     texts: List<String>,
     chunkSize: Int?,
     requestConfig: RequestConfig
-  ): List<com.xebia.functional.xef.embeddings.Embedding> {
-    val encodings = tokenizer.batchEncode(texts)
-    return encodings.mapIndexed { n, em ->
-      com.xebia.functional.xef.embeddings.Embedding(
-        em.ids.map { it.toFloat() },
-      )
+  ): List<XefEmbedding> =
+    tokenizer.batchEncode(texts).mapIndexed { n, em ->
+      XefEmbedding(em.ids.map { it.toFloat() })
     }
-  }

-  override suspend fun embedQuery(
-    text: String,
-    requestConfig: RequestConfig
-  ): List<com.xebia.functional.xef.embeddings.Embedding> =
+  override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<XefEmbedding> =
     embedDocuments(listOf(text), null, requestConfig)

   companion object {
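The aliased import added above lets this file refer to com.xebia.functional.xef.embeddings.Embedding as XefEmbedding instead of spelling out the fully qualified name, which was necessary because a second class named Embedding is already imported. A minimal, runnable sketch of Kotlin's import ... as ... aliasing, using two JDK classes that share a simple name in place of the two Embedding types:

// Both classes are simply named Date; the aliases keep each use site short
// and unambiguous, just like Embedding as XefEmbedding in the diff above.
import java.util.Date as UtilDate
import java.sql.Date as SqlDate

fun main() {
  val now: UtilDate = UtilDate()
  val day: SqlDate = SqlDate(now.time)
  println("util=$now sql=$day")
}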
@@ -61,16 +61,14 @@ suspend fun <A> CoreAIScope.prompt(
     isLenient = true
   },
   promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
-): A {
-  val functions = generateCFunction(serializer.descriptor)
-  return model.prompt(
+): A =
+  model.prompt(
     prompt,
     context,
-    functions,
+    generateCFunction(serializer.descriptor),
     { json.decodeFromString(serializer, it) },
     promptConfiguration
   )
-}

 @OptIn(ExperimentalSerializationApi::class)
 private fun generateCFunction(descriptor: SerialDescriptor): List<CFunction> {
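The change above inlines a val (functions) that was used exactly once, letting prompt collapse from a block body to an expression body. The same transformation on a tiny, hypothetical function:

// Before: a single-use local forces a block body and an explicit return.
fun greetBefore(name: String): String {
  val trimmed = name.trim()
  return "Hello, $trimmed!"
}

// After: the local is inlined and the body becomes an expression.
fun greetAfter(name: String): String = "Hello, ${name.trim()}!"

fun main() {
  check(greetBefore(" world ") == greetAfter(" world "))
  println(greetAfter(" world "))
}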
OpenAIModel.kt
@@ -8,7 +8,6 @@ import com.aallam.openai.api.completion.CompletionRequest as OpenAICompletionReq
 import com.aallam.openai.api.completion.TextCompletion
 import com.aallam.openai.api.completion.completionRequest
 import com.aallam.openai.api.core.Usage as OpenAIUsage
-import com.aallam.openai.api.embedding.EmbeddingRequest as OpenAIEmbeddingRequest
 import com.aallam.openai.api.embedding.EmbeddingResponse
 import com.aallam.openai.api.embedding.embeddingRequest
 import com.aallam.openai.api.image.ImageCreation
@@ -52,27 +51,60 @@ class OpenAIModel(
     request: ChatCompletionRequest
   ): ChatCompletionResponse {
     val response = client.chatCompletion(toChatCompletionRequest(request))
-    return chatCompletionResult(response)
+    return ChatCompletionResponse(
+      id = response.id,
+      `object` = response.model.id,
+      created = response.created,
+      model = response.model.id,
+      choices = response.choices.map { chatCompletionChoice(it) },
+      usage = usage(response.usage)
+    )
   }

   @OptIn(BetaOpenAI::class)
   override suspend fun createChatCompletionWithFunctions(
     request: ChatCompletionRequestWithFunctions
   ): ChatCompletionResponseWithFunctions {
     val response = client.chatCompletion(toChatCompletionRequestWithFunctions(request))
-    return chatCompletionResultWithFunctions(response)
+
+    fun chatCompletionChoiceWithFunctions(choice: ChatChoice): ChoiceWithFunctions =
+      ChoiceWithFunctions(
+        message =
+          choice.message?.let {
+            MessageWithFunctionCall(
+              role = it.role.role,
+              content = it.content,
+              name = it.name,
+              functionCall = it.functionCall?.let { FnCall(it.name, it.arguments) }
+            )
+          },
+        finishReason = choice.finishReason,
+        index = choice.index,
+      )
+
+    return ChatCompletionResponseWithFunctions(
+      id = response.id,
+      `object` = response.model.id,
+      created = response.created,
+      model = response.model.id,
+      choices = response.choices.map { chatCompletionChoiceWithFunctions(it) },
+      usage = usage(response.usage)
+    )
   }

   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
-    val response = client.embeddings(toEmbeddingRequest(request))
-    return embeddingResult(response)
+    val openAIRequest = embeddingRequest {
+      model = ModelId(request.model)
+      input = request.input
+      user = request.user
+    }
+
+    return embeddingResult(client.embeddings(openAIRequest))
   }

   @OptIn(BetaOpenAI::class)
-  override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse {
-    val response = client.imageURL(toImageCreationRequest(request))
-    return imageResult(response)
-  }
+  override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse =
+    imageResult(client.imageURL(toImageCreationRequest(request)))

   private fun toCompletionRequest(request: CompletionRequest): OpenAICompletionRequest =
     completionRequest {
@@ -118,46 +150,6 @@ class OpenAIModel(
       totalTokens = usage?.totalTokens,
     )

-  @OptIn(BetaOpenAI::class)
-  private fun chatCompletionResult(response: ChatCompletion): ChatCompletionResponse =
-    ChatCompletionResponse(
-      id = response.id,
-      `object` = response.model.id,
-      created = response.created,
-      model = response.model.id,
-      choices = response.choices.map { chatCompletionChoice(it) },
-      usage = usage(response.usage)
-    )
-
-  @OptIn(BetaOpenAI::class)
-  private fun chatCompletionResultWithFunctions(
-    response: ChatCompletion
-  ): ChatCompletionResponseWithFunctions =
-    ChatCompletionResponseWithFunctions(
-      id = response.id,
-      `object` = response.model.id,
-      created = response.created,
-      model = response.model.id,
-      choices = response.choices.map { chatCompletionChoiceWithFunctions(it) },
-      usage = usage(response.usage)
-    )
-
-  @OptIn(BetaOpenAI::class)
-  private fun chatCompletionChoiceWithFunctions(choice: ChatChoice): ChoiceWithFunctions =
-    ChoiceWithFunctions(
-      message =
-        choice.message?.let {
-          MessageWithFunctionCall(
-            role = it.role.role,
-            content = it.content,
-            name = it.name,
-            functionCall = it.functionCall?.let { FnCall(it.name, it.arguments) }
-          )
-        },
-      finishReason = choice.finishReason,
-      index = choice.index,
-    )
-
   @OptIn(BetaOpenAI::class)
   private fun chatCompletionChoice(choice: ChatChoice): Choice =
     Choice(
@@ -258,13 +250,6 @@ class OpenAIModel(
       usage = usage(response.usage)
     )

-  private fun toEmbeddingRequest(request: EmbeddingRequest): OpenAIEmbeddingRequest =
-    embeddingRequest {
-      model = ModelId(request.model)
-      input = request.input
-      user = request.user
-    }
-
   @OptIn(BetaOpenAI::class)
   private fun imageResult(response: List<ImageURL>): ImagesGenerationResponse =
     ImagesGenerationResponse(data = response.map { ImageGenerationUrl(it.url) })
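The OpenAIModel.kt changes inline three single-use private helpers (chatCompletionResult, chatCompletionResultWithFunctions, toEmbeddingRequest) into their only call sites, and turn chatCompletionChoiceWithFunctions into a local function scoped to the one method that calls it. A minimal sketch of that local-function pattern (Choice and Response are simplified stand-ins, not the openai-kotlin types):

data class Choice(val text: String?)
data class Response(val choices: List<Choice>)

fun renderChoices(response: Response): List<String> {
  // A local function keeps the helper next to its single caller and out of
  // the enclosing class's private API surface.
  fun renderChoice(choice: Choice): String = choice.text ?: "<no content>"
  return response.choices.map { renderChoice(it) }
}

fun main() {
  println(renderChoices(Response(listOf(Choice("hi"), Choice(null)))))
}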
