Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/gradle/com.apptasticsoftware-rssr…
Browse files Browse the repository at this point in the history
…eader-3.6.0
  • Loading branch information
Montagon authored Apr 15, 2024
2 parents e9fdb69 + d33f606 commit 8de885a
Showing 1 changed file with 91 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import kotlin.jvm.JvmName
import kotlinx.coroutines.flow.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive

class AssistantThread(
val threadId: String,
Expand Down Expand Up @@ -81,77 +82,107 @@ class AssistantThread(
when (event) {
// submit tool outputs and join streams
is RunDelta.RunRequiresAction -> {
val steps =
runSteps(event.run.id).map { metric.assistantCreateRunStep(event.run.id) { it } }

steps.forEach { step ->
val calls = step.stepDetails.toolCalls()
takeRequiredAction(1, event, this@AssistantThread, assistant, this)
}
// previous to submitting tool outputs we let all events pass through the outer flow
else -> {
emit(event)
}
}
}
}

val run = getRun(event.run.id)
if (
run.status == RunObject.Status.requires_action &&
run.requiredAction?.type == RunObjectRequiredAction.Type.submit_tool_outputs
) {
val callsResult: List<Pair<String, Assistant.Companion.ToolOutput>> =
calls
.filterIsInstance<
RunStepDetailsToolCallsObjectToolCallsInner.CaseRunStepDetailsToolCallsFunctionObject
>()
.parMapNotNull { toolCall -> executeToolCall(toolCall, assistant) }
val results: Map<String, Assistant.Companion.ToolOutput> = callsResult.toMap()
val toolOutputsRequest =
SubmitToolOutputsRunRequest(
toolOutputs =
results.map { (toolCallId, result) ->
SubmitToolOutputsRunRequestToolOutputsInner(
toolCallId = toolCallId,
output =
Json.encodeToString(Assistant.Companion.ToolOutput.serializer(), result)
)
}
)
metric.assistantToolOutputsRun(event.run.id) {
api
.submitToolOuputsToRunStream(
threadId = threadId,
runId = event.run.id,
submitToolOutputsRunRequest = toolOutputsRequest,
configure = ::defaultConfig
)
.collect {
val delta = RunDelta.fromServerSentEvent(it)
if (delta is RunDelta.RunStepCompleted) {
emit(delta)
emit(RunDelta.RunSubmitToolOutputs(toolOutputsRequest))
} else {
emit(delta)
}
}
getRun(event.run.id)
}
private suspend fun takeRequiredAction(
depth: Int,
event: RunDelta.RunRequiresAction,
assistantThread: AssistantThread,
assistant: Assistant,
flowCollector: FlowCollector<RunDelta>
) {
if (
event.run.status == RunObject.Status.requires_action &&
event.run.requiredAction?.type == RunObjectRequiredAction.Type.submit_tool_outputs
) {
val calls = event.run.requiredAction?.submitToolOutputs?.toolCalls.orEmpty()
val callsResult: List<Pair<String, Assistant.Companion.ToolOutput>> =
calls.parMapNotNull { toolCall -> assistantThread.executeToolCall(toolCall, assistant) }
val results: Map<String, Assistant.Companion.ToolOutput> = callsResult.toMap()
val toolOutputsRequest =
SubmitToolOutputsRunRequest(
toolOutputs =
results.map { (toolCallId, result) ->
SubmitToolOutputsRunRequestToolOutputsInner(
toolCallId = toolCallId,
output = Json.encodeToString(Assistant.Companion.ToolOutput.serializer(), result)
)
}
)
val run =
metric.assistantToolOutputsRun(event.run.id) {
api
.submitToolOuputsToRunStream(
threadId = threadId,
runId = event.run.id,
submitToolOutputsRunRequest = toolOutputsRequest,
configure = ::defaultConfig
)
.collect {
val delta = RunDelta.fromServerSentEvent(it)
if (delta is RunDelta.RunStepCompleted) {
flowCollector.emit(RunDelta.RunSubmitToolOutputs(toolOutputsRequest))
}
flowCollector.emit(delta)
}
}
// previous to submitting tool outputs we let all events pass through the outer flow
else -> emit(event)
val run = getRun(event.run.id)
val finalEvent =
when (run.status) {
RunObject.Status.queued -> RunDelta.RunQueued(run)
RunObject.Status.in_progress -> RunDelta.RunInProgress(run)
RunObject.Status.requires_action -> RunDelta.RunRequiresAction(run)
RunObject.Status.cancelling -> RunDelta.RunCancelling(run)
RunObject.Status.cancelled -> RunDelta.RunCancelled(run)
RunObject.Status.failed -> RunDelta.RunFailed(run)
RunObject.Status.completed -> RunDelta.RunCompleted(run)
RunObject.Status.expired -> RunDelta.RunExpired(run)
}
flowCollector.emit(finalEvent)
run
}

if (run.status == RunObject.Status.requires_action) {
takeRequiredAction(
depth + 1,
RunDelta.RunRequiresAction(run),
assistantThread,
assistant,
flowCollector
)
}
}
}

private suspend fun executeToolCall(
toolCall: RunStepDetailsToolCallsObjectToolCallsInner.CaseRunStepDetailsToolCallsFunctionObject,
toolCall: RunToolCallObject,
assistant: Assistant
): Pair<String, Assistant.Companion.ToolOutput>? {
val function = toolCall.value.function
val functionName = function.name
val functionArguments = function.arguments
return if (functionName != null && functionArguments != null) {
val result = assistant.getToolRegistered(functionName, functionArguments)
val callId = toolCall.value.id
if (callId != null) {
callId to result
return try {
val function = toolCall.function
val functionName = function.name
val functionArguments = function.arguments
return if (functionName != null && functionArguments != null) {
val result = assistant.getToolRegistered(functionName, functionArguments)
val callId = toolCall.id
if (callId != null) {
callId to result
} else null
} else null
} else null
} catch (e: Throwable) {
toolCall.id to
Assistant.Companion.ToolOutput(
schema = JsonObject(emptyMap()),
result = JsonObject(mapOf("error" to JsonPrimitive(e.message ?: "Unknown error")))
)
}
}

suspend fun getRun(runId: String): RunObject =
Expand Down

0 comments on commit 8de885a

Please sign in to comment.