diff --git a/front/components/assistant/conversation/actions/VisualizationActionIframe.tsx b/front/components/assistant/conversation/actions/VisualizationActionIframe.tsx index 13c1a9ebf09b..4b1a07e295f6 100644 --- a/front/components/assistant/conversation/actions/VisualizationActionIframe.tsx +++ b/front/components/assistant/conversation/actions/VisualizationActionIframe.tsx @@ -43,6 +43,16 @@ const sendResponseToIframe = ( ); }; +const getExtensionFromBlob = (blob: Blob): string => { + const mimeToExt: Record = { + "image/png": "png", + "image/jpeg": "jpg", + "text/csv": "csv", + }; + + return mimeToExt[blob.type] || "txt"; // Default to 'txt' if mime type is unknown. +}; + // Custom hook to encapsulate the logic for handling visualization messages. function useVisualizationDataHandler({ visualization, @@ -77,12 +87,19 @@ function useVisualizationDataHandler({ [workspaceId] ); - const downloadScreenshotFromBlob = useCallback( - (blob: Blob) => { + const downloadFileFromBlob = useCallback( + (blob: Blob, filename?: string) => { const url = URL.createObjectURL(blob); const link = document.createElement("a"); link.href = url; - link.download = `visualization-${visualization.identifier}.png`; + + if (filename) { + link.download = filename; + } else { + const ext = getExtensionFromBlob(blob); + link.download = `visualization-${visualization.identifier}.${ext}`; + } + link.click(); URL.revokeObjectURL(url); }, @@ -126,8 +143,8 @@ function useVisualizationDataHandler({ setErrorMessage(data.params.errorMessage); break; - case "sendScreenshotBlob": - downloadScreenshotFromBlob(data.params.blob); + case "downloadFileRequest": + downloadFileFromBlob(data.params.blob, data.params.filename); break; default: @@ -139,7 +156,7 @@ function useVisualizationDataHandler({ return () => window.removeEventListener("message", listener); }, [ code, - downloadScreenshotFromBlob, + downloadFileFromBlob, getFileBlob, setContentHeight, setErrorMessage, diff --git a/front/lib/api/assistant/visualization.ts b/front/lib/api/assistant/visualization.ts index aa0e9c3a32a6..e35645427be3 100644 --- a/front/lib/api/assistant/visualization.ts +++ b/front/lib/api/assistant/visualization.ts @@ -153,6 +153,7 @@ Guidelines using the :::visualization tag: - Files from the conversation can be accessed using the \`useFile()\` hook. - Once/if the file is available, \`useFile()\` will return a non-null \`File\` object. The \`File\` object is a browser File object. Examples of using \`useFile\` are available below. - Always use \`papaparse\` to parse CSV files. + - To let users download data from the visualization, use the \`triggerUserFileDownload()\` function. - Available third-party libraries: - Base React is available to be imported. In order to use hooks, they have to be imported at the top of the script, e.g. \`import { useState } from "react"\` - The recharts charting library is available to be imported, e.g. \`import { LineChart, XAxis, ... } from "recharts"\` & \` ...\`. @@ -167,6 +168,7 @@ Guidelines using the :::visualization tag: Example using the \`useFile\` hook: \`\`\` +// Reading files from conversation import { useFile } from "@dust/react-hooks"; const file = useFile(fileId); if (file) { @@ -176,9 +178,23 @@ if (file) { // for binary file: const arrayBuffer = await file.arrayBuffer(); } -\`\`\` \`fileId\` can be extracted from the \`\` tags in the conversation history. +\`\`\` + +Example using the \`triggerUserFileDownload\` hook: + +\`\`\` +// Adding download capability +import { triggerUserFileDownload } from "@dust/react-hooks"; + + +\`\`\` General example of a visualization component: diff --git a/types/src/front/assistant/visualization.ts b/types/src/front/assistant/visualization.ts index a2f23675de62..5ae2ab0065d0 100644 --- a/types/src/front/assistant/visualization.ts +++ b/types/src/front/assistant/visualization.ts @@ -16,8 +16,9 @@ interface SetContentHeightParams { height: number; } -interface SendScreenshotBlobParams { +interface DownloadFileRequestParams { blob: Blob; + filename?: string; } interface setErrorMessageParams { @@ -30,7 +31,7 @@ export type VisualizationRPCRequestMap = { getCodeToExecute: null; setContentHeight: SetContentHeightParams; setErrorMessage: setErrorMessageParams; - sendScreenshotBlob: SendScreenshotBlobParams; + downloadFileRequest: DownloadFileRequestParams; }; // Derive the command type from the keys of the request map @@ -56,7 +57,7 @@ export const validCommands: VisualizationRPCCommand[] = [ export interface CommandResultMap { getCodeToExecute: { code: string }; getFile: { fileBlob: Blob | null }; - sendScreenshotBlob: { blob: Blob }; + downloadFileRequest: { blob: Blob; filename?: string }; setContentHeight: void; setErrorMessage: void; } @@ -147,11 +148,11 @@ export function isSetErrorMessageRequest( ); } -export function isSendScreenshotBlobRequest( +export function isDownloadFileRequest( value: unknown ): value is VisualizationRPCRequest & { - command: "sendScreenshotBlob"; - params: SendScreenshotBlobParams; + command: "downloadFileRequest"; + params: DownloadFileRequestParams; } { if (typeof value !== "object" || value === null) { return false; @@ -160,12 +161,12 @@ export function isSendScreenshotBlobRequest( const v = value as Partial; return ( - v.command === "sendScreenshotBlob" && + v.command === "downloadFileRequest" && typeof v.identifier === "string" && typeof v.messageUniqueId === "string" && typeof v.params === "object" && v.params !== null && - (v.params as SendScreenshotBlobParams).blob instanceof Blob + (v.params as DownloadFileRequestParams).blob instanceof Blob ); } @@ -179,7 +180,7 @@ export function isVisualizationRPCRequest( return ( isGetCodeToExecuteRequest(value) || isGetFileRequest(value) || - isSendScreenshotBlobRequest(value) || + isDownloadFileRequest(value) || isSetContentHeightRequest(value) || isSetErrorMessageRequest(value) ); diff --git a/viz/app/components/VisualizationWrapper.tsx b/viz/app/components/VisualizationWrapper.tsx index 51d06a5656d7..57f6d5a04412 100644 --- a/viz/app/components/VisualizationWrapper.tsx +++ b/viz/app/components/VisualizationWrapper.tsx @@ -76,9 +76,9 @@ export function useVisualizationAPI( [sendCrossDocumentMessage] ); - const sendScreenshotBlob = useCallback( - async (blob: Blob) => { - await sendCrossDocumentMessage("sendScreenshotBlob", { blob }); + const downloadFile = useCallback( + async (blob: Blob, filename?: string) => { + await sendCrossDocumentMessage("downloadFileRequest", { blob, filename }); }, [sendCrossDocumentMessage] ); @@ -88,7 +88,7 @@ export function useVisualizationAPI( fetchCode, fetchFile, sendHeightToParent, - sendScreenshotBlob, + downloadFile, }; } @@ -116,6 +116,24 @@ const useFile = ( return file; }; +function useDownloadFileCallback( + downloadFile: (blob: Blob, filename?: string) => Promise +) { + return useCallback( + async ({ + content, + filename, + }: { + content: string | Blob; + filename?: string; + }) => { + const blob = typeof content === "string" ? new Blob([content]) : content; + await downloadFile(blob, filename); + }, + [downloadFile] + ); +} + interface RunnerParams { code: string; scope: Record; @@ -146,7 +164,7 @@ export function VisualizationWrapperWithErrorBoundary({ }); }} > - + ); } @@ -155,20 +173,18 @@ export function VisualizationWrapperWithErrorBoundary({ // It gets the generated code via message passing to the host window. export function VisualizationWrapper({ api, + identifier, }: { api: ReturnType; + identifier: string; }) { const [runnerParams, setRunnerParams] = useState(null); const [errored, setErrorMessage] = useState(null); - const { - fetchCode, - fetchFile, - error, - sendHeightToParent, - sendScreenshotBlob, - } = api; + const { fetchCode, fetchFile, error, sendHeightToParent, downloadFile } = api; + + const memoizedDownloadFile = useDownloadFileCallback(downloadFile); useEffect(() => { const loadCode = async () => { @@ -190,6 +206,7 @@ export function VisualizationWrapper({ papaparse: papaparseAll, "@dust/react-hooks": { useFile: (fileId: string) => useFile(fileId, fetchFile), + triggerUserFileDownload: memoizedDownloadFile, }, }, }), @@ -216,7 +233,7 @@ export function VisualizationWrapper({ onResize: sendHeightToParent, }); - const handleDownload = useCallback(async () => { + const handleScreenshotDownload = useCallback(async () => { if (ref.current) { try { const blob = await toBlob(ref.current, { @@ -224,13 +241,13 @@ export function VisualizationWrapper({ skipFonts: true, }); if (blob) { - await sendScreenshotBlob(blob); + await downloadFile(blob, `visualization-${identifier}.png`); } } catch (err) { console.error("Failed to convert to Blob", err); } } - }, [ref, sendScreenshotBlob]); + }, [ref, downloadFile]); useEffect(() => { if (error) { @@ -250,7 +267,7 @@ export function VisualizationWrapper({ return (