From 837c13ab9ad2de938f4161b1c37b8ca48357ebb6 Mon Sep 17 00:00:00 2001 From: Simon Vergauwen Date: Tue, 27 Jun 2023 11:04:36 +0200 Subject: [PATCH] Refactor CFunction to work with String schemas instead of KotlinX (#207) --- .../xef/llm/openai/functions/CFunction.kt | 52 ++++++++++++++- .../xef/auto/DeserializerLLMAgent.kt | 13 +++- .../xef/auto/serialization/JsonSchema.kt | 2 +- .../serialization/functions/FunctionSchema.kt | 64 ------------------- .../functional/xef/scala/auto/package.scala | 10 ++- 5 files changed, 70 insertions(+), 71 deletions(-) delete mode 100644 kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/functions/FunctionSchema.kt diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/openai/functions/CFunction.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/openai/functions/CFunction.kt index 3796088ba..beff263a5 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/openai/functions/CFunction.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/openai/functions/CFunction.kt @@ -1,7 +1,55 @@ package com.xebia.functional.xef.llm.openai.functions +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.JsonEncoder +import kotlinx.serialization.json.JsonUnquotedLiteral +/* +"functions": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + ] + */ @Serializable -data class CFunction(val name: String, val description: String, val parameters: JsonObject) +data class CFunction(val name: String, val description: String, val parameters: RawJsonString) + +typealias RawJsonString = @Serializable(with = RawJsonStringSerializer::class) String + +@OptIn(ExperimentalSerializationApi::class) +private object RawJsonStringSerializer : KSerializer { + override val descriptor = + PrimitiveSerialDescriptor( + "com.xebia.functional.xef.llm.openai.functions.RawJsonString", + PrimitiveKind.STRING + ) + + override fun deserialize(decoder: Decoder): String = decoder.decodeString() + + override fun serialize(encoder: Encoder, value: String) = + when (encoder) { + is JsonEncoder -> encoder.encodeJsonElement(JsonUnquotedLiteral(value)) + else -> encoder.encodeString(value) + } +} diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt index 83793f0da..908482663 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/DeserializerLLMAgent.kt @@ -1,10 +1,13 @@ package com.xebia.functional.xef.auto -import com.xebia.functional.xef.auto.serialization.functions.encodeFunctionSchema +import com.xebia.functional.xef.auto.serialization.encodeJsonSchema import com.xebia.functional.xef.llm.openai.LLMModel +import com.xebia.functional.xef.llm.openai.functions.CFunction import com.xebia.functional.xef.prompt.Prompt +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.json.Json import kotlinx.serialization.serializer @@ -96,7 +99,7 @@ suspend fun AIScope.prompt( bringFromContext: Int = this.docsInContext, minResponseTokens: Int = this.minResponseTokens, ): A { - val functions = encodeFunctionSchema(serializer.descriptor) + val functions = generateCFunction(serializer.descriptor) return prompt( prompt, functions, @@ -111,3 +114,9 @@ suspend fun AIScope.prompt( minResponseTokens ) } + +@OptIn(ExperimentalSerializationApi::class) +private fun generateCFunction(descriptor: SerialDescriptor): List { + val fnName = descriptor.serialName.substringAfterLast(".") + return listOf(CFunction(fnName, "Generated function for $fnName", encodeJsonSchema(descriptor))) +} diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt index 285923c7e..c7ddf8e5b 100644 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt +++ b/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/JsonSchema.kt @@ -117,7 +117,7 @@ fun encodeJsonSchema(descriptor: SerialDescriptor): String = Json.encodeToString(JsonObject.serializer(), buildJsonSchema(descriptor)) /** Creates a Json Schema using the provided [descriptor] */ -internal fun buildJsonSchema(descriptor: SerialDescriptor): JsonObject { +private fun buildJsonSchema(descriptor: SerialDescriptor): JsonObject { val autoDefinitions = false val prepend = mapOf("\$schema" to JsonPrimitive("http://json-schema.org/draft-07/schema")) val definitions = JsonSchemaDefinitions(autoDefinitions) diff --git a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/functions/FunctionSchema.kt b/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/functions/FunctionSchema.kt deleted file mode 100644 index 545f11c29..000000000 --- a/kotlin/src/commonMain/kotlin/com/xebia/functional/xef/auto/serialization/functions/FunctionSchema.kt +++ /dev/null @@ -1,64 +0,0 @@ -package com.xebia.functional.xef.auto.serialization.functions - -import com.xebia.functional.xef.auto.serialization.buildJsonSchema -import com.xebia.functional.xef.llm.openai.functions.CFunction -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.* - -/* -"functions": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } - } - ] - */ -fun encodeFunctionSchema(serialDescriptor: SerialDescriptor): List { - return listOf(generateCFunction(serialDescriptor)) -} - -private fun generateCFunction(descriptor: SerialDescriptor): CFunction { - val parameters = buildJsonSchema(descriptor) - val fnName = functionName(descriptor) - return CFunction(fnName, "Generated function for $fnName", parameters) -} - -@OptIn(ExperimentalSerializationApi::class) -internal fun functionName(descriptor: SerialDescriptor): String = - descriptor.serialName.substringAfterLast(".") - -@OptIn(ExperimentalSerializationApi::class) -private fun typeName(it: SerialDescriptor): String = - when (it.kind) { - PolymorphicKind.OPEN -> "object" - PolymorphicKind.SEALED -> "object" - PrimitiveKind.BOOLEAN -> "boolean" - PrimitiveKind.BYTE -> "number" - PrimitiveKind.CHAR -> "character" - PrimitiveKind.DOUBLE -> "double" - PrimitiveKind.FLOAT -> "float" - PrimitiveKind.INT -> "number" - PrimitiveKind.LONG -> "number" - PrimitiveKind.SHORT -> "number" - PrimitiveKind.STRING -> "string" - SerialKind.CONTEXTUAL -> "object" - SerialKind.ENUM -> "enum" - StructureKind.CLASS -> "object" - StructureKind.LIST -> "array" - StructureKind.MAP -> "object" - StructureKind.OBJECT -> "object" - } diff --git a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala index bac2315dc..7f0c42299 100644 --- a/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala +++ b/scala/src/main/scala/com/xebia/functional/xef/scala/auto/package.scala @@ -8,7 +8,7 @@ import io.circe.Decoder import io.circe.parser.parse import com.xebia.functional.xef.auto.AIKt import com.xebia.functional.xef.auto.AIRuntime -import com.xebia.functional.xef.auto.serialization.functions.FunctionSchemaKt +import com.xebia.functional.xef.auto.serialization.JsonSchemaKt import com.xebia.functional.xef.pdf.PDFLoaderKt import com.xebia.functional.tokenizer.ModelType import com.xebia.functional.xef.llm.openai._ @@ -54,7 +54,7 @@ def prompt[A: Decoder: SerialDescriptor]( LoomAdapter.apply((cont) => scope.kt.promptWithSerializer[A]( prompt, - FunctionSchemaKt.encodeFunctionSchema(SerialDescriptor[A].serialDescriptor), + generateCFunctions.asJava, (json: String) => parse(json).flatMap(Decoder[A].decodeJson(_)).fold(throw _, identity), maxAttempts, llmModel, @@ -68,6 +68,12 @@ def prompt[A: Decoder: SerialDescriptor]( ) ) +private def generateCFunctions[A: SerialDescriptor]: List[CFunction] = + val descriptor = SerialDescriptor[A].serialDescriptor + val serialName = descriptor.getSerialName + val fnName = serialName.substring(serialName.lastIndexOf("."), serialName.length) + List(CFunction(fnName, "Generated function for $fnName", JsonSchemaKt.encodeJsonSchema(descriptor))) + def contextScope[A: Decoder: SerialDescriptor](docs: List[String])(block: AI[A])(using scope: AIScope): A = LoomAdapter.apply(scope.kt.contextScopeWithDocs[A](docs.asJava, (_, _) => block, _))