Skip to content

Commit

Permalink
feat: Add empty constructor for GNNHeteroGraph (#358)
Browse files Browse the repository at this point in the history
* add empty heterograph constructor

* update docs

* Update src/GNNGraphs/convert.jl

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>

* Update docs/src/heterograph.md

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>

* add tests

---------

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
  • Loading branch information
askorupka and CarloLucibello authored Jan 13, 2024
1 parent 2c11b95 commit 32ca15a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 23 deletions.
7 changes: 6 additions & 1 deletion docs/src/heterograph.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ the type [`GNNHeteroGraph`](@ref).

## Creating a Heterograph

A heterograph can be created by passing pairs `edge_type => data` to the constructor.
A heterograph can be created empty or by passing pairs `edge_type => data` to the constructor.
```jldoctest
julia> g = GNNHeteroGraph()
GNNHeteroGraph:
num_nodes: Dict()
num_edges: Dict()
julia> g = GNNHeteroGraph((:user, :like, :actor) => ([1,2,2,3], [1,3,2,9]),
(:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
GNNHeteroGraph:
Expand Down
38 changes: 20 additions & 18 deletions src/GNNGraphs/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,28 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
graph = EDict{COO_T}()
_num_nodes = NDict{Int}()
num_edges = EDict{Int}()
for k in keys(data)
d = data[k]
@assert d isa Tuple
if length(d) == 2
d = (d..., nothing)
if !isempty(data)
for k in keys(data)
d = data[k]
@assert d isa Tuple
if length(d) == 2
d = (d..., nothing)
end
if num_nodes !== nothing
n1 = get(num_nodes, k[1], nothing)
n2 = get(num_nodes, k[3], nothing)
else
n1 = nothing
n2 = nothing
end
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
graph[k] = g
num_edges[k] = nedges
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
end
if num_nodes !== nothing
n1 = get(num_nodes, k[1], nothing)
n2 = get(num_nodes, k[3], nothing)
else
n1 = nothing
n2 = nothing
end
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
graph[k] = g
num_edges[k] = nedges
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
end
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
return graph, _num_nodes, num_edges
end

Expand Down
17 changes: 13 additions & 4 deletions src/GNNGraphs/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ end
GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...)

GNNHeteroGraph() = GNNHeteroGraph(Dict{Tuple{Symbol,Symbol,Symbol}, Any}())

function GNNHeteroGraph(data::Dict; kws...)
all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`"))
return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...)
Expand Down Expand Up @@ -135,10 +137,17 @@ function GNNHeteroGraph(data::EDict;
num_graphs = !isnothing(graph_indicator) ?
maximum([maximum(gi) for gi in values(graph_indicator)]) : 1

ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes)
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
duplicate_if_needed = true)
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs)

if length(keys(graph)) == 0
ndata = Dict{Symbol, DataStore}()
edata = Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}()
gdata = DataStore()
else
ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes)
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
duplicate_if_needed = true)
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs)
end

return GNNHeteroGraph(graph,
num_nodes, num_edges, num_graphs,
Expand Down
10 changes: 10 additions & 0 deletions test/GNNGraphs/gnnheterograph.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@


@testset "Empty constructor" begin
g = GNNHeteroGraph()
@test isempty(g.num_nodes)
g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4]))
@test g.num_nodes[:user] == 3
@test g.num_nodes[:actor] == 9
@test g.num_edges[(:user, :like, :actor)] == 5
end

@testset "Constructor from pairs" begin
hg = GNNHeteroGraph((:A, :e1, :B) => ([1,2,3,4], [3,2,1,5]))
@test hg.num_nodes == Dict(:A => 4, :B => 5)
Expand Down

0 comments on commit 32ca15a

Please sign in to comment.