Commit

Add GConvLSTM temporal layer (#437)
* Add first draft GConvLSTM

* Fix spaces

* Add show method

* Add `GConvLSTM` export

* Fix `GConvLSTM`

* Add `GConvLSTM` tests

* Add `GCLSTM` docstring

* Add temporal feat example

* Fix missing end
aurorarossi authored Jun 16, 2024
1 parent 942fe91 commit 36e8373
Showing 3 changed files with 136 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -74,6 +74,7 @@ export
# layers/temporalconv
TGCN,
A3TGCN,
GConvLSTM,
GConvGRU,

# layers/pool
122 changes: 122 additions & 0 deletions src/layers/temporalconv.jl
@@ -279,6 +279,128 @@ Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)

struct GConvLSTMCell <: GNNLayer
# input gate
conv_x_i::ChebConv
conv_h_i::ChebConv
w_i
b_i
# forget gate
conv_x_f::ChebConv
conv_h_f::ChebConv
w_f
b_f
# cell state
conv_x_c::ChebConv
conv_h_c::ChebConv
w_c
b_c
# output gate
conv_x_o::ChebConv
conv_h_o::ChebConv
w_o
b_o
k::Int # Chebyshev polynomial order
state0 # initial (h0, c0)
in::Int
out::Int
end

Flux.@functor GConvLSTMCell

function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int;
bias::Bool = true,
init = Flux.glorot_uniform,
init_state = Flux.zeros32)
in, out = ch
# input gate
conv_x_i = ChebConv(in => out, k; bias, init)
conv_h_i = ChebConv(out => out, k; bias, init)
w_i = init(out, 1)
b_i = bias ? Flux.create_bias(w_i, true, out) : false
# forget gate
conv_x_f = ChebConv(in => out, k; bias, init)
conv_h_f = ChebConv(out => out, k; bias, init)
w_f = init(out, 1)
b_f = bias ? Flux.create_bias(w_f, true, out) : false
# cell state
conv_x_c = ChebConv(in => out, k; bias, init)
conv_h_c = ChebConv(out => out, k; bias, init)
w_c = init(out, 1)
b_c = bias ? Flux.create_bias(w_c, true, out) : false
# output gate
conv_x_o = ChebConv(in => out, k; bias, init)
conv_h_o = ChebConv(out => out, k; bias, init)
w_o = init(out, 1)
b_o = bias ? Flux.create_bias(w_o, true, out) : false
state0 = (init_state(out, n), init_state(out, n))
return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i,
conv_x_f, conv_h_f, w_f, b_f,
conv_x_c, conv_h_c, w_c, b_c,
conv_x_o, conv_h_o, w_o, b_o,
k, state0, in, out)
end
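As an illustrative shape check (not part of the commit), the constructor builds one `ChebConv` pair per gate plus a peephole weight and bias per gate; with `in = 2`, `out = 5`, `k = 2`, `n = 4` one would expect:

```julia
using GraphNeuralNetworks

cell = GraphNeuralNetworks.GConvLSTMCell(2 => 5, 2, 4)

size(cell.w_i)      # (5, 1): peephole weight, broadcast across the n node columns
size(cell.b_i)      # (5,): per-gate bias created by Flux.create_bias
size.(cell.state0)  # ((5, 4), (5, 4)): initial hidden state h0 and cell state c0
```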

function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x)
# input gate
i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i
i = Flux.sigmoid_fast(i)
# forget gate
f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f
f = Flux.sigmoid_fast(f)
# cell state
c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c)
# output gate
o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o
o = Flux.sigmoid_fast(o)
h = o .* Flux.tanh_fast(c)
return (h, c), h
end
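For reference, the forward pass above computes the peephole-LSTM update from the paper, with each dense product replaced by a Chebyshev graph convolution, written ``*_{\mathcal{G}}`` below. This is a restatement of the code, not an addition to the commit; note that the candidate update also carries the ``w_c \odot c_{t-1}`` peephole term, matching the line above:

```math
\begin{aligned}
i_t &= \sigma\big(W_{xi} *_{\mathcal{G}} x_t + W_{hi} *_{\mathcal{G}} h_{t-1} + w_i \odot c_{t-1} + b_i\big) \\
f_t &= \sigma\big(W_{xf} *_{\mathcal{G}} x_t + W_{hf} *_{\mathcal{G}} h_{t-1} + w_f \odot c_{t-1} + b_f\big) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh\big(W_{xc} *_{\mathcal{G}} x_t + W_{hc} *_{\mathcal{G}} h_{t-1} + w_c \odot c_{t-1} + b_c\big) \\
o_t &= \sigma\big(W_{xo} *_{\mathcal{G}} x_t + W_{ho} *_{\mathcal{G}} h_{t-1} + w_o \odot c_t + b_o\big) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```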

function Base.show(io::IO, gclstm::GConvLSTMCell)
print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))")
end
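For readers exploring the cell directly, here is a minimal sketch (illustrative, not from the commit) of unrolling `GConvLSTMCell` by hand over a sequence of snapshots; the graph `g`, features `xs`, and helper `unroll` are made up for this example:

```julia
using GraphNeuralNetworks, Flux

g  = rand_graph(5, 10)                    # 5 nodes, 10 edges
xs = [rand(Float32, 2, 5) for _ in 1:3]   # 3 feature snapshots, 2 features per node

cell = GraphNeuralNetworks.GConvLSTMCell(2 => 5, 2, g.num_nodes)

function unroll(cell, g, xs)
    state = cell.state0
    h = nothing
    for x in xs
        state, h = cell(state, g, x)      # carry (h, c) across time steps
    end
    return h
end

h = unroll(cell, g, xs)
size(h)                                   # (5, 5): out × num_nodes
```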

"""
GConvLSTM(in => out, k, n; [bias, init, init_state])
Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
# Arguments
- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Chebyshev polynomial order.
- `n`: Number of nodes in the graph.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
# Examples
```jldoctest
julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes);
julia> y = gclstm(g1, x1);
julia> size(y)
(5, 5)
julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
julia> z = gclstm(g2, x2);
julia> size(z)
(5, 5, 30)
```
"""
GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
Flux.Recur(gclstm::GConvLSTMCell) = Flux.Recur(gclstm, gclstm.state0)

(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
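A small usage sketch (illustrative, not from the commit): the `Flux.Recur` wrapper keeps the `(h, c)` state between calls, so consecutive calls continue the same sequence, and `Flux.reset!` restores `state0`:

```julia
using GraphNeuralNetworks, Flux

g     = rand_graph(5, 10)
layer = GConvLSTM(2 => 5, 2, g.num_nodes)

y1 = layer(g, rand(Float32, 2, 5))   # first step, starts from state0
y2 = layer(g, rand(Float32, 2, 5))   # second step, continues from the state after y1
Flux.reset!(layer)                   # back to the initial (h0, c0)
```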

function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
return l.(tg.snapshots, x)
end
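The `GINConv` method above broadcasts a spatial layer over the snapshots of a temporal graph, pairing each snapshot with one feature matrix. A hedged usage sketch, with names and constructor arguments chosen for illustration:

```julia
using GraphNeuralNetworks, Flux

snapshots = [rand_graph(5, 10) for _ in 1:3]
tg = TemporalSnapshotsGNNGraph(snapshots)
xs = [rand(Float32, 2, 5) for _ in 1:3]   # one feature matrix per snapshot

gin = GINConv(Dense(2 => 4), 0.5)         # inner network Dense(2 => 4), ϵ = 0.5
ys  = gin(tg, xs)                         # Vector of 3 outputs, each 4 × 5
```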
14 changes: 13 additions & 1 deletion test/layers/temporalconv.jl
@@ -34,6 +34,19 @@ end
@test model(g1) isa GNNGraph
end

@testset "GConvLSTMCell" begin
gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes)
(h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
@test size(h) == (out_channel, N)
@test size(c) == (out_channel, N)
end

@testset "GConvLSTM" begin
gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes)
@test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
end

@testset "GConvGRUCell" begin
gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes)
h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
@@ -55,7 +68,6 @@ end
@test length(Flux.gradient(x -> sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S
end


@testset "ChebConv" begin
chebconv = ChebConv(in_channel => out_channel, 5)
@test length(chebconv(tg, tg.ndata.x)) == S
