diff --git a/docs/src/heterograph.md b/docs/src/heterograph.md index 889632004..c67f86fd6 100644 --- a/docs/src/heterograph.md +++ b/docs/src/heterograph.md @@ -12,8 +12,13 @@ the type [`GNNHeteroGraph`](@ref). ## Creating a Heterograph -A heterograph can be created by passing pairs `edge_type => data` to the constructor. +A heterograph can be created empty or by passing pairs `edge_type => data` to the constructor. ```jldoctest +julia> g = GNNHeteroGraph() +GNNHeteroGraph: + num_nodes: Dict() + num_edges: Dict() + julia> g = GNNHeteroGraph((:user, :like, :actor) => ([1,2,2,3], [1,3,2,9]), (:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7])) GNNHeteroGraph: diff --git a/src/GNNGraphs/convert.jl b/src/GNNGraphs/convert.jl index c9c072f54..1e103db8b 100644 --- a/src/GNNGraphs/convert.jl +++ b/src/GNNGraphs/convert.jl @@ -4,26 +4,28 @@ function to_coo(data::EDict; num_nodes = nothing, kws...) graph = EDict{COO_T}() _num_nodes = NDict{Int}() num_edges = EDict{Int}() - for k in keys(data) - d = data[k] - @assert d isa Tuple - if length(d) == 2 - d = (d..., nothing) + if !isempty(data) + for k in keys(data) + d = data[k] + @assert d isa Tuple + if length(d) == 2 + d = (d..., nothing) + end + if num_nodes !== nothing + n1 = get(num_nodes, k[1], nothing) + n2 = get(num_nodes, k[3], nothing) + else + n1 = nothing + n2 = nothing + end + g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) + graph[k] = g + num_edges[k] = nedges + _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) + _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) end - if num_nodes !== nothing - n1 = get(num_nodes, k[1], nothing) - n2 = get(num_nodes, k[3], nothing) - else - n1 = nothing - n2 = nothing - end - g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) - graph[k] = g - num_edges[k] = nedges - _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) - _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) + graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types end - graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types return graph, _num_nodes, num_edges end diff --git a/src/GNNGraphs/gnnheterograph.jl b/src/GNNGraphs/gnnheterograph.jl index 577c48faf..72d67b34b 100644 --- a/src/GNNGraphs/gnnheterograph.jl +++ b/src/GNNGraphs/gnnheterograph.jl @@ -100,6 +100,8 @@ end GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...) GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...) +GNNHeteroGraph() = GNNHeteroGraph(Dict{Tuple{Symbol,Symbol,Symbol}, Any}()) + function GNNHeteroGraph(data::Dict; kws...) all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`")) return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...) @@ -135,10 +137,17 @@ function GNNHeteroGraph(data::EDict; num_graphs = !isnothing(graph_indicator) ? maximum([maximum(gi) for gi in values(graph_indicator)]) : 1 - ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) - edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, - duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs) + + if length(keys(graph)) == 0 + ndata = Dict{Symbol, DataStore}() + edata = Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}() + gdata = DataStore() + else + ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) + edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, + duplicate_if_needed = true) + gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs) + end return GNNHeteroGraph(graph, num_nodes, num_edges, num_graphs, diff --git a/test/GNNGraphs/gnnheterograph.jl b/test/GNNGraphs/gnnheterograph.jl index 08618fdc2..20a97dbc0 100644 --- a/test/GNNGraphs/gnnheterograph.jl +++ b/test/GNNGraphs/gnnheterograph.jl @@ -1,4 +1,14 @@ + +@testset "Empty constructor" begin + g = GNNHeteroGraph() + @test isempty(g.num_nodes) + g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4])) + @test g.num_nodes[:user] == 3 + @test g.num_nodes[:actor] == 9 + @test g.num_edges[(:user, :like, :actor)] == 5 +end + @testset "Constructor from pairs" begin hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3,4], [3,2,1,5])) @test hg.num_nodes == Dict(:A => 4, :B => 5)