Skip to content

Commit

Permalink
Refactor CFunction to work with String schemas instead of KotlinX (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
nomisRev authored Jun 27, 2023
1 parent d591d6b commit 837c13a
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,55 @@
package com.xebia.functional.xef.llm.openai.functions

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonEncoder
import kotlinx.serialization.json.JsonUnquotedLiteral

/*
"functions": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
]
*/
@Serializable
data class CFunction(val name: String, val description: String, val parameters: JsonObject)
data class CFunction(val name: String, val description: String, val parameters: RawJsonString)

typealias RawJsonString = @Serializable(with = RawJsonStringSerializer::class) String

@OptIn(ExperimentalSerializationApi::class)
private object RawJsonStringSerializer : KSerializer<String> {
override val descriptor =
PrimitiveSerialDescriptor(
"com.xebia.functional.xef.llm.openai.functions.RawJsonString",
PrimitiveKind.STRING
)

override fun deserialize(decoder: Decoder): String = decoder.decodeString()

override fun serialize(encoder: Encoder, value: String) =
when (encoder) {
is JsonEncoder -> encoder.encodeJsonElement(JsonUnquotedLiteral(value))
else -> encoder.encodeString(value)
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package com.xebia.functional.xef.auto

import com.xebia.functional.xef.auto.serialization.functions.encodeFunctionSchema
import com.xebia.functional.xef.auto.serialization.encodeJsonSchema
import com.xebia.functional.xef.llm.openai.LLMModel
import com.xebia.functional.xef.llm.openai.functions.CFunction
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer

Expand Down Expand Up @@ -96,7 +99,7 @@ suspend fun <A> AIScope.prompt(
bringFromContext: Int = this.docsInContext,
minResponseTokens: Int = this.minResponseTokens,
): A {
val functions = encodeFunctionSchema(serializer.descriptor)
val functions = generateCFunction(serializer.descriptor)
return prompt(
prompt,
functions,
Expand All @@ -111,3 +114,9 @@ suspend fun <A> AIScope.prompt(
minResponseTokens
)
}

@OptIn(ExperimentalSerializationApi::class)
private fun generateCFunction(descriptor: SerialDescriptor): List<CFunction> {
val fnName = descriptor.serialName.substringAfterLast(".")
return listOf(CFunction(fnName, "Generated function for $fnName", encodeJsonSchema(descriptor)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fun encodeJsonSchema(descriptor: SerialDescriptor): String =
Json.encodeToString(JsonObject.serializer(), buildJsonSchema(descriptor))

/** Creates a Json Schema using the provided [descriptor] */
internal fun buildJsonSchema(descriptor: SerialDescriptor): JsonObject {
private fun buildJsonSchema(descriptor: SerialDescriptor): JsonObject {
val autoDefinitions = false
val prepend = mapOf("\$schema" to JsonPrimitive("http://json-schema.org/draft-07/schema"))
val definitions = JsonSchemaDefinitions(autoDefinitions)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import io.circe.Decoder
import io.circe.parser.parse
import com.xebia.functional.xef.auto.AIKt
import com.xebia.functional.xef.auto.AIRuntime
import com.xebia.functional.xef.auto.serialization.functions.FunctionSchemaKt
import com.xebia.functional.xef.auto.serialization.JsonSchemaKt
import com.xebia.functional.xef.pdf.PDFLoaderKt
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.openai._
Expand Down Expand Up @@ -54,7 +54,7 @@ def prompt[A: Decoder: SerialDescriptor](
LoomAdapter.apply((cont) =>
scope.kt.promptWithSerializer[A](
prompt,
FunctionSchemaKt.encodeFunctionSchema(SerialDescriptor[A].serialDescriptor),
generateCFunctions.asJava,
(json: String) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity),
maxAttempts,
llmModel,
Expand All @@ -68,6 +68,12 @@ def prompt[A: Decoder: SerialDescriptor](
)
)

private def generateCFunctions[A: SerialDescriptor]: List[CFunction] =
val descriptor = SerialDescriptor[A].serialDescriptor
val serialName = descriptor.getSerialName
val fnName = serialName.substring(serialName.lastIndexOf("."), serialName.length)
List(CFunction(fnName, "Generated function for $fnName", JsonSchemaKt.encodeJsonSchema(descriptor)))

def contextScope[A: Decoder: SerialDescriptor](docs: List[String])(block: AI[A])(using scope: AIScope): A =
LoomAdapter.apply(scope.kt.contextScopeWithDocs[A](docs.asJava, (_, _) => block, _))

Expand Down

0 comments on commit 837c13a

Please sign in to comment.