From c3bf05ca258ca31c1c5a56212341a5e08b2ad15f Mon Sep 17 00:00:00 2001 From: Shawn Chang Date: Thu, 21 Sep 2023 01:54:22 +0800 Subject: [PATCH] fix delayed update --- .../src/components/GraphContainer.test.tsx | 30 +++++++++---------- .../src/components/GraphContainer.tsx | 13 +++++--- .../src/features/CoreGraphAdapter.ts | 16 ++++++---- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/interactive-computational-graph/src/components/GraphContainer.test.tsx b/interactive-computational-graph/src/components/GraphContainer.test.tsx index a143554..71b59db 100644 --- a/interactive-computational-graph/src/components/GraphContainer.test.tsx +++ b/interactive-computational-graph/src/components/GraphContainer.test.tsx @@ -301,11 +301,11 @@ it("outputs should change when derivative mode/target is changed", () => { expect(getOutputItemValue("5", "VALUE")).toBe("6"); // Check the derivative labels - expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe("d(5)/d(1)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(2)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(3)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(4)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(5)"); + expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe("d(5)/d(1) ="); + expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(2) ="); + expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe("d(5)/d(3) ="); + expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe("d(5)/d(4) ="); + expect(getOutputItemLabelText("5", "DERIVATIVE")).toBe("d(5)/d(5) ="); // Check the derivative values expect(getOutputItemValue("1", "DERIVATIVE")).toBe("2"); @@ -323,11 +323,11 @@ it("outputs should change when derivative mode/target is changed", () => { expect(getOutputItemValue("5", "VALUE")).toBe("6"); // Check the derivative labels - expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe("d(1)/d(5)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(2)/d(5)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(3)/d(5)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(4)/d(5)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(5)"); + expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe("d(1)/d(5) ="); + expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(2)/d(5) ="); + expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe("d(3)/d(5) ="); + expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe("d(4)/d(5) ="); + expect(getOutputItemLabelText("5", "DERIVATIVE")).toBe("d(5)/d(5) ="); // Check the derivative values expect(getOutputItemValue("1", "DERIVATIVE")).toBe("0"); @@ -345,11 +345,11 @@ it("outputs should change when derivative mode/target is changed", () => { expect(getOutputItemValue("5", "VALUE")).toBe("6"); // Check the derivative labels - expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe("d(1)/d(2)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(2)/d(2)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(3)/d(2)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(4)/d(2)"); - expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(5)/d(2)"); + expect(getOutputItemLabelText("1", "DERIVATIVE")).toBe("d(1)/d(2) ="); + expect(getOutputItemLabelText("2", "DERIVATIVE")).toBe("d(2)/d(2) ="); + expect(getOutputItemLabelText("3", "DERIVATIVE")).toBe("d(3)/d(2) ="); + expect(getOutputItemLabelText("4", "DERIVATIVE")).toBe("d(4)/d(2) ="); + expect(getOutputItemLabelText("5", "DERIVATIVE")).toBe("d(5)/d(2) ="); // Check the derivative values expect(getOutputItemValue("1", "DERIVATIVE")).toBe("0"); diff --git a/interactive-computational-graph/src/components/GraphContainer.tsx b/interactive-computational-graph/src/components/GraphContainer.tsx index 3921b3c..5760971 100644 --- a/interactive-computational-graph/src/components/GraphContainer.tsx +++ b/interactive-computational-graph/src/components/GraphContainer.tsx @@ -20,6 +20,7 @@ import { type XYPosition, } from "reactflow"; import { TITLE_HEIGHT } from "../constants"; +import type DifferentiationMode from "../core/DifferentiationMode"; import Operation from "../core/Operation"; import Port from "../core/Port"; import type AddNodeData from "../features/AddNodeData"; @@ -413,17 +414,21 @@ const GraphContainer: FunctionComponent = ({ ); const handleDerivativeValuesUpdated = useCallback( - (nodeIdToDerivatives: Map) => { + ( + differentiationMode: DifferentiationMode, + targetNodeId: string | null, + nodeIdToDerivatives: Map, + ) => { setReactFlowNodes((nodes) => updateReactFlowNodeDerivatives( nodeIdToDerivatives, - isReverseMode, - derivativeTarget, + differentiationMode === "REVERSE", + targetNodeId, nodes, ), ); }, - [derivativeTarget, isReverseMode], + [], ); const handleExplainDerivativeDataUpdated = useCallback( diff --git a/interactive-computational-graph/src/features/CoreGraphAdapter.ts b/interactive-computational-graph/src/features/CoreGraphAdapter.ts index 6dc7fe5..f624c64 100644 --- a/interactive-computational-graph/src/features/CoreGraphAdapter.ts +++ b/interactive-computational-graph/src/features/CoreGraphAdapter.ts @@ -36,6 +36,8 @@ type HideInputFieldCallback = (nonEmptyPortConnection: Connection) => void; type FValuesUpdatedCallback = (nodeIdToFValues: Map) => void; type DerivativeValuesUpdatedCallback = ( + differentiationMode: DifferentiationMode, + targetNode: string | null, nodeIdToDerivatives: Map, ) => void; @@ -54,7 +56,7 @@ class CoreGraphAdapter { private showInputFieldsCallbacks: ShowInputFieldsCallback[] = []; private hideInputFieldCallbacks: HideInputFieldCallback[] = []; private fValuesUpdatedCallbacks: FValuesUpdatedCallback[] = []; - private derivativesUpdatedCallbacks: FValuesUpdatedCallback[] = []; + private derivativesUpdatedCallbacks: DerivativeValuesUpdatedCallback[] = []; private explainDerivativeDataUpdatedCallbacks: ExplainDerivativeDataUpdatedCallback[] = []; @@ -255,7 +257,7 @@ class CoreGraphAdapter { this.graph.setTargetNode(null); - this.emitTargetNodeChanged(null); + this.emitTargetNodeUpdated(); } private updateSelectedNodeIds(): void { @@ -507,9 +509,9 @@ class CoreGraphAdapter { }); } - private emitTargetNodeChanged(targetNodeId: string | null): void { + private emitTargetNodeUpdated(): void { this.targetNodeUpdatedCallbacks.forEach((callback) => { - callback(targetNodeId); + callback(this.graph.getTargetNode()); }); } @@ -535,7 +537,11 @@ class CoreGraphAdapter { nodeIdToDerivatives: Map, ): void { this.derivativesUpdatedCallbacks.forEach((callback) => { - callback(nodeIdToDerivatives); + callback( + this.graph.getDifferentiationMode(), + this.graph.getTargetNode(), + nodeIdToDerivatives, + ); }); }