diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 2a867209d..ca1ed68d6 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -103,7 +103,7 @@ (x_new, e_new), st_new = l(g, x, ps, st) - #@test size(x_new) == (out_dims, g.num_nodes) - #@test size(e_new) == (out_dims, g.num_edges) + @test size(x_new) == (out_dims, g.num_nodes) + @test size(e_new) == (out_dims, g.num_edges) end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index ebd9ed94a..2c13e62fc 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -357,12 +357,7 @@ end function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing}=nothing) check_num_nodes(g, x) - - if isnothing(e) - num_edges = g.num_edges - e = zeros(eltype(x), 0, num_edges) - end - + ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e l.ϕe(vcat(xi, xj, e)) end