Skip to content

Commit

Permalink
edgeconv working
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 28, 2024
1 parent 222a45d commit 4cb8112
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 22 deletions.
18 changes: 17 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, Dense, glorot_uniform, zeros32
using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
Expand All @@ -17,8 +17,24 @@ include("layers/conv.jl")
export AGNNConv,
CGConv,
ChebConv,
EdgeConv,
# EGNNConv,
# DConv,
# GATConv,
# GATv2Conv,
# GatedGraphConv,
GCNConv,
# GINConv,
# GMMConv,
GraphConv
# MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
# SGConv,
# TAGConv,
# TransformerConv


end #module

52 changes: 32 additions & 20 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
# Missing Layers

# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
# | :-------- | :---: |:---: |:---: | :---: | :---: |
# | [`EGNNConv`](@ref) | | | ✓ | | |
# | [`EdgeConv`](@ref) | | | | ✓ | |
# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`GMMConv`](@ref) | | | ✓ | | |
# | [`MEGNetConv`](@ref) | | | ✓ | | |
# | [`NNConv`](@ref) | | | ✓ | | |
# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`SGConv`](@ref) | ✓ | | | | ✓ |
# | [`TransformerConv`](@ref) | | | ✓ | | |

_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false
_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple()
_getstate(s::StatefulLuxLayer{true}) = s.st
_getstate(s::StatefulLuxLayer{false}) = s.st_any


@concrete struct GCNConv <: GNNLayer
in_dims::Int
Expand Down Expand Up @@ -235,11 +221,37 @@ function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias)
end

LuxCore.outputsize(l::CGConv) = (l.out_dims,)

(l::CGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::CGConv)(g, x, e, ps, st)
dense_f = StatefulLuxLayer(l.dense_f, ps.dense_f)
dense_s = StatefulLuxLayer(l.dense_s, ps.dense_s)
dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
dense_s = StatefulLuxLayer{true}(l.dense_s, ps.dense_s, _getstate(st, :dense_s))
m = (; dense_f, dense_s, l.residual)
return GNNlib.cg_conv(m, g, x, e), st
end

@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
aggr
end

EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)

function Base.show(io::IO, l::EdgeConv)
print(io, "EdgeConv(", l.nn)
print(io, ", aggr=", l.aggr)
print(io, ")")
end


function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
m = (; nn, l.aggr)
y = GNNlib.edge_conv(m, g, x)
stnew = _getstate(nn)
return y, stnew
end


43 changes: 43 additions & 0 deletions GNNLux/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using Lux
using Lux: AbstractExplicitContainerLayer, StatefulLuxLayer
using Random

struct A <: AbstractExplicitContainerLayer{(:x,)}
x
y
end


a = A(Dense(3, 5), true)
rng = Random.default_rng()
ps = Lux.initialparameters(rng, a)
ps.x #ERROR, no field named x
ps.weight # OK

struct B <: AbstractExplicitContainerLayer{(:x,:y)}
x
y
end

b = B(Dense(3, 5), Dense(5, 5))
rng = Random.default_rng()
ps = Lux.initialparameters(rng, b)
ps.x #OK
ps.y #OK

rng = Random.default_rng()
x = rand(rng, Float32, 2, 3)
model = Chain(Dense(2 => 5, relu), Dense(5 => 5))
ps = Lux.initialparameters(rng, model)
st = Lux.initialstates(rng, model)
y, _ = model(x, ps, st)

model2 = StatefulLuxLayer{true}(model, ps, st)
y2 = model2(x)


a = A(model, true)
ps = Lux.initialparameters(rng, a)
st = Lux.initialstates(rng, a)
model3 = StatefulLuxLayer(a.x, ps, st)
y3 = model3(x)
28 changes: 28 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,32 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
l = EdgeConv(nn, aggr = +)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "CGConv" begin
l = CGConv(3 => 5, residual = true)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
@test Lux.outputsize(l) == (5,)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end
end
2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ function edge_conv(l, g::AbstractGNNGraph, x)
xj, xi = expand_srcdst(g, x)

message = Fix1(edge_conv_message, l)
x = propagate(message, g, l.aggr, xi = xi, xj = xj, e = nothing)
x = propagate(message, g, l.aggr; xi, xj, e = nothing)
return x
end

Expand Down

0 comments on commit 4cb8112

Please sign in to comment.