Skip to content

Commit

Permalink
optimize test for heterograph (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
askorupka authored Jan 29, 2024
1 parent 9e0ad4a commit 8a6802b
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testset "HeteroGraphConv" begin
d, n = 3, 5
g = rand_bipartite_heterograph(n, 2*n, 15)
hg = rand_bipartite_heterograph((2,3), 6)

model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d),
(:B,:to,:A) => GraphConv(d => d)])
Expand Down Expand Up @@ -93,20 +94,18 @@
end

@testset "CGConv" begin
g = rand_bipartite_heterograph((2,3), 6)
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu),
(:B, :to, :A) => CGConv(4 => 2, relu));
y = layers(g, x);
y = layers(hg, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end

@testset "EdgeConv" begin
g = rand_bipartite_heterograph((2,3), 6)
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +),
(:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +));
y = layers(g, x);
y = layers(hg, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end

Expand Down

0 comments on commit 8a6802b

Please sign in to comment.