Skip to content

Commit

Permalink
add 2D forces
Browse files Browse the repository at this point in the history
  • Loading branch information
li3zhen1 committed Oct 10, 2023
1 parent 9b1a241 commit 18b4eed
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 33 deletions.
19 changes: 11 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ https://github.com/li3zhen1/Grape/assets/45376537/5f57c223-0126-428a-a72d-d9a3ed

#### Features

| Feature | Status |
| --- | --- |
| LinkForce ||
| ManyBodyForce ||
| CenterForce ||
| CollideForce ||
| PositionForce | |
| RadialForce ||
| | Status (2D) | Status (3D) | Metal |
| --- | --- | --- | --- |
| **NdTree** || 🚧 | |
| **Simulation** || 🚧 | 🚧 |
|  LinkForce || | |
|  ManyBodyForce || | |
|  CenterForce || | |
|  CollideForce || | |
|  PositionForce || | |
|  RadialForce || | |
| **SwiftUI View** | 🚧 | | |


#### Usage
Expand Down
15 changes: 7 additions & 8 deletions Sources/ForceSimulation/forces/CollideForce.swift
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,13 @@ extension CollideForce: Force {
nodes: sim.simulationNodes.map { ($0, $0.position) },
getQuadDelegate: {
MaxRadiusQuadTreeDelegate {
// switch self.radius {
// case .constant(let r):
// return r
// case .varied(_):
// return self.calculatedRadius[$0, default: 0.0]
// }
return self.calculatedRadius[$0, default: 0.0]
// return self.calculatedRadius[$0]!
switch self.radius {
case .constant(let r):
return r
case .varied(_):
return self.calculatedRadius[$0, default: 0.0]
}
// return self.calculatedRadius[$0, default: 0.0]
}
}
)
Expand Down
1 change: 0 additions & 1 deletion Sources/ForceSimulation/forces/LinkForce.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ final public class LinkForce<N>: Force where N: Identifiable {
var calculatedBias: [Float] = []
weak var simulation: Simulation<N>?


var iterations: Int

internal init(
Expand Down
4 changes: 1 addition & 3 deletions Sources/ForceSimulation/forces/ManyBodyForce.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ final public class ManyBodyForce<N>: Force where N: Identifiable {
let quad = try QuadTree2(
nodes: sim.simulationNodes.map { ($0, $0.position) }
) {
// this switch is only called on root init
// but it significantly slows down the performance
//
// this switch is only called on root init
return switch self.mass {
case .constant(let m):
MassQuadTreeDelegate<SimulationNode<N.ID>> { _ in m }
Expand Down
82 changes: 74 additions & 8 deletions Sources/ForceSimulation/forces/PositionForce.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,88 @@
// Created by li3zhen1 on 10/1/23.
//


final public class PositionForce<N>: Force where N: Identifiable {
var x: Float
var y: Float
var strength: Float

init(x: Float, y: Float, strength: Float = 0.1) {
self.x = x
self.y = y
public enum Direction {
case x
case y
}
public enum TargetOnDirection {
case constant(Float)
case varied([N.ID: Float])
}
public enum Strength {
case constant(Float)
case varied([N.ID: Float])
}
public var strength: Strength
public var direction: Direction
public var calculatedStrength: [N.ID: Float] = [:]
public var targetOnDirection: TargetOnDirection
public var calculatedTargetOnDirection: [N.ID: Float] = [:]

internal init(direction: Direction, targetOnDirection: TargetOnDirection, strength: Strength = .constant(1.0)) {
self.strength = strength
self.direction = direction
self.targetOnDirection = targetOnDirection
}

weak var simulation: Simulation<N>?
weak var simulation: Simulation<N>? {
didSet {
guard let sim = self.simulation else { return }
self.calculatedStrength = strength.calculated(sim.simulationNodes)
self.calculatedTargetOnDirection = targetOnDirection.calculated(sim.simulationNodes)
}
}

public func apply(alpha: Float) {
guard let sim = self.simulation else { return }
let vectorIndex = self.direction == .x ? 0 : 1
for i in sim.simulationNodes.indices {
let nodeId = sim.simulationNodes[i].id
sim.simulationNodes[i].velocity += (
self.calculatedTargetOnDirection[nodeId, default: 0.0] - sim.simulationNodes[i].position[vectorIndex]
) * self.calculatedStrength[nodeId, default: 0.0] * alpha
}
}
}

extension PositionForce.Strength {
func calculated<SimNode>(_ nodes: [SimNode]) -> [N.ID: Float] where SimNode: Identifiable, SimNode.ID == N.ID {
switch self {
case .constant(let value):
return nodes.reduce(into: [:]) { $0[$1.id] = value }
case .varied(let dict):
return dict
}
}
}

extension PositionForce.TargetOnDirection {
func calculated<SimNode>(_ nodes: [SimNode]) -> [N.ID: Float] where SimNode: Identifiable, SimNode.ID == N.ID {
switch self {
case .constant(let value):
return nodes.reduce(into: [:]) { $0[$1.id] = value }
case .varied(let dict):
return dict
}
}
}


public extension Simulation {
func createPositionForce(
direction: PositionForce<N>.Direction,
targetOnDirection: PositionForce<N>.TargetOnDirection,
strength: PositionForce<N>.Strength = .constant(1.0)
) -> PositionForce<N> {
let force = PositionForce<N>(
direction: direction,
targetOnDirection: targetOnDirection,
strength: strength
)
force.simulation = self
self.forces.append(force)
return force
}
}
25 changes: 25 additions & 0 deletions Sources/ForceSimulation/mpsforces/CenterForce.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <metal_stdlib>
using namespace metal;

struct Node {
float2 position;
float2 velocity;
float2 fixation;
};

kernel void applyCenterForce(
device Node* nodes [[ buffer(0) ]],
constant float2& center [[ buffer(1) ]],
constant float& strength [[ buffer(2) ]],
uint id [[ thread_position_in_grid ]],
uint nodeCount [[ threads_per_grid ]])
{
float2 meanPosition = float2(0.0, 0.0);
for (int i = 0; i < nodeCount; ++i) {
meanPosition += nodes[i].position;
}
meanPosition /= float(nodeCount);

float2 delta = (meanPosition - center) * strength;
nodes[id].position -= delta;
}
28 changes: 28 additions & 0 deletions Sources/ForceSimulation/mpsforces/MPSSimulation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import Foundation
import Metal
import simd
import MetalPerformanceShaders


final public class MPSSimulation {
init() {
guard let device = MTLCreateSystemDefaultDevice() else {
fatalError("")
}

guard let commandQueue = device.makeCommandQueue() else {
fatalError("")
}

let library = device.makeDefaultLibrary()

let function = library?.makeFunction(name: "")

var pipelineState: MTLComputePipelineState
do {
pipelineState = try device.makeComputePipelineState(function: function!)
} catch {
fatalError("")
}
}
}
10 changes: 5 additions & 5 deletions Sources/QuadTree/NdTree.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ public protocol NdDirection: RawRepresentable {
static var entryCount: Int { get }
}

public extension SIMD<Float> {
@inlinable func direction<Direction>(originalPoint point: Self) -> Direction where Direction: NdDirection, Direction.Coordinate == Self {

}
}
//public extension SIMD<Float> {
// @inlinable func direction<Direction>(originalPoint point: Self) -> Direction where Direction: NdDirection, Direction.Coordinate == Self {
//
// }
//}

struct OctDirection: NdDirection {
typealias Coordinate = SIMD3<Float>
Expand Down
18 changes: 18 additions & 0 deletions Tests/QuadTreeTests/NdTreeTests.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import XCTest
@testable import QuadTree

extension SIMD3<Float>: ComponentComparable {
public static func < (lhs: SIMD3<Scalar>, rhs: SIMD3<Scalar>) -> Bool {
return lhs.x < rhs.x && lhs.y < rhs.y && lhs.z < rhs.z
}

public static func > (lhs: SIMD3<Scalar>, rhs: SIMD3<Scalar>) -> Bool {
return lhs.x > rhs.x && lhs.y > rhs.y && lhs.z > rhs.z
}

public static func <= (lhs: SIMD3<Scalar>, rhs: SIMD3<Scalar>) -> Bool {
return lhs.x <= rhs.x && lhs.y <= rhs.y && lhs.z <= rhs.z
}

public static func >= (lhs: SIMD3<Scalar>, rhs: SIMD3<Scalar>) -> Bool {
return lhs.x >= rhs.x && lhs.y >= rhs.y && lhs.z >= rhs.z
}

}

final class NdTreeTests: XCTestCase {

Expand Down

0 comments on commit 18b4eed

Please sign in to comment.