diff --git a/front/Dockerfile b/front/Dockerfile index 9934dceb7709..fc472b9fd3b9 100644 --- a/front/Dockerfile +++ b/front/Dockerfile @@ -16,7 +16,10 @@ RUN npm ci COPY /front . ARG COMMIT_HASH -ENV NEXT_PUBLIC_COMMIT_HASH=${COMMIT_HASH} +ARG NEXT_PUBLIC_VIZ_URL + +ENV NEXT_PUBLIC_COMMIT_HASH=$COMMIT_HASH +ENV NEXT_PUBLIC_VIZ_URL=$NEXT_PUBLIC_VIZ_URL # fake database URIs are needed because Sequelize will throw if the `url` parameter # is undefined, and `next build` imports the `models.ts` file while "Collecting page data" diff --git a/front/components/assistant/conversation/AgentMessage.tsx b/front/components/assistant/conversation/AgentMessage.tsx index 244f3ce4c039..66048b422646 100644 --- a/front/components/assistant/conversation/AgentMessage.tsx +++ b/front/components/assistant/conversation/AgentMessage.tsx @@ -1,6 +1,5 @@ import { ArrowPathIcon, - BracesIcon, Button, Chip, Citation, @@ -22,7 +21,6 @@ import type { LightAgentConfigurationType, RetrievalActionType, UserType, - VisualizationActionType, WebsearchActionType, WebsearchResultType, WorkspaceType, @@ -39,6 +37,7 @@ import { isWebsearchActionType, removeNulls, } from "@dust-tt/types"; +import assert from "assert"; import Link from "next/link"; import { useRouter } from "next/router"; import { useCallback, useContext, useEffect, useRef, useState } from "react"; @@ -46,6 +45,7 @@ import { useCallback, useContext, useEffect, useRef, useState } from "react"; import { makeDocumentCitations } from "@app/components/actions/retrieval/utils"; import { AssistantDetailsDropdownMenu } from "@app/components/assistant/AssistantDetailsDropdownMenu"; import { AgentMessageActions } from "@app/components/assistant/conversation/actions/AgentMessageActions"; +import { VisualizationActionIframe } from "@app/components/assistant/conversation/actions/VisualizationActionIframe"; import type { MessageSizeType } from "@app/components/assistant/conversation/ConversationMessage"; import { ConversationMessage } from "@app/components/assistant/conversation/ConversationMessage"; import { GenerationContext } from "@app/components/assistant/conversation/GenerationContextProvider"; @@ -88,19 +88,9 @@ export function AgentMessage({ const [streamedAgentMessage, setStreamedAgentMessage] = useState(message); - const defaultVisualizations: VisualizationActionType[] = - message.actions.filter((a): a is VisualizationActionType => - isVisualizationActionType(a) - ) as VisualizationActionType[]; - const [streamedVisualizations, setStreamedVisualizations] = useState< { actionId: number; visualization: string }[] - >( - defaultVisualizations.map((v) => ({ - actionId: v.id, - visualization: v.generation ?? "", - })) - ); + >([]); const [isRetryHandlerProcessing, setIsRetryHandlerProcessing] = useState(false); @@ -213,6 +203,7 @@ export function AgentMessage({ ...event.message, }; }); + setStreamedVisualizations([]); break; } @@ -535,24 +526,6 @@ export function AgentMessage({ ) : null} - {/* This is where we will we plug Aric's work to render the graph in an iframe. */} - {streamedVisualizations.map(({ actionId, visualization }) => { - return ( -
-
- -
Visualization
-
-
- -
-
- ); - })} - {agentMessage.content !== null && (
{lastTokenClassification !== "chain_of_thought" && @@ -583,6 +556,27 @@ export function AgentMessage({ )}
)} + <> + {agentMessage.actions + .filter((a) => isVisualizationActionType(a)) + .map((a, i) => { + const streamingViz = streamedVisualizations.find( + (sv) => sv.actionId === a.id + ); + assert(isVisualizationActionType(a)); + return ( + retryHandler(agentMessage)} + owner={owner} + streamedCode={streamingViz?.visualization || null} + /> + ); + })} + {agentMessage.status === "cancelled" && ( { + target.postMessage( + { + command: "answer", + messageUniqueId: request.messageUniqueId, + actionId: request.actionId, + result: response, + }, + // TODO(2024-07-24 flav) Restrict origin. + { targetOrigin: "*" } + ); +}; + +// Custom hook to encapsulate the logic for handling visualization messages. +function useVisualizationDataHandler( + action: VisualizationActionType, + workspaceId: string, + onRetry: () => void +) { + const getFile = useCallback( + async (fileId: string) => { + const response = await fetch( + `/api/w/${workspaceId}/files/${fileId}?action=view` + ); + if (!response.ok) { + // TODO(2024-07-24 flav) Propagate the error to the iframe. + throw new Error(`Failed to fetch file ${fileId}`); + } + + const resBuffer = await response.arrayBuffer(); + return new File([resBuffer], fileId, { + type: response.headers.get("Content-Type") || undefined, + }); + }, + [workspaceId] + ); + + useEffect(() => { + const listener = async (event: MessageEvent) => { + const { data } = event; + + // TODO(2024-07-24 flav) Check origin. + if ( + !isVisualizationRPCRequest(data) || + !event.source || + data.actionId !== action.id + ) { + return; + } + + if (isGetFileRequest(data)) { + const file = await getFile(data.params.fileId); + + sendResponseToIframe(data, { file }, event.source); + } else if (isGetCodeToExecuteRequest(data)) { + const code = action.generation; + + sendResponseToIframe(data, { code }, event.source); + } else { + // TODO(2024-07-24 flav) Pass the error message to the host window. + onRetry(); + } + + // TODO: Types above are not accurate, as it can pass the first check but won't enter any if block. + }; + + window.addEventListener("message", listener); + return () => window.removeEventListener("message", listener); + }, [action.generation, action.id, onRetry, getFile]); + + return { getFile }; +} + +export function VisualizationActionIframe({ + owner, + action, + isStreaming, + streamedCode, + onRetry, +}: { + conversationId: string; + owner: WorkspaceType; + action: VisualizationActionType; + streamedCode: string | null; + isStreaming: boolean; + onRetry: () => void; +}) { + const [activeTab, setActiveTab] = useState<"code" | "runtime">("code"); + const [tabManuallyChanged, setTabManuallyChanged] = useState(false); + + const workspaceId = owner.sId; + + useVisualizationDataHandler(action, workspaceId, onRetry); + + useEffect(() => { + if (activeTab === "code" && action.generation && !tabManuallyChanged) { + setActiveTab("runtime"); + setTabManuallyChanged(true); + } + }, [action.generation, activeTab, tabManuallyChanged]); + + let extractedCode: string | null = null; + + if (action.generation) { + extractedCode = visualizationExtractCodeNonStreaming(action.generation); + } else { + extractedCode = visualizationExtractCodeStreaming(streamedCode || ""); + } + + return ( + <> + { + event.preventDefault(); + setActiveTab(tabId as "code" | "runtime"); + }} + /> + {activeTab === "code" && extractedCode && extractedCode.length > 0 && ( + + )} + {activeTab === "runtime" && ( +