From d2ab349876c749c442d9a334a19307ca2d0afe70 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 15:05:06 +0530 Subject: [PATCH 1/4] new drop edge --- GNNGraphs/src/transform.jl | 88 ++++++++++++++++++++++++++----------- GNNGraphs/test/transform.jl | 7 +++ 2 files changed, 69 insertions(+), 26 deletions(-) diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 8df726752..54fcfc14f 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -201,40 +201,76 @@ function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:In end """ - remove_multi_edges(g::GNNGraph; aggr=+) + remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) + remove_edges(g::GNNGraph, p::Float64=0.5) + +Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability. + +# Arguments +- `g`: The input graph from which edges will be removed. +- `edges_to_remove`: Vector of edge indices to be removed. This argument is only required for the first method. +- `p`: Probability of removing each edge. This argument is only required for the second method and defaults to 0.5. -Remove multiple edges (also called parallel edges or repeated edges) from graph `g`. -Possible edge features are aggregated according to `aggr`, that can take value -`+`,`min`, `max` or `mean`. +# Returns +A new GNNGraph with the specified edges removed. -See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref). +# Example +```julia +julia> using GraphNeuralNetworks + +# Construct a GNNGraph +julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) +GNNGraph: + num_nodes: 3 + num_edges: 5 + +# Remove the second edge +julia> g_new = remove_edges(g, [2]); + +julia> g_new +GNNGraph: + num_nodes: 3 + num_edges: 4 + +# Remove edges with a probability of 0.5 +julia> g_new = remove_edges(g, 0.5); + +julia> g_new +GNNGraph: + num_nodes: 3 + num_edges: 2 +``` """ -function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +) +function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) s, t = edge_index(g) w = get_edge_weight(g) edata = g.edata - num_edges = g.num_edges - idxs, idxmax = edge_encoding(s, t, g.num_nodes) - - perm = sortperm(idxs) - idxs = idxs[perm] - s, t = s[perm], t[perm] - edata = getobs(edata, perm) - w = isnothing(w) ? nothing : getobs(w, perm) - idxs = [-1; idxs] - mask = idxs[2:end] .> idxs[1:(end - 1)] - if !all(mask) - s, t = s[mask], t[mask] - idxs = similar(s, num_edges) - idxs .= 1:num_edges - idxs .= idxs .- cumsum(.!mask) - num_edges = length(s) - w = _scatter(aggr, w, idxs, num_edges) - edata = _scatter(aggr, edata, idxs, num_edges) - end + + mask_to_keep = trues(length(s)) + + mask_to_keep[edges_to_remove] .= false + + s = s[mask_to_keep] + t = t[mask_to_keep] + edata = getobs(edata, mask_to_keep) + w = isnothing(w) ? nothing : getobs(w, mask_to_keep) return GNNGraph((s, t, w), - g.num_nodes, num_edges, g.num_graphs, + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + + +function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) + num_edges = g.num_edges + edges_to_remove = filter(_ -> rand() < p, 1:num_edges) + g = remove_edges(g, edges_to_remove) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + return GNNGraph((s, t, w), + g.num_nodes, length(s), g.num_graphs, g.graph_indicator, g.ndata, edata, g.gdata) end diff --git a/GNNGraphs/test/transform.jl b/GNNGraphs/test/transform.jl index 993ac714a..05413fd4f 100644 --- a/GNNGraphs/test/transform.jl +++ b/GNNGraphs/test/transform.jl @@ -126,6 +126,13 @@ end @test new_t == [4] @test new_w == [0.3] @test new_edata == ['c'] + + # drop with probability + gnew = remove_edges(g, Float32(1.0)) + @test gnew.num_edges == 0 + + gnew = remove_edges(g, Float32(0.0)) + @test gnew.num_edges == g.num_edges end end From 96eb264a1489fdc605bc6dedbf49bb95f6f32495 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 15:06:38 +0530 Subject: [PATCH 2/4] fix --- GNNGraphs/src/transform.jl | 90 +++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 51 deletions(-) diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 54fcfc14f..1e7603897 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -149,57 +149,6 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T}) g.ndata, g.edata, g.gdata) end -""" - remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) - -Remove specified edges from a GNNGraph. - -# Arguments -- `g`: The input graph from which edges will be removed. -- `edges_to_remove`: Vector of edge indices to be removed. - -# Returns -A new GNNGraph with the specified edges removed. - -# Example -```julia -julia> using GraphNeuralNetworks - -# Construct a GNNGraph -julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1]) -GNNGraph: - num_nodes: 3 - num_edges: 5 - -# Remove the second edge -julia> g_new = remove_edges(g, [2]); - -julia> g_new -GNNGraph: - num_nodes: 3 - num_edges: 4 -``` -""" -function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer}) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - - mask_to_keep = trues(length(s)) - - mask_to_keep[edges_to_remove] .= false - - s = s[mask_to_keep] - t = t[mask_to_keep] - edata = getobs(edata, mask_to_keep) - w = isnothing(w) ? nothing : getobs(w, mask_to_keep) - - return GNNGraph((s, t, w), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) -end - """ remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) remove_edges(g::GNNGraph, p::Float64=0.5) @@ -275,6 +224,45 @@ function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) g.ndata, edata, g.gdata) end +""" + remove_multi_edges(g::GNNGraph; aggr=+) + +Remove multiple edges (also called parallel edges or repeated edges) from graph `g`. +Possible edge features are aggregated according to `aggr`, that can take value +`+`,`min`, `max` or `mean`. + +See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref). +""" +function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +) + s, t = edge_index(g) + w = get_edge_weight(g) + edata = g.edata + num_edges = g.num_edges + idxs, idxmax = edge_encoding(s, t, g.num_nodes) + + perm = sortperm(idxs) + idxs = idxs[perm] + s, t = s[perm], t[perm] + edata = getobs(edata, perm) + w = isnothing(w) ? nothing : getobs(w, perm) + idxs = [-1; idxs] + mask = idxs[2:end] .> idxs[1:(end - 1)] + if !all(mask) + s, t = s[mask], t[mask] + idxs = similar(s, num_edges) + idxs .= 1:num_edges + idxs .= idxs .- cumsum(.!mask) + num_edges = length(s) + w = _scatter(aggr, w, idxs, num_edges) + edata = _scatter(aggr, edata, idxs, num_edges) + end + + return GNNGraph((s, t, w), + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) +end + """ remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector) From b86bf2ef88b46fb2abef921d773878db9b649d0f Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:32:40 +0530 Subject: [PATCH 3/4] Update GNNGraphs/src/transform.jl Co-authored-by: Carlo Lucibello --- GNNGraphs/src/transform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 1e7603897..d7177a0e8 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -151,7 +151,7 @@ end """ remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer}) - remove_edges(g::GNNGraph, p::Float64=0.5) + remove_edges(g::GNNGraph, p=0.5) Remove specified edges from a GNNGraph, either by specifying edge indices or by randomly removing edges with a given probability. From 28712e354bfefe3099b0c5eb5cdfb8cae144d0fd Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:33:13 +0530 Subject: [PATCH 4/4] Update GNNGraphs/src/transform.jl Co-authored-by: Carlo Lucibello --- GNNGraphs/src/transform.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index d7177a0e8..8e8c98d13 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -214,14 +214,7 @@ end function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5) num_edges = g.num_edges edges_to_remove = filter(_ -> rand() < p, 1:num_edges) - g = remove_edges(g, edges_to_remove) - s, t = edge_index(g) - w = get_edge_weight(g) - edata = g.edata - return GNNGraph((s, t, w), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, edata, g.gdata) + return remove_edges(g, edges_to_remove) end """