From 8a6802b7b510491ae72df8d44d6b029c11844756 Mon Sep 17 00:00:00 2001 From: Agata Skorupka <45850123+askorupka@users.noreply.github.com> Date: Tue, 30 Jan 2024 00:07:33 +0100 Subject: [PATCH] optimize test for heterograph (#370) --- test/layers/heteroconv.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index bd869564e..f1a07b7a7 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -1,6 +1,7 @@ @testset "HeteroGraphConv" begin d, n = 3, 5 g = rand_bipartite_heterograph(n, 2*n, 15) + hg = rand_bipartite_heterograph((2,3), 6) model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d), (:B,:to,:A) => GraphConv(d => d)]) @@ -93,20 +94,18 @@ end @testset "CGConv" begin - g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu), (:B, :to, :A) => CGConv(4 => 2, relu)); - y = layers(g, x); + y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end @testset "EdgeConv" begin - g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); - y = layers(g, x); + y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end