Skip to content

Commit

Permalink
Adding more robust linear interpolation method where spline does not …
Browse files Browse the repository at this point in the history
…work.
  • Loading branch information
SebKrantz committed Sep 8, 2024
1 parent 877dd24 commit 87a3379
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 20 deletions.
29 changes: 17 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
name = "OptimalTransportNetworks"
uuid = "e2b46e68-897f-4e4e-ba36-a93c9789fd96"
authors = ["Sebastian Krantz <sebastian.krantz@graduateinstitute.ch>"]
version = "0.1.5"
version = "0.1.6"

[deps]
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
julia = "1.8.5"
Dierckx = "0.5.0"
Ipopt = "1.4.0"
JuMP = "1.20.0"
MathOptInterface = "1.0.0"
Plots = "1.19.0"
Random = "1.0.0"
SparseArrays = "1.0.0"
Statistics = "1.0.0"
LinearAlgebra = "1.0.0"
julia = "1.8.5, 2"
Ipopt = "1.4.0, 2"
JuMP = "1.20.0, 2"
LinearAlgebra = "1.0.0, 2"
NearestNeighbors = "0.4.10, 1"
Plots = "1.19.0, 2"
Dierckx = "0.5.0, 1"
Random = "1.0.0, 2"
SparseArrays = "1.0.0, 2"
StaticArrays = "1.1.0, 2"
Statistics = "1.0.0, 2"





3 changes: 2 additions & 1 deletion src/OptimalTransportNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ plot_graph(graph, results_annealing[:Ijk])
module OptimalTransportNetworks

using LinearAlgebra, JuMP, Plots
using SparseArrays: sparse
using Statistics: mean
# using MathOptInterface: Parameter
using Dierckx: Spline2D, evaluate
using NearestNeighbors: KDTree, knn
# using Interpolations: cubic_spline_interpolation
import Ipopt, Plots, Random #, MathOptSymbolicAD
# import MathOptInterface as MOI

Expand Down
45 changes: 45 additions & 0 deletions src/main/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,48 @@ function rescale_network!(param, graph, I1, Il, Iu; max_iter = 100)

return I1
end

# KNN Version: More robust
function linear_interpolation_2d(vec_x, vec_y, vec_map, xmap, ymap)

# Ensure input vectors are of the same length
@assert length(vec_x) == length(vec_y) == length(vec_map) "Input vectors must have the same length"

# Ensure input vectors are Float64
vec_x = convert(Vector{Float64}, vec_x)
vec_y = convert(Vector{Float64}, vec_y)
vec_map = convert(Vector{Float64}, vec_map)
xmap = convert(Vector{Float64}, xmap)
ymap = convert(Vector{Float64}, ymap)

# Initialize the output array
fmap = zeros(length(xmap), length(ymap))

# Create a KDTree for efficient nearest neighbor search
points = hcat(vec_x, vec_y)
tree = KDTree(points'; leafsize = 5)

# Determine the number of neighbors to use (k)
k = min(15, size(points, 1)) # Use 15 or the total number of points, whichever is smaller

for (ix, x) in enumerate(xmap), (iy, y) in enumerate(ymap)

# Find the 15 nearest neighbors
idxs, dists = knn(tree, [x, y], k, true)

# If the point is exactly on a known point, use that value
if dists[1] 0
fmap[ix, iy] = vec_map[idxs[1]]
continue
end

# Weights
weights = 1 ./ dists.^2
weights ./= sum(weights)

# Interpolate
fmap[ix, iy] = sum(weights .* vec_map[idxs])
end

return fmap
end
20 changes: 13 additions & 7 deletions src/main/plot_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,21 @@ function plot_graph(graph, edges = nothing; kwargs...)
vec_map = vec(op.map)
end
# Interpolate map onto grid
# itp = interpolate((vec_x, vec_y), vec_map, Gridded(Linear()))
spl = Spline2D(vec_x, vec_y, vec_map, s = 0.1)
xmap = range(minimum(vec_x), stop=maximum(vec_x), length=2*length(vec_x))
ymap = range(minimum(vec_y), stop=maximum(vec_y), length=2*length(vec_y))
Xmap, Ymap = xmap' .* ones(length(ymap)), ymap .* ones(length(xmap))'
Xmap, Ymap = Xmap[:], Ymap[:]
fmap = evaluate(spl, Xmap, Ymap)
# make fmap a matrix with same size as xmap and ymap
fmap = reshape(fmap, length(xmap), length(ymap))
# itp = interpolate((vec_x, vec_y), vec_map, Gridded(Linear()))
fmap = zeros(length(xmap), length(ymap))
try
spl = Spline2D(vec_x, vec_y, vec_map, s = 0.1)
Xmap, Ymap = xmap' .* ones(length(ymap)), ymap .* ones(length(xmap))'
Xmap, Ymap = Xmap[:], Ymap[:]
fmap_values = evaluate(spl, Xmap, Ymap)
fmap = reshape(fmap_values, length(xmap), length(ymap))
catch
# println("Spline2D interpolation failed, falling back to linear interpolation")
# If Spline2D interpolation fails, fall back to linear interpolation
fmap = linear_interpolation_2d(vec_x, vec_y, vec_map, xmap, ymap)
end

# Plot heatmap
heatmap!(pl, xmap, ymap, fmap,
Expand Down

0 comments on commit 87a3379

Please sign in to comment.