From 87a33794ba9e756d1af62abb6a382842cc264330 Mon Sep 17 00:00:00 2001 From: Sebastian Krantz Date: Mon, 9 Sep 2024 01:03:12 +0200 Subject: [PATCH] Adding more robust linear interpolation method where spline does not work. --- Project.toml | 29 ++++++++++++--------- src/OptimalTransportNetworks.jl | 3 ++- src/main/helper.jl | 45 +++++++++++++++++++++++++++++++++ src/main/plot_graph.jl | 20 ++++++++++----- 4 files changed, 77 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 8214b87..d2317ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,34 @@ name = "OptimalTransportNetworks" uuid = "e2b46e68-897f-4e4e-ba36-a93c9789fd96" authors = ["Sebastian Krantz "] -version = "0.1.5" +version = "0.1.6" [deps] Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94" Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -julia = "1.8.5" -Dierckx = "0.5.0" -Ipopt = "1.4.0" -JuMP = "1.20.0" -MathOptInterface = "1.0.0" -Plots = "1.19.0" -Random = "1.0.0" -SparseArrays = "1.0.0" -Statistics = "1.0.0" -LinearAlgebra = "1.0.0" +julia = "1.8.5, 2" +Ipopt = "1.4.0, 2" +JuMP = "1.20.0, 2" +LinearAlgebra = "1.0.0, 2" +NearestNeighbors = "0.4.10, 1" +Plots = "1.19.0, 2" +Dierckx = "0.5.0, 1" +Random = "1.0.0, 2" +SparseArrays = "1.0.0, 2" +StaticArrays = "1.1.0, 2" +Statistics = "1.0.0, 2" + + + diff --git a/src/OptimalTransportNetworks.jl b/src/OptimalTransportNetworks.jl index 98fa317..e7df504 100644 --- a/src/OptimalTransportNetworks.jl +++ b/src/OptimalTransportNetworks.jl @@ -73,10 +73,11 @@ plot_graph(graph, results_annealing[:Ijk]) module OptimalTransportNetworks using LinearAlgebra, JuMP, Plots -using SparseArrays: sparse using Statistics: mean # using MathOptInterface: Parameter using Dierckx: Spline2D, evaluate +using NearestNeighbors: KDTree, knn +# using Interpolations: cubic_spline_interpolation import Ipopt, Plots, Random #, MathOptSymbolicAD # import MathOptInterface as MOI diff --git a/src/main/helper.jl b/src/main/helper.jl index 7d5904a..13278f4 100644 --- a/src/main/helper.jl +++ b/src/main/helper.jl @@ -285,3 +285,48 @@ function rescale_network!(param, graph, I1, Il, Iu; max_iter = 100) return I1 end + +# KNN Version: More robust +function linear_interpolation_2d(vec_x, vec_y, vec_map, xmap, ymap) + + # Ensure input vectors are of the same length + @assert length(vec_x) == length(vec_y) == length(vec_map) "Input vectors must have the same length" + + # Ensure input vectors are Float64 + vec_x = convert(Vector{Float64}, vec_x) + vec_y = convert(Vector{Float64}, vec_y) + vec_map = convert(Vector{Float64}, vec_map) + xmap = convert(Vector{Float64}, xmap) + ymap = convert(Vector{Float64}, ymap) + + # Initialize the output array + fmap = zeros(length(xmap), length(ymap)) + + # Create a KDTree for efficient nearest neighbor search + points = hcat(vec_x, vec_y) + tree = KDTree(points'; leafsize = 5) + + # Determine the number of neighbors to use (k) + k = min(15, size(points, 1)) # Use 15 or the total number of points, whichever is smaller + + for (ix, x) in enumerate(xmap), (iy, y) in enumerate(ymap) + + # Find the 15 nearest neighbors + idxs, dists = knn(tree, [x, y], k, true) + + # If the point is exactly on a known point, use that value + if dists[1] ≈ 0 + fmap[ix, iy] = vec_map[idxs[1]] + continue + end + + # Weights + weights = 1 ./ dists.^2 + weights ./= sum(weights) + + # Interpolate + fmap[ix, iy] = sum(weights .* vec_map[idxs]) + end + + return fmap +end \ No newline at end of file diff --git a/src/main/plot_graph.jl b/src/main/plot_graph.jl index d63d970..c59a3f4 100644 --- a/src/main/plot_graph.jl +++ b/src/main/plot_graph.jl @@ -94,15 +94,21 @@ function plot_graph(graph, edges = nothing; kwargs...) vec_map = vec(op.map) end # Interpolate map onto grid - # itp = interpolate((vec_x, vec_y), vec_map, Gridded(Linear())) - spl = Spline2D(vec_x, vec_y, vec_map, s = 0.1) xmap = range(minimum(vec_x), stop=maximum(vec_x), length=2*length(vec_x)) ymap = range(minimum(vec_y), stop=maximum(vec_y), length=2*length(vec_y)) - Xmap, Ymap = xmap' .* ones(length(ymap)), ymap .* ones(length(xmap))' - Xmap, Ymap = Xmap[:], Ymap[:] - fmap = evaluate(spl, Xmap, Ymap) - # make fmap a matrix with same size as xmap and ymap - fmap = reshape(fmap, length(xmap), length(ymap)) + # itp = interpolate((vec_x, vec_y), vec_map, Gridded(Linear())) + fmap = zeros(length(xmap), length(ymap)) + try + spl = Spline2D(vec_x, vec_y, vec_map, s = 0.1) + Xmap, Ymap = xmap' .* ones(length(ymap)), ymap .* ones(length(xmap))' + Xmap, Ymap = Xmap[:], Ymap[:] + fmap_values = evaluate(spl, Xmap, Ymap) + fmap = reshape(fmap_values, length(xmap), length(ymap)) + catch + # println("Spline2D interpolation failed, falling back to linear interpolation") + # If Spline2D interpolation fails, fall back to linear interpolation + fmap = linear_interpolation_2d(vec_x, vec_y, vec_map, xmap, ymap) + end # Plot heatmap heatmap!(pl, xmap, ymap, fmap,