From bac80713ace7edbac43f2ee3c3c246d39ee9e10a Mon Sep 17 00:00:00 2001 From: li3zhen1 Date: Mon, 9 Oct 2023 17:53:01 -0400 Subject: [PATCH] add quad tree delegate --- Sources/ForceSimulation/Force.swift | 1 - .../ForceSimulation/forces/CenterForce.swift | 4 - .../ForceSimulation/forces/CollideForce.swift | 44 +- .../ForceSimulation/forces/LinkForce.swift | 8 - .../forces/ManyBodyForce.swift | 88 +++- .../forces/PositionForce.swift | 5 +- .../ForceSimulation/forces/RadialForce.swift | 6 +- Sources/QuadTree/QuadTree.swift | 8 +- Sources/QuadTree/QuadTree2.swift | 490 ++++++++++++++++++ .../MiserableGraphTest.swift | 4 +- 10 files changed, 607 insertions(+), 51 deletions(-) create mode 100644 Sources/QuadTree/QuadTree2.swift diff --git a/Sources/ForceSimulation/Force.swift b/Sources/ForceSimulation/Force.swift index 4f5e01c..c211723 100644 --- a/Sources/ForceSimulation/Force.swift +++ b/Sources/ForceSimulation/Force.swift @@ -69,7 +69,6 @@ public protocol Force { associatedtype N: Identifiable func apply(alpha: Float) - func initialize() } diff --git a/Sources/ForceSimulation/forces/CenterForce.swift b/Sources/ForceSimulation/forces/CenterForce.swift index 6e2cd7e..e7e2006 100644 --- a/Sources/ForceSimulation/forces/CenterForce.swift +++ b/Sources/ForceSimulation/forces/CenterForce.swift @@ -46,10 +46,6 @@ public class CenterForce : Force where N : Identifiable { } } - public func initialize() { - - } - } diff --git a/Sources/ForceSimulation/forces/CollideForce.swift b/Sources/ForceSimulation/forces/CollideForce.swift index 434717b..abce1f1 100644 --- a/Sources/ForceSimulation/forces/CollideForce.swift +++ b/Sources/ForceSimulation/forces/CollideForce.swift @@ -1,32 +1,52 @@ // -// File.swift +// CollideForce.swift // // // Created by li3zhen1 on 10/1/23. // -import Foundation +import QuadTree +enum CollideForceError: Error { + case applyBeforeSimulationInitialized +} + + +public class CollideForce where N : Identifiable { + + let radius: CollideRadius + let iterationsPerTick: Int + + + weak var simulation: Simulation? -public class CollideForce : Force where N : Identifiable { + internal init( + radius: CollideRadius, + iterationsPerTick: Int = 1 + ) { + self.radius = radius + self.iterationsPerTick = iterationsPerTick + } +} - public enum CollideRadius { + +public extension CollideForce { + enum CollideRadius{ case constant(Float) case varied( (N.ID) -> Float ) case polarCoordinatesOnRad( (Float, N.ID) -> Float ) } +} - weak var simulation: Simulation? +extension CollideForce: Force { public func apply(alpha: Float) { - - } + guard let sim = self.simulation else { return } + + for _ in 0.. : Force where N : Identifiable { } } - public func initialize() { - guard let sim = self.simulation else { return } - for link in self.links { - - } - - } - public let defaultStiffness: LinkStiffness = .varied { link, lookup in 1 / Float( diff --git a/Sources/ForceSimulation/forces/ManyBodyForce.swift b/Sources/ForceSimulation/forces/ManyBodyForce.swift index 0c0b8c7..7206957 100644 --- a/Sources/ForceSimulation/forces/ManyBodyForce.swift +++ b/Sources/ForceSimulation/forces/ManyBodyForce.swift @@ -9,12 +9,77 @@ import QuadTree import simd -enum ManyBodyForceError: Error { - case buildQuadTreeBeforeSimulationInitialized +final class MassQuadTreeDelegate: QuadDelegate where N : Identifiable { + + typealias Node = N + typealias Property = Float + typealias MassProvider = [N.ID: Float] + + public var accumulatedProperty: Float = 0.0 + public var accumulatedCount = 0 + public var weightedAccumulatedNodePositions: Vector2f = .zero + + let massProvider: [N.ID: Float] + + init( + massProvider: MassProvider + ) { + self.massProvider = massProvider + } + + + internal init( + initialAccumulatedProperty: Float, + initialAccumulatedCount: Int, + initialWeightedAccumulatedNodePositions: Vector2f, + massProvider: MassProvider + ) { + self.accumulatedProperty = initialAccumulatedProperty + self.accumulatedCount = initialAccumulatedCount + self.weightedAccumulatedNodePositions = initialWeightedAccumulatedNodePositions + self.massProvider = massProvider + } + + func didAddNode(_ node: N, at position: Vector2f) { + let p = massProvider[node.id, default: 0] + accumulatedCount += 1 + accumulatedProperty += p + weightedAccumulatedNodePositions += p * position + } + + func didRemoveNode(_ node: N, at position: Vector2f) { + let p = massProvider[node.id, default: 0] + accumulatedCount -= 1 + accumulatedProperty -= p + weightedAccumulatedNodePositions -= p * position + + // TODO: parent removal? + } + + + func createForExpanded(towards _: Quadrant, from _: Quad, to _: Quad) -> Self { + return Self( + initialAccumulatedProperty: self.accumulatedProperty, + initialAccumulatedCount: self.accumulatedCount, + initialWeightedAccumulatedNodePositions: self.weightedAccumulatedNodePositions, + massProvider: self.massProvider + ) + } + + func stem() -> Self { + return Self(massProvider: self.massProvider) + } + + + var centroid : Vector2f? { + guard accumulatedCount > 0 else { return nil } + return weightedAccumulatedNodePositions / accumulatedProperty + } } -extension SimulationNode: HasMassLikeProperty { - public var property: Float {1.0} + +enum ManyBodyForceError: Error { + case buildQuadTreeBeforeSimulationInitialized } public class ManyBodyForce : Force where N : Identifiable { @@ -50,9 +115,7 @@ public class ManyBodyForce : Force where N : Identifiable { } } - public func initialize() { - - } + func calculateForce(alpha: Float) throws -> [Vector2f] { @@ -60,13 +123,15 @@ public class ManyBodyForce : Force where N : Identifiable { throw ManyBodyForceError.buildQuadTreeBeforeSimulationInitialized } - let quad = try QuadTree(nodes: sim.simulationNodes.map { ($0, $0.position) }) + let quad = try QuadTree2(nodes: sim.simulationNodes.map { ($0, $0.position) }) { + MassQuadTreeDelegate(massProvider: sim.simulationNodes.reduce(into: [N.ID: Float]()) { $0[$1.id] = 1.0 }) + } var forces = Array(repeating: .zero, count: sim.simulationNodes.count) for i in sim.simulationNodes.indices { quad.visit { quadNode in - if let centroid = quadNode.centroid { + if let centroid = quadNode.quadDelegate.centroid { let vec = centroid - sim.simulationNodes[i].position var distanceSquared = vec.jiggled() @@ -82,10 +147,9 @@ public class ManyBodyForce : Force where N : Identifiable { distanceSquared = sqrt(self.distanceMin2 * distanceSquared) } - - if quadNode.isLeaf || distanceSquared * self.theta2 > quadNode.quad.area { + if (quadNode.isLeaf || (distanceSquared * self.theta2 > quadNode.quad.area)) { - forces[i] += self.strength * alpha * quadNode.accumulatedProperty * vec / pow(distanceSquared, 1.5) + forces[i] += self.strength * alpha * quadNode.quadDelegate.accumulatedProperty * vec / pow(distanceSquared, 1.5) return false } diff --git a/Sources/ForceSimulation/forces/PositionForce.swift b/Sources/ForceSimulation/forces/PositionForce.swift index 749c4ca..28f22b1 100644 --- a/Sources/ForceSimulation/forces/PositionForce.swift +++ b/Sources/ForceSimulation/forces/PositionForce.swift @@ -23,10 +23,7 @@ public class PositionForce : Force where N: Identifiable { public func apply(alpha: Float) { } - - public func initialize() { - - } + diff --git a/Sources/ForceSimulation/forces/RadialForce.swift b/Sources/ForceSimulation/forces/RadialForce.swift index 20d8365..0c6187b 100644 --- a/Sources/ForceSimulation/forces/RadialForce.swift +++ b/Sources/ForceSimulation/forces/RadialForce.swift @@ -11,10 +11,6 @@ public class RadialForce : Force where N : Identifiable { public func apply(alpha: Float) { } - - public func initialize() { - - } - + } diff --git a/Sources/QuadTree/QuadTree.swift b/Sources/QuadTree/QuadTree.swift index 11204ee..9e6e01a 100644 --- a/Sources/QuadTree/QuadTree.swift +++ b/Sources/QuadTree/QuadTree.swift @@ -9,13 +9,14 @@ import simd // TODO: https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjoh_vKttuBAxUunokEHdchDZAQFnoECBkQAQ&url=https%3A%2F%2Fosf.io%2Fdu6gq%2Fdownload%2F%3Fversion%3D1%26displayName%3Dgove-2018-updating-tree-approximations-2018-06-13T02%253A16%253A17.463Z.pdf&usg=AOvVaw3KFAE5U8cnhTDMN_qrzV6a&opi=89978449 +@available(*, deprecated) public class QuadTreeNode where N: Identifiable, N: HasMassLikeProperty { - + + public private(set) var quad: Quad public var nodes: [N.ID: Vector2f] = [:] // TODO: merge nodes if close enough -// public var allNodes: [NodeID: Vector2f] = [:] public var accumulatedProperty: Float = 0.0 public var accumulatedCount = 0 public var weightedAccumulatedNodePositions: Vector2f = .zero @@ -28,7 +29,6 @@ public class QuadTreeNode where N: Identifiable, N: HasMassLikeProperty { } } - final public class Children { public private(set) var northWest: QuadTreeNode @@ -274,6 +274,8 @@ enum QuadTreeError: Error { case noNodeProvidedError } + +@available(*, deprecated) final public class QuadTree where N: HasMassLikeProperty { public private(set) var root: QuadTreeNode private var nodeIds: Set = [] diff --git a/Sources/QuadTree/QuadTree2.swift b/Sources/QuadTree/QuadTree2.swift new file mode 100644 index 0000000..205ed95 --- /dev/null +++ b/Sources/QuadTree/QuadTree2.swift @@ -0,0 +1,490 @@ +// The Swift Programming Language +// https://docs.swift.org/swift-book + +// #if arch(wasm32) +// import SimdPolyfill +// #else +import simd +// #endif + + +// TODO: https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwjoh_vKttuBAxUunokEHdchDZAQFnoECBkQAQ&url=https%3A%2F%2Fosf.io%2Fdu6gq%2Fdownload%2F%3Fversion%3D1%26displayName%3Dgove-2018-updating-tree-approximations-2018-06-13T02%253A16%253A17.463Z.pdf&usg=AOvVaw3KFAE5U8cnhTDMN_qrzV6a&opi=89978449 +public class QuadTreeNode2 where N: Identifiable, QD: QuadDelegate, QD.Node == N { + + public private(set) var quad: Quad + + public var nodes: [N.ID: Vector2f] = [:] // TODO: merge nodes if close enough + + final public class Children { + public private(set) var northWest: QuadTreeNode2 + public private(set) var northEast: QuadTreeNode2 + public private(set) var southWest: QuadTreeNode2 + public private(set) var southEast: QuadTreeNode2 + internal init( + _ northWest: QuadTreeNode2, + _ northEast: QuadTreeNode2, + _ southWest: QuadTreeNode2, + _ southEast: QuadTreeNode2 + ) { + self.northWest = northWest + self.northEast = northEast + self.southWest = southWest + self.southEast = southEast + } + + + } + + public private(set) var children: Children? + + public let clusterDistance: Float + + public var quadDelegate: QD + + internal init( + quad: Quad, + clusterDistance: Float, + rootQuadDelegate: QD + ) { + self.quad = quad + self.clusterDistance = clusterDistance + self.quadDelegate = rootQuadDelegate.stem() + } + + public func add(_ node: N, at point: Vector2f) { + cover(point) + + // accumulatedCount += 1 + + quadDelegate.didAddNode(node, at: point) + + // accumulatedProperty += QD.getPropertyFor(node) + // weightedAccumulatedNodePositions += node.quadDelegate * point + + guard let children = self.children else { + if nodes.isEmpty { + // no children, not occupied => take this point + nodes[node.id] = point + return + } + else if nodes.first!.value.distanceTo(point) < clusterDistance { + // no children, close enough => take this point + nodes[node.id] = point + return + } + else { + // no children, not close enough => divide & add to children + let divided = QuadTreeNode2.divide(quad: quad, clusterDistance: clusterDistance, rootQuadDelegate: quadDelegate) + + if !nodes.isEmpty { + let direction = quad.quadrantOf(nodes.first!.value) + divided[at: direction].nodes = self.nodes + } + self.nodes = [:] + + let direction = quad.quadrantOf(point) + divided[at: direction].add(node, at: point) + + self.children = divided + return + } + } + // has children => add to children + let direction = quad.quadrantOf(point) + children[at: direction].add(node, at: point) + } + + + public func addAll(_ nodesAndPoints: [(N, Vector2f)]) { + for entry in nodesAndPoints { + add(entry.0, at: entry.1) + } + } + + @discardableResult + public func remove(_ node: N.ID) -> Bool { + if nodes.removeValue(forKey: node) != nil { + return true + } + else { + guard let children = self.children else { + return false + } + return children.anyMutating { child, _ in + child.remove(node) + } + } + } + + public func removeAll() { + nodes.removeAll() + children?.forEachMutating { n, _ in + n.removeAll() + } + } + + /// Expand the quad by exponential of 2 to cover the point, does nothing if the point is already covered + /// Does not add point to the tree + /// - Parameter point: + internal func cover(_ point: Vector2f) { + if quad.contains(point) { return } + + repeat { + + /** + * (0, 0) + * | point: .northEast + * | + * ---- quad.x0y0 ---- + * | + * | + * + */ + let quadrant: Quadrant = switch (point.y < quad.y0, point.x < quad.x0) { + case (false, false): .southEast + case (false, true): .southWest + case (true, true): .northWest + case (true, false): .northEast + } + + expand(towards: quadrant) + + } while !quad.contains(point) + + } + + + private func expand(towards quadrant: Quadrant) { + let nailedQuadrant = quadrant.reversed + let nailedCorner = quad.getCorner(of: nailedQuadrant) + let expandedCorner = quad.getCorner(of: quadrant) * 2 - nailedCorner + + let newRootQuad = Quad(corner: nailedCorner, oppositeCorner: expandedCorner) + let copiedCurrentNode = shallowCopy() + let divided = QuadTreeNode2.divide(quad: newRootQuad, clusterDistance: clusterDistance, rootQuadDelegate:quadDelegate) + divided[at: nailedQuadrant] = copiedCurrentNode + + self.quad = newRootQuad + self.children = divided + self.nodes = [:] + self.quadDelegate = quadDelegate.createForExpanded(towards: quadrant, from: copiedCurrentNode.quad, to: newRootQuad) + } + + private static func divide(quad: Quad, clusterDistance: Float, rootQuadDelegate: QD) -> Children { + let divided = quad.divide() + let northWest = QuadTreeNode2(quad: divided.northWest, clusterDistance: clusterDistance, rootQuadDelegate: rootQuadDelegate) + let northEast = QuadTreeNode2(quad: divided.northEast, clusterDistance: clusterDistance,rootQuadDelegate:rootQuadDelegate) + let southWest = QuadTreeNode2(quad: divided.southWest, clusterDistance: clusterDistance,rootQuadDelegate:rootQuadDelegate) + let southEast = QuadTreeNode2(quad: divided.southEast, clusterDistance: clusterDistance,rootQuadDelegate:rootQuadDelegate) + return Children(northWest, northEast, southWest, southEast) + } + + /** + * Copy object while holding the same reference to children + */ + private func shallowCopy() -> QuadTreeNode2 { + let copy = QuadTreeNode2(quad: quad, clusterDistance: clusterDistance, rootQuadDelegate: quadDelegate) + copy.nodes = nodes + copy.children = children + copy.quadDelegate = quadDelegate + return copy + } + + public var isLeaf: Bool { + return children == nil + } +} + + +enum QuadTree2Error: Error { + case noNodeProvidedError +} + +final public class QuadTree2 where N: Identifiable, QD: QuadDelegate, QD.Node == N { + public private(set) var root: QuadTreeNode2 + private var nodeIds: Set = [] + + public let clusterDistance: Float + + public init( + quad: Quad, + clusterDistance: Float = 1e-6, + getQuadDelegate: @escaping() -> QD + ) { + self.clusterDistance = clusterDistance + self.root = QuadTreeNode2( + quad: quad, + clusterDistance: clusterDistance, + rootQuadDelegate: getQuadDelegate() + ) + } + + public init( + nodes: [(N, Vector2f)], + clusterDistance: Float = 1e-6, + getQuadDelegate: @escaping() -> QD + ) throws { + guard let firstEntry = nodes.first else { + throw QuadTreeError.noNodeProvidedError + } + self.clusterDistance = clusterDistance + self.root = QuadTreeNode2( + quad: Quad.cover(firstEntry.1), + clusterDistance: clusterDistance, + rootQuadDelegate: getQuadDelegate() + ) + self.addAll(nodes) + } + + public func add(_ node: N, at point: Vector2f) { + root.add(node, at: point) + nodeIds.insert(node.id) + } + + public func add(_ node: N, at point: (Float, Float)) { + root.add(node, at: Vector2f(point.0, point.1)) + nodeIds.insert(node.id) + } + + public func addAll(_ nodes: [(N, Vector2f)]) { + for (node, position) in nodes { + add(node, at: position) + } + } + + public func remove(_ nodeID: N.ID) { + root.remove(nodeID) + nodeIds.remove(nodeID) + // self.nodeLookup.removeValue(forKey: nodeID) + } + + public func removeAll() { + root.removeAll() + nodeIds = [] + } + + // public var centroid : Vector2f? { + // get { + // return root.centroid + // } + // } + + static public func create( + startingWith node: N, + at point: Vector2f, + clusterDistance: Float = 1e-6, + getQuadDelegate: @escaping() -> QD + ) -> QuadTree2 where N: Identifiable { + let tree = QuadTree2( + quad: Quad.cover(point), + clusterDistance: clusterDistance, + getQuadDelegate: getQuadDelegate + ) + tree.add(node, at: point) + return tree + } + + + public var quad: Quad { return root.quad } +} + + + +extension QuadTreeNode2.Children { + + public subscript(at quadrant: Quadrant) -> QuadTreeNode2 { + get { + switch quadrant { + case .northWest: + return northWest + case .northEast: + return northEast + case .southWest: + return southWest + case .southEast: + return southEast + } + } + set { + switch quadrant { + case .northWest: + northWest = newValue + case .northEast: + northEast = newValue + case .southWest: + southWest = newValue + case .southEast: + southEast = newValue + } + } + } + + public func forEach(_ body: @escaping (QuadTreeNode2, Quadrant) -> Void) { + body(northWest, .northWest) + body(northEast, .northEast) + body(southWest, .southWest) + body(southEast, .southEast) + } + + public func forEachMutating(_ body: @escaping (inout QuadTreeNode2, Quadrant) -> Void) { + body(&northWest, .northWest) + body(&northEast, .northEast) + body(&southWest, .southWest) + body(&southEast, .southEast) + } + + @discardableResult + public func anyMutating(_ predicate: @escaping (inout QuadTreeNode2, Quadrant) -> Bool) -> Bool { + return predicate(&northWest, .northWest) + || predicate(&northEast, .northEast) + || predicate(&southWest, .southWest) + || predicate(&southEast, .southEast) + } + + public func any(_ predicate: @escaping (QuadTreeNode2, Quadrant) -> Bool) -> Bool { + return predicate(northWest, .northWest) + || predicate(northEast, .northEast) + || predicate(southWest, .southWest) + || predicate(southEast, .southEast) + } +} + +public protocol QuadDelegate { + associatedtype Node + associatedtype Property + + mutating func didAddNode(_ node: Node, at position: Vector2f) + mutating func didRemoveNode(_ node: Node, at position: Vector2f) + func createForExpanded(towards: Quadrant, from oldQuad: Quad, to newQuad: Quad) -> Self + func stem() -> Self +} + + + + + + + + + + + + +extension QuadTreeNode2 { + + public func visitAfter( + _ action: @escaping( + T, QuadTreeNode2 + )->T + ) -> T where T: AdditiveArithmetic { + + if let children { + + let nw = children.northWest.visitAfter(action) + let ne = children.northEast.visitAfter(action) + let sw = children.southWest.visitAfter(action) + let se = children.southEast.visitAfter(action) + + return action(nw+ne+sw+se, self) + } + else if !isLeaf { + return action(.zero, self) + } + return .zero + } + + + public func visit( + _ decideWhetherToVisitChildrenAfterAction: @escaping( + QuadTreeNode2 + ) -> Bool + ) { + + if decideWhetherToVisitChildrenAfterAction(self), let children { + // this is an internal node + children.northWest.visit(decideWhetherToVisitChildrenAfterAction) + children.northEast.visit(decideWhetherToVisitChildrenAfterAction) + children.southWest.visit(decideWhetherToVisitChildrenAfterAction) + children.southEast.visit(decideWhetherToVisitChildrenAfterAction) + } + } + + + public func visitAfter( + _ action: @escaping( + QuadTreeNode2 + )->Void + ) { + + if let children { + + children.northWest.visitAfter(action) + children.northEast.visitAfter(action) + children.southWest.visitAfter(action) + children.southEast.visitAfter(action) + + action(self) + } + else if !isLeaf { + action(self) + } + } + + public func visitAfter( + onInternal: @escaping( + Quad + )->Void, + onFilledLeaf: @escaping( + [N.ID: Vector2f], + Quad + )->Void + ) { + + if let children { + onInternal(quad) + children.northWest.visitAfter(onInternal: onInternal, onFilledLeaf: onFilledLeaf) + children.northEast.visitAfter(onInternal: onInternal, onFilledLeaf: onFilledLeaf) + children.southWest.visitAfter(onInternal: onInternal, onFilledLeaf: onFilledLeaf) + children.southEast.visitAfter(onInternal: onInternal, onFilledLeaf: onFilledLeaf) + } + else if !isLeaf { + onFilledLeaf(nodes, quad) + } + } +} + + + + + + + +extension QuadTree2 { + + @discardableResult + public func visitAfter( + withResult action: @escaping( + T, QuadTreeNode2 + )->T + ) -> T where T: AdditiveArithmetic { + return root.visitAfter(action) + } + + + public func visitAfter( + _ action: @escaping( + QuadTreeNode2 + )->Void + ) { + return root.visitAfter(action) + } + + public func visit( + _ decideWhetherToVisitChildrenAfterAction: @escaping( + QuadTreeNode2 + ) -> Bool + ) { + root.visit(decideWhetherToVisitChildrenAfterAction) + } +} + diff --git a/Tests/ForceSimulationTests/MiserableGraphTest.swift b/Tests/ForceSimulationTests/MiserableGraphTest.swift index 965301f..821ccb9 100644 --- a/Tests/ForceSimulationTests/MiserableGraphTest.swift +++ b/Tests/ForceSimulationTests/MiserableGraphTest.swift @@ -26,11 +26,11 @@ final class MiserableGraphTest: XCTestCase { // sim.tick() -// measure { + measure { for _ in 0..<60{ sim.tick() } -// } + } sim.tick() print(sim.simulationNodes)