Skip to content

Commit

Permalink
broadcast, iterate, and setindex! for TemporalSnapshotsGNNGraphs (#563)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Dec 17, 2024
1 parent 27d13c8 commit 1a72242
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 1,830 deletions.
74 changes: 63 additions & 11 deletions GNNGraphs/docs/src/guides/temporalgraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ CurrentModule = GNNGraphs

# Temporal Graphs

Temporal Graphs are graphs with time varying topologies and features. In GNNGraphs.jl, temporal graphs with fixed number of nodes over time are supported by the [`TemporalSnapshotsGNNGraph`](@ref) type.
Temporal graphs are graphs with time-varying topologies and features. In GNNGraphs.jl, they are represented by the [`TemporalSnapshotsGNNGraph`](@ref) type.

## Creating a TemporalSnapshotsGNNGraph

Expand All @@ -13,7 +13,7 @@ A temporal graph can be created by passing a list of snapshots to the constructo
```jldoctest temporal
julia> using GNNGraphs
julia> snapshots = [rand_graph(10,20) for i in 1:5];
julia> snapshots = [rand_graph(10, 20) for i in 1:5];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
Expand Down Expand Up @@ -57,23 +57,77 @@ TemporalSnapshotsGNNGraph:
num_snapshots: 3
```

## Indexing

Snapshots in a temporal graph can be accessed using indexing:

```jldoctest temporal
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
julia> tg[1] # first snapshot
GNNGraph:
num_nodes: 10
num_edges: 20
julia> tg[2:3] # snapshots 2 and 3
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10]
num_edges: [14, 22]
num_snapshots: 2
```

A snapshot can be modified by assigning a new snapshot to the temporal graph:

```jldoctest temporal
julia> tg[1] = rand_graph(10, 16) # replace first snapshot
GNNGraph:
num_nodes: 10
num_edges: 16
```

## Iteration and Broadcasting

Iteration and broadcasting over a temporal graph is similar to that of a vector of snapshots:

```jldoctest temporal
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];
julia> tg = TemporalSnapshotsGNNGraph(snapshots);
julia> [g for g in tg] # iterate over snapshots
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph(10, 20) with no data
GNNGraph(10, 14) with no data
GNNGraph(10, 22) with no data
julia> f(g) = g isa GNNGraph;
julia> f.(tg) # broadcast over snapshots
3-element BitVector:
1
1
1
```

## Basic Queries

Basic queries are similar to those for [`GNNGraph`](@ref)s:
```jldoctest temporal
julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)];
julia> snapshots = [rand_graph(10,20), rand_graph(12,14), rand_graph(14,22)];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10]
num_nodes: [10, 12, 14]
num_edges: [20, 14, 22]
num_snapshots: 3
julia> tg.num_nodes # number of nodes in each snapshot
3-element Vector{Int64}:
10
10
10
12
14
julia> tg.num_edges # number of edges in each snapshot
3-element Vector{Int64}:
Expand All @@ -87,8 +141,8 @@ julia> tg.num_snapshots # number of snapshots
julia> tg.snapshots # list of snapshots
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph(10, 20) with no data
GNNGraph(10, 14) with no data
GNNGraph(10, 22) with no data
GNNGraph(12, 14) with no data
GNNGraph(14, 22) with no data
julia> tg.snapshots[1] # first snapshot, same as tg[1]
GNNGraph:
Expand All @@ -97,7 +151,7 @@ GNNGraph:
```

## Data Features
A temporal graph can store global feature for the entire time series in the `tgdata` filed.
A temporal graph can store global feature for the entire time series in the `tgdata` field.
Also, each snapshot can store node, edge, and graph features in the `ndata`, `edata`, and `gdata` fields, respectively.

```jldoctest temporal
Expand Down Expand Up @@ -131,5 +185,3 @@ julia> [ds.x for ds in tg.ndata]; # vector containing the x feature of each snap
julia> [g.x for g in tg.snapshots]; # same vector as above, now accessing
# the x feature directly from the snapshots
```


138 changes: 70 additions & 68 deletions GNNGraphs/src/temporalsnapshotsgnngraph.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,73 @@
"""
TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
TemporalSnapshotsGNNGraph(snapshots)
A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref).
A type representing a time-varying graph as a sequence of snapshots,
each snapshot being a [`GNNGraph`](@ref).
`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object,
and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features.
The features can be passed at construction time or added later.
The argument `snapshots` is a collection of `GNNGraph`s with arbitrary
number of nodes and edges each.
# Constructor Arguments
Calling `tg` the temporal graph, `tg[t]` returns the `t`-th snapshot.
- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes.
The snapshots can contain node/edge/graph features, while global features for the
whole temporal sequence can be stored in `tg.tgdata`.
# Examples
See [`add_snapshot`](@ref) and [`remove_snapshot`](@ref) for adding and removing snapshots.
```julia
julia> using GNNGraphs
# Examples
julia> snapshots = [rand_graph(10,20) for i in 1:5];
```jldoctest
julia> snapshots = [rand_graph(i , 2*i) for i in 10:10:50];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10, 10, 10]
num_edges: [20, 20, 20, 20, 20]
num_nodes: [10, 20, 30, 40, 50]
num_edges: [20, 40, 60, 80, 100]
num_snapshots: 5
julia> tg.tgdata.x = rand(4); # add temporal graph feature
julia> tg.num_snapshots
5
julia> tg.num_nodes
5-element Vector{Int64}:
10
20
30
40
50
julia> tg # show temporal graph with new feature
julia> tg[1]
GNNGraph:
num_nodes: 10
num_edges: 20
julia> tg[2:3]
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10, 10, 10]
num_edges: [20, 20, 20, 20, 20]
num_snapshots: 5
tgdata:
x = 4-element Vector{Float64}
num_nodes: [20, 30]
num_edges: [40, 60]
num_snapshots: 2
julia> tg[1] = rand_graph(10, 16)
GNNGraph:
num_nodes: 10
num_edges: 16
```
"""
struct TemporalSnapshotsGNNGraph
num_nodes::AbstractVector{Int}
num_edges::AbstractVector{Int}
struct TemporalSnapshotsGNNGraph{G<:GNNGraph, D<:DataStore}
num_nodes::Vector{Int}
num_edges::Vector{Int}
num_snapshots::Int
snapshots::AbstractVector{<:GNNGraph}
tgdata::DataStore
snapshots::Vector{G}
tgdata::D
end

function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
@assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes"
function TemporalSnapshotsGNNGraph(snapshots)
snapshots = collect(snapshots)
return TemporalSnapshotsGNNGraph(
[s.num_nodes for s in snapshots],
[s.num_edges for s in snapshots],
length(snapshots),
snapshots,
collect(snapshots),
DataStore()
)
end
Expand All @@ -67,7 +85,25 @@ function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int)
end

function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector)
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata)
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t],
length(t), tg.snapshots[t], tg.tgdata)
end

function Base.length(tg::TemporalSnapshotsGNNGraph)
return tg.num_snapshots
end

# Allow broadcasting over the temporal snapshots
Base.broadcastable(tg::TemporalSnapshotsGNNGraph) = tg.snapshots

Base.iterate(tg::TemporalSnapshotsGNNGraph) = Base.iterate(tg.snapshots)
Base.iterate(tg::TemporalSnapshotsGNNGraph, i) = Base.iterate(tg.snapshots, i)

function Base.setindex!(tg::TemporalSnapshotsGNNGraph, g::GNNGraph, t::Int)
tg.snapshots[t] = g
tg.num_nodes[t] = g.num_nodes
tg.num_edges[t] = g.num_edges
return tg
end

"""
Expand All @@ -78,8 +114,6 @@ Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the sn
# Examples
```jldoctest
julia> using GNNGraphs
julia> snapshots = [rand_graph(10, 20) for i in 1:5];
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
Expand Down Expand Up @@ -185,58 +219,26 @@ end
function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol)
if prop fieldnames(TemporalSnapshotsGNNGraph)
return getfield(tg, prop)
elseif prop == :ndata
return [s.ndata for s in tg.snapshots]
elseif prop == :edata
return [s.edata for s in tg.snapshots]
elseif prop == :gdata
return [s.gdata for s in tg.snapshots]
else
return [getproperty(s,prop) for s in tg.snapshots]
else
return [getproperty(s, prop) for s in tg.snapshots]
end
end

function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph)
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
print_feature_t(io, tsg.tgdata)
print(io, " data")
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))")
end

function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph)
if get(io, :compact, false)
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
print_feature_t(io, tsg.tgdata)
print(io, " data")
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))")
else
print(io,
"TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)")
if !isempty(tsg.tgdata)
print(io, "\n tgdata:")
for k in keys(tsg.tgdata)
print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))")
end
end
end
end

function print_feature_t(io::IO, feature)
if !isempty(feature)
if length(keys(feature)) == 1
k = first(keys(feature))
v = first(values(feature))
print(io, "$(k): $(dims2string(size(v)))")
else
print(io, "(")
for (i, (k, v)) in enumerate(pairs(feature))
print(io, "$k: $(dims2string(size(v)))")
if i == length(feature)
print(io, ")")
else
print(io, ", ")
end
print(io, "\n $k = $(shortsummary(tsg.tgdata[k]))")
end
end
else
print(io, "no")
end
end
Loading

0 comments on commit 1a72242

Please sign in to comment.