diff --git a/.gitignore b/.gitignore index 7181205b6..91820619c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ Manifest.toml .vscode LocalPreferences.toml .DS_Store -docs/src/democards/gridtheme.css \ No newline at end of file +docs/src/democards/gridtheme.css +test.jl \ No newline at end of file diff --git a/GNNGraphs/src/abstracttypes.jl b/GNNGraphs/src/abstracttypes.jl index b8959b807..73146160f 100644 --- a/GNNGraphs/src/abstracttypes.jl +++ b/GNNGraphs/src/abstracttypes.jl @@ -1,5 +1,5 @@ -const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V} +const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}} const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} const ADJMAT_T = AbstractMatrix const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T diff --git a/GNNGraphs/src/convert.jl b/GNNGraphs/src/convert.jl index 1e103db8b..3789309cb 100644 --- a/GNNGraphs/src/convert.jl +++ b/GNNGraphs/src/convert.jl @@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...) graph = EDict{COO_T}() _num_nodes = NDict{Int}() num_edges = EDict{Int}() - if !isempty(data) - for k in keys(data) - d = data[k] - @assert d isa Tuple - if length(d) == 2 - d = (d..., nothing) - end - if num_nodes !== nothing - n1 = get(num_nodes, k[1], nothing) - n2 = get(num_nodes, k[3], nothing) - else - n1 = nothing - n2 = nothing - end - g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) - graph[k] = g - num_edges[k] = nedges - _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) - _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) + for k in keys(data) + d = data[k] + @assert d isa Tuple + if length(d) == 2 + d = (d..., nothing) end - graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types + if num_nodes !== nothing + n1 = get(num_nodes, k[1], nothing) + n2 = get(num_nodes, k[3], nothing) + else + n1 = nothing + n2 = nothing + end + g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) + graph[k] = g + num_edges[k] = nedges + _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) + _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) end return graph, _num_nodes, num_edges end diff --git a/GNNGraphs/src/generate.jl b/GNNGraphs/src/generate.jl index 4e6738279..6005ac023 100644 --- a/GNNGraphs/src/generate.jl +++ b/GNNGraphs/src/generate.jl @@ -1,5 +1,5 @@ """ - rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...) + rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...) Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges. @@ -10,7 +10,7 @@ In any case, the output graph will contain no self-loops or multi-edges. A vector can be passed as `edge_weight`. Its length has to be equal to `m` in the directed case, and `m÷2` in the bidirected one. -Use a `seed > 0` for reproducibility. +Pass a random number generator as the first argument to make the generation reproducible. Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. @@ -36,25 +36,42 @@ GNNGraph: # Each edge has a reverse julia> edge_index(g) ([1, 3, 3, 4], [3, 4, 1, 3]) - ``` """ -function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...) +function rand_graph(n::Integer, m::Integer; seed=-1, kws...) 
+ if seed != -1 + Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph) + rng = MersenneTwister(seed) + else + rng = Random.default_rng() + end + return rand_graph(rng, n, m; kws...) +end + +function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; + bidirected::Bool = true, + edge_weight::Union{AbstractVector, Nothing} = nothing, kws...) if bidirected - @assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m." + @assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m." + s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false) + s, t = vcat(s, t), vcat(t, s) + if edge_weight !== nothing + edge_weight = vcat(edge_weight, edge_weight) + end + else + s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false) end - m2 = bidirected ? m ÷ 2 : m - return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...) + return GNNGraph((s, t, edge_weight); num_nodes=n, kws...) end """ - rand_heterograph(n, m; seed=-1, bidirected=false, kws...) + rand_heterograph([rng,] n, m; bidirected=false, kws...) -Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges +Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs specifing node/edge types and their numbers. -Use a `seed > 0` for reproducibility. +Pass a random number generator as a first argument to make the generation reproducible. Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge. Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)` @@ -76,9 +93,19 @@ function rand_heterograph end # for generic iterators of pairs rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...) +rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...) -function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...) - rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG +function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...) + if seed != -1 + Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph) + rng = MersenneTwister(seed) + else + rng = Random.default_rng() + end + return rand_heterograph(rng, n, m; kws...) +end + +function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...) if bidirected return _rand_bidirected_heterograph(rng, n, m; kws...) end @@ -86,7 +113,7 @@ function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws return GNNHeteroGraph(graphs; num_nodes = n, kws...) end -function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...) +function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...) for k in keys(m) if reverse(k) ∈ keys(m) @assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs." @@ -104,43 +131,60 @@ function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...) return GNNHeteroGraph(graphs; num_nodes = n, kws...) 
end -function _rand_edges(rng, (n1, n2), m) - idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false) - s, t = edge_decoding(idx, n1, n2) - val = nothing - return s, t, val -end """ - rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...]) - rand_bipartite_heterograph((n1, n2), m; ...) - rand_bipartite_heterograph((n1, n2), (m1, m2); ...) + rand_bipartite_heterograph([rng,] + (n1, n2), (m12, m21); + bidirected = true, + node_t = (:A, :B), + edge_t = :to, + kws...) -Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges -specified by `n1`, `n2` and `m1` and `m2` respectively. +Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph. +The graph will have two types of nodes, and edges will only connect nodes of different types. -See [`rand_heterograph`](@ref) for a more general version. +The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type. +The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2` +and vice versa. -# Keyword arguments +The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments, +which default to `(:A, :B)` and `:to` respectively. -- `bidirected`: whether to generate a bidirected graph. Default is `true`. -- `seed`: random seed. Default is `-1` (no seed). -- `node_t`: node types. If `bipartite=true`, this should be a tuple of two node types, otherwise it should be a single node type. -- `edge_t`: edge types. If `bipartite=true`, this should be a tuple of two edge types, otherwise it should be a single edge type. -""" -function rand_bipartite_heterograph end +If `bidirected=true` (default), the reverse edge of each edge will be present. In this case +`m12 == m21` is required. + +A random number generator can be passed as the first argument to make the generation reproducible. + +Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor. + +See [`rand_heterograph`](@ref) for a more general version. + +# Examples -rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...) +```julia-repl +julia> g = rand_bipartite_heterograph((10, 15), 20) +GNNHeteroGraph: + num_nodes: (:A => 10, :B => 15) + num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20) -rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...) +julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false) +GNNHeteroGraph: + num_nodes: Dict(:item => 15, :user => 10) + num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20) +``` +""" +rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...) -function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true, - node_t = (:A, :B), edge_t = :to, kws...) - if edge_t isa Symbol - edge_t = (edge_t, edge_t) +function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true, + node_t = (:A, :B), edge_t::Symbol = :to, kws...) 
+ if m isa Integer + m12 = m21 = m + else + m12, m21 = m end - return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2), - Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2); + + return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2), + Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21); bidirected, kws...) end diff --git a/GNNGraphs/src/gnngraph.jl b/GNNGraphs/src/gnngraph.jl index 64fd32aad..a9af576e2 100644 --- a/GNNGraphs/src/gnngraph.jl +++ b/GNNGraphs/src/gnngraph.jl @@ -209,10 +209,10 @@ function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata else graph = g.graph end - GNNGraph(graph, - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - ndata, edata, gdata) + return GNNGraph(graph, + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, edata, gdata) end """ diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 8e8c98d13..325a20f5c 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -57,7 +57,8 @@ then all new self loops will have no weight. If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same. This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type. """ -function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V} +function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) get(g.graph, edge_t, (nothing, nothing, nothing))[3] end @@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where n = get(g.num_nodes, src_t, 0) if haskey(g.graph, edge_t) - x = g.graph[edge_t] - s, t = x[1:2] + s, t = g.graph[edge_t][1:2] nodes = convert(typeof(s), [1:n;]) s = [s; nodes] t = [t; nodes] else - nodes = convert(T, [1:n;]) + if !isempty(g.graph) + T = typeof(first(values(g.graph))[1]) + nodes = convert(T, [1:n;]) + else + nodes = [1:n;] + end s = nodes t = nodes end @@ -518,7 +523,6 @@ end Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. -Optionally, a random `seed` can be provided to ensure reproducible perturbations. The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges. 
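Note (not part of the diff): a minimal usage sketch of the rng-first API introduced in `GNNGraphs/src/generate.jl` above. It assumes the generators are accessed through `GNNGraphs` (they are also re-exported by `GraphNeuralNetworks`); the sizes are arbitrary.

```julia
using Random
using GNNGraphs

rng = MersenneTwister(17)

# reproducible random graph: pass the rng as the first positional argument
g = rand_graph(rng, 10, 30; bidirected = true)

# the old keyword still works but now emits a deprecation warning:
# g = rand_graph(10, 30; seed = 17)

# the same convention holds for the heterograph generators
hg = rand_heterograph(rng, Dict(:A => 10, :B => 15), Dict((:A, :to, :B) => 20))
bg = rand_bipartite_heterograph(rng, (10, 15), (20, 20))
```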
diff --git a/GNNGraphs/src/utils.jl b/GNNGraphs/src/utils.jl index 4bba304ef..7cdc3e543 100644 --- a/GNNGraphs/src/utils.jl +++ b/GNNGraphs/src/utils.jl @@ -205,17 +205,13 @@ end numnonzeros(a::AbstractSparseMatrix) = nnz(a) numnonzeros(a::AbstractMatrix) = count(!=(0), a) -# each edge is represented by a number in -# 1:N^2 -function edge_encoding(s, t, n; directed = true) - if directed - # directed edges and self-loops allowed - idx = (s .- 1) .* n .+ t +## Map edges into a contiguous range of integers +function edge_encoding(s, t, n; directed = true, self_loops = true) + if directed && self_loops maxid = n^2 - else - # Undirected edges and self-loops allowed + idx = (s .- 1) .* n .+ t + elseif !directed && self_loops maxid = n * (n + 1) ÷ 2 - mask = s .> t snew = copy(s) tnew = copy(t) @@ -228,18 +224,34 @@ function edge_encoding(s, t, n; directed = true) # = ∑_{i',i' s) + elseif !directed && !self_loops + @assert all(s .!= t) + maxid = n * (n - 1) ÷ 2 + mask = s .> t + snew = copy(s) + tnew = copy(t) + snew[mask] .= t[mask] + tnew[mask] .= s[mask] + s, t = snew, tnew + + # idx(s,t) = ∑_{s',1<= s'= s) + elseif !directed && !self_loops + # Considering t = s + 1 in + # idx = @. (s - 1) * n - s * (s - 1) ÷ 2 + (t - s) + # and inverting for s we have + s = @. floor(Int, 1/2 + n - 1/2 * sqrt(9 - 4n + 4n^2 - 8*idx)) + # now we can find t + t = @. idx - (s - 1) * n + s * (s - 1) ÷ 2 + s end return s, t end -# each edge is represented by a number in -# 1:n1*n2 +# for bipartite graphs function edge_decoding(idx, n1, n2) @assert all(1 .<= idx .<= n1 * n2) s = (idx .- 1) .÷ n2 .+ 1 @@ -265,6 +287,29 @@ function edge_decoding(idx, n1, n2) return s, t end +function _rand_edges(rng, n::Int, m::Int; directed = true, self_loops = true) + idmax = if directed && self_loops + n^2 + elseif !directed && self_loops + n * (n + 1) ÷ 2 + elseif directed && !self_loops + n * (n - 1) + elseif !directed && !self_loops + n * (n - 1) ÷ 2 + end + idx = StatsBase.sample(rng, 1:idmax, m, replace = false) + s, t = edge_decoding(idx, n; directed, self_loops) + val = nothing + return s, t, val +end + +function _rand_edges(rng, (n1, n2), m) + idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false) + s, t = edge_decoding(idx, n1, n2) + val = nothing + return s, t, val +end + binarize(x) = map(>(0), x) @non_differentiable binarize(x...) 
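Note (not part of the diff): a round-trip sketch of the extended `edge_encoding`/`edge_decoding` helpers, reusing the 5-node edge list from the new tests in `GNNGraphs/test/utils.jl` further down. The helpers are internal, hence the `GNNGraphs.`-qualified calls.

```julia
using Random
using GNNGraphs

rng = MersenneTwister(17)
n = 5
s = [1, 3, 1, 1, 2, 2, 2, 3, 3, 4]
t = [2, 1, 4, 5, 3, 4, 5, 4, 5, 5]

# undirected graph without self-loops: at most n*(n-1)÷2 = 10 distinct edges
idx, maxid = GNNGraphs.edge_encoding(s, t, n; directed = false, self_loops = false)
# maxid == 10 and, for this particular edge list, idx == 1:10

# decoding recovers the edges in canonical (s < t) orientation
sdec, tdec = GNNGraphs.edge_decoding(idx, n; directed = false, self_loops = false)
# sdec == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4]
# tdec == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5]

# `_rand_edges` samples distinct edges in the same encoding, here 4 directed
# edges on 5 nodes with self-loops excluded
snew, tnew, _ = GNNGraphs._rand_edges(rng, n, 4; directed = true, self_loops = false)
```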
diff --git a/GNNGraphs/test/generate.jl b/GNNGraphs/test/generate.jl index 867fec399..c26b651c3 100644 --- a/GNNGraphs/test/generate.jl +++ b/GNNGraphs/test/generate.jl @@ -16,19 +16,23 @@ @test g.edata.e[:, (m2 + 1):end] == e end - g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) @test g.num_nodes == n @test g.num_edges == m - g2 = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g2 = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) @test edge_index(g2) == edge_index(g) ew = rand(m2) - g = rand_graph(n, m, bidirected = true, seed = 17, graph_type = GRAPH_T, edge_weight = ew) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = true, graph_type = GRAPH_T, edge_weight = ew) @test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo) ew = rand(m) - g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T, edge_weight = ew) + rng = MersenneTwister(17) + g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T, edge_weight = ew) @test get_edge_weight(g) == ew broken=(GRAPH_T != :coo) end @@ -77,7 +81,7 @@ end end @testset "rand_bipartite_heterograph" begin - g = rand_bipartite_heterograph(10, 15, 20) + g = rand_bipartite_heterograph((10, 15), (20, 20)) @test g.num_nodes == Dict(:A => 10, :B => 15) @test g.num_edges == Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20) sA, tB = edge_index(g, (:A, :to, :B)) diff --git a/GNNGraphs/test/gnnheterograph.jl b/GNNGraphs/test/gnnheterograph.jl index 6764b7814..f3c29b80f 100644 --- a/GNNGraphs/test/gnnheterograph.jl +++ b/GNNGraphs/test/gnnheterograph.jl @@ -123,7 +123,7 @@ end @testset "get/set node features" begin d, n = 3, 5 - g = rand_bipartite_heterograph(n, 2*n, 15) + g = rand_bipartite_heterograph((n, 2*n), 15) g[:A].x = rand(Float32, d, n) g[:B].y = rand(Float32, d, 2*n) @@ -133,7 +133,7 @@ end @testset "add_edges" begin d, n = 3, 5 - g = rand_bipartite_heterograph(n, 2 * n, 15) + g = rand_bipartite_heterograph((n, 2 * n), 15) s, t = [1, 2, 3], [3, 2, 1] ## Keep the same ntypes - construct with args g1 = add_edges(g, (:A, :rel1, :B), s, t) diff --git a/GNNGraphs/test/utils.jl b/GNNGraphs/test/utils.jl index db65b6357..31a1c7373 100644 --- a/GNNGraphs/test/utils.jl +++ b/GNNGraphs/test/utils.jl @@ -47,10 +47,55 @@ tnew[mask] .= s1[mask] @test sdec == snew @test tdec == tnew + + @testset "directed=false, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + g = GNNGraph(s, t) + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) + @test idxmax == n * (n - 1) ÷ 2 + @test idx == 1:idxmax + + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) + @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] + @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] + end + + @testset "directed=false, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) + @test idxmax == n * (n - 1) ÷ 2 + @test idx == 1:idxmax + + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) + @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] + @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] + end + + 
@testset "directed=true, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=true, self_loops=false) + @test idxmax == n^2 - n + @test idx == [1, 9, 3, 4, 6, 7, 8, 11, 12, 16] + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=true, self_loops=false) + @test snew == s + @test tnew == t + end end -@testset "color_refinment" begin - g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T) +@testset "color_refinement" begin + rng = MersenneTwister(17) + g = rand_graph(rng, 10, 20, graph_type = GRAPH_T) x0 = ones(Int, 10) x, ncolors, niters = color_refinement(g, x0) @test ncolors == 8 @@ -59,4 +104,4 @@ end x2, _, _ = color_refinement(g) @test x2 == x -end \ No newline at end of file +end \ No newline at end of file diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index 9f59f3b10..ac937d128 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/basic" setup=[SharedTestSetup] begin rng = StableRNG(17) - g = rand_graph(10, 40, seed=17) + g = rand_graph(rng, 10, 40) x = randn(rng, Float32, 3, 10) @testset "GNNLayer" begin diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 8db856803..142435074 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,12 +1,12 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) - g = rand_graph(10, 40, seed=1234) + g = rand_graph(rng, 10, 40) in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) @testset "GCNConv" begin - l = GCNConv(in_dims => out_dims, relu) + l = GCNConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @@ -16,7 +16,7 @@ end @testset "GraphConv" begin - l = GraphConv(in_dims => out_dims, relu) + l = GraphConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @@ -26,7 +26,7 @@ end @testset "EdgeConv" begin - nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims)) + nn = Chain(Dense(2*in_dims => 2, tanh), Dense(2 => out_dims)) l = EdgeConv(nn, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 50b5b34aa..3a5c543a1 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -74,7 +74,7 @@ end # when we also have edge_weight we need to convert the graph to COO function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} - g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO + g = GNNGraph(g, graph_type = :coo) return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) end @@ -449,9 +449,10 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T}, return (x .+ l.bias) end +# when we also have edge_weight we need to convert the graph to COO function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) + g = GNNGraph(g; graph_type=:coo) return sgc_conv(l, g, x, edge_weight) end @@ -542,9 +543,10 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T}, return (x .+ l.bias) end +# when we also have edge_weight we need to convert the graph to COO function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, 
x::AbstractMatrix, edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) + g = GNNGraph(g; graph_type=:coo) return sg_conv(l, g, x, edge_weight) end @@ -684,9 +686,10 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T}, return (sum_total .+ l.bias) end +# when we also have edge_weight we need to convert the graph to COO function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) + g = GNNGraph(g; graph_type = :coo) return l(g, x, edge_weight) end diff --git a/GNNlib/src/msgpass.jl b/GNNlib/src/msgpass.jl index acab02217..1aa17437a 100644 --- a/GNNlib/src/msgpass.jl +++ b/GNNlib/src/msgpass.jl @@ -45,7 +45,7 @@ struct GNNConv <: GNNLayer σ end -Flux.@functor GNNConv +Flux.@layer GNNConv function GNNConv(ch::Pair{Int,Int}, σ=identity) in, out = ch diff --git a/docs/src/messagepassing.md b/docs/src/messagepassing.md index 9db0062e6..f59ad6561 100644 --- a/docs/src/messagepassing.md +++ b/docs/src/messagepassing.md @@ -109,7 +109,7 @@ struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer σ::F end -Flux.@functor GCN # allow gpu movement, select trainable params etc... +Flux.@layer GCN # allow gpu movement, select trainable params etc... function GCN(ch::Pair{Int,Int}, σ=identity) in, out = ch diff --git a/docs/src/models.md b/docs/src/models.md index d0265964e..96e49055a 100644 --- a/docs/src/models.md +++ b/docs/src/models.md @@ -13,7 +13,7 @@ and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but In the explicit modeling style, the model is created according to the following steps: 1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields. -2. Apply `Flux.@functor` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...) +2. Apply `Flux.@layer` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...) 3. Optionally define a convenience constructor for your model. 4. Define the forward pass by implementing the call method for your type. 5. Instantiate the model. 
@@ -30,7 +30,7 @@ struct GNN # step 1 dense end -Flux.@functor GNN # step 2 +Flux.@layer GNN # step 2 function GNN(din::Int, d::Int, dout::Int) # step 3 GNN(GCNConv(din => d), diff --git a/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl b/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl index 76e2e870e..977f621ce 100644 --- a/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl +++ b/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl @@ -182,7 +182,7 @@ begin layers::NamedTuple end - Flux.@functor GCN # provides parameter collection, gpu movement and more + Flux.@layer GCN # provides parameter collection, gpu movement and more function GCN(num_features, num_classes) layers = (conv1 = GCNConv(num_features => 4), diff --git a/docs/tutorials/introductory_tutorials/node_classification_pluto.jl b/docs/tutorials/introductory_tutorials/node_classification_pluto.jl index 9b3876b20..edf73d4fc 100644 --- a/docs/tutorials/introductory_tutorials/node_classification_pluto.jl +++ b/docs/tutorials/introductory_tutorials/node_classification_pluto.jl @@ -138,7 +138,7 @@ begin layers::NamedTuple end - Flux.@functor MLP + Flux.@layer :expand MLP function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5) layers = (hidden = Dense(num_features => hidden_channels), @@ -235,7 +235,7 @@ begin layers::NamedTuple end - Flux.@functor GCN # provides parameter collection, gpu movement and more + Flux.@layer GCN # provides parameter collection, gpu movement and more function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5) layers = (conv1 = GCNConv(num_features => hidden_channels), diff --git a/docs/tutorials_broken/temporal_graph_classification_pluto.jl b/docs/tutorials_broken/temporal_graph_classification_pluto.jl index b5460c1ec..6afd988c3 100644 --- a/docs/tutorials_broken/temporal_graph_classification_pluto.jl +++ b/docs/tutorials_broken/temporal_graph_classification_pluto.jl @@ -117,7 +117,7 @@ begin dense::Dense end - Flux.@functor GenderPredictionModel + Flux.@layer GenderPredictionModel function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) diff --git a/examples/graph_classification_temporalbrains.jl b/examples/graph_classification_temporalbrains.jl index 2aac2a3e8..e25e9c1f0 100644 --- a/examples/graph_classification_temporalbrains.jl +++ b/examples/graph_classification_temporalbrains.jl @@ -62,7 +62,7 @@ struct GenderPredictionModel dense::Dense end -Flux.@functor GenderPredictionModel +Flux.@layer GenderPredictionModel function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) diff --git a/notebooks/gnn_intro.ipynb b/notebooks/gnn_intro.ipynb index 3f9748f93..db3721ea5 100644 --- a/notebooks/gnn_intro.ipynb +++ b/notebooks/gnn_intro.ipynb @@ -354,7 +354,7 @@ " layers::NamedTuple\n", "end\n", "\n", - "Flux.@functor GCN # provides parameter collection, gpu movement and more\n", + "Flux.@layer :expand GCN # provides parameter collection, gpu movement and more\n", "\n", "function GCN(num_features, num_classes)\n", " layers = (conv1 = GCNConv(num_features => 4),\n", diff --git a/notebooks/graph_classification_solved.ipynb b/notebooks/graph_classification_solved.ipynb index af2e6bf38..a54c5b359 100644 --- a/notebooks/graph_classification_solved.ipynb +++ b/notebooks/graph_classification_solved.ipynb @@ -857,7 +857,7 @@ 
"\tact::F\n", "end\n", "\n", - "Flux.@functor MyConv\n", + "Flux.@layer MyConv\n", "\n", "function MyConv((nin, nout)::Pair, act=identity)\n", "\tW1 = Flux.glorot_uniform(nout, nin)\n", diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 021d4d8b2..bf6991155 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -3,7 +3,7 @@ module GraphNeuralNetworks using Statistics: mean using LinearAlgebra, Random using Flux -using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch +using Flux: glorot_uniform, leakyrelu, GRUCell, batch using MacroTools: @forward using NNlib using NNlib: scatter, gather diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 22fd029f9..4f99ddba4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -45,7 +45,7 @@ end WithGraph(model, g::GNNGraph; traingraph = false) = WithGraph(model, g, traingraph) -@functor WithGraph +Flux.@layer :expand WithGraph Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model) (l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...) @@ -107,7 +107,7 @@ struct GNNChain{T <: Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer layers::T end -@functor GNNChain +Flux.@layer :expand GNNChain GNNChain(xs...) = GNNChain(xs) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4a9f31783..8c3565dce 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -76,7 +76,7 @@ struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer use_edge_weight::Bool end -@functor GCNConv +Flux.@layer GCNConv function GCNConv(ch::Pair{Int, Int}, σ = identity; init = glorot_uniform, @@ -167,7 +167,7 @@ function ChebConv(ch::Pair{Int, Int}, k::Int; ChebConv(W, b, k) end -@functor ChebConv +Flux.@layer ChebConv (l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x) @@ -225,7 +225,7 @@ struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer aggr::A end -@functor GraphConv +Flux.@layer GraphConv function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, init = glorot_uniform, bias::Bool = true) @@ -300,8 +300,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea y = l(g, x) ``` """ -struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMatrix, F, B} <: - GNNLayer +struct GATConv{DX<:Dense,DE<:Union{Dense, Nothing},DV,T,A<:AbstractMatrix,F,B} <: GNNLayer dense_x::DX dense_e::DE bias::B @@ -315,8 +314,8 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMat dropout::DV end -@functor GATConv -Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a) +Flux.@layer GATConv +Flux.trainable(l::GATConv) = (; l.dense_x, l.dense_e, l.bias, l.a) GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) @@ -420,7 +419,7 @@ struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer dropout::DV end -@functor GATv2Conv +Flux.@layer GATv2Conv Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a) function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) 
@@ -515,7 +514,7 @@ struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer aggr::A end -@functor GatedGraphConv +Flux.@layer GatedGraphConv function GatedGraphConv(dims::Int, num_layers::Int; aggr = +, init = glorot_uniform) @@ -572,7 +571,7 @@ struct EdgeConv{NN, A} <: GNNLayer aggr::A end -@functor EdgeConv +Flux.@layer :expand EdgeConv EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) @@ -626,7 +625,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer aggr::A end -@functor GINConv +Flux.@layer :expand GINConv Flux.trainable(l::GINConv) = (nn = l.nn,) GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) @@ -680,7 +679,7 @@ edim = 10 g = GNNGraph(s, t) # create dense layer -nn = Dense(edim, out_channel * in_channel) +nn = Dense(edim => out_channel * in_channel) # create layer l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +) @@ -697,7 +696,7 @@ struct NNConv{W, B, NN, F, A} <: GNNLayer aggr::A end -@functor NNConv +Flux.@layer :expand NNConv function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true, init = glorot_uniform) @@ -763,7 +762,7 @@ struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer aggr::A end -@functor SAGEConv +Flux.@layer SAGEConv function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, init = glorot_uniform, bias::Bool = true) @@ -833,7 +832,7 @@ struct ResGatedGraphConv{W, B, F} <: GNNLayer σ::F end -@functor ResGatedGraphConv +Flux.@layer ResGatedGraphConv function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; init = glorot_uniform, bias::Bool = true) @@ -907,7 +906,7 @@ struct CGConv{D1, D2} <: GNNLayer residual::Bool end -@functor CGConv +Flux.@layer CGConv CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) @@ -980,7 +979,7 @@ struct AGNNConv{A <: AbstractVector} <: GNNLayer trainable::Bool end -@functor AGNNConv +Flux.@layer AGNNConv Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;) @@ -1027,7 +1026,7 @@ struct MEGNetConv{TE, TV, A} <: GNNLayer aggr::A end -@functor MEGNetConv +Flux.@layer :expand MEGNetConv MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) @@ -1108,7 +1107,7 @@ struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer residual::Bool end -@functor GMMConv +Flux.@layer GMMConv function GMMConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; @@ -1191,7 +1190,7 @@ struct SGConv{A <: AbstractMatrix, B} <: GNNLayer use_edge_weight::Bool end -@functor SGConv +Flux.@layer SGConv function SGConv(ch::Pair{Int, Int}, k = 1; init = glorot_uniform, @@ -1259,7 +1258,7 @@ struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer use_edge_weight::Bool end -@functor TAGConv +Flux.@layer TAGConv function TAGConv(ch::Pair{Int, Int}, k = 3; init = glorot_uniform, @@ -1269,7 +1268,7 @@ function TAGConv(ch::Pair{Int, Int}, k = 3; in, out = ch W = init(out, in) b = bias ? 
Flux.create_bias(W, true, out) : false - TAGConv(W, b, k, add_self_loops, use_edge_weight) + return TAGConv(W, b, k, add_self_loops, use_edge_weight) end (l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight) @@ -1343,10 +1342,10 @@ struct EGNNConv{TE, TX, TH, NF} <: GNNLayer residual::Bool end -@functor EGNNConv +Flux.@layer EGNNConv function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false) - EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) + return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) end #Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py @@ -1477,7 +1476,7 @@ struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLaye sqrt_out::Float32 end -@functor TransformerConv +Flux.@layer TransformerConv function Flux.trainable(l::TransformerConv) (; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2) @@ -1568,7 +1567,7 @@ struct DConv <: GNNLayer k::Int end -@functor DConv +Flux.@layer DConv function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true) in, out = ch diff --git a/src/layers/heteroconv.jl b/src/layers/heteroconv.jl index b2603e455..a10ebb0c7 100644 --- a/src/layers/heteroconv.jl +++ b/src/layers/heteroconv.jl @@ -43,7 +43,7 @@ struct HeteroGraphConv aggr::Function end -Flux.@functor HeteroGraphConv +Flux.@layer HeteroGraphConv HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr) HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index ed2f7eca6..59164e199 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -90,7 +90,7 @@ struct GlobalAttentionPool{G, F} ffeat::F end -@functor GlobalAttentionPool +Flux.@layer GlobalAttentionPool GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity) @@ -146,7 +146,7 @@ struct Set2Set{L} <: GNNLayer num_iters::Int end -@functor Set2Set +Flux.@layer Set2Set function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) @assert n_layers >= 1 diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index 44688cea4..443ef2a3a 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -18,7 +18,7 @@ struct TGCNCell <: GNNLayer out::Int end -Flux.@functor TGCNCell +Flux.@layer TGCNCell function TGCNCell(ch::Pair{Int, Int}; bias::Bool = true, @@ -156,7 +156,7 @@ struct A3TGCN <: GNNLayer out::Int end -Flux.@functor A3TGCN +Flux.@layer A3TGCN function A3TGCN(ch::Pair{Int, Int}, bias::Bool = true, @@ -200,7 +200,7 @@ struct GConvGRUCell <: GNNLayer out::Int end -Flux.@functor GConvGRUCell +Flux.@layer GConvGRUCell function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; bias::Bool = true, @@ -302,7 +302,7 @@ struct GConvLSTMCell <: GNNLayer out::Int end -Flux.@functor GConvLSTMCell +Flux.@layer GConvLSTMCell function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; bias::Bool = true, @@ -411,7 +411,7 @@ struct DCGRUCell dconv_c::DConv end -Flux.@functor DCGRUCell +Flux.@layer DCGRUCell function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) in, out = ch diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 9a3b6ee9f..2428865ae 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -21,7 +21,7 @@ @testset "constructor with names" begin m = GNNChain(GCNConv(din => d), LayerNorm(d), - x -> relu.(x), + x -> tanh.(x), Dense(d, dout)) m2 = GNNChain(enc = m, @@ -34,7 +34,7 @@ 
@testset "constructor with vector" begin m = GNNChain(GCNConv(din => d), LayerNorm(d), - x -> relu.(x), + x -> tanh.(x), Dense(d, dout)) m2 = GNNChain([m.layers...]) @test m2(g, x) == m(g, x) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 224b98697..b96baa880 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -84,12 +84,8 @@ end @test l.k == k for g in test_graphs g = add_self_loops(g) - test_layer(l, g, rtol = RTOL_HIGH, test_gpu = false, + test_layer(l, g, rtol = RTOL_HIGH, test_gpu = TEST_GPU, outsize = (out_channel, g.num_nodes)) - if TEST_GPU - @test_broken test_layer(l, g, rtol = RTOL_HIGH, test_gpu = true, - outsize = (out_channel, g.num_nodes)) - end end @testset "bias=false" begin @@ -104,7 +100,7 @@ end test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) end - l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) + l = GraphConv(in_channel => out_channel, tanh, bias = false, aggr = mean) for g in test_graphs test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 29f36ba63..d9eaf0c7f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -1,6 +1,6 @@ @testset "HeteroGraphConv" begin d, n = 3, 5 - g = rand_bipartite_heterograph(n, 2*n, 15) + g = rand_bipartite_heterograph((n, 2*n), 15) hg = rand_bipartite_heterograph((2,3), 6) model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d), @@ -30,8 +30,8 @@ end @testset "Constructor from pairs" begin - layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu), - (:B, :to, :A) => GraphConv(64 => 32, relu)); + layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, tanh), + (:B, :to, :A) => GraphConv(64 => 32, tanh)); @test length(layer.etypes) == 2 end @@ -95,8 +95,8 @@ @testset "CGConv" begin x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu), - (:B, :to, :A) => CGConv(4 => 2, relu)); + layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, tanh), + (:B, :to, :A) => CGConv(4 => 2, tanh)); y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), - (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, tanh, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, tanh, bias = false, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end @@ -152,8 +152,8 @@ @testset "GCNConv" begin g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu), - (:B, :to, :A) => GCNConv(4 => 2, relu)); + layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), + (:B, :to, :A) => GCNConv(4 => 2, tanh)); y = layers(g, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl index b55aff808..45c8acf04 100644 --- a/test/layers/temporalconv.jl +++ b/test/layers/temporalconv.jl @@ -133,7 +133,7 @@ end end @testset "ResGatedGraphConv" begin - resgatedconv = ResGatedGraphConv(in_channel => out_channel, relu) + resgatedconv = ResGatedGraphConv(in_channel => out_channel, tanh) @test length(resgatedconv(tg, tg.ndata.x)) == S 
@test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N) @test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S @@ -147,7 +147,7 @@ end end @testset "GraphConv" begin - graphconv = GraphConv(in_channel => out_channel,relu) + graphconv = GraphConv(in_channel => out_channel, tanh) @test length(graphconv(tg, tg.ndata.x)) == S @test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N) @test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S diff --git a/test/runtests.jl b/test/runtests.jl index e41c7c1ae..05cb6fd5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using GNNGraphs: sort_edge_index using GNNGraphs: getn, getdata using Functors using Flux -using Flux: gpu, @functor +using Flux: gpu using LinearAlgebra, Statistics, Random using NNlib import MLUtils @@ -36,11 +36,10 @@ tests = [ # @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) for graph_type in (:coo, :dense, :sparse) + @info "Testing graph format :$graph_type" global GRAPH_T = graph_type global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) - # global GRAPH_T = :sparse - # global TEST_GPU = false @testset "$t" for t in tests startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI