diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index ecac67b5a..ee868b6b6 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -1,8 +1,8 @@
 module GNNLux
 using ConcreteStructs: @concrete
-using NNlib: NNlib, sigmoid, relu
+using NNlib: NNlib, sigmoid, relu, swish
 using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
-using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
+using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
 using Reexport: @reexport
 using Random: AbstractRNG
 using GNNlib: GNNlib
@@ -18,10 +18,10 @@ export AGNNConv,
        CGConv,
        ChebConv,
        EdgeConv,
-       # EGNNConv,
-       # DConv,
-       # GATConv,
-       # GATv2Conv,
+       EGNNConv,
+       DConv,
+       GATConv,
+       GATv2Conv,
        # GatedGraphConv,
        GCNConv,
        # GINConv,
diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index 2dc638d95..15b1bbf4b 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -255,3 +255,264 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
 end
 
+@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
+    ϕe
+    ϕx
+    ϕh
+    num_features
+    residual::Bool
+end
+
+function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
+    return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
+end
+
+# Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
+function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
+                  residual = false)
+    (in_size, edge_feat_size), out_size = ch
+    act_fn = swish
+
+    # +1 for the radial feature: ||x_i - x_j||^2
+    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
+               Dense(hidden_size => hidden_size, act_fn))
+
+    ϕh = Chain(Dense(in_size + hidden_size => hidden_size, swish),
+               Dense(hidden_size => out_size))
+
+    ϕx = Chain(Dense(hidden_size => hidden_size, swish),
+               Dense(hidden_size => 1, use_bias = false))
+
+    num_features = (in = in_size, edge = edge_feat_size, out = out_size,
+                    hidden = hidden_size)
+    if residual
+        @assert in_size==out_size "Residual connection only possible if in_size == out_size"
+    end
+    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
+end
+
+LuxCore.outputsize(l::EGNNConv) = (l.num_features.out,)
+
+(l::EGNNConv)(g, h, x, ps, st) = l(g, h, x, nothing, ps, st)
+
+function (l::EGNNConv)(g, h, x, e, ps, st)
+    ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
+    ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx))
+    ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh))
+    m = (; ϕe, ϕx, ϕh, l.residual, l.num_features)
+    return GNNlib.egnn_conv(m, g, h, x, e), st
+end
+
+function Base.show(io::IO, l::EGNNConv)
+    ne = l.num_features.edge
+    nin = l.num_features.in
+    nout = l.num_features.out
+    nh = l.num_features.hidden
+    print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh")
+    if l.residual
+        print(io, ", residual=true")
+    end
+    print(io, ")")
+end
+
+@concrete struct DConv <: GNNLayer
+    in_dims::Int
+    out_dims::Int
+    k::Int
+    init_weight
+    init_bias
+    use_bias::Bool
+end
+
+function DConv(ch::Pair{Int, Int}, k::Int;
+               init_weight = glorot_uniform,
+               init_bias = zeros32,
+               use_bias = true)
+    in, out = ch
+    return DConv(in, out, k, init_weight, init_bias, use_bias)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::DConv)
+    weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        return (; weights, bias)
+    else
+        return (; weights)
+    end
+end
+
+LuxCore.outputsize(l::DConv) = (l.out_dims,)
+LuxCore.parameterlength(l::DConv) = l.use_bias ? 2 * l.in_dims * l.out_dims * l.k + l.out_dims :
+                                                 2 * l.in_dims * l.out_dims * l.k
+
+function (l::DConv)(g, x, ps, st)
+    m = (; ps.weights, bias = _getbias(ps), l.k)
+    return GNNlib.d_conv(m, g, x), st
+end
+
+function Base.show(io::IO, l::DConv)
+    print(io, "DConv($(l.in_dims) => $(l.out_dims), k=$(l.k))")
+end
+
+@concrete struct GATConv <: GNNLayer
+    dense_x
+    dense_e
+    init_weight
+    init_bias
+    use_bias::Bool
+    σ
+    negative_slope
+    channel::Pair{NTuple{2, Int}, Int}
+    heads::Int
+    concat::Bool
+    add_self_loops::Bool
+    dropout
+end
+
+
+GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)
+
+function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
+                 heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
+                 init_weight = glorot_uniform, init_bias = zeros32,
+                 use_bias::Bool = true,
+                 add_self_loops = true, dropout=0.0)
+    (in, ein), out = ch
+    if add_self_loops
+        @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+    end
+
+    dense_x = Dense(in => out * heads, use_bias = false)
+    dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing
+    negative_slope = convert(Float32, negative_slope)
+    return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias,
+                   σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
+end
+
+LuxCore.outputsize(l::GATConv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],)
+##TODO: parameterlength
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv)
+    (in, ein), out = l.channel
+    dense_x = LuxCore.initialparameters(rng, l.dense_x)
+    a = l.init_weight(ein > 0 ? 3out : 2out, l.heads)
+    ps = (; dense_x, a)
+    if ein > 0
+        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
+    end
+    if l.use_bias
+        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
+    end
+    return ps
+end
+
+(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
+
+function (l::GATConv)(g, x, e, ps, st)
+    dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
+    dense_e = l.dense_e === nothing ? nothing :
+              StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))
+
+    m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
+         ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope)
+    return GNNlib.gat_conv(m, g, x, e), st
+end
+
+function Base.show(io::IO, l::GATConv)
+    (in, ein), out = l.channel
+    print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ", negative_slope=", l.negative_slope)
+    print(io, ")")
+end
+
+@concrete struct GATv2Conv <: GNNLayer
+    dense_i
+    dense_j
+    dense_e
+    init_weight
+    init_bias
+    use_bias::Bool
+    σ
+    negative_slope
+    channel::Pair{NTuple{2, Int}, Int}
+    heads::Int
+    concat::Bool
+    add_self_loops::Bool
+    dropout
+end
+
+function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
+    GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
+end
+
+function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
+                   σ = identity;
+                   heads::Int = 1,
+                   concat::Bool = true,
+                   negative_slope = 0.2,
+                   init_weight = glorot_uniform,
+                   init_bias = zeros32,
+                   use_bias::Bool = true,
+                   add_self_loops = true,
+                   dropout=0.0)
+
+    (in, ein), out = ch
+
+    if add_self_loops
+        @assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+    end
+
+    dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias)
+    dense_j = Dense(in => out * heads; use_bias = false, init_weight)
+    if ein > 0
+        dense_e = Dense(ein => out * heads; use_bias = false, init_weight)
+    else
+        dense_e = nothing
+    end
+    return GATv2Conv(dense_i, dense_j, dense_e,
+                     init_weight, init_bias, use_bias,
+                     σ, negative_slope,
+                     ch, heads, concat, add_self_loops, dropout)
+end
+
+
+LuxCore.outputsize(l::GATv2Conv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],)
+##TODO: parameterlength
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv)
+    (in, ein), out = l.channel
+    dense_i = LuxCore.initialparameters(rng, l.dense_i)
+    dense_j = LuxCore.initialparameters(rng, l.dense_j)
+    a = l.init_weight(out, l.heads)
+    ps = (; dense_i, dense_j, a)
+    if ein > 0
+        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
+    end
+    if l.use_bias
+        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
+    end
+    return ps
+end
+
+(l::GATv2Conv)(g, x, ps, st) = l(g, x, nothing, ps, st)
+
+function (l::GATv2Conv)(g, x, e, ps, st)
+    dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
+    dense_j = StatefulLuxLayer{true}(l.dense_j, ps.dense_j, _getstate(st, :dense_j))
+    dense_e = l.dense_e === nothing ? nothing :
+              StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))
+
+    m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
+         ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope)
+    return GNNlib.gatv2_conv(m, g, x, e), st
+end
+
+function Base.show(io::IO, l::GATv2Conv)
+    (in, ein), out = l.channel
+    print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ", negative_slope=", l.negative_slope)
+    print(io, ")")
+end
\ No newline at end of file
diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl
index 520fcc570..b2e81173d 100644
--- a/GNNLux/test/layers/conv_tests.jl
+++ b/GNNLux/test/layers/conv_tests.jl
@@ -1,36 +1,81 @@
 @testitem "layers/conv" setup=[SharedTestSetup] begin
     rng = StableRNG(1234)
     g = rand_graph(10, 40, seed=1234)
-    x = randn(rng, Float32, 3, 10)
+    in_dims = 3
+    out_dims = 5
+    x = randn(rng, Float32, in_dims, 10)
 
     @testset "GCNConv" begin
-        l = GCNConv(3 => 5, relu)
-        test_lux_layer(rng, l, g, x, outputsize=(5,))
+        l = GCNConv(in_dims => out_dims, relu)
+        test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
     end
 
     @testset "ChebConv" begin
-        l = ChebConv(3 => 5, 2)
-        test_lux_layer(rng, l, g, x, outputsize=(5,))
+        l = ChebConv(in_dims => out_dims, 2)
+        test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
     end
 
     @testset "GraphConv" begin
-        l = GraphConv(3 => 5, relu)
-        test_lux_layer(rng, l, g, x, outputsize=(5,))
+        l = GraphConv(in_dims => out_dims, relu)
+        test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
     end
 
     @testset "AGNNConv" begin
         l = AGNNConv(init_beta=1.0f0)
-        test_lux_layer(rng, l, g, x, sizey=(3,10))
+        test_lux_layer(rng, l, g, x, sizey=(in_dims, 10))
     end
 
     @testset "EdgeConv" begin
-        nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
+        nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims))
         l = EdgeConv(nn, aggr = +)
-        test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
+        test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true)
     end
 
     @testset "CGConv" begin
-        l = CGConv(3 => 3, residual = true)
-        test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
+        l = CGConv(in_dims => in_dims, residual = true)
+        test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true)
+    end
+
+    @testset "DConv" begin
+        l = DConv(in_dims => out_dims, 2)
+        test_lux_layer(rng, l, g, x, outputsize=(5,))
+    end
+
+    @testset "EGNNConv" begin
+        hin = 6
+        hout = 7
+        hidden = 8
+        l = EGNNConv(hin => hout, hidden)
+        ps = LuxCore.initialparameters(rng, l)
+        st = LuxCore.initialstates(rng, l)
+        h = randn(rng, Float32, hin, g.num_nodes)
+        (hnew, xnew), stnew = l(g, h, x, ps, st)
+        @test size(hnew) == (hout, g.num_nodes)
+        @test size(xnew) == (in_dims, g.num_nodes)
+    end
+
+    @testset "GATConv" begin
+        x = randn(rng, Float32, 6, 10)
+
+        l = GATConv(6 => 8, heads=2)
+        test_lux_layer(rng, l, g, x, outputsize=(16,))
+
+        l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5)
+        test_lux_layer(rng, l, g, x, outputsize=(8,))
+
+        #TODO test edge
+    end
+
+    @testset "GATv2Conv" begin
+        x = randn(rng, Float32, 6, 10)
+
+        l = GATv2Conv(6 => 8, heads=2)
+        test_lux_layer(rng, l, g, x, outputsize=(16,))
+
+        l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5)
+        test_lux_layer(rng, l, g, x, outputsize=(8,))
+
+        #TODO test edge
+    end
 end
+
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index cd3606291..2fb5bc44f 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -161,7 +161,8 @@ function gat_message(l, Wxi, Wxj, e)
         Wxx = vcat(Wxi, Wxj, We)
     end
     aWW = sum(l.a .* Wxx, dims = 1) # 1 × nheads × nedges
-    logα = leakyrelu.(aWW, l.negative_slope)
+    slope = convert(eltype(aWW), l.negative_slope)
+    logα = leakyrelu.(aWW, slope)
     return (; logα, Wxj)
 end
 
@@ -207,7 +208,8 @@ function gatv2_message(l, Wxi, Wxj, e)
     if e !== nothing
         Wx += reshape(l.dense_e(e), out, heads, :)
     end
-    logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims = 1) # 1 × heads × nedges
+    slope = convert(eltype(Wx), l.negative_slope)
+    logα = sum(l.a .* leakyrelu.(Wx, slope), dims = 1) # 1 × heads × nedges
     return (; logα, Wxj)
 end
 
@@ -703,14 +705,14 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
     h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x
 
     T0 = x
-    if l.K > 1
+    if l.k > 1
        # T1_in = T0 * deg_in * A'
        #T1_out = T0 * deg_out' * A
        T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out')
        T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in)
        h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out
     end
-    for i in 2:l.K
+    for i in 2:l.k
        T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in)
        T2_in = 2 * T2_in - T0
        T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out')
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index ddfa4e945..ec9268bd0 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -451,8 +451,8 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
     end
     b = bias ? Flux.create_bias(dense_i.weight, true, concat ? out * heads : out) : false
     a = init(out, heads)
-    negative_slope = convert(eltype(dense_i.weight), negative_slope)
-    GATv2Conv(dense_i, dense_j, dense_e, b, a, σ, negative_slope, ch, heads, concat,
+    return GATv2Conv(dense_i, dense_j, dense_e,
+                     b, a, σ, negative_slope, ch, heads, concat,
                      add_self_loops, dropout)
 end
 
@@ -1536,14 +1536,14 @@ function Base.show(io::IO, l::TransformerConv)
 end
 
 """
-    DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
+    DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
 
 Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Neural Networks: Data-Driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926).
 
 # Arguments
 
 - `ch`: Pair of input and output dimensions.
-- `K`: Number of diffusion steps.
+- `k`: Number of diffusion steps.
 - `init`: Weights' initializer. Default `glorot_uniform`.
 - `bias`: Add learnable bias. Default `true`.
 
@@ -1552,7 +1552,7 @@ Diffusion convolution layer from the paper [Diffusion Convolutional Recurrent Ne
 julia> g = GNNGraph(rand(10, 10), ndata = rand(Float32, 2, 10));
 
 julia> dconv = DConv(2 => 4, 4)
-DConv(2 => 4, K=4)
+DConv(2 => 4, 4)
 
 julia> y = dconv(g, g.ndata.x);
 
@@ -1565,20 +1565,20 @@ struct DConv <: GNNLayer
     out::Int
     weights::AbstractArray
     bias::AbstractArray
-    K::Int
+    k::Int
 end
 
 @functor DConv
 
-function DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true)
+function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
     in, out = ch
-    weights = init(2, K, out, in)
+    weights = init(2, k, out, in)
     b = bias ? Flux.create_bias(weights, true, out) : false
-    DConv(in, out, weights, b, K)
+    return DConv(in, out, weights, b, k)
 end
 
 (l::DConv)(g, x) = GNNlib.d_conv(l, g, x)
 
 function Base.show(io::IO, l::DConv)
-    print(io, "DConv($(l.in) => $(l.out), K=$(l.K))")
-end
\ No newline at end of file
+    print(io, "DConv($(l.in) => $(l.out), $(l.k))")
+end
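
Usage note (not part of the diff above): a minimal sketch of how the newly ported Lux-style layers are expected to be driven, following the explicit parameter/state convention exercised in the test changes; the graph size, feature dimensions, and variable names below are illustrative only.

using GNNLux, GNNGraphs, LuxCore, Random

rng = Random.default_rng()
g = rand_graph(10, 40)                    # small random graph, as in the tests
h = randn(rng, Float32, 6, g.num_nodes)   # node features (6 channels)
x = randn(rng, Float32, 3, g.num_nodes)   # node coordinates (3 channels)

# EGNNConv takes node features and coordinates and returns updated versions of both
l = EGNNConv(6 => 7, 8)                   # hin => hout, hidden_size
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
(hnew, xnew), st = l(g, h, x, ps, st)     # hnew: 7 × num_nodes, xnew: 3 × num_nodes

# GATConv (and likewise GATv2Conv and DConv) follow the single-feature call convention
l = GATConv(6 => 8, heads = 2)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y, st = l(g, h, ps, st)                   # 16 × num_nodes with the default concat = true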