Skip to content

Commit

Permalink
added sgconv lux
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 1, 2024
1 parent 4b4477e commit b55d31d
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 2 deletions.
4 changes: 2 additions & 2 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ export AGNNConv,
GCNConv,
# GINConv,
# GMMConv,
GraphConv
GraphConv,
# MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
# SGConv,
SGConv
# TAGConv,
# TransformerConv

Expand Down
56 changes: 56 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,60 @@ function Base.show(io::IO, l::GATv2Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end

@concrete struct SGConv <: GNNLayer
in_dims::Int
out_dims::Int
k::Int
use_bias::Bool
add_self_loops::Bool
use_edge_weight::Bool
init_weight
init_bias
end

function SGConv(ch::Pair{Int, Int}, k = 1;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
use_edge_weight::Bool = false)
in_dims, out_dims = ch
return SGConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, k)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::SGConv) = 0
LuxCore.outputsize(d::SGConv) = (d.out_dims,)

function Base.show(io::IO, l::SGConv)
print(io, "SGConv(", l.in_dims, " => ", l.out_dims)
l.k || print(io, ", ", l.k)
l.use_bias || print(io, ", use_bias=false")
l.add_self_loops || print(io, ", add_self_loops=false")
!l.use_edge_weight || print(io, ", use_edge_weight=true")
print(io, ")")
end

(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
l(g, x, edge_weight, ps, st; conv_weight)

function (l::SGConv)(g, x, edge_weight, ps, st;
conv_weight=nothing, )

m = (; ps.weight, bias = _getbias(ps),
l.add_self_loops, l.use_edge_weight, l.σ)
y = GNNlib.sg_conv(m, g, x, edge_weight, conv_weight)
return y, st
end
5 changes: 5 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,10 @@

#TODO test edge
end

@testset "SGConv" begin
l = SGConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
end

54 changes: 54 additions & 0 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,57 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
end
return h .+ l.bias
end

####################### GCNConv ######################################

function sg_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, conv_weight::CW) where
{EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F}
if edge_weight !== nothing
@assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
end

if conv_weight === nothing
weight = l.weight
else
weight = conv_weight
if size(weight) != size(l.weight)
throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))"))
end
end

if l.add_self_loops
g = add_self_loops(g)
if edge_weight !== nothing
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
@assert length(edge_weight) == g.num_edges
end
end
Dout, Din = size(l.weight)
if Dout < Din
x = l.weight * x
end
d = degree(g, T; dir=:in, edge_weight)
c = 1 ./ sqrt.(d)
for iter in 1:l.k
x = x .* c'
if edge_weight !== nothing
x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight)
elseif l.use_edge_weight
x = propagate(w_mul_xj, g, +, xj=x)
else
x = propagate(copy_xj, g, +, xj=x)
end
x = x .* c'
end
if Dout >= Din
x = l.weight * x
end
return (x .+ l.bias)
end

# when we also have edge_weight we need to convert the graph to COO
function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, conv_weight::CW) where
{EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F}
g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
return gcn_conv(l, g, x, edge_weight, conv_weight)
end

0 comments on commit b55d31d

Please sign in to comment.