Docs GCNConv
aurorarossi committed Nov 7, 2024
1 parent ebd7929 commit ae58c08
Showing 1 changed file with 76 additions and 3 deletions.
79 changes: 76 additions & 3 deletions GNNLux/src/layers/conv.jl
@@ -5,7 +5,80 @@ _getstate(s::StatefulLuxLayer{Static.True}) = s.st
-_getstate(s::StatefulLuxLayer{false}) = s.st_any
+_getstate(s::StatefulLuxLayer{Static.False}) = s.st_any


@doc raw"""
    GCNConv(in => out, σ=identity; [init_weight, init_bias, use_bias, add_self_loops, use_edge_weight])

Graph convolutional layer from the paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
Performs the operation
```math
\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j
```
where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.
If the input graph has weighted edges and `use_edge_weight=true`, then ``a_{ij}`` will be computed as
```math
a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
```

# Arguments
- `in`: Number of input features.
- `out`: Number of output features.
- `σ`: Activation function. Default `identity`.
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
- `init_bias`: Bias initializer. Default `zeros32`.
- `use_bias`: Add learnable bias. Default `true`.
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true` (matching the constructor below).
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
If `add_self_loops=true` the new weights will be set to 1.
This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
Default `false`.

# Forward

    (::GCNConv)(g, x, [edge_weight], ps, st; norm_fn = d -> 1 ./ sqrt.(d), conv_weight=nothing)

Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, optionally an edge weight vector, and the layer's parameters and state. Returns a node feature matrix of size `[out, num_nodes]` together with the updated layer state.
The `norm_fn` keyword allows custom normalization of the convolution: it receives the degree vector and returns the per-node normalization coefficients.
By default it computes ``\frac{1}{\sqrt{d}}``, i.e., the inverse square root of the degree `d` of each node in the graph.
If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix.

# Examples
```julia
using GNNLux, Lux, Random
# initialize random number generator
rng = Random.default_rng()
Random.seed!(rng, 0)
# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)
# create layer
l = GCNConv(3 => 5)
# setup layer
ps, st = LuxCore.setup(rng, l)
# forward pass
y, st = l(g, x, ps, st) # size(y) == (5, g.num_nodes)
# convolution with edge weights and custom normalization function
w = [1.1, 0.1, 2.3, 0.5]
custom_norm_fn(d) = 1 ./ sqrt.(d .+ 1) # custom normalization; `.+` broadcasts over the degree vector
y, st = l(g, x, w, ps, st; norm_fn = custom_norm_fn)
# Edge weights can also be embedded in the graph.
g = GNNGraph(s, t, w)
l = GCNConv(3 => 5, use_edge_weight=true)
y, st = l(g, x, ps, st) # same as l(g, x, w, ps, st)
```
"""
@concrete struct GCNConv <: GNNLayer
in_dims::Int
out_dims::Int
@@ -18,7 +91,7 @@ _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any
end

function GCNConv(ch::Pair{Int, Int}, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
@@ -55,7 +128,7 @@ end

function (l::GCNConv)(g, x, edge_weight, ps, st;
norm_fn = d -> 1 ./ sqrt.(d),
-conv_weight=nothing, )
+conv_weight=nothing)

m = (; ps.weight, bias = _getbias(ps),
l.add_self_loops, l.use_edge_weight, l.σ)
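For readers who want to check the propagation rule stated in the new docstring, it can be sketched with dense matrices in plain Julia. This is an illustrative sketch only, not the GNNLux implementation (which uses message passing over sparse graphs); the function name `gcn_dense` and the dense-adjacency setup are assumptions made for this example.

```julia
using LinearAlgebra

# Sketch of x'_i = Σ_{j ∈ N(i)} a_ij * W * x_j with a_ij = 1/sqrt(|N(i)||N(j)|),
# written in matrix form as X' = σ.(W * X * Â) where Â = D^{-1/2} * A * D^{-1/2}.
function gcn_dense(A::AbstractMatrix, X::AbstractMatrix, W::AbstractMatrix; σ = identity)
    d = vec(sum(A, dims = 2))       # (weighted) node degrees
    Dinv = Diagonal(1 ./ sqrt.(d))  # D^{-1/2}
    Â = Dinv * A * Dinv             # Â[j, i] = a_ij
    return σ.(W * X * Â)            # features stored as [channels, num_nodes]
end

# same 3-node graph as the docstring example (edges 1<->2 and 1<->3)
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
n = 3
A = zeros(n, n)
for (i, j) in zip(s, t)
    A[i, j] = 1.0                   # A[i, j] holds edge i -> j; column j gathers into node j
end
X = randn(Float32, 3, n)            # in = 3 features
W = randn(Float32, 5, 3)            # out = 5 features
Y = gcn_dense(A, X, W)
size(Y)                             # (5, 3), i.e. [out, num_nodes]
```

With `use_edge_weight=true` the same normalization applies, with `A[i, j]` holding the edge weight ``e_{i\to j}`` and `d` becoming the weighted degrees, matching the second formula in the docstring.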
