diff --git a/src/state/store.ts b/src/state/store.ts index 97bd8552..ae7ffbb1 100644 --- a/src/state/store.ts +++ b/src/state/store.ts @@ -284,44 +284,45 @@ export const createTreeStore = (id?: string, initialTree?: TreeState) => { nextBranch: () => set( produce((state) => { - if (state.position.length === 0) return state; - - let branchIndex = state.position[state.position.length - 1]; - let branchCount = getNodeAtPath(state.root, state.position).children - .length; - - while (state.position.length > 0 && branchCount <= 1) { - branchIndex = state.position[state.position.length - 1]; - state.position = state.position.slice(0, -1); - branchCount = getNodeAtPath(state.root, state.position).children - .length; + if (state.position.length === 0) return; + + let parent = getNodeAtPath(state.root, state.position.slice(0, -1)); + let branchCount = parent.children.length; + + if (branchCount >= 2) { + // Go to next branch + let branchIndex = state.position[state.position.length - 1]; + state.position = [...state.position.slice(0, -1), (branchIndex + 1) % branchCount]; + } else { + // Go to next branching point + while (branchCount === 1) { + state.position.push(0); + parent = parent.children[0]; + branchCount = parent.children.length; + } } - - const currentNode = getNodeAtPath(state.root, state.position); - state.position.push((branchIndex + 1) % currentNode.children.length); }), ), previousBranch: () => set( produce((state) => { - if (state.position.length === 0) return state; - - let branchIndex = state.position[state.position.length - 1]; - let branchCount = getNodeAtPath(state.root, state.position).children - .length; - - while (state.position.length > 0 && branchCount <= 1) { - branchIndex = state.position[state.position.length - 1]; - state.position = state.position.slice(0, -1); - branchCount = getNodeAtPath(state.root, state.position).children - .length; + if (state.position.length === 0) return; + + let parent = getNodeAtPath(state.root, state.position.slice(0, -1)); + let branchCount = parent.children.length; + + if (branchCount >= 2) { + // Go to previous branch + let branchIndex = state.position[state.position.length - 1]; + state.position = [...state.position.slice(0, -1), (branchIndex + branchCount - 1) % branchCount]; + } else { + // Go to previous branching point + while (branchCount === 1 && state.position.length > 0) { + state.position = state.position.slice(0, -1); + parent = getNodeAtPath(state.root, state.position.slice(0, -1)); + branchCount = parent.children.length; + } } - - const currentNode = getNodeAtPath(state.root, state.position); - state.position.push( - (branchIndex + currentNode.children.length - 1) % - currentNode.children.length, - ); }), ), diff --git a/src/utils/tests/store.test.ts b/src/utils/tests/store.test.ts index 82c61855..9fbbc046 100644 --- a/src/utils/tests/store.test.ts +++ b/src/utils/tests/store.test.ts @@ -11,6 +11,7 @@ beforeEach(() => { const e4 = parseUci("e2e4")!; const d5 = parseUci("d7d5")!; +const e5 = parseUci("e7e5")!; const treeE4D5: () => TreeState = () => ({ ...defaultTree(), position: [0, 0], @@ -114,6 +115,77 @@ const treeE4D5Nf3: () => TreeState = () => ({ }, }); +const treeE4D5E5Nf3: () => TreeState = () => ({ + ...defaultTree(), + position: [0, 0], + root: { + fen: "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", + move: null, + san: null, + children: [ + { + fen: "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1", + move: e4, + san: "e4", + children: [ + { + fen: "rnbqkbnr/ppp1pppp/8/3p4/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2", + move: d5, + san: "d5", + clock: undefined, + children: [], + score: null, + depth: null, + halfMoves: 2, + shapes: [], + annotations: [], + comment: "", + }, + { + fen: "rnbqkbnr/ppp1pppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2", + move: e5, + san: "e5", + clock: undefined, + children: [], + score: null, + depth: null, + halfMoves: 2, + shapes: [], + annotations: [], + comment: "", + }, + ], + clock: undefined, + score: null, + depth: null, + halfMoves: 1, + shapes: [], + annotations: [], + comment: "", + }, + { + fen: "rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b KQkq - 1 1", + move: parseUci("g1f3")!, + san: "Nf3", + children: [], + clock: undefined, + score: null, + depth: null, + halfMoves: 1, + shapes: [], + annotations: [], + comment: "", + }, + ], + score: null, + depth: null, + halfMoves: 0, + shapes: [], + annotations: [], + comment: "", + }, +}) + const getNewState = () => { const s = store.getState(); return { @@ -249,28 +321,43 @@ test("should handle goToPrevious", () => { }); test("should handle nextBranch", () => { - store.setState({ ...treeE4D5Nf3(), position: [] }); + store.setState({ ...treeE4D5E5Nf3(), position: [] }); store.getState().nextBranch(); expect(getNewState()).toStrictEqual({ - ...treeE4D5Nf3(), + ...treeE4D5E5Nf3(), position: [], }); - store.setState({ ...treeE4D5Nf3(), position: [0, 0] }); + store.setState({ ...treeE4D5E5Nf3(), position: [0] }); store.getState().nextBranch(); expect(getNewState()).toStrictEqual({ - ...treeE4D5Nf3(), + ...treeE4D5E5Nf3(), position: [1], }); store.getState().nextBranch(); expect(getNewState()).toStrictEqual({ - ...treeE4D5Nf3(), + ...treeE4D5E5Nf3(), position: [0], }); + + store.setState({ ...treeE4D5E5Nf3(), position: [0, 0] }); + + store.getState().nextBranch(); + expect(getNewState()).toStrictEqual({ + ...treeE4D5E5Nf3(), + position: [0, 1], + }); + + store.getState().nextBranch(); + expect(getNewState()).toStrictEqual({ + ...treeE4D5E5Nf3(), + position: [0, 0], + }); + }); test("should handle previousBranch", () => { @@ -282,7 +369,8 @@ test("should handle previousBranch", () => { position: [], }); - store.setState({ ...treeE4D5Nf3(), position: [0, 0] }); + store.setState({ ...treeE4D5Nf3(), position: [0] }); + store.getState().previousBranch(); expect(getNewState()).toStrictEqual({ ...treeE4D5Nf3(), @@ -294,6 +382,20 @@ test("should handle previousBranch", () => { ...treeE4D5Nf3(), position: [0], }); + + store.setState({ ...treeE4D5Nf3(), position: [0, 1] }); + + store.getState().previousBranch(); + expect(getNewState()).toStrictEqual({ + ...treeE4D5Nf3(), + position: [0, 0], + }); + + store.getState().previousBranch(); + expect(getNewState()).toStrictEqual({ + ...treeE4D5Nf3(), + position: [0, 1], + }); }); test("should handle goToBranchEnd", () => {