Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

broadcast, iterate, and setindex! for TemporalSnapshotsGNNGraph #563

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 63 additions & 11 deletions GNNGraphs/docs/src/guides/temporalgraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ CurrentModule = GNNGraphs

# Temporal Graphs

Temporal Graphs are graphs with time varying topologies and features. In GNNGraphs.jl, temporal graphs with fixed number of nodes over time are supported by the [`TemporalSnapshotsGNNGraph`](@ref) type.
Temporal graphs are graphs with time-varying topologies and features. In GNNGraphs.jl, they are represented by the [`TemporalSnapshotsGNNGraph`](@ref) type.

## Creating a TemporalSnapshotsGNNGraph

Expand All @@ -13,7 +13,7 @@ A temporal graph can be created by passing a list of snapshots to the constructo
```jldoctest temporal
julia> using GNNGraphs

julia> snapshots = [rand_graph(10,20) for i in 1:5];
julia> snapshots = [rand_graph(10, 20) for i in 1:5];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
Expand Down Expand Up @@ -57,23 +57,77 @@ TemporalSnapshotsGNNGraph:
num_snapshots: 3
```

## Indexing

Snapshots in a temporal graph can be accessed using indexing:

```jldoctest temporal
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)

julia> tg[1] # first snapshot
GNNGraph:
num_nodes: 10
num_edges: 20

julia> tg[2:3] # snapshots 2 and 3
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10]
num_edges: [14, 22]
num_snapshots: 2
```

A snapshot can be modified by assigning a new snapshot to the temporal graph:

```jldoctest temporal
julia> tg[1] = rand_graph(10, 16) # replace first snapshot
GNNGraph:
num_nodes: 10
num_edges: 16
```

## Iteration and Broadcasting

Iteration and broadcasting over a temporal graph is similar to that of a vector of snapshots:

```jldoctest temporal
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];

julia> tg = TemporalSnapshotsGNNGraph(snapshots);

julia> [g for g in tg] # iterate over snapshots
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph(10, 20) with no data
GNNGraph(10, 14) with no data
GNNGraph(10, 22) with no data

julia> f(g) = g isa GNNGraph;

julia> f.(tg) # broadcast over snapshots
3-element BitVector:
1
1
1
```

## Basic Queries

Basic queries are similar to those for [`GNNGraph`](@ref)s:
```jldoctest temporal
julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)];
julia> snapshots = [rand_graph(10,20), rand_graph(12,14), rand_graph(14,22)];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10]
num_nodes: [10, 12, 14]
num_edges: [20, 14, 22]
num_snapshots: 3

julia> tg.num_nodes # number of nodes in each snapshot
3-element Vector{Int64}:
10
10
10
12
14

julia> tg.num_edges # number of edges in each snapshot
3-element Vector{Int64}:
Expand All @@ -87,8 +141,8 @@ julia> tg.num_snapshots # number of snapshots
julia> tg.snapshots # list of snapshots
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph(10, 20) with no data
GNNGraph(10, 14) with no data
GNNGraph(10, 22) with no data
GNNGraph(12, 14) with no data
GNNGraph(14, 22) with no data

julia> tg.snapshots[1] # first snapshot, same as tg[1]
GNNGraph:
Expand All @@ -97,7 +151,7 @@ GNNGraph:
```

## Data Features
A temporal graph can store global feature for the entire time series in the `tgdata` filed.
A temporal graph can store global feature for the entire time series in the `tgdata` field.
Also, each snapshot can store node, edge, and graph features in the `ndata`, `edata`, and `gdata` fields, respectively.

```jldoctest temporal
Expand Down Expand Up @@ -131,5 +185,3 @@ julia> [ds.x for ds in tg.ndata]; # vector containing the x feature of each snap
julia> [g.x for g in tg.snapshots]; # same vector as above, now accessing
# the x feature directly from the snapshots
```


138 changes: 70 additions & 68 deletions GNNGraphs/src/temporalsnapshotsgnngraph.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,73 @@
"""
TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
TemporalSnapshotsGNNGraph(snapshots)

A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref).
A type representing a time-varying graph as a sequence of snapshots,
each snapshot being a [`GNNGraph`](@ref).

`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object,
and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features.
The features can be passed at construction time or added later.
The argument `snapshots` is a collection of `GNNGraph`s with arbitrary
number of nodes and edges each.

# Constructor Arguments
Calling `tg` the temporal graph, `tg[t]` returns the `t`-th snapshot.

- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes.
The snapshots can contain node/edge/graph features, while global features for the
whole temporal sequence can be stored in `tg.tgdata`.

# Examples
See [`add_snapshot`](@ref) and [`remove_snapshot`](@ref) for adding and removing snapshots.

```julia
julia> using GNNGraphs
# Examples

julia> snapshots = [rand_graph(10,20) for i in 1:5];
```jldoctest
julia> snapshots = [rand_graph(i , 2*i) for i in 10:10:50];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10, 10, 10]
num_edges: [20, 20, 20, 20, 20]
num_nodes: [10, 20, 30, 40, 50]
num_edges: [20, 40, 60, 80, 100]
num_snapshots: 5

julia> tg.tgdata.x = rand(4); # add temporal graph feature
julia> tg.num_snapshots
5

julia> tg.num_nodes
5-element Vector{Int64}:
10
20
30
40
50

julia> tg # show temporal graph with new feature
julia> tg[1]
GNNGraph:
num_nodes: 10
num_edges: 20

julia> tg[2:3]
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10, 10, 10]
num_edges: [20, 20, 20, 20, 20]
num_snapshots: 5
tgdata:
x = 4-element Vector{Float64}
num_nodes: [20, 30]
num_edges: [40, 60]
num_snapshots: 2

julia> tg[1] = rand_graph(10, 16)
GNNGraph:
num_nodes: 10
num_edges: 16
```
"""
struct TemporalSnapshotsGNNGraph
num_nodes::AbstractVector{Int}
num_edges::AbstractVector{Int}
struct TemporalSnapshotsGNNGraph{G<:GNNGraph, D<:DataStore}
num_nodes::Vector{Int}
num_edges::Vector{Int}
num_snapshots::Int
snapshots::AbstractVector{<:GNNGraph}
tgdata::DataStore
snapshots::Vector{G}
tgdata::D
end

function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph})
@assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes"
function TemporalSnapshotsGNNGraph(snapshots)
snapshots = collect(snapshots)
return TemporalSnapshotsGNNGraph(
[s.num_nodes for s in snapshots],
[s.num_edges for s in snapshots],
length(snapshots),
snapshots,
collect(snapshots),
DataStore()
)
end
Expand All @@ -67,7 +85,25 @@ function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int)
end

function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector)
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata)
return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t],
length(t), tg.snapshots[t], tg.tgdata)
end

function Base.length(tg::TemporalSnapshotsGNNGraph)
return tg.num_snapshots
end

# Allow broadcasting over the temporal snapshots
Base.broadcastable(tg::TemporalSnapshotsGNNGraph) = tg.snapshots

Base.iterate(tg::TemporalSnapshotsGNNGraph) = Base.iterate(tg.snapshots)
Base.iterate(tg::TemporalSnapshotsGNNGraph, i) = Base.iterate(tg.snapshots, i)

function Base.setindex!(tg::TemporalSnapshotsGNNGraph, g::GNNGraph, t::Int)
tg.snapshots[t] = g
tg.num_nodes[t] = g.num_nodes
tg.num_edges[t] = g.num_edges
return tg
end

"""
Expand All @@ -78,8 +114,6 @@ Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the sn
# Examples

```jldoctest
julia> using GNNGraphs

julia> snapshots = [rand_graph(10, 20) for i in 1:5];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
Expand Down Expand Up @@ -185,58 +219,26 @@ end
function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol)
if prop ∈ fieldnames(TemporalSnapshotsGNNGraph)
return getfield(tg, prop)
elseif prop == :ndata
return [s.ndata for s in tg.snapshots]
elseif prop == :edata
return [s.edata for s in tg.snapshots]
elseif prop == :gdata
return [s.gdata for s in tg.snapshots]
else
return [getproperty(s,prop) for s in tg.snapshots]
else
return [getproperty(s, prop) for s in tg.snapshots]
end
end

function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph)
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
print_feature_t(io, tsg.tgdata)
print(io, " data")
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))")
end

function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph)
if get(io, :compact, false)
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ")
print_feature_t(io, tsg.tgdata)
print(io, " data")
print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))")
else
print(io,
"TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)")
if !isempty(tsg.tgdata)
print(io, "\n tgdata:")
for k in keys(tsg.tgdata)
print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))")
end
end
end
end

function print_feature_t(io::IO, feature)
if !isempty(feature)
if length(keys(feature)) == 1
k = first(keys(feature))
v = first(values(feature))
print(io, "$(k): $(dims2string(size(v)))")
else
print(io, "(")
for (i, (k, v)) in enumerate(pairs(feature))
print(io, "$k: $(dims2string(size(v)))")
if i == length(feature)
print(io, ")")
else
print(io, ", ")
end
print(io, "\n $k = $(shortsummary(tsg.tgdata[k]))")
end
end
else
print(io, "no")
end
end
Loading
Loading