From 23b89c26dcfcd9c4aae566ab3f36ad87a24f79f6 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 23 Aug 2024 20:18:48 +0530 Subject: [PATCH] Update conv.jl: reverted --- GNNLux/src/layers/conv.jl | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 64403e092..bca5eef7d 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -635,42 +635,45 @@ end in_dims::Int out_dims::Int use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool init_weight init_bias σ end + function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, init_bias = zeros32, use_bias::Bool = true, init_weight = glorot_uniform, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, allow_fast_activation::Bool = true) in_dims, out_dims = ch σ = allow_fast_activation ? NNlib.fast_act(σ) : σ - return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ) + return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) end function (l::NNConv)(g, x, edge_weight, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn) + nn = StatefulLuxLayer{true}(l.nn, ps, st) - m = (; nn, l.aggr, ps.weight, bias = ps.bias, l.σ) + m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), + l.add_self_loops, l.use_edge_weight, l.σ) y = GNNlib.nn_conv(m, g, x, edge_weight) - stnew = (; nn = _getstate(nn)) + stnew = _getstate(nn) return y, stnew end - -function LuxCore.initialstates(rng::AbstractRNG, l::NNConv) - return (; nn = LuxCore.initialstates(rng, l.nn)) -end - -LuxCore.statelength(l::NNConv) = statelength(l.nn) LuxCore.outputsize(d::NNConv) = (d.out_dims,) function Base.show(io::IO, l::NNConv) - out, in = size(l.weight) - print(io, "NNConv($in => $out") - print(io, ", aggr=", l.aggr) + print(io, "NNConv($(l.nn)") + print(io, ", $(l.ϵ)") + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") print(io, ")") end