diff --git a/interactive-computational-graph/src/components/GraphContainer.tsx b/interactive-computational-graph/src/components/GraphContainer.tsx index c89a70c..53653d9 100644 --- a/interactive-computational-graph/src/components/GraphContainer.tsx +++ b/interactive-computational-graph/src/components/GraphContainer.tsx @@ -15,8 +15,10 @@ import { type Node, type OnConnect, type OnEdgesChange, + type OnInit, type OnNodesChange, type OnSelectionChangeParams, + type ReactFlowInstance, type XYPosition, } from "reactflow"; import { TITLE_HEIGHT } from "../constants"; @@ -69,6 +71,7 @@ import type ExplainDerivativeData from "../features/ExplainDerivativeData"; import type FeatureNodeType from "../features/FeatureNodeType"; import type FeatureOperation from "../features/FeatureOperation"; import { findFeatureOperation } from "../features/FeatureOperationFinder"; +import type NodeData from "../features/NodeData"; import NodeNameBuilder from "../features/NodeNameBuilder"; import { addReactFlowNode, @@ -310,6 +313,8 @@ const GraphContainer: FunctionComponent = ({ >([]); // React Flow states + const [reactFlowInstance, setReactFlowInstance] = + useState(null); const [reactFlowNodes, setReactFlowNodes] = useState([]); const [reactFlowEdges, setReactFlowEdges] = useState([]); const [lastSelectedNodeId, setLastSelectedNodeId] = useState( @@ -522,17 +527,23 @@ const GraphContainer: FunctionComponent = ({ }, [featureOperations]); const handleSave = useCallback((): GraphContainerState => { + if (reactFlowInstance === null) { + throw new Error("React flow instance should not be null"); + } + const coreGraphAdapterState = coreGraphAdapter.save(); return { coreGraphAdapterState, isReverseMode, derivativeTarget, featureOperations: saveFeatureOperations(), + reactFlowState: reactFlowInstance.toObject(), }; }, [ coreGraphAdapter, derivativeTarget, isReverseMode, + reactFlowInstance, saveFeatureOperations, ]); @@ -559,6 +570,43 @@ const GraphContainer: FunctionComponent = ({ [], ); + const loadReactFlowNode = useCallback( + (nodes: Node[]) => { + return nodes.map((node) => { + const data = node.data as NodeData; + // Set the new data to notify React Flow about the change + const newData: NodeData = { + ...data, + // Set the callbacks because the json file doesn't have these + onNameChange: handleNameChange, + onInputChange: handleInputChange, + onBodyClick: handleBodyClick, + onDerivativeClick: handleDerivativeClick, + }; + + node.data = newData; + return node; + }); + }, + [ + handleBodyClick, + handleDerivativeClick, + handleInputChange, + handleNameChange, + ], + ); + + const loadReactFlow = useCallback( + (reactFlowState: any) => { + // Reference: https://reactflow.dev/examples/interaction/save-and-restore + const { x = 0, y = 0, zoom = 1 } = reactFlowState.viewport; + setReactFlowNodes(loadReactFlowNode(reactFlowState.nodes)); + setReactFlowEdges(reactFlowState.edges); + reactFlowInstance?.setViewport({ x, y, zoom }); + }, + [loadReactFlowNode, reactFlowInstance], + ); + const handleLoad = useCallback( (graphContainerState: GraphContainerState) => { const loadedFeatureOperations = loadFeatureOperations( @@ -572,8 +620,9 @@ const GraphContainer: FunctionComponent = ({ setReverseMode(graphContainerState.isReverseMode); setDerivativeTarget(graphContainerState.derivativeTarget); setFeatureOperations(loadedFeatureOperations); + loadReactFlow(graphContainerState.reactFlowState); }, - [coreGraphAdapter, loadFeatureOperations], + [coreGraphAdapter, loadFeatureOperations, loadReactFlow], ); const handleReverseModeChange = useCallback( @@ -596,6 +645,10 @@ const GraphContainer: FunctionComponent = ({ [coreGraphAdapter], ); + const handleReactFlowInit: OnInit = useCallback((reactFlowInstance) => { + setReactFlowInstance(reactFlowInstance); + }, []); + const handleNodesChange: OnNodesChange = useCallback( (changes) => { coreGraphAdapter.changeNodes(changes); @@ -832,6 +885,7 @@ const GraphContainer: FunctionComponent = ({ { ], "helpText": "Add two numbers, i.e., a + b" } - ] + ], + "reactFlowState": { + "nodes": [ + { + "width": 182, + "height": 96, + "id": "0", + "type": "custom", + "data": { + "name": "c_1", + "operationData": null, + "featureNodeType": { + "nodeType": "CONSTANT" + }, + "inputItems": [ + { + "id": "value", + "label": "=", + "showHandle": false, + "showInputField": true, + "value": "1" + } + ], + "outputItems": [], + "isDarkMode": false, + "isHighlighted": false + }, + "dragHandle": ".drag-handle", + "selected": false, + "position": { + "x": 10, + "y": 10 + }, + "positionAbsolute": { + "x": 10, + "y": 10 + } + }, + { + "width": 209, + "height": 153, + "id": "1", + "type": "custom", + "data": { + "name": "v_1", + "operationData": null, + "featureNodeType": { + "nodeType": "VARIABLE" + }, + "inputItems": [ + { + "id": "value", + "label": "=", + "showHandle": false, + "showInputField": true, + "value": "2" + } + ], + "outputItems": [ + { + "type": "DERIVATIVE", + "labelParts": [ + { + "type": "latexLink", + "id": "derivative", + "latex": "\\\\displaystyle \\\\frac{\\\\partial{?}}{\\\\partial{v_1}}", + "href": "1" + }, + { + "type": "latex", + "id": "equal", + "latex": "=" + } + ], + "value": "0" + } + ], + "isDarkMode": false, + "isHighlighted": false + }, + "dragHandle": ".drag-handle", + "selected": false, + "position": { + "x": 10, + "y": 150 + }, + "positionAbsolute": { + "x": 10, + "y": 150 + }, + "dragging": false + }, + { + "width": 209, + "height": 233, + "id": "2", + "type": "custom", + "data": { + "name": "a_1", + "operationData": { + "text": "Add", + "helpText": "Add two numbers, i.e., a + b" + }, + "featureNodeType": { + "nodeType": "OPERATION", + "operationId": "add" + }, + "inputItems": [ + { + "id": "a", + "label": "a", + "showHandle": true, + "showInputField": false, + "value": "0" + }, + { + "id": "b", + "label": "b", + "showHandle": true, + "showInputField": false, + "value": "0" + } + ], + "outputItems": [ + { + "type": "VALUE", + "labelParts": [ + { + "type": "latex", + "id": "value", + "latex": "=" + } + ], + "value": "3" + }, + { + "type": "DERIVATIVE", + "labelParts": [ + { + "type": "latexLink", + "id": "derivative", + "latex": "\\\\displaystyle \\\\frac{\\\\partial{?}}{\\\\partial{a_1}}", + "href": "2" + }, + { + "type": "latex", + "id": "equal", + "latex": "=" + } + ], + "value": "0" + } + ], + "isDarkMode": false, + "isHighlighted": false + }, + "dragHandle": ".drag-handle", + "selected": false, + "position": { + "x": 300, + "y": 0 + }, + "positionAbsolute": { + "x": 300, + "y": 0 + }, + "dragging": false + } + ], + "edges": [ + { + "source": "0", + "sourceHandle": "output", + "target": "2", + "targetHandle": "a", + "id": "reactflow__edge-0output-2a", + "animated": false + }, + { + "source": "1", + "sourceHandle": "output", + "target": "2", + "targetHandle": "b", + "id": "reactflow__edge-1output-2b", + "animated": false + } + ], + "viewport": { + "x": 150, + "y": 100, + "zoom": 1 + } + } } `; const file = new File([contents], "graph.json", { type: "text/plain" }); @@ -125,6 +317,8 @@ test("should trigger the event when clicking the load button", async () => { }); test("should show the error message when the file contents is not valid", async () => { + const mockConsole = jest.spyOn(console, "error").mockImplementation(); + const handleSave = jest.fn(); const handleLoad = jest.fn(); render(); @@ -139,4 +333,8 @@ test("should show the error message when the file contents is not valid", async expect(handleLoad).not.toHaveBeenCalled(); const successAlert = screen.queryByText(successMessage); expect(successAlert).not.toBeInTheDocument(); + + expect(mockConsole).toHaveBeenCalled(); + + jest.restoreAllMocks(); // restores the spy created with spyOn }); diff --git a/interactive-computational-graph/src/reactflow/InputItems.tsx b/interactive-computational-graph/src/reactflow/InputItems.tsx index e4c19cb..37e0906 100644 --- a/interactive-computational-graph/src/reactflow/InputItems.tsx +++ b/interactive-computational-graph/src/reactflow/InputItems.tsx @@ -119,7 +119,7 @@ const InputItems: FunctionComponent = ({ void; @@ -40,6 +41,7 @@ interface ReactFlowGraphProps { const ReactFlowGraph: FunctionComponent = ({ nodes, edges, + onInit, onNodesChange, onEdgesChange, onSelectionChange, @@ -62,9 +64,11 @@ const ReactFlowGraph: FunctionComponent = ({ const handleInit: OnInit = useCallback( (reactFlowInstance: ReactFlowInstance) => { + onInit(reactFlowInstance); + setReactFlowInstance(reactFlowInstance); }, - [], + [onInit], ); const handleDragOver = useCallback((event: DragEvent) => { diff --git a/interactive-computational-graph/src/states/GraphContainerState.ts b/interactive-computational-graph/src/states/GraphContainerState.ts index 4c09595..e859dab 100644 --- a/interactive-computational-graph/src/states/GraphContainerState.ts +++ b/interactive-computational-graph/src/states/GraphContainerState.ts @@ -19,7 +19,7 @@ interface GraphContainerState { // operationIdsAddedAtLeastOnce: Set; // React Flow states - // reactFlowState: object; + reactFlowState: object; } export default GraphContainerState;