-
Notifications
You must be signed in to change notification settings - Fork 112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation actions scafolding + initial list_files action #8646
Changes from all commits
c5270e4
b26fc95
8c3adf8
ee51e66
64c6d62
74d904e
0079271
fabbf9b
b5a4360
84b3b32
53892e2
1ad4f02
fe06c9c
17eea88
bbbfb4d
d771ae1
a6b2342
f7ae369
6eea7d7
37ce947
7ab6c5f
deb9b77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import type { | ||
AgentMessageType, | ||
ConversationFileType, | ||
ConversationListFilesActionType, | ||
ConversationType, | ||
FunctionCallType, | ||
FunctionMessageTypeModel, | ||
ModelId, | ||
} from "@dust-tt/types"; | ||
import { | ||
BaseAction, | ||
getTablesQueryResultsFileTitle, | ||
isAgentMessageType, | ||
isContentFragmentType, | ||
isTablesQueryActionType, | ||
} from "@dust-tt/types"; | ||
|
||
interface ConversationListFilesActionBlob { | ||
agentMessageId: ModelId; | ||
functionCallId: string | null; | ||
functionCallName: string | null; | ||
files: ConversationFileType[]; | ||
} | ||
|
||
export class ConversationListFilesAction extends BaseAction { | ||
readonly agentMessageId: ModelId; | ||
readonly files: ConversationFileType[]; | ||
readonly functionCallId: string | null; | ||
readonly functionCallName: string | null; | ||
readonly step: number = -1; | ||
readonly type = "conversation_list_files_action"; | ||
|
||
constructor(blob: ConversationListFilesActionBlob) { | ||
super(-1, "conversation_list_files_action"); | ||
|
||
this.agentMessageId = blob.agentMessageId; | ||
this.files = blob.files; | ||
this.functionCallId = blob.functionCallId; | ||
this.functionCallName = blob.functionCallName; | ||
} | ||
|
||
renderForFunctionCall(): FunctionCallType { | ||
return { | ||
id: this.functionCallId ?? `call_${this.id.toString()}`, | ||
name: this.functionCallName ?? "list_conversation_files", | ||
arguments: JSON.stringify({}), | ||
}; | ||
} | ||
|
||
renderForMultiActionsModel(): FunctionMessageTypeModel { | ||
let content = "CONVERSATION FILES:\n"; | ||
for (const f of this.files) { | ||
content += `<file id="${f.fileId}" name="${f.title}" type="${f.contentType}" />\n`; | ||
} | ||
|
||
return { | ||
role: "function" as const, | ||
name: this.functionCallName ?? "list_conversation_files", | ||
function_call_id: this.functionCallId ?? `call_${this.id.toString()}`, | ||
content, | ||
}; | ||
} | ||
} | ||
|
||
export function makeConversationListFilesAction( | ||
agentMessage: AgentMessageType, | ||
conversation: ConversationType | ||
): ConversationListFilesActionType | null { | ||
const files: ConversationFileType[] = []; | ||
|
||
for (const m of conversation.content.flat(1)) { | ||
if (isContentFragmentType(m)) { | ||
if (m.fileId) { | ||
files.push({ | ||
fileId: m.fileId, | ||
title: m.title, | ||
contentType: m.contentType, | ||
}); | ||
} | ||
} else if (isAgentMessageType(m)) { | ||
for (const a of m.actions) { | ||
if (isTablesQueryActionType(a)) { | ||
if (a.resultsFileId && a.resultsFileSnippet) { | ||
files.push({ | ||
fileId: a.resultsFileId, | ||
contentType: "text/csv", | ||
title: getTablesQueryResultsFileTitle({ output: a.output }), | ||
}); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
if (files.length === 0) { | ||
return null; | ||
} | ||
|
||
return new ConversationListFilesAction({ | ||
functionCallId: "call_" + Math.random().toString(36).substring(7), | ||
functionCallName: "list_conversation_files", | ||
files, | ||
agentMessageId: agentMessage.agentMessageId, | ||
}); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ import type { | |
AgentActionSpecification, | ||
AgentActionSpecificEvent, | ||
AgentActionSuccessEvent, | ||
AgentActionType, | ||
AgentChainOfThoughtEvent, | ||
AgentConfigurationType, | ||
AgentContentEvent, | ||
|
@@ -29,9 +30,13 @@ import { | |
isWebsearchConfiguration, | ||
SUPPORTED_MODEL_CONFIGS, | ||
} from "@dust-tt/types"; | ||
import assert from "assert"; | ||
|
||
import { runActionStreamed } from "@app/lib/actions/server"; | ||
import { isJITActionsEnabled } from "@app/lib/api/assistant//jit_actions"; | ||
import { makeConversationListFilesAction } from "@app/lib/api/assistant/actions/conversation/list_files"; | ||
import { getRunnerForActionConfiguration } from "@app/lib/api/assistant/actions/runners"; | ||
import { getCitationsCount } from "@app/lib/api/assistant/actions/utils"; | ||
import { | ||
AgentMessageContentParser, | ||
getDelimitersConfiguration, | ||
|
@@ -49,8 +54,6 @@ import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_con | |
import { cloneBaseConfig, DustProdActionRegistry } from "@app/lib/registry"; | ||
import logger from "@app/logger/logger"; | ||
|
||
import { getCitationsCount } from "./actions/utils"; | ||
|
||
const CANCELLATION_CHECK_INTERVAL = 500; | ||
const MAX_ACTIONS_PER_STEP = 16; | ||
|
||
|
@@ -292,6 +295,29 @@ async function* runMultiActionsAgentLoop( | |
} | ||
} | ||
|
||
async function getEmulatedAgentMessageActions( | ||
auth: Authenticator, | ||
{ | ||
agentMessage, | ||
conversation, | ||
}: { agentMessage: AgentMessageType; conversation: ConversationType } | ||
): Promise<AgentActionType[]> { | ||
const actions: AgentActionType[] = []; | ||
if (await isJITActionsEnabled(auth)) { | ||
const a = makeConversationListFilesAction(agentMessage, conversation); | ||
if (a) { | ||
actions.push(a); | ||
} | ||
} | ||
|
||
// We ensure that all emulated actions are injected with step -1. | ||
assert( | ||
actions.every((a) => a.step === -1), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't this also be enforced in type system removing the need for assert ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would mix their type and the fact that they are emulated. One could imagine that a classic action is emulated. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: the -1 index is going to be limiting if we need to order multiple emulated actions, I suggest < 0 |
||
"Emulated actions must have step -1" | ||
); | ||
return actions; | ||
} | ||
|
||
// This method is used by the multi-actions execution loop to pick the next action to execute and | ||
// generate its inputs. | ||
async function* runMultiActionsAgent( | ||
|
@@ -362,6 +388,15 @@ async function* runMultiActionsAgent( | |
|
||
const MIN_GENERATION_TOKENS = 2048; | ||
|
||
const emulatedActions = await getEmulatedAgentMessageActions(auth, { | ||
agentMessage, | ||
conversation, | ||
}); | ||
|
||
// Prepend emulated actions to the current agent message before rendering the conversation for the | ||
// model. | ||
agentMessage.actions = emulatedActions.concat(agentMessage.actions); | ||
|
||
// Turn the conversation into a digest that can be presented to the model. | ||
const modelConversationRes = await renderConversationForModel(auth, { | ||
conversation, | ||
|
@@ -370,6 +405,11 @@ async function* runMultiActionsAgent( | |
allowedTokenCount: model.contextSize - MIN_GENERATION_TOKENS, | ||
}); | ||
|
||
// Scrub emulated actions from the agent message after rendering. | ||
agentMessage.actions = agentMessage.actions.filter( | ||
(a) => !emulatedActions.includes(a) | ||
); | ||
|
||
if (modelConversationRes.isErr()) { | ||
logger.error( | ||
{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,12 +12,10 @@ import type { | |
import { | ||
assertNever, | ||
Err, | ||
getTablesQueryResultsFileAttachment, | ||
isAgentMessageType, | ||
isContentFragmentMessageTypeModel, | ||
isContentFragmentType, | ||
isDevelopment, | ||
isTablesQueryActionType, | ||
isTextContent, | ||
isUserMessageType, | ||
Ok, | ||
|
@@ -81,46 +79,55 @@ export async function renderConversationForModelJIT({ | |
const now = Date.now(); | ||
const messages: ModelMessageTypeMultiActions[] = []; | ||
|
||
// Render loop. | ||
// Render all messages and all actions. | ||
// Render loop: dender all messages and all actions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: typo ( |
||
for (const versions of conversation.content) { | ||
const m = versions[versions.length - 1]; | ||
|
||
if (isAgentMessageType(m)) { | ||
const actions = removeNulls(m.actions); | ||
|
||
// This array is 2D, because we can have multiple calls per agent message (parallel calls). | ||
|
||
const steps = [] as Array<{ | ||
contents: string[]; | ||
actions: Array<{ | ||
call: FunctionCallType; | ||
result: FunctionMessageTypeModel; | ||
}>; | ||
}>; | ||
// This is a record of arrays, because we can have multiple calls per agent message (parallel | ||
// calls). Actions all have a step index which indicates how they should be grouped but some | ||
// actions injected by `getEmulatedAgentMessageActions` have a step index of `-1`. We | ||
// therefore group by index, then order and transform in a 2D array to present to the model. | ||
const stepByStepIndex = {} as Record< | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I'd rather use const aRecord:Record<string, {"k": string}> = {};
aRecord["ruc"].k |
||
string, | ||
{ | ||
contents: string[]; | ||
actions: Array<{ | ||
call: FunctionCallType; | ||
result: FunctionMessageTypeModel; | ||
}>; | ||
} | ||
>; | ||
|
||
const emptyStep = () => | ||
({ | ||
contents: [], | ||
actions: [], | ||
}) satisfies (typeof steps)[number]; | ||
}) satisfies (typeof stepByStepIndex)[number]; | ||
|
||
for (const action of actions) { | ||
const stepIndex = action.step; | ||
steps[stepIndex] = steps[stepIndex] || emptyStep(); | ||
steps[stepIndex].actions.push({ | ||
stepByStepIndex[stepIndex] = stepByStepIndex[stepIndex] || emptyStep(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: i know it's previous code but how can a step be empty ? (a comment would go a long way) |
||
stepByStepIndex[stepIndex].actions.push({ | ||
call: action.renderForFunctionCall(), | ||
result: action.renderForMultiActionsModel(), | ||
}); | ||
} | ||
|
||
for (const content of m.rawContents) { | ||
steps[content.step] = steps[content.step] || emptyStep(); | ||
stepByStepIndex[content.step] = | ||
stepByStepIndex[content.step] || emptyStep(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: shorter version |
||
if (content.content.trim()) { | ||
steps[content.step].contents.push(content.content); | ||
stepByStepIndex[content.step].contents.push(content.content); | ||
} | ||
} | ||
|
||
const steps = Object.entries(stepByStepIndex) | ||
.sort(([a], [b]) => Number(a) - Number(b)) | ||
.map(([, step]) => step); | ||
|
||
if (excludeActions) { | ||
// In Exclude Actions mode, we only render the last step that has content. | ||
const stepsWithContent = steps.filter((s) => s?.contents.length); | ||
|
@@ -222,44 +229,6 @@ export async function renderConversationForModelJIT({ | |
} | ||
} | ||
|
||
// If we have messages... | ||
if (messages.length > 0) { | ||
const { filesAsXML, hasFiles } = listConversationFiles({ | ||
conversation, | ||
}); | ||
|
||
// ... and files, we simulate a function call to list the files at the end of the conversation. | ||
if (hasFiles) { | ||
const randomCallId = "tool_" + Math.random().toString(36).substring(7); | ||
const functionName = "list_conversation_files"; | ||
|
||
const simulatedAgentMessages = [ | ||
// 1. We add a message from the agent, asking to use the files listing function | ||
{ | ||
role: "assistant", | ||
function_calls: [ | ||
{ | ||
id: randomCallId, | ||
name: functionName, | ||
arguments: "{}", | ||
}, | ||
], | ||
} as AssistantFunctionCallMessageTypeModel, | ||
|
||
// 2. We add a message with the resulting files listing | ||
{ | ||
function_call_id: randomCallId, | ||
role: "function", | ||
name: functionName, | ||
content: filesAsXML, | ||
} as FunctionMessageTypeModel, | ||
]; | ||
|
||
// Append the simulated messages to the end of the conversation. | ||
messages.push(...simulatedAgentMessages); | ||
} | ||
} | ||
|
||
// Compute in parallel the token count for each message and the prompt. | ||
const res = await tokenCountForTexts( | ||
[prompt, ...getTextRepresentationFromMessages(messages)], | ||
|
@@ -402,43 +371,3 @@ export async function renderConversationForModelJIT({ | |
tokensUsed, | ||
}); | ||
} | ||
|
||
function listConversationFiles({ | ||
conversation, | ||
}: { | ||
conversation: ConversationType; | ||
}) { | ||
const fileAttachments: string[] = []; | ||
for (const m of conversation.content.flat(1)) { | ||
if (isContentFragmentType(m)) { | ||
if (!m.fileId) { | ||
continue; | ||
} | ||
fileAttachments.push( | ||
`<file id="${m.fileId}" name="${m.title}" type="${m.contentType}" />` | ||
); | ||
} else if (isAgentMessageType(m)) { | ||
for (const a of m.actions) { | ||
if (isTablesQueryActionType(a)) { | ||
const attachment = getTablesQueryResultsFileAttachment({ | ||
resultsFileId: a.resultsFileId, | ||
resultsFileSnippet: a.resultsFileSnippet, | ||
output: a.output, | ||
includeSnippet: false, | ||
}); | ||
if (attachment) { | ||
fileAttachments.push(attachment); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
let filesAsXML = "<files>\n"; | ||
|
||
if (fileAttachments.length > 0) { | ||
filesAsXML += fileAttachments.join("\n"); | ||
} | ||
filesAsXML += "\n</files>"; | ||
|
||
return { filesAsXML, hasFiles: fileAttachments.length > 0 }; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unrelated to this PR, but we need to rename this