Skip to content

Commit

Permalink
save/load react flow states
Browse files Browse the repository at this point in the history
  • Loading branch information
sc420 committed Feb 20, 2024
1 parent 6481845 commit ca67bb5
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ import {
type Node,
type OnConnect,
type OnEdgesChange,
type OnInit,
type OnNodesChange,
type OnSelectionChangeParams,
type ReactFlowInstance,
type XYPosition,
} from "reactflow";
import { TITLE_HEIGHT } from "../constants";
Expand Down Expand Up @@ -69,6 +71,7 @@ import type ExplainDerivativeData from "../features/ExplainDerivativeData";
import type FeatureNodeType from "../features/FeatureNodeType";
import type FeatureOperation from "../features/FeatureOperation";
import { findFeatureOperation } from "../features/FeatureOperationFinder";
import type NodeData from "../features/NodeData";
import NodeNameBuilder from "../features/NodeNameBuilder";
import {
addReactFlowNode,
Expand Down Expand Up @@ -310,6 +313,8 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
>([]);

// React Flow states
const [reactFlowInstance, setReactFlowInstance] =
useState<ReactFlowInstance | null>(null);
const [reactFlowNodes, setReactFlowNodes] = useState<Node[]>([]);
const [reactFlowEdges, setReactFlowEdges] = useState<Edge[]>([]);
const [lastSelectedNodeId, setLastSelectedNodeId] = useState<string | null>(
Expand Down Expand Up @@ -522,17 +527,23 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
}, [featureOperations]);

const handleSave = useCallback((): GraphContainerState => {
if (reactFlowInstance === null) {
throw new Error("React flow instance should not be null");
}

const coreGraphAdapterState = coreGraphAdapter.save();
return {
coreGraphAdapterState,
isReverseMode,
derivativeTarget,
featureOperations: saveFeatureOperations(),
reactFlowState: reactFlowInstance.toObject(),
};
}, [
coreGraphAdapter,
derivativeTarget,
isReverseMode,
reactFlowInstance,
saveFeatureOperations,
]);

Expand All @@ -559,6 +570,43 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
[],
);

const loadReactFlowNode = useCallback(
(nodes: Node[]) => {
return nodes.map((node) => {
const data = node.data as NodeData;
// Set the new data to notify React Flow about the change
const newData: NodeData = {
...data,
// Set the callbacks because the json file doesn't have these
onNameChange: handleNameChange,
onInputChange: handleInputChange,
onBodyClick: handleBodyClick,
onDerivativeClick: handleDerivativeClick,
};

node.data = newData;
return node;
});
},
[
handleBodyClick,
handleDerivativeClick,
handleInputChange,
handleNameChange,
],
);

const loadReactFlow = useCallback(
(reactFlowState: any) => {
// Reference: https://reactflow.dev/examples/interaction/save-and-restore
const { x = 0, y = 0, zoom = 1 } = reactFlowState.viewport;
setReactFlowNodes(loadReactFlowNode(reactFlowState.nodes));
setReactFlowEdges(reactFlowState.edges);
reactFlowInstance?.setViewport({ x, y, zoom });
},
[loadReactFlowNode, reactFlowInstance],
);

const handleLoad = useCallback(
(graphContainerState: GraphContainerState) => {
const loadedFeatureOperations = loadFeatureOperations(
Expand All @@ -572,8 +620,9 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
setReverseMode(graphContainerState.isReverseMode);
setDerivativeTarget(graphContainerState.derivativeTarget);
setFeatureOperations(loadedFeatureOperations);
loadReactFlow(graphContainerState.reactFlowState);
},
[coreGraphAdapter, loadFeatureOperations],
[coreGraphAdapter, loadFeatureOperations, loadReactFlow],
);

const handleReverseModeChange = useCallback(
Expand All @@ -596,6 +645,10 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
[coreGraphAdapter],
);

const handleReactFlowInit: OnInit = useCallback((reactFlowInstance) => {
setReactFlowInstance(reactFlowInstance);
}, []);

const handleNodesChange: OnNodesChange = useCallback(
(changes) => {
coreGraphAdapter.changeNodes(changes);
Expand Down Expand Up @@ -832,6 +885,7 @@ const GraphContainer: FunctionComponent<GraphContainerProps> = ({
<ReactFlowGraph
nodes={reactFlowNodes}
edges={reactFlowEdges}
onInit={handleReactFlowInit}
onNodesChange={handleNodesChange}
onEdgesChange={handleEdgesChange}
onSelectionChange={handleSelectionChange}
Expand Down
200 changes: 199 additions & 1 deletion interactive-computational-graph/src/components/SaveLoadPanel.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,199 @@ test("should trigger the event when clicking the load button", async () => {
],
"helpText": "Add two numbers, i.e., a + b"
}
]
],
"reactFlowState": {
"nodes": [
{
"width": 182,
"height": 96,
"id": "0",
"type": "custom",
"data": {
"name": "c_1",
"operationData": null,
"featureNodeType": {
"nodeType": "CONSTANT"
},
"inputItems": [
{
"id": "value",
"label": "=",
"showHandle": false,
"showInputField": true,
"value": "1"
}
],
"outputItems": [],
"isDarkMode": false,
"isHighlighted": false
},
"dragHandle": ".drag-handle",
"selected": false,
"position": {
"x": 10,
"y": 10
},
"positionAbsolute": {
"x": 10,
"y": 10
}
},
{
"width": 209,
"height": 153,
"id": "1",
"type": "custom",
"data": {
"name": "v_1",
"operationData": null,
"featureNodeType": {
"nodeType": "VARIABLE"
},
"inputItems": [
{
"id": "value",
"label": "=",
"showHandle": false,
"showInputField": true,
"value": "2"
}
],
"outputItems": [
{
"type": "DERIVATIVE",
"labelParts": [
{
"type": "latexLink",
"id": "derivative",
"latex": "\\\\displaystyle \\\\frac{\\\\partial{?}}{\\\\partial{v_1}}",
"href": "1"
},
{
"type": "latex",
"id": "equal",
"latex": "="
}
],
"value": "0"
}
],
"isDarkMode": false,
"isHighlighted": false
},
"dragHandle": ".drag-handle",
"selected": false,
"position": {
"x": 10,
"y": 150
},
"positionAbsolute": {
"x": 10,
"y": 150
},
"dragging": false
},
{
"width": 209,
"height": 233,
"id": "2",
"type": "custom",
"data": {
"name": "a_1",
"operationData": {
"text": "Add",
"helpText": "Add two numbers, i.e., a + b"
},
"featureNodeType": {
"nodeType": "OPERATION",
"operationId": "add"
},
"inputItems": [
{
"id": "a",
"label": "a",
"showHandle": true,
"showInputField": false,
"value": "0"
},
{
"id": "b",
"label": "b",
"showHandle": true,
"showInputField": false,
"value": "0"
}
],
"outputItems": [
{
"type": "VALUE",
"labelParts": [
{
"type": "latex",
"id": "value",
"latex": "="
}
],
"value": "3"
},
{
"type": "DERIVATIVE",
"labelParts": [
{
"type": "latexLink",
"id": "derivative",
"latex": "\\\\displaystyle \\\\frac{\\\\partial{?}}{\\\\partial{a_1}}",
"href": "2"
},
{
"type": "latex",
"id": "equal",
"latex": "="
}
],
"value": "0"
}
],
"isDarkMode": false,
"isHighlighted": false
},
"dragHandle": ".drag-handle",
"selected": false,
"position": {
"x": 300,
"y": 0
},
"positionAbsolute": {
"x": 300,
"y": 0
},
"dragging": false
}
],
"edges": [
{
"source": "0",
"sourceHandle": "output",
"target": "2",
"targetHandle": "a",
"id": "reactflow__edge-0output-2a",
"animated": false
},
{
"source": "1",
"sourceHandle": "output",
"target": "2",
"targetHandle": "b",
"id": "reactflow__edge-1output-2b",
"animated": false
}
],
"viewport": {
"x": 150,
"y": 100,
"zoom": 1
}
}
}
`;
const file = new File([contents], "graph.json", { type: "text/plain" });
Expand All @@ -125,6 +317,8 @@ test("should trigger the event when clicking the load button", async () => {
});

test("should show the error message when the file contents is not valid", async () => {
const mockConsole = jest.spyOn(console, "error").mockImplementation();

const handleSave = jest.fn();
const handleLoad = jest.fn();
render(<SaveLoadPanel onSave={handleSave} onLoad={handleLoad} />);
Expand All @@ -139,4 +333,8 @@ test("should show the error message when the file contents is not valid", async
expect(handleLoad).not.toHaveBeenCalled();
const successAlert = screen.queryByText(successMessage);
expect(successAlert).not.toBeInTheDocument();

expect(mockConsole).toHaveBeenCalled();

jest.restoreAllMocks(); // restores the spy created with spyOn
});
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ const InputItems: FunctionComponent<InputItemProps> = ({
<OutlinedInput
id={getInputId(item.id)}
data-testid={getInputId(item.id)}
defaultValue={item.value}
value={item.value}
size="small"
inputProps={{
style: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import "./ReactFlowGraph.css";
interface ReactFlowGraphProps {
nodes: Node[];
edges: Edge[];
onInit: OnInit;
onNodesChange: OnNodesChange;
onEdgesChange: OnEdgesChange;
onSelectionChange: (params: OnSelectionChangeParams) => void;
Expand All @@ -40,6 +41,7 @@ interface ReactFlowGraphProps {
const ReactFlowGraph: FunctionComponent<ReactFlowGraphProps> = ({
nodes,
edges,
onInit,
onNodesChange,
onEdgesChange,
onSelectionChange,
Expand All @@ -62,9 +64,11 @@ const ReactFlowGraph: FunctionComponent<ReactFlowGraphProps> = ({

const handleInit: OnInit = useCallback(
(reactFlowInstance: ReactFlowInstance) => {
onInit(reactFlowInstance);

setReactFlowInstance(reactFlowInstance);
},
[],
[onInit],
);

const handleDragOver = useCallback((event: DragEvent) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ interface GraphContainerState {
// operationIdsAddedAtLeastOnce: Set<string>;

// React Flow states
// reactFlowState: object;
reactFlowState: object;
}

export default GraphContainerState;

0 comments on commit ca67bb5

Please sign in to comment.