Small Suggestions - Rewrites.
diesalbla committed Jul 4, 2023
1 parent 8796424 commit 6c967c0
Showing 4 changed files with 164 additions and 222 deletions.
43 changes: 20 additions & 23 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -45,7 +45,9 @@ interface Chat : LLM {
         minResponseTokens = promptConfiguration.minResponseTokens
       )

-    fun checkTotalLeftChatTokens(messages: List<Message>): Int {
+    val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
+
+    fun checkTotalLeftChatTokens(): Int {
       val maxContextLength: Int = modelType.maxContextLength
       val messagesTokens: Int = tokensFromMessages(messages)
       val totalLeftTokens: Int = maxContextLength - messagesTokens
@@ -55,21 +57,19 @@ interface Chat : LLM {
       return totalLeftTokens
     }

-    fun buildChatRequest(): ChatCompletionRequest {
-      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
-      return ChatCompletionRequest(
+    val request: ChatCompletionRequest =
+      ChatCompletionRequest(
         model = name,
         user = promptConfiguration.user,
         messages = messages,
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(messages),
+        maxTokens = checkTotalLeftChatTokens(),
         streamToStandardOut = true
       )
-    }

     return flow {
-      createChatCompletions(buildChatRequest()).collect {
+      createChatCompletions(request).collect {
         emit(it.choices.mapNotNull { it.delta?.content }.joinToString(""))
       }
     }
@@ -99,7 +99,9 @@ interface Chat : LLM {
         minResponseTokens = promptConfiguration.minResponseTokens
       )

-    fun checkTotalLeftChatTokens(messages: List<Message>): Int {
+    val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
+
+    fun checkTotalLeftChatTokens(): Int {
       val maxContextLength: Int = modelType.maxContextLength
       val messagesTokens: Int = tokensFromMessages(messages)
       val totalLeftTokens: Int = maxContextLength - messagesTokens
@@ -109,45 +111,40 @@ interface Chat : LLM {
       return totalLeftTokens
     }

-    fun buildChatRequest(): ChatCompletionRequest {
-      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
-      return ChatCompletionRequest(
+    fun chatRequest(): ChatCompletionRequest =
+      ChatCompletionRequest(
         model = name,
         user = promptConfiguration.user,
         messages = messages,
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(messages),
+        maxTokens = checkTotalLeftChatTokens(),
         streamToStandardOut = promptConfiguration.streamToStandardOut
       )
-    }

-    fun chatWithFunctionsRequest(): ChatCompletionRequestWithFunctions {
-      val firstFnName: String? = functions.firstOrNull()?.name
-      val messages: List<Message> = listOf(Message(Role.USER, promptWithContext, Role.USER.name))
-      return ChatCompletionRequestWithFunctions(
+    fun withFunctionsRequest(): ChatCompletionRequestWithFunctions =
+      ChatCompletionRequestWithFunctions(
         model = name,
         user = promptConfiguration.user,
         messages = messages,
         n = promptConfiguration.numberOfPredictions,
         temperature = promptConfiguration.temperature,
-        maxTokens = checkTotalLeftChatTokens(messages),
+        maxTokens = checkTotalLeftChatTokens(),
         functions = functions,
-        functionCall = mapOf("name" to (firstFnName ?: ""))
+        functionCall = mapOf("name" to (functions.firstOrNull()?.name ?: ""))
       )
-    }

     return when (this) {
       is ChatWithFunctions ->
         // we only support functions for now with GPT_3_5_TURBO_FUNCTIONS
         if (modelType == ModelType.GPT_3_5_TURBO_FUNCTIONS) {
-          createChatCompletionWithFunctions(chatWithFunctionsRequest()).choices.mapNotNull {
+          createChatCompletionWithFunctions(withFunctionsRequest()).choices.mapNotNull {
             it.message?.functionCall?.arguments
           }
         } else {
-          createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content }
+          createChatCompletion(chatRequest()).choices.mapNotNull { it.message?.content }
         }
-      else -> createChatCompletion(buildChatRequest()).choices.mapNotNull { it.message?.content }
+      else -> createChatCompletion(chatRequest()).choices.mapNotNull { it.message?.content }
     }
   }

@@ -111,24 +111,24 @@ interface GPT4All : AutoCloseable, Chat, Completion {
     val channel = Channel<String>(capacity = UNLIMITED)

     val outputStream = object : OutputStream() {
-      override fun write(b: Int) {
+      override fun write(b: Int) {
         val c = b.toChar()
         channel.trySend(c.toString())
       }
     }

     val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)

-    val flow = channel.consumeAsFlow()
-      .map { text ->
-        ChatCompletionChunk(
-          UUID.randomUUID().toString(),
-          System.currentTimeMillis().toInt(),
-          path.name,
-          listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
-          Usage.ZERO,
-        )
-      }
+    fun toChunk(text: String?): ChatCompletionChunk =
+      ChatCompletionChunk(
+        UUID.randomUUID().toString(),
+        System.currentTimeMillis().toInt(),
+        path.name,
+        listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
+        Usage.ZERO,
+      )
+
+    val flow = channel.consumeAsFlow().map { toChunk(it) }

     launch(Dispatchers.IO) {
       System.setOut(printStream) // Set the standard output to the print stream
@@ -1,6 +1,7 @@
 package com.xebia.functional.gpt4all

 import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
+import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding
 import com.xebia.functional.xef.embeddings.Embeddings
 import com.xebia.functional.xef.llm.models.embeddings.Embedding
 import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
@@ -26,19 +27,10 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
     texts: List<String>,
     chunkSize: Int?,
     requestConfig: RequestConfig
-  ): List<com.xebia.functional.xef.embeddings.Embedding> {
-    val encodings = tokenizer.batchEncode(texts)
-    return encodings.mapIndexed { n, em ->
-      com.xebia.functional.xef.embeddings.Embedding(
-        em.ids.map { it.toFloat() },
-      )
-    }
-  }
+  ): List<XefEmbedding> =
+    tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) }

-  override suspend fun embedQuery(
-    text: String,
-    requestConfig: RequestConfig
-  ): List<com.xebia.functional.xef.embeddings.Embedding> =
+  override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<XefEmbedding> =
     embedDocuments(listOf(text), null, requestConfig)

   companion object {
