Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation actions scafolding + initial list_files action #8646

Merged
merged 22 commits into from
Nov 15, 2024
4 changes: 4 additions & 0 deletions front/components/actions/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ const actionsSpecification: ActionSpecifications = {
detailsComponent: BrowseActionDetails,
runningLabel: ACTION_RUNNING_LABELS.browse_action,
},
conversation_list_files_action: {
detailsComponent: () => null,
runningLabel: ACTION_RUNNING_LABELS.conversation_list_files_action,
},
};

export function getActionSpecification<T extends ActionType>(
Expand Down
105 changes: 105 additions & 0 deletions front/lib/api/assistant/actions/conversation/list_files.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import type {
AgentMessageType,
ConversationFileType,
ConversationListFilesActionType,
ConversationType,
FunctionCallType,
FunctionMessageTypeModel,
ModelId,
} from "@dust-tt/types";
import {
BaseAction,
getTablesQueryResultsFileTitle,
isAgentMessageType,
isContentFragmentType,
isTablesQueryActionType,
} from "@dust-tt/types";

interface ConversationListFilesActionBlob {
agentMessageId: ModelId;
functionCallId: string | null;
functionCallName: string | null;
files: ConversationFileType[];
}

export class ConversationListFilesAction extends BaseAction {
readonly agentMessageId: ModelId;
readonly files: ConversationFileType[];
readonly functionCallId: string | null;
readonly functionCallName: string | null;
readonly step: number = -1;
readonly type = "conversation_list_files_action";

constructor(blob: ConversationListFilesActionBlob) {
super(-1, "conversation_list_files_action");

this.agentMessageId = blob.agentMessageId;
this.files = blob.files;
this.functionCallId = blob.functionCallId;
this.functionCallName = blob.functionCallName;
}

renderForFunctionCall(): FunctionCallType {
return {
id: this.functionCallId ?? `call_${this.id.toString()}`,
name: this.functionCallName ?? "list_conversation_files",
arguments: JSON.stringify({}),
};
}

renderForMultiActionsModel(): FunctionMessageTypeModel {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated to this PR, but we need to rename this

let content = "CONVERSATION FILES:\n";
for (const f of this.files) {
content += `<file id="${f.fileId}" name="${f.title}" type="${f.contentType}" />\n`;
}

return {
role: "function" as const,
name: this.functionCallName ?? "list_conversation_files",
function_call_id: this.functionCallId ?? `call_${this.id.toString()}`,
content,
};
}
}

export function makeConversationListFilesAction(
agentMessage: AgentMessageType,
conversation: ConversationType
): ConversationListFilesActionType | null {
const files: ConversationFileType[] = [];

for (const m of conversation.content.flat(1)) {
if (isContentFragmentType(m)) {
if (m.fileId) {
files.push({
fileId: m.fileId,
title: m.title,
contentType: m.contentType,
});
}
} else if (isAgentMessageType(m)) {
for (const a of m.actions) {
if (isTablesQueryActionType(a)) {
if (a.resultsFileId && a.resultsFileSnippet) {
files.push({
fileId: a.resultsFileId,
contentType: "text/csv",
title: getTablesQueryResultsFileTitle({ output: a.output }),
});
}
}
}
}
}

if (files.length === 0) {
return null;
}

return new ConversationListFilesAction({
functionCallId: "call_" + Math.random().toString(36).substring(7),
functionCallName: "list_conversation_files",
files,
agentMessageId: agentMessage.agentMessageId,
});
}
44 changes: 42 additions & 2 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type {
AgentActionSpecification,
AgentActionSpecificEvent,
AgentActionSuccessEvent,
AgentActionType,
AgentChainOfThoughtEvent,
AgentConfigurationType,
AgentContentEvent,
Expand All @@ -29,9 +30,13 @@ import {
isWebsearchConfiguration,
SUPPORTED_MODEL_CONFIGS,
} from "@dust-tt/types";
import assert from "assert";

import { runActionStreamed } from "@app/lib/actions/server";
import { isJITActionsEnabled } from "@app/lib/api/assistant//jit_actions";
import { makeConversationListFilesAction } from "@app/lib/api/assistant/actions/conversation/list_files";
import { getRunnerForActionConfiguration } from "@app/lib/api/assistant/actions/runners";
import { getCitationsCount } from "@app/lib/api/assistant/actions/utils";
import {
AgentMessageContentParser,
getDelimitersConfiguration,
Expand All @@ -49,8 +54,6 @@ import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_con
import { cloneBaseConfig, DustProdActionRegistry } from "@app/lib/registry";
import logger from "@app/logger/logger";

import { getCitationsCount } from "./actions/utils";

const CANCELLATION_CHECK_INTERVAL = 500;
const MAX_ACTIONS_PER_STEP = 16;

Expand Down Expand Up @@ -292,6 +295,29 @@ async function* runMultiActionsAgentLoop(
}
}

async function getEmulatedAgentMessageActions(
auth: Authenticator,
{
agentMessage,
conversation,
}: { agentMessage: AgentMessageType; conversation: ConversationType }
): Promise<AgentActionType[]> {
const actions: AgentActionType[] = [];
if (await isJITActionsEnabled(auth)) {
const a = makeConversationListFilesAction(agentMessage, conversation);
if (a) {
actions.push(a);
}
}

// We ensure that all emulated actions are injected with step -1.
assert(
actions.every((a) => a.step === -1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't this also be enforced in type system removing the need for assert ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would mix their type and the fact that they are emulated. One could imagine that a classic action is emulated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair !

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@spolu

nit: the -1 index is going to be limiting if we need to order multiple emulated actions, I suggest < 0

"Emulated actions must have step -1"
);
return actions;
}

// This method is used by the multi-actions execution loop to pick the next action to execute and
// generate its inputs.
async function* runMultiActionsAgent(
Expand Down Expand Up @@ -362,6 +388,15 @@ async function* runMultiActionsAgent(

const MIN_GENERATION_TOKENS = 2048;

const emulatedActions = await getEmulatedAgentMessageActions(auth, {
agentMessage,
conversation,
});

// Prepend emulated actions to the current agent message before rendering the conversation for the
// model.
agentMessage.actions = emulatedActions.concat(agentMessage.actions);

// Turn the conversation into a digest that can be presented to the model.
const modelConversationRes = await renderConversationForModel(auth, {
conversation,
Expand All @@ -370,6 +405,11 @@ async function* runMultiActionsAgent(
allowedTokenCount: model.contextSize - MIN_GENERATION_TOKENS,
});

// Scrub emulated actions from the agent message after rendering.
agentMessage.actions = agentMessage.actions.filter(
(a) => !emulatedActions.includes(a)
);

if (modelConversationRes.isErr()) {
logger.error(
{
Expand Down
121 changes: 25 additions & 96 deletions front/lib/api/assistant/jit_actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ import type {
import {
assertNever,
Err,
getTablesQueryResultsFileAttachment,
isAgentMessageType,
isContentFragmentMessageTypeModel,
isContentFragmentType,
isDevelopment,
isTablesQueryActionType,
isTextContent,
isUserMessageType,
Ok,
Expand Down Expand Up @@ -81,46 +79,55 @@ export async function renderConversationForModelJIT({
const now = Date.now();
const messages: ModelMessageTypeMultiActions[] = [];

// Render loop.
// Render all messages and all actions.
// Render loop: dender all messages and all actions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo (dender => render)

for (const versions of conversation.content) {
const m = versions[versions.length - 1];

if (isAgentMessageType(m)) {
const actions = removeNulls(m.actions);

// This array is 2D, because we can have multiple calls per agent message (parallel calls).

const steps = [] as Array<{
contents: string[];
actions: Array<{
call: FunctionCallType;
result: FunctionMessageTypeModel;
}>;
}>;
// This is a record of arrays, because we can have multiple calls per agent message (parallel
// calls). Actions all have a step index which indicates how they should be grouped but some
// actions injected by `getEmulatedAgentMessageActions` have a step index of `-1`. We
// therefore group by index, then order and transform in a 2D array to present to the model.
const stepByStepIndex = {} as Record<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd rather use Map() insead of Record, unless we are serializing as this is not catched by the typechecking system:

const aRecord:Record<string, {"k": string}> = {};

aRecord["ruc"].k

string,
{
contents: string[];
actions: Array<{
call: FunctionCallType;
result: FunctionMessageTypeModel;
}>;
}
>;

const emptyStep = () =>
({
contents: [],
actions: [],
}) satisfies (typeof steps)[number];
}) satisfies (typeof stepByStepIndex)[number];

for (const action of actions) {
const stepIndex = action.step;
steps[stepIndex] = steps[stepIndex] || emptyStep();
steps[stepIndex].actions.push({
stepByStepIndex[stepIndex] = stepByStepIndex[stepIndex] || emptyStep();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i know it's previous code but how can a step be empty ? (a comment would go a long way)

stepByStepIndex[stepIndex].actions.push({
call: action.renderForFunctionCall(),
result: action.renderForMultiActionsModel(),
});
}

for (const content of m.rawContents) {
steps[content.step] = steps[content.step] || emptyStep();
stepByStepIndex[content.step] =
stepByStepIndex[content.step] || emptyStep();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shorter version stepByStepIndex[content.step] ??= emptyStep(); (same above)

if (content.content.trim()) {
steps[content.step].contents.push(content.content);
stepByStepIndex[content.step].contents.push(content.content);
}
}

const steps = Object.entries(stepByStepIndex)
.sort(([a], [b]) => Number(a) - Number(b))
.map(([, step]) => step);

if (excludeActions) {
// In Exclude Actions mode, we only render the last step that has content.
const stepsWithContent = steps.filter((s) => s?.contents.length);
Expand Down Expand Up @@ -222,44 +229,6 @@ export async function renderConversationForModelJIT({
}
}

// If we have messages...
if (messages.length > 0) {
const { filesAsXML, hasFiles } = listConversationFiles({
conversation,
});

// ... and files, we simulate a function call to list the files at the end of the conversation.
if (hasFiles) {
const randomCallId = "tool_" + Math.random().toString(36).substring(7);
const functionName = "list_conversation_files";

const simulatedAgentMessages = [
// 1. We add a message from the agent, asking to use the files listing function
{
role: "assistant",
function_calls: [
{
id: randomCallId,
name: functionName,
arguments: "{}",
},
],
} as AssistantFunctionCallMessageTypeModel,

// 2. We add a message with the resulting files listing
{
function_call_id: randomCallId,
role: "function",
name: functionName,
content: filesAsXML,
} as FunctionMessageTypeModel,
];

// Append the simulated messages to the end of the conversation.
messages.push(...simulatedAgentMessages);
}
}

// Compute in parallel the token count for each message and the prompt.
const res = await tokenCountForTexts(
[prompt, ...getTextRepresentationFromMessages(messages)],
Expand Down Expand Up @@ -402,43 +371,3 @@ export async function renderConversationForModelJIT({
tokensUsed,
});
}

function listConversationFiles({
conversation,
}: {
conversation: ConversationType;
}) {
const fileAttachments: string[] = [];
for (const m of conversation.content.flat(1)) {
if (isContentFragmentType(m)) {
if (!m.fileId) {
continue;
}
fileAttachments.push(
`<file id="${m.fileId}" name="${m.title}" type="${m.contentType}" />`
);
} else if (isAgentMessageType(m)) {
for (const a of m.actions) {
if (isTablesQueryActionType(a)) {
const attachment = getTablesQueryResultsFileAttachment({
resultsFileId: a.resultsFileId,
resultsFileSnippet: a.resultsFileSnippet,
output: a.output,
includeSnippet: false,
});
if (attachment) {
fileAttachments.push(attachment);
}
}
}
}
}
let filesAsXML = "<files>\n";

if (fileAttachments.length > 0) {
filesAsXML += fileAttachments.join("\n");
}
filesAsXML += "\n</files>";

return { filesAsXML, hasFiles: fileAttachments.length > 0 };
}
Loading
Loading