diff --git a/.gitignore b/.gitignore index 3d1804049..a2c3a9ddc 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,4 @@ Manifest.toml /docs/build/ .vscode LocalPreferences.toml -.DS_Store -/test.jl +.DS_Store \ No newline at end of file diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl index 2f989c9b1..1c98f3e02 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -72,6 +72,7 @@ export add_nodes, negative_sample, rand_edge_split, remove_self_loops, + remove_edges, remove_multi_edges, set_edge_weight, to_bidirected, diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 42a2796f1..ee2d4b410 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -149,6 +149,57 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T}) g.ndata, g.edata, g.gdata) end +""" + remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) + +Remove specified edges from a GNNGraph. + +# Arguments +- `g`: The input graph from which edges will be removed. +- `edges_to_remove`: Vector of edge indices to be removed. + +# Returns +A new GNNGraph with the specified edges removed. + +# Example +```julia +julia> using GraphNeuralNetworks + +# Construct a GNNGraph +julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) +GNNGraph: + num_nodes: 3 + num_edges: 5 + +# Remove the second edge +julia> g_new = remove_edges(g, [2]); + +julia> g_new +GNNGraph: + num_nodes: 3 + num_edges: 4 +``` +""" +function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + + mask_to_keep = trues(length(s)) + + mask_to_keep[edges_to_remove] .= false + + s = s[mask_to_keep] + t = t[mask_to_keep] + edata = getobs(edata, mask_to_keep) + w = isnothing(w) ? nothing : getobs(w, mask_to_keep) + + return GNNGraph((s, t, w), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + """ remove_multi_edges(g::GNNGraph; aggr=+) diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index 6edbb6511..20efdd2bf 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -101,6 +101,34 @@ end @test nodemap == 1:(g1.num_nodes) end +@testset "remove_edges" begin + if GRAPH_T == :coo + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + w = [0.1, 0.2, 0.3, 0.4] + edata = ['a', 'b', 'c', 'd'] + g = GNNGraph(s, t, w, edata = edata, graph_type = GRAPH_T) + + # single edge removal + gnew = remove_edges(g, [1]) + new_s, new_t = edge_index(gnew) + @test gnew.num_edges == 3 + @test new_s == s[2:end] + @test new_t == t[2:end] + + # multiple edge removal + gnew = remove_edges(g, [1,2,4]) + new_s, new_t = edge_index(gnew) + new_w = get_edge_weight(gnew) + new_edata = gnew.edata.e + @test gnew.num_edges == 1 + @test new_s == [2] + @test new_t == [4] + @test new_w == [0.3] + @test new_edata == ['c'] + end +end + @testset "add_edges" begin if GRAPH_T == :coo s = [1, 1, 2, 3]