From 68386d4bf07a891ea5242fb8608fb3c72fcca9ef Mon Sep 17 00:00:00 2001 From: Flavien David Date: Thu, 25 Jul 2024 19:07:23 +0200 Subject: [PATCH] Enforce origin on cross-document message in Visualization (#6523) * Enforce origin on cross-document message in Visualization * :scissors: * Add viz-secrets --- k8s/deployments/viz-deployment.yaml | 2 + viz/.eslintrc.json | 3 +- viz/app/components/ErrorBoundary.tsx | 92 +++++++ viz/app/components/VisualizationWrapper.tsx | 282 ++++++++------------ viz/app/content/page.tsx | 18 +- 5 files changed, 218 insertions(+), 179 deletions(-) create mode 100644 viz/app/components/ErrorBoundary.tsx diff --git a/k8s/deployments/viz-deployment.yaml b/k8s/deployments/viz-deployment.yaml index 82563c74c6a6..f5c72f2b3e76 100644 --- a/k8s/deployments/viz-deployment.yaml +++ b/k8s/deployments/viz-deployment.yaml @@ -33,6 +33,8 @@ spec: envFrom: - configMapRef: name: viz-config + - secretRef: + name: viz-secrets env: - name: DD_AGENT_HOST diff --git a/viz/.eslintrc.json b/viz/.eslintrc.json index 3faa07b9a705..de1e7e4bb822 100644 --- a/viz/.eslintrc.json +++ b/viz/.eslintrc.json @@ -6,7 +6,8 @@ ], "rules": { "eqeqeq": "error", - "no-unused-vars": "error", + "no-unused-vars": "off", + "@typescript-eslint/no-unused-vars": "error", "@typescript-eslint/no-explicit-any": "error" } } diff --git a/viz/app/components/ErrorBoundary.tsx b/viz/app/components/ErrorBoundary.tsx new file mode 100644 index 000000000000..5bd8392f853b --- /dev/null +++ b/viz/app/components/ErrorBoundary.tsx @@ -0,0 +1,92 @@ +import { Button, ErrorMessage } from "@viz/app/components/Components"; +import React, { useState } from "react"; + +interface ErrorBoundaryState { + error: unknown; + hasError: boolean; +} + +interface ErrorBoundaryProps { + children: React.ReactNode; + errorMessage: string; + onRetryClick: (errorMessage: string) => void; +} + +export class ErrorBoundary extends React.Component< + ErrorBoundaryProps, + ErrorBoundaryState +> { + constructor(props: ErrorBoundaryProps) { + super(props); + this.state = { hasError: false, error: null }; + } + + static getDerivedStateFromError() { + // Update state so the next render will show the fallback UI. + return { hasError: true }; + } + + componentDidCatch(error: unknown) { + this.setState({ hasError: true, error }); + } + + render() { + if (this.state.hasError) { + let error: Error; + if (this.state.error instanceof Error) { + error = this.state.error; + } else { + error = new Error("Unknown error."); + } + + return ( + + ); + } + + return <>{this.props.children}; + } +} + +// This is the component to render when an error occurs. +export function RenderError({ + error, + message, + onRetryClick, +}: { + error: Error; + message: string; + onRetryClick: (errorMessage: string) => void; +}) { + const [showDetails, setShowDetails] = useState(false); + + return ( +
+ + <> + {message} +
+ + {showDetails && ( +
+ Error message: {error.message} +
+ )} +
+ +
+
+
+
+ ); +} diff --git a/viz/app/components/VisualizationWrapper.tsx b/viz/app/components/VisualizationWrapper.tsx index 673a6a671e3d..8f1a5f73e582 100644 --- a/viz/app/components/VisualizationWrapper.tsx +++ b/viz/app/components/VisualizationWrapper.tsx @@ -5,25 +5,24 @@ import type { VisualizationRPCCommand, VisualizationRPCRequestMap, } from "@dust-tt/types"; -import { Button, ErrorMessage, Spinner } from "@viz/app/components/Components"; +import { Spinner } from "@viz/app/components/Components"; import * as papaparseAll from "papaparse"; import * as reactAll from "react"; -import React, { useCallback } from "react"; +import React, { useCallback, useMemo } from "react"; import { useEffect, useState } from "react"; import { importCode, Runner } from "react-runner"; import * as rechartsAll from "recharts"; import { useResizeDetector } from "react-resize-detector"; +import { ErrorBoundary } from "@viz/app/components/ErrorBoundary"; -export function useVisualizationAPI(actionId: number) { +export function useVisualizationAPI( + sendCrossDocumentMessage: ReturnType +) { const [error, setError] = useState(null); const fetchCode = useCallback(async (): Promise => { - const getCode = makeIframeMessagePassingFunction( - "getCodeToExecute", - actionId - ); try { - const result = await getCode(null); + const result = await sendCrossDocumentMessage("getCodeToExecute", null); const { code } = result; if (!code) { @@ -42,12 +41,11 @@ export function useVisualizationAPI(actionId: number) { return null; } - }, [actionId]); + }, [sendCrossDocumentMessage]); const fetchFile = useCallback( async (fileId: string): Promise => { - const getFile = makeIframeMessagePassingFunction("getFile", actionId); - const res = await getFile({ fileId }); + const res = await sendCrossDocumentMessage("getFile", { fileId }); const { fileBlob: blob } = res; @@ -60,59 +58,40 @@ export function useVisualizationAPI(actionId: number) { return file; }, - [actionId] + [sendCrossDocumentMessage] ); // This retry function sends a command to the host window requesting a retry of a previous // operation, typically if the generated code fails. const retry = useCallback( async (errorMessage: string): Promise => { - const sendRetry = makeIframeMessagePassingFunction("retry", actionId); - await sendRetry({ errorMessage }); + await sendCrossDocumentMessage("retry", { errorMessage }); }, - [actionId] + [sendCrossDocumentMessage] ); - return { fetchCode, fetchFile, error, retry }; -} + const sendHeightToParent = useCallback( + async ({ height }: { height: number | null }) => { + if (height === null) { + return; + } -// This function creates a function that sends a command to the host window with templated Input and Output types. -function makeIframeMessagePassingFunction( - methodName: T, - actionId: number -) { - return (params: VisualizationRPCRequestMap[T]) => { - return new Promise((resolve, reject) => { - const messageUniqueId = Math.random().toString(); - const listener = (event: MessageEvent) => { - if (event.data.messageUniqueId === messageUniqueId) { - if (event.data.error) { - reject(event.data.error); - } else { - resolve(event.data.result); - } - window.removeEventListener("message", listener); - } - }; - window.addEventListener("message", listener); - window.top?.postMessage( - { - command: methodName, - messageUniqueId, - actionId, - params, - }, - "*" - ); - }); - }; + await sendCrossDocumentMessage("setContentHeight", { + height, + }); + }, + [sendCrossDocumentMessage] + ); + + return { fetchCode, fetchFile, error, retry, sendHeightToParent }; } -const useFile = (actionId: number, fileId: string) => { +const useFile = ( + fileId: string, + fetchFile: (fileId: string) => Promise +) => { const [file, setFile] = useState(null); - const { fetchFile } = useVisualizationAPI(actionId); // Adjust the import based on your project structure - useEffect(() => { const fetch = async () => { try { @@ -131,20 +110,34 @@ const useFile = (actionId: number, fileId: string) => { return file; }; +interface RunnerParams { + code: string; + scope: Record; +} + // This component renders the generated code. // It gets the generated code via message passing to the host window. -export function VisualizationWrapper({ actionId }: { actionId: string }) { - type RunnerParams = { - code: string; - scope: Record; - }; - +export function VisualizationWrapper({ + actionId, + allowedVisualizationOrigin, +}: { + actionId: number; + allowedVisualizationOrigin: string | undefined; +}) { const [runnerParams, setRunnerParams] = useState(null); const [errored, setErrored] = useState(null); - const actionIdParsed = parseInt(actionId, 10); + const sendCrossDocumentMessage = useMemo( + () => + makeSendCrossDocumentMessage({ + actionId, + allowedVisualizationOrigin, + }), + [actionId, allowedVisualizationOrigin] + ); - const { fetchCode, error, retry } = useVisualizationAPI(actionIdParsed); + const { fetchCode, fetchFile, error, retry, sendHeightToParent } = + useVisualizationAPI(sendCrossDocumentMessage); useEffect(() => { const loadCode = async () => { @@ -165,8 +158,7 @@ export function VisualizationWrapper({ actionId }: { actionId: string }) { react: reactAll, papaparse: papaparseAll, "@dust/react-hooks": { - useFile: (fileId: string) => - useFile(actionIdParsed, fileId), + useFile: (fileId: string) => useFile(fileId, fetchFile), }, }, }), @@ -180,21 +172,7 @@ export function VisualizationWrapper({ actionId }: { actionId: string }) { }; loadCode(); - }, [fetchCode, actionIdParsed]); - - const sendHeightToParent = useCallback( - ({ height }: { height: number | null }) => { - if (height === null) { - return; - } - const sendHeight = makeIframeMessagePassingFunction<"setContentHeight">( - "setContentHeight", - actionIdParsed - ); - sendHeight({ height }); - }, - [actionIdParsed] - ); + }, [fetchCode, fetchFile]); const { ref } = useResizeDetector({ handleHeight: true, @@ -210,12 +188,8 @@ export function VisualizationWrapper({ actionId }: { actionId: string }) { }, [error]); if (errored) { - return ( - retry(errored.message)} - /> - ); + // Throw the error to the ErrorBoundary. + throw errored; } if (!runnerParams) { @@ -223,103 +197,67 @@ export function VisualizationWrapper({ actionId }: { actionId: string }) { } return ( -
- { - if (error) { - setErrored(error); - } - }} - /> -
+ +
+ { + if (error) { + setErrored(error); + } + }} + /> +
+
); } -// This is the component to render when an error occurs. -function VisualizationError({ - error, - retry, +export function makeSendCrossDocumentMessage({ + actionId, + allowedVisualizationOrigin, }: { - error: Error; - retry: () => void; + actionId: number; + allowedVisualizationOrigin: string | undefined; }) { - const [showDetails, setShowDetails] = useState(false); - - return ( -
- - <> - We encountered an error while running the code generated above. You - can try again by clicking the button below. -
- - {showDetails && ( -
- Error message: {error.message} -
- )} -
- -
-
-
-
- ); -} - -type ErrorBoundaryProps = { - actionId: string; -}; - -type ErrorBoundaryState = { - hasError: boolean; - error: unknown; -}; - -// This is the error boundary component that wraps the VisualizationWrapper component. -// It needs to be a class component for error handling to work. -export class VisualizationWrapperWithErrorHandling extends React.Component< - ErrorBoundaryProps, - ErrorBoundaryState -> { - constructor(props: ErrorBoundaryProps) { - super(props); - this.state = { hasError: false, error: null }; - } - - static getDerivedStateFromError() { - // Update state so the next render will show the fallback UI. - return { hasError: true }; - } - - componentDidCatch(error: unknown) { - this.setState({ hasError: true, error }); - } + return ( + command: T, + params: VisualizationRPCRequestMap[T] + ) => { + return new Promise((resolve, reject) => { + const messageUniqueId = Math.random().toString(); + const listener = (event: MessageEvent) => { + if (event.origin !== allowedVisualizationOrigin) { + console.log( + `Ignored message from unauthorized origin: ${event.origin}` + ); - render() { - if (this.state.hasError) { - let error: Error; - if (this.state.error instanceof Error) { - error = this.state.error; - } else { - error = new Error("Unknown error."); - } + // Simply ignore messages from unauthorized origins. + return; + } - const retry = makeIframeMessagePassingFunction( - "retry", - parseInt(this.props.actionId, 10) + if (event.data.messageUniqueId === messageUniqueId) { + if (event.data.error) { + reject(event.data.error); + } else { + resolve(event.data.result); + } + window.removeEventListener("message", listener); + } + }; + window.addEventListener("message", listener); + window.top?.postMessage( + { + command, + messageUniqueId, + actionId, + params, + }, + "*" ); - return retry} />; - } - - return ; - } + }); + }; } diff --git a/viz/app/content/page.tsx b/viz/app/content/page.tsx index e01902cc6aa9..f2a974c81dd8 100644 --- a/viz/app/content/page.tsx +++ b/viz/app/content/page.tsx @@ -1,14 +1,20 @@ -import { VisualizationWrapperWithErrorHandling } from "@viz/app/components/VisualizationWrapper"; +import { VisualizationWrapper } from "@viz/app/components/VisualizationWrapper"; -type IframeProps = { - wId: string; +type RenderVisualizationSearchParams = { aId: string; }; -export default function Iframe({ +const { ALLOWED_VISUALIZATION_ORIGIN } = process.env; + +export default function RenderVisualization({ searchParams, }: { - searchParams: IframeProps; + searchParams: RenderVisualizationSearchParams; }) { - return ; + return ( + + ); }