From 82c28b01e083afa529626576bdc356252193f52d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 7 Dec 2024 12:03:37 +0100 Subject: [PATCH] fix lstm --- GNNLux/Project.toml | 4 ++-- GNNlib/Project.toml | 2 +- GNNlib/src/layers/pool.jl | 5 ++++- GraphNeuralNetworks/Project.toml | 4 ++-- GraphNeuralNetworks/src/layers/pool.jl | 21 ++++++++------------- GraphNeuralNetworks/test/layers/pool.jl | 2 +- 6 files changed, 18 insertions(+), 20 deletions(-) diff --git a/GNNLux/Project.toml b/GNNLux/Project.toml index b0ef019fd..1a4b9d5d4 100644 --- a/GNNLux/Project.toml +++ b/GNNLux/Project.toml @@ -1,7 +1,7 @@ name = "GNNLux" uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d" authors = ["Carlo Lucibello and contributors"] -version = "0.1.0" +version = "0.2.0-DEV" [deps] ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" @@ -18,7 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ConcreteStructs = "0.2.3" GNNGraphs = "1.3" -GNNlib = "0.2.3" +GNNlib = "1" Lux = "1" LuxCore = "1" NNlib = "0.9.21" diff --git a/GNNlib/Project.toml b/GNNlib/Project.toml index 8a3d2179f..8f8ec9532 100644 --- a/GNNlib/Project.toml +++ b/GNNlib/Project.toml @@ -1,7 +1,7 @@ name = "GNNlib" uuid = "a6a84749-d869-43f8-aacc-be26a1996e48" authors = ["Carlo Lucibello and contributors"] -version = "0.2.3" +version = "1.0.0-DEV" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl index 4a6735a06..991e18465 100644 --- a/GNNlib/src/layers/pool.jl +++ b/GNNlib/src/layers/pool.jl @@ -29,8 +29,11 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k) function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) n_in = size(x, 1) qstar = zeros_like(x, (2*n_in, g.num_graphs)) + h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) + c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) for t in 1:l.num_iters - q = l.lstm(qstar) # [n_in, n_graphs] + h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs] + q = h qn = broadcast_nodes(g, q) # [n_in, n_nodes] α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes] r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs] diff --git a/GraphNeuralNetworks/Project.toml b/GraphNeuralNetworks/Project.toml index 3184ed3cb..d7892bf36 100644 --- a/GraphNeuralNetworks/Project.toml +++ b/GraphNeuralNetworks/Project.toml @@ -19,8 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" Flux = "0.15" -GNNGraphs = "1.0" -GNNlib = "0.2" +GNNGraphs = "1" +GNNlib = "1" LinearAlgebra = "1" MLUtils = "0.4" MacroTools = "0.5" diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index 59164e199..493ef6715 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -149,24 +149,19 @@ end Flux.@layer Set2Set function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) - @assert n_layers >= 1 + @assert n_layers == 1 "multiple layers not implemented yet" #TODO n_out = 2 * n_in - - if n_layers == 1 - lstm = LSTM(n_out => n_in) - else - layers = [LSTM(n_out => n_in)] - for _ in 2:n_layers - push!(layers, LSTM(n_in => n_in)) - end - lstm = Chain(layers...) - end - + lstm = LSTMCell(n_out => n_in) return Set2Set(lstm, n_iters) end +function initialstates(cell::LSTMCell) + h = zeros_like(cell.Wh, size(cell.Wh, 2)) + c = zeros_like(cell.Wh, size(cell.Wh, 2)) + return h, c +end + function (l::Set2Set)(g, x) - Flux.reset!(l.lstm) return GNNlib.set2set_pool(l, g, x) end diff --git a/GraphNeuralNetworks/test/layers/pool.jl b/GraphNeuralNetworks/test/layers/pool.jl index 382a728ea..fa1475b20 100644 --- a/GraphNeuralNetworks/test/layers/pool.jl +++ b/GraphNeuralNetworks/test/layers/pool.jl @@ -76,7 +76,7 @@ end n_in = 3 n_iters = 2 - n_layers = 1 + n_layers = 1 #TODO test with more layers g = batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5]) g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes)) l = Set2Set(n_in, n_iters, n_layers)