Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow @Schema on Tool requests and read description from annotations when available #771

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.xebia.functional.xef.conversation

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialInfo

/** Schema for a tool request */
@OptIn(ExperimentalSerializationApi::class)
@SerialInfo
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS)
annotation class Schema(val value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import com.xebia.functional.openai.generated.api.Chat
import com.xebia.functional.openai.generated.model.*
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.AIEvent
import com.xebia.functional.xef.Config
import com.xebia.functional.xef.Tool
import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.conversation.Schema
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.PromptBuilder.Companion.tool
Expand All @@ -21,17 +24,31 @@ import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.json.*

@OptIn(ExperimentalSerializationApi::class)
fun chatFunction(descriptor: SerialDescriptor): FunctionObject {
val fnName = descriptor.serialName.substringAfterLast(".")
return chatFunction(fnName, buildJsonSchema(descriptor))
val functionName = functionName(descriptor)
return FunctionObject(
name = functionName,
description = functionDescription(descriptor, functionName),
parameters = functionSchema(descriptor)
)
}

fun chatFunctions(descriptors: List<SerialDescriptor>): List<FunctionObject> =
descriptors.map(::chatFunction)
@OptIn(ExperimentalSerializationApi::class)
fun functionSchema(descriptor: SerialDescriptor): JsonObject =
descriptor.annotations.filterIsInstance<Schema>().firstOrNull()?.value?.let {
Config.DEFAULT.json.decodeFromString(JsonObject.serializer(), it)
} ?: buildJsonSchema(descriptor)

fun chatFunction(fnName: String, schema: JsonObject): FunctionObject =
FunctionObject(fnName, "Generated function for $fnName", schema)
@OptIn(ExperimentalSerializationApi::class)
fun functionDescription(descriptor: SerialDescriptor, fnName: String): String =
(descriptor.annotations.filterIsInstance<Description>().firstOrNull()?.value
?: defaultFunctionDescription(fnName))

fun defaultFunctionDescription(fnName: String): String = "Generated function for $fnName"

@OptIn(ExperimentalSerializationApi::class)
fun functionName(descriptor: SerialDescriptor): String =
descriptor.serialName.substringAfterLast(".")

data class UsageTracker(
var llmCalls: Int = 0,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.xebia.functional.xef.functions

import com.xebia.functional.xef.conversation.Description
import com.xebia.functional.xef.conversation.Schema
import com.xebia.functional.xef.llm.chatFunction
import com.xebia.functional.xef.llm.defaultFunctionDescription
import com.xebia.functional.xef.llm.functionName
import com.xebia.functional.xef.llm.models.functions.buildJsonSchema
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

class FunctionSchemaTests :
StringSpec({
"Request has default description" {
val descriptor = Request.serializer().descriptor
val function = chatFunction(descriptor)
val fnName = functionName(descriptor)
function.description shouldBe defaultFunctionDescription(fnName)
}

"Description can be set on request" {
val descriptor = RequestWithDescription.serializer().descriptor
val function = chatFunction(descriptor)
function.description shouldBe "Request With Description"
}

"Schema can be generated on request" {
val descriptor = Request.serializer().descriptor
val function = chatFunction(descriptor)
function.parameters shouldBe buildJsonSchema(descriptor)
}

"Schema can be set on request" {
val descriptor = RequestWithSchema.serializer().descriptor
val function = chatFunction(descriptor)
function.parameters shouldBe JsonObject(emptyMap())
}
}) {

@Serializable data class Request(val input: String)

@Serializable
@Description("Request With Description")
data class RequestWithDescription(val input: String)

@Serializable
@Description("Request with schema")
@Schema("""
{
}
""")
data class RequestWithSchema(val input: String)
}
Loading