diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl index 294806215..1e28fc9c8 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -79,6 +79,7 @@ export add_nodes, to_unidirected, random_walk_pe, remove_nodes, + ppr_diffusion, drop_nodes, # from Flux batch, diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index da43aba7e..f05d14a12 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -1168,3 +1168,49 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci @non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule @non_differentiable dense_zeros_like(x...) +""" + ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph + +Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix. +References paper: [The pagerank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422) + + +The function performs the following steps: +1. Constructs a modified adjacency matrix `A` using the graph's edge weights, where `A` is adjusted by `(α - 1) * A + I`, with `α` being the damping factor (`alpha_f32`) and `I` the identity matrix. +2. Normalizes `A` to ensure each column sums to 1, representing transition probabilities. +3. Applies the PPR formula `α * (I + (α - 1) * A)^-1` to compute the diffusion matrix. +4. Updates the original edge weights of the graph based on the PPR diffusion matrix, assigning new weights for each edge from the PPR matrix. + +# Arguments +- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available. +- `alpha_f32::Float32`: The damping factor used in PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`. + +# Returns +- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation. +""" +function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0) + s, t = edge_index(g) + w = get_edge_weight(g) + if isnothing(w) + w = ones(Float32, g.num_edges) + end + + N = g.num_nodes + + initial_A = sparse(t, s, w, N, N) + scaled_A = (Float32(alpha) - 1) * initial_A + + I_sparse = sparse(Diagonal(ones(Float32, N))) + A_sparse = I_sparse + scaled_A + + A_dense = Matrix(A_sparse) + + PPR = alpha * inv(A_dense) + + new_w = [PPR[dst, src] for (src, dst) in zip(s, t)] + + return GNNGraph((s, t, new_w), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index c9f413064..af414bbd1 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -595,4 +595,24 @@ end @test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n)) end +end + +@testset "ppr_diffusion" begin + if GRAPH_T == :coo + s = [1, 1, 2, 3] + t = [2, 3, 4, 5] + eweights = [0.1, 0.2, 0.3, 0.4] + + g = GNNGraph(s, t, eweights) + + g_new = ppr_diffusion(g) + w_new = get_edge_weight(g_new) + + check_ew = Float32[0.012749999 + 0.025499998 + 0.038249996 + 0.050999995] + + @test w_new ≈ check_ew + end end \ No newline at end of file