Skip to content

Commit

Permalink
changes for Flux v0.15 (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Dec 9, 2024
1 parent 6e5296a commit a6700c3
Show file tree
Hide file tree
Showing 25 changed files with 689 additions and 713 deletions.
4 changes: 2 additions & 2 deletions GNNGraphs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GNNGraphs"
uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c"
authors = ["Carlo Lucibello and contributors"]
version = "1.3.1"
version = "1.4.0-DEV"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -31,7 +31,7 @@ GNNGraphsSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
Adapt = "4"
CUDA = "5"
ChainRulesCore = "1"
Functors = "0.4.1, 0.5"
Functors = "0.5"
Graphs = "1.4"
KrylovKit = "0.8"
LinearAlgebra = "1"
Expand Down
2 changes: 0 additions & 2 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module GNNGraphs

using SparseArrays
using Functors: @functor
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
has_self_loops, is_directed, induced_subgraph, has_edge
Expand All @@ -13,7 +12,6 @@ using ChainRulesCore
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
import Functors
using MLDataDevices: get_device, cpu_device, CPUDevice

include("chainrules.jl") # hacks for differentiability
Expand Down
2 changes: 0 additions & 2 deletions GNNGraphs/src/datastore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ struct DataStore
end
end

@functor DataStore

DataStore(data) = DataStore(-1, data)
DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol, Any}(pairs(data)))
DataStore(n::Int, data) = DataStore(n, Dict{Symbol, Any}(data))
Expand Down
2 changes: 0 additions & 2 deletions GNNGraphs/src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
gdata::DataStore
end

@functor GNNGraph

function GNNGraph(data::D;
num_nodes = nothing,
graph_indicator = nothing,
Expand Down
2 changes: 0 additions & 2 deletions GNNGraphs/src/gnnheterograph/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
etypes::Vector{EType}
end

@functor GNNHeteroGraph

GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...)

Expand Down
2 changes: 0 additions & 2 deletions GNNGraphs/src/temporalsnapshotsgnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,3 @@ function print_feature_t(io::IO, feature)
print(io, "no")
end
end

@functor TemporalSnapshotsGNNGraph
2 changes: 1 addition & 1 deletion GNNGraphs/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using CUDA, cuDNN
using GNNGraphs
using GNNGraphs: getn, getdata
using Functors
using Functors: Functors
using LinearAlgebra, Statistics, Random
using NNlib
import MLUtils
Expand Down
6 changes: 3 additions & 3 deletions GNNLux/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GNNLux"
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.1"
version = "0.2.0-DEV"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand All @@ -17,8 +17,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ConcreteStructs = "0.2.3"
GNNGraphs = "1.3"
GNNlib = "0.2.3"
GNNGraphs = "1.4"
GNNlib = "1"
Lux = "1"
LuxCore = "1"
NNlib = "0.9.21"
Expand Down
8 changes: 5 additions & 3 deletions GNNLux/docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNGraphs"))
Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNlib"))
Pkg.develop(path=joinpath(@__DIR__, ".."))
Pkg.develop([
PackageSpec(path=joinpath(@__DIR__, "..", "..", "GNNGraphs")),
PackageSpec(path=joinpath(@__DIR__, "..", "..", "GNNlib")),
PackageSpec(path=joinpath(@__DIR__, "..")),
])
Pkg.instantiate()

using Documenter
Expand Down
2 changes: 1 addition & 1 deletion GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l

function (l::GatedGraphConv)(g, x, ps, st)
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style
fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
return GNNlib.gated_graph_conv(m, g, x), st
end
Expand Down
8 changes: 8 additions & 0 deletions GNNLux/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
## The test environment is instantiated as follows:
# using Pkg
# Pkg.activate(@__DIR__)
# Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNGraphs"))
# Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNlib"))
# Pkg.develop(path=joinpath(@__DIR__, ".."))
# Pkg.instantiate()

using TestItemRunner

## See https://www.julia-vscode.org/docs/stable/userguide/testitems/
Expand Down
6 changes: 3 additions & 3 deletions GNNlib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GNNlib"
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
authors = ["Carlo Lucibello and contributors"]
version = "0.2.5"
version = "1.0.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -28,10 +28,10 @@ GNNlibCUDAExt = "CUDA"

[compat]
AMDGPU = "1"
CUDA = "4, 5"
CUDA = "5"
ChainRulesCore = "1.24"
DataStructures = "0.18"
GNNGraphs = "1.0"
GNNGraphs = "1.4"
GPUArraysCore = "0.1"
LinearAlgebra = "1"
MLUtils = "0.4"
Expand Down
6 changes: 4 additions & 2 deletions GNNlib/docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNGraphs"))
Pkg.develop(path=joinpath(@__DIR__, ".."))
Pkg.develop([
PackageSpec(path=joinpath(@__DIR__, "..", "..", "GNNGraphs")),
PackageSpec(path=joinpath(@__DIR__, "..")),
])
Pkg.instantiate()

using Documenter
Expand Down
2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
m = view(l.weight, :, :, i) * h
m = propagate(copy_xj, g, l.aggr; xj = m)
# in gru forward, hidden state is first argument, input is second
h, _ = l.gru(h, m)
h = l.gru(m, h)
end
return h
end
Expand Down
5 changes: 4 additions & 1 deletion GNNlib/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
n_in = size(x, 1)
qstar = zeros_like(x, (2*n_in, g.num_graphs))
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
for t in 1:l.num_iters
q = l.lstm(qstar) # [n_in, n_graphs]
h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs]
q = h
qn = broadcast_nodes(g, q) # [n_in, n_nodes]
α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
Expand Down
7 changes: 7 additions & 0 deletions GNNlib/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## The test environment is instantiated as follows:
# using Pkg
# Pkg.activate(@__DIR__)
# Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNGraphs"))
# Pkg.develop(path=joinpath(@__DIR__, ".."))
# Pkg.instantiate()

using TestItemRunner

## See https://www.julia-vscode.org/docs/stable/userguide/testitems/
Expand Down
8 changes: 4 additions & 4 deletions GraphNeuralNetworks/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GraphNeuralNetworks"
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
authors = ["Carlo Lucibello and contributors"]
version = "0.6.23"
version = "1.0.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -18,9 +18,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Flux = "0.14"
GNNGraphs = "1.0"
GNNlib = "0.2"
Flux = "0.15"
GNNGraphs = "1.4"
GNNlib = "1"
LinearAlgebra = "1"
MLUtils = "0.4"
MacroTools = "0.5"
Expand Down
8 changes: 5 additions & 3 deletions GraphNeuralNetworks/docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNGraphs"))
Pkg.develop(path=joinpath(@__DIR__, "..", "..", "GNNlib"))
Pkg.develop(path=joinpath(@__DIR__, ".."))
Pkg.develop([
PackageSpec(path=joinpath(@__DIR__, "..", "..", "GNNGraphs")),
PackageSpec(path=joinpath(@__DIR__, "..", "..", "GNNlib")),
PackageSpec(path=joinpath(@__DIR__, "..")),
])
Pkg.instantiate()

using Documenter
Expand Down
21 changes: 8 additions & 13 deletions GraphNeuralNetworks/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,19 @@ end
Flux.@layer Set2Set

function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
@assert n_layers >= 1
@assert n_layers == 1 "multiple layers not implemented yet" #TODO
n_out = 2 * n_in

if n_layers == 1
lstm = LSTM(n_out => n_in)
else
layers = [LSTM(n_out => n_in)]
for _ in 2:n_layers
push!(layers, LSTM(n_in => n_in))
end
lstm = Chain(layers...)
end

lstm = LSTMCell(n_out => n_in)
return Set2Set(lstm, n_iters)
end

function initialstates(cell::LSTMCell)
h = zeros_like(cell.Wh, size(cell.Wh, 2))
c = zeros_like(cell.Wh, size(cell.Wh, 2))
return h, c
end

function (l::Set2Set)(g, x)
Flux.reset!(l.lstm)
return GNNlib.set2set_pool(l, g, x)
end

Expand Down
Loading

0 comments on commit a6700c3

Please sign in to comment.