Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenApi and Swagger to Xef Server routes #735

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package com.xebia.functional.xef.llm

import com.xebia.functional.openai.generated.model.CreateChatCompletionResponse
import com.xebia.functional.openai.generated.model.RunObject
import com.xebia.functional.openai.generated.model.RunStepObject
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.assistants.RunDelta
import com.xebia.functional.xef.metrics.Metric
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.completionRole
Expand All @@ -22,18 +20,6 @@ suspend fun CreateChatCompletionResponse.addMetrics(
)
conversation.metric.parameter("openai.chat_completion.token.count", "${it.totalTokens}")
}
choices.forEach { choice ->
choice.message.content?.let {
conversation.metric.parameter("openai.chat_completion.choice.${choice.index}.content", it)
}
choice.message.toolCalls?.zip(choice.message.toolCalls!!.indices)?.forEach {
(toolCall, toolCallIndex) ->
conversation.metric.parameter(
"openai.chat_completion.choice.${choice.index}.tool_call.$toolCallIndex",
toolCall.function.arguments
)
}
}
return this
}

Expand Down Expand Up @@ -64,30 +50,3 @@ suspend fun RunObject.addMetrics(metric: Metric): RunObject {
metric.assistantCreateRun(this)
return this
}

suspend fun RunStepObject.addMetrics(metric: Metric): RunStepObject {
metric.assistantCreateRunStep(this)
return this
}

suspend fun RunDelta.addMetrics(metric: Metric): RunDelta {
when (this) {
is RunDelta.RunCancelled -> run.addMetrics(metric)
is RunDelta.RunCancelling -> run.addMetrics(metric)
is RunDelta.RunCompleted -> run.addMetrics(metric)
is RunDelta.RunCreated -> run.addMetrics(metric)
is RunDelta.RunExpired -> run.addMetrics(metric)
is RunDelta.RunFailed -> run.addMetrics(metric)
is RunDelta.RunInProgress -> run.addMetrics(metric)
is RunDelta.RunQueued -> run.addMetrics(metric)
is RunDelta.RunRequiresAction -> run.addMetrics(metric)
is RunDelta.RunStepCancelled -> runStep.addMetrics(metric)
is RunDelta.RunStepCompleted -> runStep.addMetrics(metric)
is RunDelta.RunStepCreated -> runStep.addMetrics(metric)
is RunDelta.RunStepExpired -> runStep.addMetrics(metric)
is RunDelta.RunStepFailed -> runStep.addMetrics(metric)
is RunDelta.RunStepInProgress -> runStep.addMetrics(metric)
else -> {} // ignore other cases
}
return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,13 @@ class AssistantThread(
)
.data

suspend fun createRun(assistant: Assistant): RunObject =
createRun(CreateRunRequest(assistantId = assistant.assistantId))

suspend fun createRun(request: CreateRunRequest): RunObject =
api.createRun(threadId, request, configure = ::defaultConfig).addMetrics(metric)

fun createRunStream(assistant: Assistant, request: CreateRunRequest): Flow<RunDelta> = flow {
api
.createRunStream(threadId, request, configure = ::defaultConfig)
.map { RunDelta.fromServerSentEvent(it) }
.map { it.addMetrics(metric) }
.collect { event ->
when (event) {
// submit tool outputs and join streams
Expand Down Expand Up @@ -188,6 +184,9 @@ class AssistantThread(
suspend fun getRun(runId: String): RunObject =
api.getRun(threadId, runId, configure = ::defaultConfig)

suspend fun createRun(assistant: Assistant): RunObject =
createRun(CreateRunRequest(assistantId = assistant.assistantId))

fun run(assistant: Assistant): Flow<RunDelta> =
createRunStream(assistant, CreateRunRequest(assistantId = assistant.assistantId))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,6 @@ class LogsMetric(private val level: Level = Level.INFO) : Metric {
return output
}

override suspend fun assistantCreateRunStep(runObject: RunStepObject) {
logger.at(level) {
this.message = "${writeIndent(numberOfBlocks.get())}|-- AssistantId: ${runObject.assistantId}"
}
logger.at(level) {
this.message = "${writeIndent(numberOfBlocks.get())}|-- ThreadId: ${runObject.threadId}"
}
logger.at(level) {
this.message = "${writeIndent(numberOfBlocks.get())}|-- RunId: ${runObject.id}"
}
logger.at(level) {
this.message = "${writeIndent(numberOfBlocks.get())}|-- Status: ${runObject.status.name}"
}
}

override suspend fun assistantCreatedMessage(
runId: String,
block: suspend Metric.() -> List<MessageObject>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ interface Metric {

suspend fun assistantCreateRun(runId: String, block: suspend Metric.() -> RunObject): RunObject

suspend fun assistantCreateRunStep(runObject: RunStepObject)

suspend fun assistantCreatedMessage(
runId: String,
block: suspend Metric.() -> List<MessageObject>
Expand Down Expand Up @@ -53,8 +51,6 @@ interface Metric {
block: suspend Metric.() -> RunObject
): RunObject = block()

override suspend fun assistantCreateRunStep(runObject: RunStepObject) {}

override suspend fun assistantCreatedMessage(
runId: String,
block: suspend Metric.() -> List<MessageObject>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xebia.functional.xef.assistants

import com.xebia.functional.openai.generated.model.*
import com.xebia.functional.xef.OpenAI
import com.xebia.functional.xef.llm.assistants.Assistant
import com.xebia.functional.xef.llm.assistants.AssistantThread
Expand Down
17 changes: 12 additions & 5 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[versions]
arrow = "1.2.3"
arrow = "1.2.4"
arrowGradle = "0.12.0-rc.24"
exposed = "0.49.0"
kotlin = "1.9.23"
Expand All @@ -13,7 +13,7 @@ kotest = "5.8.1"
kotest-testcontainers = "2.0.2"
kotest-arrow = "1.4.0"
klogging = "6.0.9"
uuid = "0.0.22"
uuid = "0.0.25"
postgresql = "42.7.3"
testcontainers = "1.19.5"
hikari = "5.1.0"
Expand All @@ -26,19 +26,20 @@ junit = "5.10.2"
pdfbox = "3.0.2"
mysql = "8.0.33"
semverGradle = "0.5.0-rc.6"
jackson = "2.16.1"
jackson = "2.17.1"
jsonschema = "4.35.0"
jakarta = "3.0.2"
suspendApp = "0.4.0"
flyway = "9.22.3"
resources-kmp = "0.4.0"
detekt = "1.23.6"
opentelemetry="1.36.0"
opentelemetry-semconv="1.30.1-alpha"
opentelemetry-semconv="1.31.0-alpha"
progressbar = "0.10.0"
jmf = "2.1.1e"
mp3-wav-converter = "1.0.4"
yamlkt="0.13.0"
tegral = "0.0.4"


[libraries]
Expand Down Expand Up @@ -115,7 +116,13 @@ opentelemetry-extension-kotlin = { module = "io.opentelemetry:opentelemetry-exte
progressbar = { module = "me.tongfei:progressbar", version.ref = "progressbar" }
jmf = { module = "javax.media:jmf", version.ref = "jmf" }
mp3-wav-converter = { module = "com.sipgate:mp3-wav", version.ref = "mp3-wav-converter" }

tegral-catalog = { module = "guru.zoroark.tegral:tegral-catalog", version.ref = "tegral" }
tegral-core = { module = "guru.zoroark.tegral:tegral-core", version.ref = "tegral" }
tegral-openapi-dsl = { module = "guru.zoroark.tegral:tegral-openapi-dsl", version.ref = "tegral" }
tegral-openapi-scriptdef = { module = "guru.zoroark.tegral:tegral-openapi-scriptdef", version.ref = "tegral" }
tegral-openapi-ktor = { module = "guru.zoroark.tegral:tegral-openapi-ktor", version.ref = "tegral" }
tegral-openapi-ktor-resources = { module = "guru.zoroark.tegral:tegral-openapi-ktor-resources", version.ref = "tegral" }
tegral-openapi-ktorui = { module = "guru.zoroark.tegral:tegral-openapi-ktorui", version.ref = "tegral" }


[bundles]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,6 @@ class OpenTelemetryAssistantState(private val tracer: Tracer) {
}
}

fun runStepSpan(runObject: RunStepObject) {

val parentOrRoot: Context = runObject.runId.getOrCreateContext()

val currentSpan =
tracer
.spanBuilder("step ${runObject.status.name} ${runObject.id}")
.setParent(parentOrRoot)
.setSpanKind(SpanKind.CLIENT)
.startSpan()

try {
currentSpan.makeCurrent().use { runObject.setParameters(currentSpan) }
} finally {
currentSpan.end()
}
}

suspend fun runStepSpan(runId: String, block: suspend () -> RunStepObject): RunStepObject {

val parentOrRoot: Context = runId.getOrCreateContext()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ class OpenTelemetryMetric(
block: suspend Metric.() -> RunObject
): RunObject = assistantState.runSpan(runId) { block() }

override suspend fun assistantCreateRunStep(runObject: RunStepObject) =
assistantState.runStepSpan(runObject)

override suspend fun assistantCreatedMessage(
runId: String,
block: suspend Metric.() -> List<MessageObject>
Expand Down
6 changes: 6 additions & 0 deletions server/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ dependencies {
implementation(libs.ktor.server.status.pages)
implementation(libs.suspendApp.core)
implementation(libs.suspendApp.ktor)
implementation(libs.tegral.core)
implementation(libs.tegral.openapi.dsl)
implementation(libs.tegral.openapi.scriptdef)
implementation(libs.tegral.openapi.ktor)
implementation(libs.tegral.openapi.ktor.resources)
implementation(libs.tegral.openapi.ktorui)
implementation(libs.uuid)
implementation(projects.xefCore)
implementation(projects.xefPostgresql)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import com.xebia.functional.xef.server.http.routes.xefRoutes
import com.xebia.functional.xef.server.services.hikariDataSource
import com.xebia.functional.xef.server.services.vectorStoreService
import com.xebia.functional.xef.store.migrations.runDatabaseMigrations
import guru.zoroark.tegral.openapi.ktor.TegralOpenApiKtor
import guru.zoroark.tegral.openapi.ktor.openApiEndpoint
import guru.zoroark.tegral.openapi.ktorui.TegralSwaggerUiKtor
import guru.zoroark.tegral.openapi.ktorui.swaggerUiEndpoint
import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.engine.cio.CIO
Expand Down Expand Up @@ -77,10 +81,14 @@ object Server {
authenticate { tokenCredential -> UserIdPrincipal(tokenCredential.token) }
}
}
install(TegralOpenApiKtor) { title = "Xef Server" }
install(TegralSwaggerUiKtor)
exceptionsHandler()
routing {
xefRoutes(logger)
aiRoutes(ktorClient)
openApiEndpoint("/openapi")
swaggerUiEndpoint(path = "/docs", openApiPath = "/openapi")
}
}
awaitCancellation()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package com.xebia.functional.xef.server.http.routes

import com.xebia.functional.openai.generated.model.ChatCompletionRequestMessage
import com.xebia.functional.xef.server.models.Token
import com.xebia.functional.xef.server.models.exceptions.XefExceptions
import guru.zoroark.tegral.openapi.dsl.schema
import guru.zoroark.tegral.openapi.ktor.resources.ResourceDescription
import guru.zoroark.tegral.openapi.ktor.resources.describeResource
import guru.zoroark.tegral.openapi.ktor.resources.postD
import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.resources.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
Expand All @@ -22,7 +27,7 @@ fun Routing.aiRoutes(client: HttpClient) {
val openAiUrl = "https://api.openai.com/v1"

authenticate("auth-bearer") {
post("/chat/completions") {
postD<ChatCompletionRoutes> {
val token = call.getToken()
val byteArrayBody = call.receiveChannel().toByteArray()
val body = byteArrayBody.toString(Charsets.UTF_8)
Expand All @@ -37,14 +42,48 @@ fun Routing.aiRoutes(client: HttpClient) {
}
}

post("/embeddings") {
postD<EmbeddingsRoutes> {
val token = call.getToken()
val context = call.receiveChannel().toByteArray()
client.makeRequest(call, "$openAiUrl/embeddings", context, token)
}
}
}

@Resource("/chat/completions")
class ChatCompletionRoutes {
companion object :
ResourceDescription by describeResource({
tags += "AI"
post {
description = "Create chat completions"
body {
description = "The chat details"
required = true
json { schema<ChatCompletionRequestMessage>() }
}
HttpStatusCode.OK.value response { description = "Chat completions" }
}
})
}

@Resource("/embeddings")
class EmbeddingsRoutes {
companion object :
ResourceDescription by describeResource({
tags += "AI"
post {
description = "Create embeddings"
body {
description = "The context"
required = true
json { schema<ByteArray>() }
}
HttpStatusCode.OK.value response { description = "Embeddings" }
}
})
}

private val conflictingRequestHeaders =
listOf("Host", "Content-Type", "Content-Length", "Accept", "Accept-Encoding")
private val conflictingResponseHeaders = listOf("Content-Length")
Expand Down Expand Up @@ -97,12 +136,3 @@ internal fun HeadersBuilder.copyFrom(headers: Headers) =
headers
.filter { key, _ -> !conflictingRequestHeaders.any { it.equals(key, true) } }
.forEach { key, values -> appendMissing(key, values) }

fun ApplicationCall.getToken(): Token =
principal<UserIdPrincipal>()?.name?.let { Token(it) }
?: throw XefExceptions.AuthorizationException("No token found")

fun ApplicationCall.getId(): Int = getInt("id")

fun ApplicationCall.getInt(field: String): Int =
this.parameters[field]?.toInt() ?: throw XefExceptions.ValidationException("Invalid $field")
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.xebia.functional.xef.server.http.routes

import arrow.core.raise.catch
import com.xebia.functional.xef.server.models.Token
import com.xebia.functional.xef.server.models.exceptions.XefExceptions
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import kotlinx.serialization.json.Json

fun ApplicationCall.getToken(): Token =
principal<UserIdPrincipal>()?.name?.let { Token(it) }
?: throw XefExceptions.AuthorizationException("No token found")

fun ApplicationCall.getId(): Int = getInt("id")

fun ApplicationCall.getInt(field: String): Int =
this.parameters[field]?.toInt() ?: throw XefExceptions.ValidationException("Invalid $field")

suspend inline fun <reified T> ApplicationCall.decodeFromStringRequest(): T =
catch({ Json.decodeFromString<T>(this.receive<String>()) }) {
throw XefExceptions.ValidationException("Invalid ${T::class.simpleName}")
}
Loading
Loading