-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
884c2fa
commit fb83792
Showing
17 changed files
with
193 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
[deps] | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" | ||
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" | ||
GNNLux = "e8545f4d-a905-48ac-a8c4-ca114b98986d" | ||
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" | ||
LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" | ||
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,85 @@ | ||
using Documenter | ||
using DocumenterInterLinks | ||
using GNNlib | ||
using GNNLux | ||
using Lux, GNNGraphs, GNNlib, Graphs | ||
using DocumenterInterLinks | ||
|
||
DocMeta.setdocmeta!(GNNLux, :DocTestSetup, :(using GNNLux); recursive = true) | ||
|
||
mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]), | ||
:tex => Dict("inlineMath" => [["\$", "\$"], ["\\(", "\\)"]], | ||
"packages" => [ | ||
"base", | ||
"ams", | ||
"autoload", | ||
"mathtools", | ||
"require" | ||
]))) | ||
|
||
assets=[] | ||
prettyurls = get(ENV, "CI", nothing) == "true" | ||
mathengine = MathJax3() | ||
|
||
interlinks = InterLinks( | ||
"GNNGraphs" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNGraphs/", joinpath(dirname(dirname(@__DIR__)), "GNNGraphs", "docs", "build", "objects.inv")), | ||
"GNNlib" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNlib/", joinpath(dirname(dirname(@__DIR__)), "GNNlib", "docs", "build", "objects.inv"))) | ||
|
||
"NNlib" => "https://fluxml.ai/NNlib.jl/stable/", | ||
# "GNNGraphs" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNGraphs/", joinpath(dirname(dirname(@__DIR__)), "GNNGraphs", "docs", "build", "objects.inv")), | ||
# "GNNlib" => ("https://carlolucibello.github.io/GraphNeuralNetworks.jl/GNNlib/", joinpath(dirname(dirname(@__DIR__)), "GNNlib", "docs", "build", "objects.inv")) | ||
) | ||
|
||
# Copy the docs from GNNGraphs and GNNlib. Will be removed at the end of the script | ||
cp(joinpath(@__DIR__, "../../GNNGraphs/docs/src"), | ||
joinpath(@__DIR__, "src/GNNGraphs"), force=true) | ||
cp(joinpath(@__DIR__, "../../GNNlib/docs/src"), | ||
joinpath(@__DIR__, "src/GNNlib"), force=true) | ||
|
||
makedocs(; | ||
modules = [GNNLux], | ||
doctest = false, | ||
clean = true, | ||
modules = [GNNLux, GNNGraphs, GNNlib], | ||
doctest = false, # TODO: enable doctest | ||
plugins = [interlinks], | ||
format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing), | ||
format = Documenter.HTML(; mathengine, | ||
prettyurls = get(ENV, "CI", nothing) == "true", | ||
assets = [], | ||
size_threshold=nothing, | ||
size_threshold_warn=2000000), | ||
sitename = "GNNLux.jl", | ||
pages = ["Home" => "index.md", | ||
"API Reference" => [ | ||
"Basic" => "api/basic.md", | ||
"Convolutional layers" => "api/conv.md", | ||
"Temporal Convolutional layers" => "api/temporalconv.md",], | ||
] | ||
) | ||
|
||
pages = [ | ||
|
||
"Home" => "index.md", | ||
|
||
"Guides" => [ | ||
"Graphs" => "GNNGraphs/guides/gnngraph.md", | ||
"Message Passing" => "GNNlib/guides/messagepassing.md", | ||
"Models" => "guides/models.md", | ||
"Datasets" => "GNNGraphs/guides/datasets.md", | ||
"Heterogeneous Graphs" => "GNNGraphs/guides/heterograph.md", | ||
"Temporal Graphs" => "GNNGraphs/guides/temporalgraph.md", | ||
], | ||
|
||
"API Reference" => [ | ||
"Graphs (GNNGraphs.jl)" => [ | ||
"GNNGraph" => "GNNGraphs/api/gnngraph.md", | ||
"GNNHeteroGraph" => "GNNGraphs/api/heterograph.md", | ||
"TemporalSnapshotsGNNGraph" => "GNNGraphs/api/temporalgraph.md", | ||
"Samplers" => "GNNGraphs/api/samplers.md", | ||
] | ||
|
||
"Message Passing (GNNlib.jl)" => [ | ||
"Message Passing" => "GNNlib/api/messagepassing.md", | ||
"Other Operators" => "GNNlib/api/utils.md", | ||
] | ||
|
||
"Layers" => [ | ||
"Basic layers" => "api/basic.md", | ||
"Convolutional layers" => "api/conv.md", | ||
# "Pooling layers" => "api/pool.md", | ||
"Temporal Convolutional layers" => "api/temporalconv.md", | ||
# "Hetero Convolutional layers" => "api/heteroconv.md", | ||
] | ||
], | ||
|
||
# "Developer guide" => "dev.md", | ||
], | ||
) | ||
|
||
rm(joinpath(@__DIR__, "src/GNNGraphs"), force=true, recursive=true) | ||
rm(joinpath(@__DIR__, "src/GNNlib"), force=true, recursive=true) | ||
|
||
deploydocs(;repo = "github.com/JuliaGraphs/GraphNeuralNetworks.jl.git", devbranch = "master", dirname = "GNNLux") | ||
deploydocs(repo = "github.com/JuliaGraphs/GraphNeuralNetworks.jl.git", | ||
devbranch = "master", | ||
dirname = "GNNLux") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Models | ||
|
||
GNNLux.jl provides common graph convolutional layers by which you can assemble arbitrarily deep or complex models. GNN layers are compatible with | ||
Lux.jl ones, therefore expert Lux users are promptly able to define and train | ||
their models. | ||
|
||
In what follows, we discuss two different styles for model creation: | ||
the *explicit modeling* style, more verbose but more flexible, | ||
and the *implicit modeling* style based on [`GNNLux.GNNChain`](@ref), more concise but less flexible. | ||
|
||
## Explicit modeling | ||
|
||
In the explicit modeling style, the model is created according to the following steps: | ||
|
||
1. Define a new type for your model (`GNN` in the example below). Refer to the | ||
[Lux Manual](https://lux.csail.mit.edu/dev/manual/interface#lux-interface) for the | ||
definition of the type. | ||
2. Define a convenience constructor for your model. | ||
4. Define the forward pass by implementing the call method for your type. | ||
5. Instantiate the model. | ||
|
||
Here is an example of this construction: | ||
```julia | ||
using Lux, GNNLux | ||
using Zygote | ||
using Random, Statistics | ||
|
||
struct GNN <: AbstractLuxContainerLayer{(:conv1, :bn, :conv2, :dropout, :dense)} # step 1 | ||
conv1 | ||
bn | ||
conv2 | ||
dropout | ||
dense | ||
end | ||
|
||
function GNN(din::Int, d::Int, dout::Int) # step 2 | ||
GNN(GraphConv(din => d), | ||
BatchNorm(d), | ||
GraphConv(d => d, relu), | ||
Dropout(0.5), | ||
Dense(d, dout)) | ||
end | ||
|
||
function (model::GNN)(g::GNNGraph, x, ps, st) # step 3 | ||
x, st_conv1 = model.conv1(g, x, ps.conv1, st.conv1) | ||
x, st_bn = model.bn(x, ps.bn, st.bn) | ||
x = relu.(x) | ||
x, st_conv2 = model.conv2(g, x, ps.conv2, st.conv2) | ||
x, st_drop = model.dropout(x, ps.dropout, st.dropout) | ||
x, st_dense = model.dense(x, ps.dense, st.dense) | ||
return x, (conv1=st_conv1, bn=st_bn, conv2=st_conv2, dropout=st_drop, dense=st_dense) | ||
end | ||
|
||
din, d, dout = 3, 4, 2 | ||
model = GNN(din, d, dout) # step 4 | ||
rng = Random.default_rng() | ||
ps, st = Lux.setup(rng, model) | ||
g = rand_graph(rng, 10, 30) | ||
X = randn(Float32, din, 10) | ||
|
||
st = Lux.testmode(st) | ||
y, st = model(g, X, ps, st) | ||
st = Lux.trainmode(st) | ||
grad = Zygote.gradient(ps -> mean(model(g, X, ps, st)[1]), ps)[1] | ||
``` | ||
|
||
## Implicit modeling with GNNChains | ||
|
||
While very flexible, the way in which we defined `GNN` model definition in last section is a bit verbose. | ||
In order to simplify things, we provide the [`GNNLux.GNNChain`](@ref) type. It is very similar | ||
to Lux's well known `Chain`. It allows to compose layers in a sequential fashion as Chain | ||
does, propagating the output of each layer to the next one. In addition, `GNNChain` | ||
propagates the input graph as well, providing it as a first argument | ||
to layers subtyping the [`GNNLux.GNNLayer`](@ref) abstract type. | ||
|
||
Using `GNNChain`, the model definition becomes more concise: | ||
|
||
```julia | ||
model = GNNChain(GraphConv(din => d), | ||
BatchNorm(d), | ||
x -> relu.(x), | ||
GraphConv(d => d, relu), | ||
Dropout(0.5), | ||
Dense(d, dout)) | ||
``` | ||
|
||
The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when also edge features are updated, have to be handled using the explicit definition of the forward pass. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters