From b16c9cc490b0fef1f21c72ecca53955dba7ac67a Mon Sep 17 00:00:00 2001 From: Ethan McElroy Date: Tue, 3 Dec 2024 18:19:00 -0500 Subject: [PATCH] feat: use chat completions API --- src/components/App.tsx | 2 +- src/models/assistant-model.ts | 112 ++++++++++++++++++---------------- src/utils/openai-utils.ts | 4 +- 3 files changed, 61 insertions(+), 57 deletions(-) diff --git a/src/components/App.tsx b/src/components/App.tsx index 3075084..131aeac 100755 --- a/src/components/App.tsx +++ b/src/components/App.tsx @@ -60,7 +60,7 @@ export const App = observer(() => { const handleChatInputSubmit = async (messageText: string) => { transcriptStore.addMessage(USER_SPEAKER, messageText); - assistantStore.handleMessageSubmit(messageText); + await assistantStore.handleMessageSubmit(transcriptStore.messages); }; return ( diff --git a/src/models/assistant-model.ts b/src/models/assistant-model.ts index 3fb7243..5e7c439 100644 --- a/src/models/assistant-model.ts +++ b/src/models/assistant-model.ts @@ -1,7 +1,6 @@ import { types, flow } from "mobx-state-tree"; import { initLlmConnection, getTools } from "../utils/llm-utils"; import { transcriptStore } from "./chat-transcript-model"; -import { Message } from "openai/resources/beta/threads/messages"; import { getAttributeList, getDataContext } from "@concord-consortium/codap-plugin-api"; import { DAVAI_SPEAKER } from "../constants"; import { createGraph } from "../utils/codap-utils"; @@ -17,70 +16,60 @@ const AssistantModel = types const initialize = flow(function* () { try { const tools = getTools(); - const assistantInstructions = - "You are DAVAI, an Data Analysis through Voice and Artificial Intelligence partner. You are an intermediary for a user who is blind who wants to interact with data tables in a data analysis app named CODAP."; - const newAssistant = yield davai.beta.assistants.create({ - instructions: assistantInstructions, + const newAssistant = yield davai.chat.completions.create({ model: "gpt-4o-mini", + messages: [ + { + role: "system", + content: [{ + type: "text", + text: "You are DAVAI, an Data Analysis through Voice and Artificial Intelligence partner. You are an intermediary for a user who is blind who wants to interact with data tables in a data analysis app named CODAP." + }] + } + ], tools, }); self.assistant = newAssistant; - self.thread = yield davai.beta.threads.create(); } catch (err) { console.error("Failed to initialize assistant:", err); } }); - const handleMessageSubmit = flow(function* (messageText) { + const handleMessageSubmit = flow(function* (messages) { try { - yield davai.beta.threads.messages.create(self.thread.id, { - role: "user", - content: messageText, - }); - yield startRun(); + const tools = getTools(); - } catch (err) { - console.error("Failed to handle message submit:", err); - } - }); + // Transform messages for the API request. + const transformedMessages = messages.map((message: any) => ({ + role: message.speaker === "DAVAI" ? "assistant" : "user", + content: [{ type: "text", text: message.content }], + })); - const startRun = flow(function* () { - try { - const run = yield davai.beta.threads.runs.create(self.thread.id, { - assistant_id: self.assistant.id, + const response = yield davai.chat.completions.create({ + model: "gpt-4o-mini", + tools, + messages: transformedMessages, }); - // Wait for run completion and handle responses - let runState = yield davai.beta.threads.runs.retrieve(self.thread.id, run.id); - while (runState.status !== "completed" && runState.status !== "requires_action") { - runState = yield davai.beta.threads.runs.retrieve(self.thread.id, run.id); - } - - if (runState.status === "requires_action") { - yield handleRequiredAction(runState, run.id); + if (response?.choices[0]?.finish_reason === "tool_calls") { + yield handleRequiredAction(response?.choices[0]?.message.tool_calls, messages); + } else { + transcriptStore.addMessage( + DAVAI_SPEAKER, + response?.choices[0]?.message.content || "Error processing request." + ); } - - // Get the last assistant message from the messages array - const messages = yield davai.beta.threads.messages.list(self.thread.id); - const lastMessageForRun = messages.data.filter( - (msg: Message) => msg.run_id === run.id && msg.role === "assistant" - ).pop(); - - transcriptStore.addMessage( - DAVAI_SPEAKER, - lastMessageForRun?.content[0]?.text?.value || "Error processing request." - ); } catch (err) { - console.error("Failed to complete run:", err); + console.error("Failed to handle message submit:", err); } }); - const handleRequiredAction = flow(function* (runState, runId) { + const handleRequiredAction = flow(function* (tool_calls, messages) { try { - const toolOutputs = runState.required_action?.submit_tool_outputs.tool_calls + const toolOutputs = tool_calls ? yield Promise.all( - runState.required_action.submit_tool_outputs.tool_calls.map(async (toolCall: any) => { + tool_calls.map(async (toolCall: any) => { if (toolCall.function.name === "get_attributes") { const { dataset } = JSON.parse(toolCall.function.arguments); // getting the root collection won't always work. what if a user wants the attributes @@ -98,22 +87,37 @@ const AssistantModel = types : []; if (toolOutputs) { - davai.beta.threads.runs.submitToolOutputsStream( - self.thread.id, runId, { tool_outputs: toolOutputs } - ); + const tools = getTools(); - const threadMessageList = yield davai.beta.threads.messages.list(self.thread.id); - const threadMessages = threadMessageList.data.map((msg: any) => ({ - role: msg.role, - content: msg.content[0].text.value, + const newMessage = toolOutputs.map((toolOutput: any) => ({ + role: "system", + content: [ + { type: "text", text: toolOutput.output }, + ], })); - yield davai.chat.completions.create({ - model: "gpt-4o-mini", - messages: [ - ...threadMessages + // Transform existing messages for the API request. They need to be included in the request to maintain context. + const transformedMessages = messages.map((message: any) => ({ + role: message.speaker === "DAVAI" ? "assistant" : "user", + content: [ + { + type: "text", + text: message.content, + }, ], + })); + transformedMessages.push(...newMessage); + + const response = yield davai.chat.completions.create({ + model: "gpt-4o-mini", + messages: transformedMessages, + tools, }); + + transcriptStore.addMessage( + DAVAI_SPEAKER, + response?.choices[0]?.message.content || "Error processing request." + ); } } catch (err) { console.error(err); diff --git a/src/utils/openai-utils.ts b/src/utils/openai-utils.ts index 54576a6..a9ade6c 100644 --- a/src/utils/openai-utils.ts +++ b/src/utils/openai-utils.ts @@ -1,5 +1,5 @@ import { OpenAI } from "openai"; -import { AssistantTool } from "openai/resources/beta/assistants"; +import { ChatCompletionTool } from "openai/resources"; export const newOpenAI = () => { return new OpenAI({ @@ -10,7 +10,7 @@ export const newOpenAI = () => { }); }; -export const openAiTools: AssistantTool[] = [ +export const openAiTools: ChatCompletionTool[] = [ { type: "function", function: {