Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Aug 15, 2024
1 parent c5d90ee commit c06a465
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
end

function initialstates(rng::AbstractRNG, r::StatefulRecurrentCell)
return (cell=initialstates(rng, r.cell), carry=nothing)
function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell)
return (cell=LuxCore.initialstates(rng, r.cell), carry=nothing)
end

function (r::StatefulRecurrentCell)(g, x, ps, st::NamedTuple)
(out, carry), st_ = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
return out, (; cell=st_, carry)
function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple)
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
return out, (; cell=st, carry)
end

function applyrecurrentcell(l, g, x, ps, st, carry)
return Lux.apply(l, g, (x, carry), ps, st)
function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple)
(out, carry), st = applyrecurrentcell(r.cell, g, first(x), ps, st.cell, st.carry)
for xᵢ in x[(begin + 1):end]
(out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry)
end
return out, (; cell=st, carry)
end

function applyrecurrentcell(l, g, x, ps, st, ::Nothing)
return Lux.apply(l, g, x, ps, st)
function applyrecurrentcell(l, g, x, ps, st, carry)
return Lux.apply(l, g, (x, carry), ps, st)
end

LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
Expand All @@ -26,29 +30,30 @@ LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
out_dims::Int
conv
gru
init_state::Function
end

function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
in_dims, out_dims = ch
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
return TGCNCell(in_dims, out_dims, conv, gru)
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
end

LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)

function (l::TGCNCell)(g, x, h, ps, st)
conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv))
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
#m = (; conv, gru)

function (l::TGCNCell)(g, (x, h), ps, st)
if h === nothing
h = l.init_state(l.out_dims, 1)
end
x̃, stconv = l.conv(g, x, ps.conv, st.conv)
(h, (h,)), st = l.gru((x̃,(h,)), ps.gru,st.gru)
return (h, (h,)), st
(h, (h,)), stgru = l.gru((x̃,(h,)), ps.gru,st.gru)
return (h, h), (conv=stconv, gru=stgru)
end

LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)
LuxCore.outputsize(l::GNNLux.StatefulRecurrentCell) = (l.cell.out_dims,)

function Base.show(io::IO, tgcn::TGCNCell)
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
end

tgcn = StatefulRecurrentCell(TGCNCell(1 =>3))
TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))

0 comments on commit c06a465

Please sign in to comment.