From b55d31de8fd55b1c05ae5bfd7d37ad47c408c7bb Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 1 Aug 2024 18:13:44 +0530 Subject: [PATCH] added sgconv lux --- GNNLux/src/GNNLux.jl | 4 +-- GNNLux/src/layers/conv.jl | 56 ++++++++++++++++++++++++++++++++ GNNLux/test/layers/conv_tests.jl | 5 +++ GNNlib/src/layers/conv.jl | 54 ++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index ee868b6b6..e4d1c09aa 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -26,12 +26,12 @@ export AGNNConv, GCNConv, # GINConv, # GMMConv, - GraphConv + GraphConv, # MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, - # SGConv, + SGConv # TAGConv, # TransformerConv diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 15b1bbf4b..26af1cf7e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -515,4 +515,60 @@ function Base.show(io::IO, l::GATv2Conv) l.σ == identity || print(io, ", ", l.σ) print(io, ", negative_slope=", l.negative_slope) print(io, ")") +end + +@concrete struct SGConv <: GNNLayer + in_dims::Int + out_dims::Int + k::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias +end + +function SGConv(ch::Pair{Int, Int}, k = 1; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops::Bool = true, + use_edge_weight::Bool = false) + in_dims, out_dims = ch + return SGConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, k) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims +LuxCore.statelength(d::SGConv) = 0 +LuxCore.outputsize(d::SGConv) = (d.out_dims,) + +function Base.show(io::IO, l::SGConv) + print(io, "SGConv(", l.in_dims, " => ", l.out_dims) + l.k || print(io, ", ", l.k) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") + print(io, ")") +end + +(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) = + l(g, x, edge_weight, ps, st; conv_weight) + +function (l::SGConv)(g, x, edge_weight, ps, st; + conv_weight=nothing, ) + + m = (; ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.σ) + y = GNNlib.sg_conv(m, g, x, edge_weight, conv_weight) + return y, st end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index b2e81173d..a4e2fdf79 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -77,5 +77,10 @@ #TODO test edge end + + @testset "SGConv" begin + l = SGConv(in_dims => out_dims, relu) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 2fb5bc44f..3e97fae75 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -723,3 +723,57 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) end return h .+ l.bias end + +####################### GCNConv ###################################### + +function sg_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} + if edge_weight !== nothing + @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" + end + + if conv_weight === nothing + weight = l.weight + else + weight = conv_weight + if size(weight) != size(l.weight) + throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))")) + end + end + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if Dout < Din + x = l.weight * x + end + d = degree(g, T; dir=:in, edge_weight) + c = 1 ./ sqrt.(d) + for iter in 1:l.k + x = x .* c' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj=x) + else + x = propagate(copy_xj, g, +, xj=x) + end + x = x .* c' + end + if Dout >= Din + x = l.weight * x + end + return (x .+ l.bias) +end + +# when we also have edge_weight we need to convert the graph to COO +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} + g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO + return gcn_conv(l, g, x, edge_weight, conv_weight) +end \ No newline at end of file