diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index bd869564e..f1a07b7a7 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -1,6 +1,7 @@ @testset "HeteroGraphConv" begin d, n = 3, 5 g = rand_bipartite_heterograph(n, 2*n, 15) + hg = rand_bipartite_heterograph((2,3), 6) model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d), (:B,:to,:A) => GraphConv(d => d)]) @@ -93,20 +94,18 @@ end @testset "CGConv" begin - g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu), (:B, :to, :A) => CGConv(4 => 2, relu)); - y = layers(g, x); + y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end @testset "EdgeConv" begin - g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); - y = layers(g, x); + y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end