Commit b9b6f05
More small suggestions
- Replace blocks with equals / single expressions
diesalbla committed Jul 3, 2023
1 parent ef49ba7 commit b9b6f05
Showing 4 changed files with 32 additions and 46 deletions.
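
The running theme of these changes is Kotlin's expression-body syntax: a function whose block body is nothing but `return <expression>` can bind the expression directly with `=`, dropping both the braces and the `return`. A minimal sketch of the pattern, using a hypothetical `greet` function rather than code from this repository:

```kotlin
// Block body: braces and an explicit return statement.
fun greet(name: String): String {
  return "Hello, $name!"
}

// Expression body: `=` binds the function to the expression directly.
fun greetExpr(name: String): String = "Hello, $name!"
```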
@@ -99,15 +99,14 @@ class CoreAIScope(
     functions: List<CFunction>,
     serializer: (json: String) -> A,
     promptConfiguration: PromptConfiguration,
-  ): A {
-    return prompt(
+  ): A =
+    prompt(
       prompt = Prompt(prompt),
       context = context,
       functions = functions,
       serializer = serializer,
       promptConfiguration = promptConfiguration,
     )
-  }
 
   @AiDsl
   suspend fun Chat.promptMessage(
@@ -24,16 +24,15 @@ interface ChatWithFunctions : Chat {
     functions: List<CFunction>,
     serializer: (json: String) -> A,
     promptConfiguration: PromptConfiguration,
-  ): A {
-    return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+  ): A =
+    tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
       promptMessage(
         prompt = Prompt(prompt),
         context = context,
         functions = functions,
         promptConfiguration
       )
     }
-  }
 
   @AiDsl
   suspend fun <A> prompt(
@@ -42,11 +41,10 @@ interface ChatWithFunctions : Chat {
     functions: List<CFunction>,
     serializer: (json: String) -> A,
     promptConfiguration: PromptConfiguration,
-  ): A {
-    return tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
+  ): A =
+    tryDeserialize(serializer, promptConfiguration.maxDeserializationAttempts) {
       promptMessage(prompt = prompt, context = context, functions = functions, promptConfiguration)
     }
-  }
 
   private suspend fun <A> tryDeserialize(
     serializer: (json: String) -> A,
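
Both `prompt` overloads delegate to `tryDeserialize`, which retries the prompt until the model's reply parses or `maxDeserializationAttempts` is exhausted; its body is cut off in this view. A minimal sketch of what such a retry helper could look like, assuming the trailing lambda yields the raw model reply and the last failure is rethrown; the library's actual signature and types may differ:

```kotlin
// Hypothetical sketch of a retry-until-parse helper, not the library's code.
private suspend fun <A> tryDeserialize(
  serializer: (json: String) -> A,
  maxAttempts: Int,
  block: suspend () -> String // assumed: produces the raw model reply
): A {
  var lastError: Throwable? = null
  repeat(maxAttempts) {
    try {
      return serializer(block()) // parsed successfully on this attempt
    } catch (e: Throwable) {
      lastError = e // remember the failure and ask the model again
    }
  }
  throw IllegalStateException("Deserialization failed after $maxAttempts attempts", lastError)
}
```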
@@ -46,7 +46,7 @@ interface GPT4All : AutoCloseable, Chat, Completion {
 
       override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
         with(request) {
-          val response: String = generateCompletion(prompt, generationConfig)
+          val response: String = gpt4allModel.prompt(prompt, llmModelContext(generationConfig))
           return CompletionResult(
             UUID.randomUUID().toString(),
             path.name,
@@ -59,8 +59,8 @@ interface GPT4All : AutoCloseable, Chat, Completion {
 
       override suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse =
         with(request) {
-          val prompt: String = messages.buildPrompt()
-          val response: String = generateCompletion(prompt, generationConfig)
+          val response: String =
+            gpt4allModel.prompt(messages.buildPrompt(), llmModelContext(generationConfig))
           return ChatCompletionResponse(
             UUID.randomUUID().toString(),
             path.name,
@@ -71,9 +71,7 @@ interface GPT4All : AutoCloseable, Chat, Completion {
           )
         }
 
-      override fun tokensFromMessages(messages: List<Message>): Int {
-        return 0
-      }
+      override fun tokensFromMessages(messages: List<Message>): Int = 0
 
       override val name: String = path.name

@@ -92,31 +90,25 @@ interface GPT4All : AutoCloseable, Chat, Completion {
         return "$messages\n### Response:"
       }
 
-      private fun generateCompletion(
-        prompt: String,
-        generationConfig: GenerationConfig
-      ): String {
-        val context = LLModelContext(
-          logits_size = LibCAPI.size_t(generationConfig.logitsSize.toLong()),
-          tokens_size = LibCAPI.size_t(generationConfig.tokensSize.toLong()),
-          n_past = generationConfig.nPast,
-          n_ctx = generationConfig.nCtx,
-          n_predict = generationConfig.nPredict,
-          top_k = generationConfig.topK,
-          top_p = generationConfig.topP.toFloat(),
-          temp = generationConfig.temp.toFloat(),
-          n_batch = generationConfig.nBatch,
-          repeat_penalty = generationConfig.repeatPenalty.toFloat(),
-          repeat_last_n = generationConfig.repeatLastN,
-          context_erase = generationConfig.contextErase.toFloat()
-        )
-
-        return gpt4allModel.prompt(prompt, context)
-      }
+      private fun llmModelContext(generationConfig: GenerationConfig): LLModelContext =
+        with(generationConfig) {
+          LLModelContext(
+            logits_size = LibCAPI.size_t(logitsSize.toLong()),
+            tokens_size = LibCAPI.size_t(tokensSize.toLong()),
+            n_past = nPast,
+            n_ctx = nCtx,
+            n_predict = nPredict,
+            top_k = topK,
+            top_p = topP.toFloat(),
+            temp = temp.toFloat(),
+            n_batch = nBatch,
+            repeat_penalty = repeatPenalty.toFloat(),
+            repeat_last_n = repeatLastN,
+            context_erase = contextErase.toFloat()
+          )
+        }
     }
 
   }
 }
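Besides inlining `generateCompletion` at its two call sites, the new `llmModelContext` helper uses the `with` scope function so the twelve `generationConfig.` qualifiers disappear: inside `with(generationConfig) { ... }`, the config's properties resolve against the implicit receiver. The same idiom in isolation, with hypothetical `Config` and `Address` types:

```kotlin
data class Config(val host: String, val port: Int)
data class Address(val value: String)

// Without `with`: the receiver is repeated on every property access.
fun addressOf(config: Config): Address =
  Address("${config.host}:${config.port}")

// With `with`: Config's properties are in scope unqualified.
fun addressOfScoped(config: Config): Address =
  with(config) { Address("$host:$port") }
```
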
@@ -16,21 +16,18 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
   override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName
 
   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
-    val embeddings = tokenizer.batchEncode(request.input)
-    return EmbeddingResult(
-      data = embeddings.mapIndexed { n, em -> Embedding("embedding", em.ids.map { it.toFloat() }, n) },
-      usage = Usage.ZERO
-    )
+    val embeddings = tokenizer.batchEncode(request.input).mapIndexed { ix, em ->
+      Embedding("embedding", em.ids.map { it.toFloat() }, ix)
+    }
+    return EmbeddingResult(embeddings, Usage.ZERO)
   }
 
   override suspend fun embedDocuments(
     texts: List<String>,
     chunkSize: Int?,
     requestConfig: RequestConfig
   ): List<XefEmbedding> =
-    tokenizer.batchEncode(texts).mapIndexed { n, em ->
-      XefEmbedding(em.ids.map { it.toFloat() })
-    }
+    tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) }
 
   override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<XefEmbedding> =
     embedDocuments(listOf(text), null, requestConfig)
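
Two small cleanups in this file: `createEmbeddings` now builds the indexed `Embedding` list in one pass instead of materializing the raw batch first, and `embedDocuments` replaces `mapIndexed` with `map` because its lambda never used the index. The second change in isolation, over hypothetical values:

```kotlin
val batches = listOf(listOf(1, 2), listOf(3))

// mapIndexed exposes an index the lambda ignores...
val a = batches.mapIndexed { _, em -> em.map { it.toFloat() } }

// ...so plain map expresses the same thing with less noise.
val b = batches.map { em -> em.map { it.toFloat() } }
check(a == b) // identical results
```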