From 213c640b90973ad7369024642a17c4fba1cd74ec Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 22 Dec 2024 21:56:51 +0100 Subject: [PATCH 1/4] First draft node_classification tutorial --- .../docs/src_tutorials/node_classification.jl | 277 ++++++++++++++++++ 1 file changed, 277 insertions(+) create mode 100644 GNNLux/docs/src_tutorials/node_classification.jl diff --git a/GNNLux/docs/src_tutorials/node_classification.jl b/GNNLux/docs/src_tutorials/node_classification.jl new file mode 100644 index 000000000..e40e6aa8c --- /dev/null +++ b/GNNLux/docs/src_tutorials/node_classification.jl @@ -0,0 +1,277 @@ +# # Node Classification with Graph Neural Networks + +# In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning). + + +# ## Import +# Let us start off by importing some libraries. We will be using `Lux.jl` and `GNNLux.jl` for our tutorial. + +using Lux, GNNLux +using MLDatasets +using Plots, TSne +using Random, Statistics +using Zygote, Optimisers, OneHotArrays + + +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation +rng = Random.seed!(17) # for reproducibility + +# ## Visualize +# We want to visualize the outputs of the results using t-distributed stochastic neighbor embedding (tsne) to embed our output embeddings onto a 2D plane. + +function visualize_tsne(out, targets) + z = tsne(out, 2) + scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) +end; + + +# ## Dataset: Cora + +# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents classified into one of seven classes and 5429 links. Each node represents articles/documents and the edges between them when they cite each other. + +# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. + +# This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. + +dataset = Cora() + +# Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. + +dataset.metadata + +# The `graphs` variable GraphDataset contains the graph. The `Cora` dataset contains only 1 graph. + +dataset.graphs + + +# There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. + +g = mldataset2gnngraph(dataset) + + +println("Number of nodes: $(g.num_nodes)") +println("Number of edges: $(g.num_edges)") +println("Average node degree: $(g.num_edges / g.num_nodes)") +println("Number of training nodes: $(sum(g.ndata.train_mask))") +println("Training node label rate: $(mean(g.ndata.train_mask))") +println("Has isolated nodes: $(has_isolated_nodes(g))") +println("Has self-loops: $(has_self_loops(g))") +println("Is undirected: $(is_bidirected(g))") + + + +# Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. 
+# We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. +# For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). +# This results in a training node label rate of only 5%. + +# We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). + +x = g.ndata.features # we onehot encode both the node labels (what we want to predict): +y = onehotbatch(g.ndata.targets, 1:7) +train_mask = g.ndata.train_mask; +num_features = size(x)[1]; +hidden_channels = 16; +drop_rate = 0.5; +num_classes = dataset.metadata["num_classes"]; + + +# ## Multi-layer Perception Network (MLP) + +# In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. + +# Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): + +MLP = Chain(Dense(num_features => hidden_channels, relu), + Dropout(drop_rate), + Dense(hidden_channels => num_classes)) + +ps, st = Lux.setup(rng, MLP) + +# ### Training a Multilayer Perceptron + +# Our MLP is defined by two linear layers and enhanced by [ReLU](https://lux.csail.mit.edu/stable/api/NN_Primitives/ActivationFunctions#NNlib.relu) non-linearity and [Dropout](https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Dropout). +# Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. + +# Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/tutorials/gnn_intro/). +# We again make use of the **cross entropy loss** and **Adam optimizer**. +# This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). + + +function custom_loss(model, ps, st, x) + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + ŷ, st = model(x, ps, st) + return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 +end + +function train_model!(MLP, ps, st, x, epochs) + train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3)) + for iter in 1:epochs + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state) + + if iter % 100 == 0 + println("Epoch: $(iter) Loss: $(loss)") + end + end +end + +function accuracy(model, x, ps, st, y, mask) + st = Lux.testmode(st) + ŷ, st = model(x, ps, st) + mean(onecold(ŷ)[mask] .== onecold(y)[mask]) +end + +train_model!(MLP, ps, st, x, 2000) + + + +# After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. +# Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: + +accuracy(MLP, x, ps, st, y, .!train_mask) + +# As one can see, our MLP performs rather bad with only about ~50% test accuracy. +# But why does the MLP do not perform better? 
+# The main reason for that is that this model suffers from heavy overfitting due to only having access to a **small amount of training nodes**, and therefore generalizes poorly to unseen node representations. + +# It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**. +# That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model. + + +# ## Training a Graph Convolutional Neural Network (GNN) + +# Following-up on the first part of this tutorial, we replace the `Dense` linear layers by the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GCNConv) module. +# To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as + +# ```math +# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} +# ``` + +# where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. +# In contrast, a single `Linear` layer is defined as + +# ```math +# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} +# ``` + +# which does not make use of neighboring node information. + +Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} + nf::Int + nc::Int + hd::Int + conv1 + conv2 + drop + use_bias::Bool + init_weight + init_bias +end + +function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor + conv1 = GCNConv(num_features => hidden_channels) + conv2 = GCNConv(hidden_channels => num_classes) + drop = Dropout(drop_rate) + return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias) +end + +function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass + x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1) + x = relu.(x) + x, stdrop = gcn.drop(x, ps.drop, st.drop) + x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2) + return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2) +end + + +# function LuxCore.initialparameters(rng::TaskLocalRNG, l::GCN) # initialize model parameters +# weight_c1 = l.init_weight(rng, l.hd, l.nf) +# weight_c2 = l.init_weight(rng, l.nc, l.hd) +# if l.use_bias +# bias_c1 = l.init_bias(rng, l.hd) +# bias_c2 = l.init_bias(rng, l.nc) +# return (; conv1 = ( weight = weight_c1, bias = bias_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2, bias = bias_c2)) +# end +# return (; conv1 = ( weight = weight_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2)) +# end + + +# Now let's visualize the node embeddings of our **untrained** GCN network. + + + +gcn = GCN(num_features, num_classes, hidden_channels, drop_rate) +ps, st = Lux.setup(rng, gcn) +h_untrained, st = gcn(g, x, ps, st) +h_untrained = h_untrained |> transpose +visualize_tsne(h_untrained, g.ndata.targets) + + +# We certainly can do better by training our model. +# The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model. 
+ + + +function custom_loss(gcn, ps, st, tuple) + g, x, y = tuple + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + ŷ, st = gcn(g, x, ps, st) + return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 +end + +function train_model!(gcn, ps, st, g, x, y) + train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2)) + for iter in 1:2000 + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, x, y), train_state) + + if iter % 100 == 0 + println("Epoch: $(iter) Loss: $(loss)") + end + end + + return gcn, ps, st +end + +gcn, ps, st = train_model!(gcn, ps, st, g, x, y); + + + +# Now let's evaluate the loss of our trained GCN. + +function accuracy(model, g, x, ps, st, y, mask) + st = Lux.testmode(st) + ŷ, st = model(g, x, ps, st) + mean(onecold(ŷ)[mask] .== onecold(y)[mask]) +end + +train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask) +test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask) + +println("Train accuracy: $(train_accuracy)") +println("Test accuracy: $(test_accuracy)") + +# **There it is!** +# By simply swapping the linear layers with GNN layers, we can reach **76% of test accuracy**! +# This is in stark contrast to the 50% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance. + +# We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category. + + + +st = Lux.testmode(st) # inference mode + +out_trained, st = gcn(g, x, ps, st) +out_trained = out_trained|> transpose +visualize_tsne(out_trained, g.ndata.targets) + +# ## (Optional) Exercises + +# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**. + +# 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? + +# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. + + +# ## Conclusion +# In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification. 
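
# As an appendix to the exercises above, here is one possible starting point for Exercise 3. It is only a sketch, not a reference solution:
# it mirrors the `GCN` container defined earlier, swaps the `GCNConv` layers for `GATConv`, and assumes that `GATConv`
# accepts the `heads` and `dropout` keyword arguments (as its GraphNeuralNetworks.jl counterpart does) and that the
# default parameter initialization of `GNNContainerLayer` covers the three sub-layers, as it does for `GCN`.

Lux.@concrete struct GAT <: GNNContainerLayer{(:conv1, :drop, :conv2)}
    conv1
    conv2
    drop
end

function GAT(num_features, num_classes; hidden_channels = 8, heads = 8, drop_rate = 0.6) # constructor
    conv1 = GATConv(num_features => hidden_channels, relu; heads = heads, dropout = drop_rate) # 8 heads, concatenated outputs
    conv2 = GATConv(hidden_channels * heads => num_classes; heads = 1, dropout = drop_rate) # single-head classifier
    drop = Dropout(drop_rate)
    return GAT(conv1, conv2, drop)
end

function (gat::GAT)(g::GNNGraph, x, ps, st) # forward pass, with dropout outside each `GATConv` call
    x, stdrop = gat.drop(x, ps.drop, st.drop)
    x, stconv1 = gat.conv1(g, x, ps.conv1, st.conv1)
    x, stdrop = gat.drop(x, ps.drop, stdrop)
    x, stconv2 = gat.conv2(g, x, ps.conv2, st.conv2)
    return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2)
end

# Training such a `GAT` would follow exactly the same loop as for the `GCN` above, with `GAT(num_features, num_classes)` in place of `gcn`.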
From cac65067f510f51e11936edda6a98e8102671254 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 22 Dec 2024 22:33:57 +0100 Subject: [PATCH 2/4] Add tutorial --- GNNLux/docs/Project.toml | 1 + GNNLux/docs/make.jl | 1 + GNNLux/docs/make_tutorials.jl | 4 +- .../docs/src/tutorials/node_classification.md | 5906 +++++++++++++++++ .../docs/src_tutorials/node_classification.jl | 31 +- 5 files changed, 5923 insertions(+), 20 deletions(-) create mode 100644 GNNLux/docs/src/tutorials/node_classification.md diff --git a/GNNLux/docs/Project.toml b/GNNLux/docs/Project.toml index 63940977a..cfd4b89cd 100644 --- a/GNNLux/docs/Project.toml +++ b/GNNLux/docs/Project.toml @@ -12,4 +12,5 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +TSne = "24678dba-d5e9-5843-a4c6-250288b04835" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/GNNLux/docs/make.jl b/GNNLux/docs/make.jl index feae0f3c5..a4a990756 100644 --- a/GNNLux/docs/make.jl +++ b/GNNLux/docs/make.jl @@ -61,6 +61,7 @@ makedocs(; "Tutorials" => [ "Introductory tutorials" => [ "Hands on" => "tutorials/gnn_intro.md", + "Node Classification" => "tutorials/node_classification.md", ], ], diff --git a/GNNLux/docs/make_tutorials.jl b/GNNLux/docs/make_tutorials.jl index ebeaaff9d..a204d4b56 100644 --- a/GNNLux/docs/make_tutorials.jl +++ b/GNNLux/docs/make_tutorials.jl @@ -1,3 +1,5 @@ using Literate -Literate.markdown("src_tutorials/gnn_intro.jl", "src/tutorials/"; execute = true) \ No newline at end of file +Literate.markdown("src_tutorials/gnn_intro.jl", "src/tutorials/"; execute = true) + +Literate.markdown("src_tutorials/node_classification.jl", "src/tutorials/"; execute = true) \ No newline at end of file diff --git a/GNNLux/docs/src/tutorials/node_classification.md b/GNNLux/docs/src/tutorials/node_classification.md new file mode 100644 index 000000000..6b96d3a6b --- /dev/null +++ b/GNNLux/docs/src/tutorials/node_classification.md @@ -0,0 +1,5906 @@ +# Node Classification with Graph Neural Networks + +In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, we want to infer the labels for all the remaining nodes (transductive learning). + +## Import +Let us start off by importing some libraries. We will be using `Lux.jl` and `GNNLux.jl` for our tutorial. + +````julia +using Lux, GNNLux +using MLDatasets +using Plots, TSne +using Random, Statistics +using Zygote, Optimisers, OneHotArrays + + +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation +rng = Random.seed!(17); # for reproducibility +```` + +## Visualize +We want to visualize our results using t-distributed stochastic neighbor embedding (tsne) to project our output onto a 2D plane. + +````julia +function visualize_tsne(out, targets) + z = tsne(out, 2) + scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) +end; +```` + +## Dataset: Cora + +For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents categorized into seven classes with 5,429 citation links. Each node represents an article or document, and edges between nodes indicate a citation relationship, where one cites the other. 
+ +Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. + +This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. + +````julia +dataset = Cora() +```` + +```` +dataset Cora: + metadata => Dict{String, Any} with 3 entries + graphs => 1-element Vector{MLDatasets.Graph} +```` + +Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. + +````julia +dataset.metadata +```` + +```` +Dict{String, Any} with 3 entries: + "name" => "cora" + "classes" => [1, 2, 3, 4, 5, 6, 7] + "num_classes" => 7 +```` + +The `graphs` variable contains the graph. The `Cora` dataset contains only 1 graph. + +````julia +dataset.graphs +```` + +```` +1-element Vector{MLDatasets.Graph}: + Graph(2708, 10556) +```` + +There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. + +````julia +g = mldataset2gnngraph(dataset) + + +println("Number of nodes: $(g.num_nodes)") +println("Number of edges: $(g.num_edges)") +println("Average node degree: $(g.num_edges / g.num_nodes)") +println("Number of training nodes: $(sum(g.ndata.train_mask))") +println("Training node label rate: $(mean(g.ndata.train_mask))") +println("Has isolated nodes: $(has_isolated_nodes(g))") +println("Has self-loops: $(has_self_loops(g))") +println("Is undirected: $(is_bidirected(g))") +```` + +```` +Number of nodes: 2708 +Number of edges: 10556 +Average node degree: 3.8980797636632203 +Number of training nodes: 140 +Training node label rate: 0.051698670605613 +Has isolated nodes: false +Has self-loops: false +Is undirected: true + +```` + +Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. +We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. +For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). +This results in a training node label rate of only 5%. + +We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). + +````julia +x = g.ndata.features # we onehot encode the node labels (what we want to predict): +y = onehotbatch(g.ndata.targets, 1:7) +train_mask = g.ndata.train_mask; +num_features = size(x)[1]; +hidden_channels = 16; +drop_rate = 0.5; +num_classes = dataset.metadata["num_classes"]; +```` + +## Multi-layer Perception Network (MLP) + +In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. 
+ +Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): + +````julia +MLP = Chain(Dense(num_features => hidden_channels, relu), + Dropout(drop_rate), + Dense(hidden_channels => num_classes)) + +ps, st = Lux.setup(rng, MLP); +```` + +```` +┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. +└ @ LuxCore ~/.julia/packages/LuxCore/SN4dl/src/LuxCore.jl:18 + +```` + +### Training a Multilayer Perceptron + +Our MLP is defined by two linear layers and enhanced by [ReLU](https://lux.csail.mit.edu/stable/api/NN_Primitives/ActivationFunctions#NNlib.relu) non-linearity and [Dropout](https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Dropout). +Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. + +Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/tutorials/gnn_intro/). +We again make use of the **cross entropy loss** and **Adam optimizer**. +This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). + +````julia +function custom_loss(model, ps, st, x) + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + ŷ, st = model(x, ps, st) + return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 +end + +function train_model!(MLP, ps, st, x, epochs) + train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3)) + for iter in 1:epochs + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state) + + if iter % 100 == 0 + println("Epoch: $(iter) Loss: $(loss)") + end + end +end + +function accuracy(model, x, ps, st, y, mask) + st = Lux.testmode(st) + ŷ, st = model(x, ps, st) + mean(onecold(ŷ)[mask] .== onecold(y)[mask]) +end + +train_model!(MLP, ps, st, x, 2000) +```` + +```` +┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. +└ @ LuxCore ~/.julia/packages/LuxCore/SN4dl/src/LuxCore.jl:18 +Epoch: 100 Loss: 0.810594 +Epoch: 200 Loss: 0.48982772 +Epoch: 300 Loss: 0.31716076 +Epoch: 400 Loss: 0.2397098 +Epoch: 500 Loss: 0.20041731 +Epoch: 600 Loss: 0.11589075 +Epoch: 700 Loss: 0.21093586 +Epoch: 800 Loss: 0.18869051 +Epoch: 900 Loss: 0.15322906 +Epoch: 1000 Loss: 0.12451931 +Epoch: 1100 Loss: 0.13396983 +Epoch: 1200 Loss: 0.111468166 +Epoch: 1300 Loss: 0.17113678 +Epoch: 1400 Loss: 0.18155631 +Epoch: 1500 Loss: 0.17731342 +Epoch: 1600 Loss: 0.11386197 +Epoch: 1700 Loss: 0.09408201 +Epoch: 1800 Loss: 0.15806198 +Epoch: 1900 Loss: 0.104388796 +Epoch: 2000 Loss: 0.18465123 + +```` + +After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. +Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: + +````julia +accuracy(MLP, x, ps, st, y, .!train_mask) +```` + +```` +0.5089563862928349 +```` + +As one can see, our MLP performs rather bad with only about ~50% test accuracy. +But why does the MLP do not perform better? 
+The main reason for that is that this model suffers from heavy overfitting due to only having access to a **small amount of training nodes**, and therefore generalizes poorly to unseen node representations. + +It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**. +That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model. + +## Training a Graph Convolutional Neural Network (GNN) + +Following-up on the first part of this tutorial, we replace the `Dense` linear layers by the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GCNConv) module. +To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} +``` + +where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. +In contrast, a single `Linear` layer is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} +``` + +which does not make use of neighboring node information. + +````julia +Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} + nf::Int + nc::Int + hd::Int + conv1 + conv2 + drop + use_bias::Bool + init_weight + init_bias +end; + +function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor + conv1 = GCNConv(num_features => hidden_channels) + conv2 = GCNConv(hidden_channels => num_classes) + drop = Dropout(drop_rate) + return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias) +end; + +function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass + x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1) + x = relu.(x) + x, stdrop = gcn.drop(x, ps.drop, st.drop) + x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2) + return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2) +end; +```` + +function LuxCore.initialparameters(rng::TaskLocalRNG, l::GCN) # initialize model parameters + weight_c1 = l.init_weight(rng, l.hd, l.nf) + weight_c2 = l.init_weight(rng, l.nc, l.hd) + if l.use_bias + bias_c1 = l.init_bias(rng, l.hd) + bias_c2 = l.init_bias(rng, l.nc) + return (; conv1 = ( weight = weight_c1, bias = bias_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2, bias = bias_c2)) + end + return (; conv1 = ( weight = weight_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2)) +end + +Now let's visualize the node embeddings of our **untrained** GCN network. 

````julia
gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
ps, st = Lux.setup(rng, gcn)
h_untrained, st = gcn(g, x, ps, st)
h_untrained = h_untrained |> transpose
visualize_tsne(h_untrained, g.ndata.targets)
````

```@raw html
<!-- t-SNE scatter plot of the untrained GCN node embeddings (inline SVG omitted) -->
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +``` + +We certainly can do better by training our model. +The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model. + +````julia +function custom_loss(gcn, ps, st, tuple) + g, x, y = tuple + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + ŷ, st = gcn(g, x, ps, st) + return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 +end + +function train_model!(gcn, ps, st, g, x, y) + train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2)) + for iter in 1:2000 + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, x, y), train_state) + + if iter % 100 == 0 + println("Epoch: $(iter) Loss: $(loss)") + end + end + + return gcn, ps, st +end + +gcn, ps, st = train_model!(gcn, ps, st, g, x, y); +```` + +```` +┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. 
└ @ LuxCore ~/.julia/packages/LuxCore/SN4dl/src/LuxCore.jl:18
Epoch: 100 Loss: 0.019381031
Epoch: 200 Loss: 0.017426146
Epoch: 300 Loss: 0.006051709
Epoch: 400 Loss: 0.0015434261
Epoch: 500 Loss: 0.0052008606
Epoch: 600 Loss: 0.025294377
Epoch: 700 Loss: 0.0012917791
Epoch: 800 Loss: 0.005089373
Epoch: 900 Loss: 0.00912053
Epoch: 1000 Loss: 0.002442247
Epoch: 1100 Loss: 0.00024606875
Epoch: 1200 Loss: 0.00046606906
Epoch: 1300 Loss: 0.002437515
Epoch: 1400 Loss: 0.00019191795
Epoch: 1500 Loss: 0.0056298207
Epoch: 1600 Loss: 0.00020503976
Epoch: 1700 Loss: 0.0028860446
Epoch: 1800 Loss: 0.02319943
Epoch: 1900 Loss: 0.00030635786
Epoch: 2000 Loss: 0.00013437525

````

Now let's evaluate the loss of our trained GCN.

````julia
function accuracy(model, g, x, ps, st, y, mask)
    st = Lux.testmode(st)
    ŷ, st = model(g, x, ps, st)
    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
end

train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask)
test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask)

println("Train accuracy: $(train_accuracy)")
println("Test accuracy: $(test_accuracy)")
````

````
Train accuracy: 1.0
Test accuracy: 0.7636292834890965

````

**There it is!**
By simply swapping the linear layers with GNN layers, we can reach **76% of test accuracy**!
This is in stark contrast to the 50% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.

We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category.

````julia
st = Lux.testmode(st) # inference mode

out_trained, st = gcn(g, x, ps, st)
out_trained = out_trained|> transpose
visualize_tsne(out_trained, g.ndata.targets)
````

```@raw html
<!-- t-SNE scatter plot of the trained GCN node embeddings (inline SVG omitted) -->
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +``` + +## (Optional) Exercises + +1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **> 80% accuracy**. + +2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? + +3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. + +## Conclusion +In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification. + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/GNNLux/docs/src_tutorials/node_classification.jl b/GNNLux/docs/src_tutorials/node_classification.jl index e40e6aa8c..dac773175 100644 --- a/GNNLux/docs/src_tutorials/node_classification.jl +++ b/GNNLux/docs/src_tutorials/node_classification.jl @@ -1,6 +1,6 @@ # # Node Classification with Graph Neural Networks -# In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning). +# In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, we want to infer the labels for all the remaining nodes (transductive learning). # ## Import @@ -14,10 +14,10 @@ using Zygote, Optimisers, OneHotArrays ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation -rng = Random.seed!(17) # for reproducibility +rng = Random.seed!(17); # for reproducibility # ## Visualize -# We want to visualize the outputs of the results using t-distributed stochastic neighbor embedding (tsne) to embed our output embeddings onto a 2D plane. +# We want to visualize our results using t-distributed stochastic neighbor embedding (tsne) to project our output onto a 2D plane. function visualize_tsne(out, targets) z = tsne(out, 2) @@ -27,7 +27,7 @@ end; # ## Dataset: Cora -# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents classified into one of seven classes and 5429 links. Each node represents articles/documents and the edges between them when they cite each other. 
+# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents categorized into seven classes with 5,429 citation links. Each node represents an article or document, and edges between nodes indicate a citation relationship, where one cites the other. # Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. @@ -39,7 +39,7 @@ dataset = Cora() dataset.metadata -# The `graphs` variable GraphDataset contains the graph. The `Cora` dataset contains only 1 graph. +# The `graphs` variable contains the graph. The `Cora` dataset contains only 1 graph. dataset.graphs @@ -58,8 +58,6 @@ println("Has isolated nodes: $(has_isolated_nodes(g))") println("Has self-loops: $(has_self_loops(g))") println("Is undirected: $(is_bidirected(g))") - - # Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. # We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. # For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). @@ -67,7 +65,7 @@ println("Is undirected: $(is_bidirected(g))") # We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). -x = g.ndata.features # we onehot encode both the node labels (what we want to predict): +x = g.ndata.features # we onehot encode the node labels (what we want to predict): y = onehotbatch(g.ndata.targets, 1:7) train_mask = g.ndata.train_mask; num_features = size(x)[1]; @@ -86,7 +84,7 @@ MLP = Chain(Dense(num_features => hidden_channels, relu), Dropout(drop_rate), Dense(hidden_channels => num_classes)) -ps, st = Lux.setup(rng, MLP) +ps, st = Lux.setup(rng, MLP); # ### Training a Multilayer Perceptron @@ -123,8 +121,6 @@ end train_model!(MLP, ps, st, x, 2000) - - # After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels. # Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes: @@ -166,14 +162,14 @@ Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} use_bias::Bool init_weight init_bias -end +end; function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor conv1 = GCNConv(num_features => hidden_channels) conv2 = GCNConv(hidden_channels => num_classes) drop = Dropout(drop_rate) return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias) -end +end; function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1) @@ -181,7 +177,7 @@ function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass x, stdrop = gcn.drop(x, ps.drop, st.drop) x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2) return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2) -end +end; # function LuxCore.initialparameters(rng::TaskLocalRNG, l::GCN) # initialize model parameters @@ -234,8 +230,6 @@ end gcn, ps, st = train_model!(gcn, ps, st, g, x, y); - - # Now let's evaluate the loss of our trained GCN. 
function accuracy(model, g, x, ps, st, y, mask) @@ -249,7 +243,6 @@ test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask) println("Train accuracy: $(train_accuracy)") println("Test accuracy: $(test_accuracy)") - # **There it is!** # By simply swapping the linear layers with GNN layers, we can reach **76% of test accuracy**! # This is in stark contrast to the 50% of test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance. @@ -266,11 +259,11 @@ visualize_tsne(out_trained, g.ndata.targets) # ## (Optional) Exercises -# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **82% accuracy**. +# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **> 80% accuracy**. # 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all? -# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/api/conv/#GraphNeuralNetworks.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. +# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimensions of `8` per head. # ## Conclusion From b62b53b75dc8b424273cebd52cb0b0a0df57b2dd Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Tue, 24 Dec 2024 22:53:36 +0100 Subject: [PATCH 3/4] Fix loss function name --- GNNLux/docs/src_tutorials/node_classification.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/GNNLux/docs/src_tutorials/node_classification.jl b/GNNLux/docs/src_tutorials/node_classification.jl index dac773175..16eafc0e8 100644 --- a/GNNLux/docs/src_tutorials/node_classification.jl +++ b/GNNLux/docs/src_tutorials/node_classification.jl @@ -96,7 +96,7 @@ ps, st = Lux.setup(rng, MLP); # This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). 
-function custom_loss(model, ps, st, x) +function loss(model, ps, st, x) logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) ŷ, st = model(x, ps, st) return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 @@ -105,10 +105,10 @@ end function train_model!(MLP, ps, st, x, epochs) train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3)) for iter in 1:epochs - _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state) + _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, x, train_state) if iter % 100 == 0 - println("Epoch: $(iter) Loss: $(loss)") + println("Epoch: $(iter) Loss: $(loss_value)") end end end @@ -208,7 +208,7 @@ visualize_tsne(h_untrained, g.ndata.targets) -function custom_loss(gcn, ps, st, tuple) +function loss(gcn, ps, st, tuple) g, x, y = tuple logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) ŷ, st = gcn(g, x, ps, st) @@ -218,10 +218,10 @@ end function train_model!(gcn, ps, st, g, x, y) train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2)) for iter in 1:2000 - _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, x, y), train_state) + _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss,(g, x, y), train_state) if iter % 100 == 0 - println("Epoch: $(iter) Loss: $(loss)") + println("Epoch: $(iter) Loss: $(loss_value)") end end From 5310bbff193999b9353d5bb08d2673d595d21427 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Tue, 24 Dec 2024 23:09:44 +0100 Subject: [PATCH 4/4] Fixes --- GNNLux/docs/Project.toml | 1 + .../docs/src/tutorials/node_classification.md | 10993 ++++++++-------- .../docs/src_tutorials/node_classification.jl | 19 +- 3 files changed, 5494 insertions(+), 5519 deletions(-) diff --git a/GNNLux/docs/Project.toml b/GNNLux/docs/Project.toml index cfd4b89cd..36822253f 100644 --- a/GNNLux/docs/Project.toml +++ b/GNNLux/docs/Project.toml @@ -1,5 +1,6 @@ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" diff --git a/GNNLux/docs/src/tutorials/node_classification.md b/GNNLux/docs/src/tutorials/node_classification.md index 6b96d3a6b..8f332dac6 100644 --- a/GNNLux/docs/src/tutorials/node_classification.md +++ b/GNNLux/docs/src/tutorials/node_classification.md @@ -10,7 +10,7 @@ using Lux, GNNLux using MLDatasets using Plots, TSne using Random, Statistics -using Zygote, Optimisers, OneHotArrays +using Zygote, Optimisers, OneHotArrays, ConcreteStructs ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation @@ -130,7 +130,7 @@ ps, st = Lux.setup(rng, MLP); ```` ┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. -└ @ LuxCore ~/.julia/packages/LuxCore/SN4dl/src/LuxCore.jl:18 +└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18 ```` @@ -144,7 +144,7 @@ We again make use of the **cross entropy loss** and **Adam optimizer**. This time, we also define a **`accuracy` function** to evaluate how well our final model performs on the test node set (which labels have not been observed during training). 
````julia -function custom_loss(model, ps, st, x) +function loss(model, ps, st, x) logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) ŷ, st = model(x, ps, st) return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0 @@ -153,10 +153,10 @@ end function train_model!(MLP, ps, st, x, epochs) train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3)) for iter in 1:epochs - _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, x, train_state) + _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, x, train_state) if iter % 100 == 0 - println("Epoch: $(iter) Loss: $(loss)") + println("Epoch: $(iter) Loss: $(loss_value)") end end end @@ -172,7 +172,7 @@ train_model!(MLP, ps, st, x, 2000) ```` ┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. -└ @ LuxCore ~/.julia/packages/LuxCore/SN4dl/src/LuxCore.jl:18 +└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18 Epoch: 100 Loss: 0.810594 Epoch: 200 Loss: 0.48982772 Epoch: 300 Loss: 0.31716076 @@ -233,7 +233,7 @@ In contrast, a single `Linear` layer is defined as which does not make use of neighboring node information. ````julia -Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} +@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} nf::Int nc::Int hd::Int @@ -261,17 +261,6 @@ function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass end; ```` -function LuxCore.initialparameters(rng::TaskLocalRNG, l::GCN) # initialize model parameters - weight_c1 = l.init_weight(rng, l.hd, l.nf) - weight_c2 = l.init_weight(rng, l.nc, l.hd) - if l.use_bias - bias_c1 = l.init_bias(rng, l.hd) - bias_c2 = l.init_bias(rng, l.nc) - return (; conv1 = ( weight = weight_c1, bias = bias_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2, bias = bias_c2)) - end - return (; conv1 = ( weight = weight_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2)) -end - Now let's visualize the node embeddings of our **untrained** GCN network. 

````julia
gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
ps, st = Lux.setup(rng, gcn)
h_untrained, st = gcn(g, x, ps, st)
h_untrained = h_untrained |> transpose
visualize_tsne(h_untrained, g.ndata.targets)
````

```@raw html
<!-- t-SNE scatter plot of the untrained GCN node embeddings (inline SVG omitted) -->
```
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
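As a baseline for the numbers that follow, one can also score these untrained embeddings on the held-out nodes. The snippet below is an illustrative sketch rather than part of the tutorial's code: `graph_accuracy` is a hypothetical helper that mirrors the earlier MLP accuracy check but additionally passes the graph, and it assumes the `gcn`, `g`, `x`, `ps`, `st`, `y` and `train_mask` variables defined in the tutorial.

````julia
# Hypothetical helper (illustrative only): accuracy for models that take the graph
# as an extra argument.
function graph_accuracy(model, g, x, ps, st, y, mask)
    st = Lux.testmode(st)            # disable dropout for evaluation
    ŷ, _ = model(g, x, ps, st)       # forward pass through the GCN
    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
end

# With random weights this should land near chance level for a 7-class problem.
graph_accuracy(gcn, g, x, ps, st, y, .!train_mask)
````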
@@ -3040,7 +3029,7 @@ We certainly can do better by training our model.
The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model.

````julia
-function custom_loss(gcn, ps, st, tuple)
+function loss(gcn, ps, st, tuple)
    g, x, y = tuple
    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
    ŷ, st = gcn(g, x, ps, st)
@@ -3050,10 +3039,10 @@ end

function train_model!(gcn, ps, st, g, x, y)
    train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
    for iter in 1:2000
-        _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, x, y), train_state)
+        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss,(g, x, y), train_state)

        if iter % 100 == 0
-            println("Epoch: $(iter) Loss: $(loss)")
+            println("Epoch: $(iter) Loss: $(loss_value)")
        end
    end
@@ -3065,7 +3054,7 @@ gcn, ps, st = train_model!(gcn, ps, st, g, x, y);
````

┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
-└ @ LuxCore ~/.julia/packages/LuxCore/SN4dl/src/LuxCore.jl:18
+└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
Epoch: 100 Loss: 0.019381031
Epoch: 200 Loss: 0.017426146
Epoch: 300 Loss: 0.006051709
@@ -3129,2762 +3118,2762 @@ visualize_tsne(out_trained, g.ndata.targets)
[figure diff omitted: regenerated SVG of the t-SNE scatter plot for the trained model's node embeddings]
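After training, the same hypothetical `graph_accuracy` helper from the sketch above gives the quantitative counterpart to the t-SNE plot, using the parameters and state returned by `train_model!`:

````julia
# Accuracy on the nodes whose labels were hidden during training (illustrative).
graph_accuracy(gcn, g, x, ps, st, y, .!train_mask)
````

This should come out well above the MLP baseline, since the GCN can propagate label information along the citation edges.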
diff --git a/GNNLux/docs/src_tutorials/node_classification.jl b/GNNLux/docs/src_tutorials/node_classification.jl
index 16eafc0e8..3ba82fb93 100644
--- a/GNNLux/docs/src_tutorials/node_classification.jl
+++ b/GNNLux/docs/src_tutorials/node_classification.jl
@@ -10,7 +10,7 @@ using Lux, GNNLux
using MLDatasets
using Plots, TSne
using Random, Statistics
-using Zygote, Optimisers, OneHotArrays
+using Zygote, Optimisers, OneHotArrays, ConcreteStructs


ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
@@ -152,7 +152,7 @@ accuracy(MLP, x, ps, st, y, .!train_mask)
# which does not make use of neighboring node information. 

-Lux.@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)}
+@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)}
    nf::Int
    nc::Int
    hd::Int
@@ -179,23 +179,8 @@ function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass
    return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2)
end;
-
-# function LuxCore.initialparameters(rng::TaskLocalRNG, l::GCN) # initialize model parameters
-# weight_c1 = l.init_weight(rng, l.hd, l.nf)
-# weight_c2 = l.init_weight(rng, l.nc, l.hd)
-# if l.use_bias
-# bias_c1 = l.init_bias(rng, l.hd)
-# bias_c2 = l.init_bias(rng, l.nc)
-# return (; conv1 = ( weight = weight_c1, bias = bias_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2, bias = bias_c2))
-# end
-# return (; conv1 = ( weight = weight_c1), drop= LuxCore.initialparameters(rng, l.drop), conv2 = ( weight = weight_c2))
-# end
-
-
# Now let's visualize the node embeddings of our **untrained** GCN network.
-
-
gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
ps, st = Lux.setup(rng, gcn)
h_untrained, st = gcn(g, x, ps, st)
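With the hand-written `LuxCore.initialparameters` method removed, `Lux.setup` derives the parameter `NamedTuple` from the container's sub-layers (`conv1`, `drop`, `conv2`). The following is an illustrative sanity check of what that produces, assuming each `GCNConv` holds a weight matrix and a bias vector and reusing the `ps` returned by `Lux.setup` above:

````julia
# Count the initialized parameters: two GCNConv layers, 1433 => 16 and 16 => 7,
# each assumed to carry a weight matrix and a bias vector.
expected = (num_features * hidden_channels + hidden_channels) +
           (hidden_channels * num_classes + num_classes)
println(Lux.parameterlength(ps), " parameters (expected ", expected, ")")
````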