diff --git a/GNNGraphs/docs/src/guides/temporalgraph.md b/GNNGraphs/docs/src/guides/temporalgraph.md index 74202041f..631e3b944 100644 --- a/GNNGraphs/docs/src/guides/temporalgraph.md +++ b/GNNGraphs/docs/src/guides/temporalgraph.md @@ -4,7 +4,7 @@ CurrentModule = GNNGraphs # Temporal Graphs -Temporal Graphs are graphs with time varying topologies and features. In GNNGraphs.jl, temporal graphs with fixed number of nodes over time are supported by the [`TemporalSnapshotsGNNGraph`](@ref) type. +Temporal graphs are graphs with time-varying topologies and features. In GNNGraphs.jl, they are represented by the [`TemporalSnapshotsGNNGraph`](@ref) type. ## Creating a TemporalSnapshotsGNNGraph @@ -13,7 +13,7 @@ A temporal graph can be created by passing a list of snapshots to the constructo ```jldoctest temporal julia> using GNNGraphs -julia> snapshots = [rand_graph(10,20) for i in 1:5]; +julia> snapshots = [rand_graph(10, 20) for i in 1:5]; julia> tg = TemporalSnapshotsGNNGraph(snapshots) TemporalSnapshotsGNNGraph: @@ -57,23 +57,77 @@ TemporalSnapshotsGNNGraph: num_snapshots: 3 ``` +## Indexing + +Snapshots in a temporal graph can be accessed using indexing: + +```jldoctest temporal +julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) + +julia> tg[1] # first snapshot +GNNGraph: + num_nodes: 10 + num_edges: 20 + +julia> tg[2:3] # snapshots 2 and 3 +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10] + num_edges: [14, 22] + num_snapshots: 2 +``` + +A snapshot can be modified by assigning a new snapshot to the temporal graph: + +```jldoctest temporal +julia> tg[1] = rand_graph(10, 16) # replace first snapshot +GNNGraph: + num_nodes: 10 + num_edges: 16 +``` + +## Iteration and Broadcasting + +Iteration and broadcasting over a temporal graph is similar to that of a vector of snapshots: + +```jldoctest temporal +julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots); + +julia> [g for g in tg] # iterate over snapshots +3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}: + GNNGraph(10, 20) with no data + GNNGraph(10, 14) with no data + GNNGraph(10, 22) with no data + +julia> f(g) = g isa GNNGraph; + +julia> f.(tg) # broadcast over snapshots +3-element BitVector: + 1 + 1 + 1 +``` + ## Basic Queries Basic queries are similar to those for [`GNNGraph`](@ref)s: ```jldoctest temporal -julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)]; +julia> snapshots = [rand_graph(10,20), rand_graph(12,14), rand_graph(14,22)]; julia> tg = TemporalSnapshotsGNNGraph(snapshots) TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10] + num_nodes: [10, 12, 14] num_edges: [20, 14, 22] num_snapshots: 3 julia> tg.num_nodes # number of nodes in each snapshot 3-element Vector{Int64}: 10 - 10 - 10 + 12 + 14 julia> tg.num_edges # number of edges in each snapshot 3-element Vector{Int64}: @@ -87,8 +141,8 @@ julia> tg.num_snapshots # number of snapshots julia> tg.snapshots # list of snapshots 3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}: GNNGraph(10, 20) with no data - GNNGraph(10, 14) with no data - GNNGraph(10, 22) with no data + GNNGraph(12, 14) with no data + GNNGraph(14, 22) with no data julia> tg.snapshots[1] # first snapshot, same as tg[1] GNNGraph: @@ -97,7 +151,7 @@ GNNGraph: ``` ## Data Features -A temporal graph can store global feature for the entire time series in the `tgdata` filed. +A temporal graph can store global feature for the entire time series in the `tgdata` field. Also, each snapshot can store node, edge, and graph features in the `ndata`, `edata`, and `gdata` fields, respectively. ```jldoctest temporal @@ -131,5 +185,3 @@ julia> [ds.x for ds in tg.ndata]; # vector containing the x feature of each snap julia> [g.x for g in tg.snapshots]; # same vector as above, now accessing # the x feature directly from the snapshots ``` - - diff --git a/GNNGraphs/src/temporalsnapshotsgnngraph.jl b/GNNGraphs/src/temporalsnapshotsgnngraph.jl index 53983e4c2..641162e58 100644 --- a/GNNGraphs/src/temporalsnapshotsgnngraph.jl +++ b/GNNGraphs/src/temporalsnapshotsgnngraph.jl @@ -1,55 +1,73 @@ """ - TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) + TemporalSnapshotsGNNGraph(snapshots) -A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref). +A type representing a time-varying graph as a sequence of snapshots, +each snapshot being a [`GNNGraph`](@ref). -`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object, -and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features. -The features can be passed at construction time or added later. +The argument `snapshots` is a collection of `GNNGraph`s with arbitrary +number of nodes and edges each. -# Constructor Arguments +Calling `tg` the temporal graph, `tg[t]` returns the `t`-th snapshot. -- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes. +The snapshots can contain node/edge/graph features, while global features for the +whole temporal sequence can be stored in `tg.tgdata`. -# Examples +See [`add_snapshot`](@ref) and [`remove_snapshot`](@ref) for adding and removing snapshots. -```julia -julia> using GNNGraphs +# Examples -julia> snapshots = [rand_graph(10,20) for i in 1:5]; +```jldoctest +julia> snapshots = [rand_graph(i , 2*i) for i in 10:10:50]; julia> tg = TemporalSnapshotsGNNGraph(snapshots) TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [20, 20, 20, 20, 20] + num_nodes: [10, 20, 30, 40, 50] + num_edges: [20, 40, 60, 80, 100] num_snapshots: 5 -julia> tg.tgdata.x = rand(4); # add temporal graph feature +julia> tg.num_snapshots +5 + +julia> tg.num_nodes +5-element Vector{Int64}: + 10 + 20 + 30 + 40 + 50 -julia> tg # show temporal graph with new feature +julia> tg[1] +GNNGraph: + num_nodes: 10 + num_edges: 20 + +julia> tg[2:3] TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10, 10, 10] - num_edges: [20, 20, 20, 20, 20] - num_snapshots: 5 - tgdata: - x = 4-element Vector{Float64} + num_nodes: [20, 30] + num_edges: [40, 60] + num_snapshots: 2 + +julia> tg[1] = rand_graph(10, 16) +GNNGraph: + num_nodes: 10 + num_edges: 16 ``` """ -struct TemporalSnapshotsGNNGraph - num_nodes::AbstractVector{Int} - num_edges::AbstractVector{Int} +struct TemporalSnapshotsGNNGraph{G<:GNNGraph, D<:DataStore} + num_nodes::Vector{Int} + num_edges::Vector{Int} num_snapshots::Int - snapshots::AbstractVector{<:GNNGraph} - tgdata::DataStore + snapshots::Vector{G} + tgdata::D end -function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) - @assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes" +function TemporalSnapshotsGNNGraph(snapshots) + snapshots = collect(snapshots) return TemporalSnapshotsGNNGraph( [s.num_nodes for s in snapshots], [s.num_edges for s in snapshots], length(snapshots), - snapshots, + collect(snapshots), DataStore() ) end @@ -67,7 +85,25 @@ function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int) end function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector) - return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata) + return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], + length(t), tg.snapshots[t], tg.tgdata) +end + +function Base.length(tg::TemporalSnapshotsGNNGraph) + return tg.num_snapshots +end + +# Allow broadcasting over the temporal snapshots +Base.broadcastable(tg::TemporalSnapshotsGNNGraph) = tg.snapshots + +Base.iterate(tg::TemporalSnapshotsGNNGraph) = Base.iterate(tg.snapshots) +Base.iterate(tg::TemporalSnapshotsGNNGraph, i) = Base.iterate(tg.snapshots, i) + +function Base.setindex!(tg::TemporalSnapshotsGNNGraph, g::GNNGraph, t::Int) + tg.snapshots[t] = g + tg.num_nodes[t] = g.num_nodes + tg.num_edges[t] = g.num_edges + return tg end """ @@ -78,8 +114,6 @@ Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the sn # Examples ```jldoctest -julia> using GNNGraphs - julia> snapshots = [rand_graph(10, 20) for i in 1:5]; julia> tg = TemporalSnapshotsGNNGraph(snapshots) @@ -185,58 +219,26 @@ end function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol) if prop ∈ fieldnames(TemporalSnapshotsGNNGraph) return getfield(tg, prop) - elseif prop == :ndata - return [s.ndata for s in tg.snapshots] - elseif prop == :edata - return [s.edata for s in tg.snapshots] - elseif prop == :gdata - return [s.gdata for s in tg.snapshots] - else - return [getproperty(s,prop) for s in tg.snapshots] + else + return [getproperty(s, prop) for s in tg.snapshots] end end function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph) - print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") - print_feature_t(io, tsg.tgdata) - print(io, " data") + print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))") end function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph) if get(io, :compact, false) - print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") - print_feature_t(io, tsg.tgdata) - print(io, " data") + print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots))") else print(io, "TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)") if !isempty(tsg.tgdata) print(io, "\n tgdata:") for k in keys(tsg.tgdata) - print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))") - end - end - end -end - -function print_feature_t(io::IO, feature) - if !isempty(feature) - if length(keys(feature)) == 1 - k = first(keys(feature)) - v = first(values(feature)) - print(io, "$(k): $(dims2string(size(v)))") - else - print(io, "(") - for (i, (k, v)) in enumerate(pairs(feature)) - print(io, "$k: $(dims2string(size(v)))") - if i == length(feature) - print(io, ")") - else - print(io, ", ") - end + print(io, "\n $k = $(shortsummary(tsg.tgdata[k]))") end end - else - print(io, "no") end end diff --git a/GNNGraphs/test/temporalsnapshotsgnngraph.jl b/GNNGraphs/test/temporalsnapshotsgnngraph.jl index bb4c061ae..352dbbedd 100644 --- a/GNNGraphs/test/temporalsnapshotsgnngraph.jl +++ b/GNNGraphs/test/temporalsnapshotsgnngraph.jl @@ -1,12 +1,20 @@ +#TODO add graph_type = GRAPH_TYPE to all constructor calls + @testset "Constructor array TemporalSnapshotsGNNGraph" begin snapshots = [rand_graph(10, 20) for i in 1:5] - tsg = TemporalSnapshotsGNNGraph(snapshots) - @test tsg.num_nodes == [10 for i in 1:5] - @test tsg.num_edges == [20 for i in 1:5] - wrsnapshots = [rand_graph(10,20), rand_graph(12,22)] - @test_throws AssertionError TemporalSnapshotsGNNGraph(wrsnapshots) + tg = TemporalSnapshotsGNNGraph(snapshots) + @test tg.num_nodes == [10 for i in 1:5] + @test tg.num_edges == [20 for i in 1:5] + @test tg.num_snapshots == 5 + + snapshots = [rand_graph(i, 2*i) for i in 10:10:50] + tg = TemporalSnapshotsGNNGraph(snapshots) + @test tg.num_nodes == [i for i in 10:10:50] + @test tg.num_edges == [2*i for i in 10:10:50] + @test tg.num_snapshots == 5 end + @testset "==" begin snapshots = [rand_graph(10, 20) for i in 1:5] tsg1 = TemporalSnapshotsGNNGraph(snapshots) @@ -24,8 +32,19 @@ end @test tsg[[1,2]] == TemporalSnapshotsGNNGraph([10,10], [20,20], 2, snapshots[1:2], tsg.tgdata) end +@testset "setindex!" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + g = rand_graph(20, 40) + tsg[3] = g + @test tsg.snapshots[3] === g + @test tsg.num_nodes == [10, 10, 20, 10, 10] + @test tsg.num_edges == [20, 20, 40, 20, 20] + @test_throws MethodError tsg[3:4] = g +end + @testset "getproperty" begin - x = rand(10) + x = rand(Float32, 10) snapshots = [rand_graph(10, 20, ndata = x) for i in 1:5] tsg = TemporalSnapshotsGNNGraph(snapshots) @test tsg.tgdata == DataStore() @@ -95,18 +114,31 @@ end @testset "show" begin snapshots = [rand_graph(10, 20) for i in 1:5] tsg = TemporalSnapshotsGNNGraph(snapshots) - @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with no data" - @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5) with no data" + @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)" + @test sprint(show, MIME("text/plain"), tsg; context=:compact => true) == "TemporalSnapshotsGNNGraph(5)" @test sprint(show, MIME("text/plain"), tsg; context=:compact => false) == "TemporalSnapshotsGNNGraph:\n num_nodes: [10, 10, 10, 10, 10]\n num_edges: [20, 20, 20, 20, 20]\n num_snapshots: 5" - tsg.tgdata.x=rand(4) - @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data" + tsg.tgdata.x = rand(Float32, 4) + @test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5)" +end + +@testset "broadcastable" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + f(g) = g isa GNNGraph + @test f.(tsg) == trues(5) +end + +@testset "iterate" begin + snapshots = [rand_graph(10, 20) for i in 1:5] + tsg = TemporalSnapshotsGNNGraph(snapshots) + @test [g for g in tsg] isa Vector{<:GNNGraph} end if TEST_GPU @testset "gpu" begin - snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5] + snapshots = [rand_graph(10, 20; ndata = rand(Float32, 5,10)) for i in 1:5] tsg = TemporalSnapshotsGNNGraph(snapshots) - tsg.tgdata.x = rand(5) + tsg.tgdata.x = rand(Float32, 5) dev = CUDADevice() #TODO replace with `gpu_device()` tsg = tsg |> dev @test tsg.snapshots[1].ndata.x isa CuArray diff --git a/GNNLux/docs/src_tutorials/gnn_intro.jl b/GNNLux/docs/src_tutorials/gnn_intro.jl index 1fa18e41a..d09cae3e7 100644 --- a/GNNLux/docs/src_tutorials/gnn_intro.jl +++ b/GNNLux/docs/src_tutorials/gnn_intro.jl @@ -220,7 +220,7 @@ visualize_embeddings(emb_init, colors = labels) # If you are not new to Lux, this scheme should appear familiar to you. # Note that our semi-supervised learning scenario is achieved by the following line: -# ``` +# ```julia # logitcrossentropy(ŷ[:,train_mask], y[:,train_mask]) # ``` # While we compute node embeddings for all of our nodes, we **only make use of the training nodes for computing the loss**. diff --git a/GraphNeuralNetworks/docs/Project.toml b/GraphNeuralNetworks/docs/Project.toml index f317ed97f..a6f5fbee7 100644 --- a/GraphNeuralNetworks/docs/Project.toml +++ b/GraphNeuralNetworks/docs/Project.toml @@ -10,6 +10,7 @@ GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" @@ -19,3 +20,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TSne = "24678dba-d5e9-5843-a4c6-250288b04835" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[compat] +Literate = "2.20" diff --git a/GraphNeuralNetworks/docs/make_tutorials_literate.jl b/GraphNeuralNetworks/docs/make_tutorials_literate.jl new file mode 100644 index 000000000..92035e40b --- /dev/null +++ b/GraphNeuralNetworks/docs/make_tutorials_literate.jl @@ -0,0 +1,6 @@ +using Literate + +ENV["DATADEPS_ALWAYS_ACCEPT"] = true + +Literate.markdown("src_tutorials/introductory_tutorials/temporal_graph_classification.jl", + "src/tutorials/"; execute = true) diff --git a/GraphNeuralNetworks/docs/make_tutorials.jl b/GraphNeuralNetworks/docs/make_tutorials_pluto.jl similarity index 79% rename from GraphNeuralNetworks/docs/make_tutorials.jl rename to GraphNeuralNetworks/docs/make_tutorials_pluto.jl index 5091a696f..78545017b 100644 --- a/GraphNeuralNetworks/docs/make_tutorials.jl +++ b/GraphNeuralNetworks/docs/make_tutorials_pluto.jl @@ -22,14 +22,14 @@ build_notebooks(bopt, move_tutorials("src_tutorials/introductory_tutorials/", "src/tutorials/") # Build temporal tutorials -bopt_temp = BuildOptions("src_tutorials/temporalconv_tutorials/"; +bopt_temp = BuildOptions("src_tutorials/"; output_format = documenter_output, use_distributed = false) build_notebooks( BuildOptions(bopt_temp; output_format = documenter_output), - ["temporal_graph_classification_pluto.jl", "traffic_prediction.jl"], + ["traffic_prediction.jl"], OutputOptions() ) -move_tutorials("src_tutorials/temporalconv_tutorials/", "src/tutorials/") \ No newline at end of file +move_tutorials("src_tutorials/", "src/tutorials/") \ No newline at end of file diff --git a/GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl b/GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl new file mode 100644 index 000000000..ff0cf439d --- /dev/null +++ b/GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl @@ -0,0 +1,150 @@ + +# # Temporal Graph classification with GraphNeuralNetworks.jl +# +# In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying. +# +# We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU. + +# ## Import +# +# We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others. + +using Flux +using GraphNeuralNetworks +using Statistics, Random +using LinearAlgebra +using MLDatasets: TemporalBrains +using CUDA # comment out if you don't have a CUDA GPU + +# ## Dataset: TemporalBrains +# The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects obtained from resting-state functional MRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation). +# Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions. +# +# The graph nodes represent brain regions and their number is fixed at 102 for each of the 27 snapshots, while the edges, representing functional connectivity, change over time. +# For each snapshot, the feature of a node represents the average activation of the node during that snapshot. +# Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+). +# The network's edge weights are binarized, and the threshold is set to 0.6 by default. + +brain_dataset = TemporalBrains() + +# After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format. +# So we create a function called `data_loader` that implements the latter and splits the dataset into the training set that will be used to train the model and the test set that will be used to test the performance of the model. + + +function data_loader(brain_dataset) + graphs = brain_dataset.graphs + dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs)) + for i in 1:length(graphs) + graph = graphs[i] + dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(graph.snapshots)) + # Add graph and node features + for t in 1:27 + s = dataset[i].snapshots[t] + s.ndata.x = [I(102); s.ndata.x'] + end + dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"])) + end + # Split the dataset into a 80% training set and a 20% test set + train_loader = dataset[1:200] + test_loader = dataset[201:250] + return train_loader, test_loader +end + +# The first part of the `data_loader` function calls the `mlgraph2gnngraph` function for each snapshot, which takes the graph and converts it to a `GNNGraph`. The vector of `GNNGraph`s is then rewritten to a `TemporalSnapshotsGNNGraph`. +# +# The second part adds the graph and node features to the temporal graphs, in particular it adds the one-hot encoding of the label of the graph (in this case we directly use the identity matrix) and appends the mean activation of the node of the snapshot (which is contained in the vector `dataset[i].snapshots[t].ndata.x`, where `i` is the index indicating the subject and `t` is the snapshot). For the graph feature, it adds the one-hot encoding of gender. +# +# The last part splits the dataset. + +# ## Model +# +# We now implement a simple model that takes a `TemporalSnapshotsGNNGraph` as input. +# It consists of a `GINConv` applied independently to each snapshot, a `GlobalPool` to get an embedding for each snapshot, a pooling on the time dimension to get an embedding for the whole temporal graph, and finally a `Dense` layer. +# +# First, we start by adapting the `GlobalPool` to the `TemporalSnapshotsGNNGraphs`. + +function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector) + h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)] + sze = size(h[1]) + reshape(reduce(hcat, h), sze[1], length(h)) +end + +# Then we implement the constructor of the model, which we call `GenderPredictionModel`, and the foward pass. + +struct GenderPredictionModel + gin::GINConv + mlp::Chain + globalpool::GlobalPool + dense::Dense +end + +Flux.@layer GenderPredictionModel + +function GenderPredictionModel(; nfeatures = 103, nhidden = 128, σ = relu) + mlp = Chain(Dense(nfeatures => nhidden, σ), Dense(nhidden => nhidden, σ)) + gin = GINConv(mlp, 0.5) + globalpool = GlobalPool(mean) + dense = Dense(nhidden => 2) + return GenderPredictionModel(gin, mlp, globalpool, dense) +end + +function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph) + h = m.gin(g, g.ndata.x) + h = m.globalpool(g, h) + h = mean(h, dims=2) + return m.dense(h) +end + +# ## Training +# +# We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the `logitbinarycrossentropy` as the loss function, which is typically used as the loss in two-class classification, where the labels are given in a one-hot format. +# The accuracy expresses the number of correct classifications. + +lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y); + +function eval_loss_accuracy(model, data_loader) + error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader]) + acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader]) + return (loss = error, acc = acc) +end + +function train(dataset) + device = gpu_device() + + function report(epoch) + train_loss, train_acc = eval_loss_accuracy(model, train_loader) + test_loss, test_acc = eval_loss_accuracy(model, test_loader) + println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") + return (train_loss, train_acc, test_loss, test_acc) + end + + model = GenderPredictionModel() |> device + + opt = Flux.setup(Adam(1.0f-3), model) + + train_loader, test_loader = data_loader(dataset) + train_loader = train_loader |> device + test_loader = test_loader |> device + + report(0) + for epoch in 1:100 + for g in train_loader + grads = Flux.gradient(model) do model + ŷ = model(g) + lossfunction(vec(ŷ), g.tgdata.g) + end + Flux.update!(opt, model, grads[1]) + end + if epoch % 10 == 0 + report(epoch) + end + end + return model +end + + +train(brain_dataset) + +## Conclusions +# +# In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but can be improved by fine-tuning the parameters and training on more data. diff --git a/GraphNeuralNetworks/docs/src_tutorials/temporalconv_tutorials/traffic_prediction.jl b/GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/traffic_prediction_pluto.jl similarity index 100% rename from GraphNeuralNetworks/docs/src_tutorials/temporalconv_tutorials/traffic_prediction.jl rename to GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/traffic_prediction_pluto.jl diff --git a/GraphNeuralNetworks/docs/src_tutorials/temporalconv_tutorials/temporal_graph_classification_pluto.jl b/GraphNeuralNetworks/docs/src_tutorials/temporalconv_tutorials/temporal_graph_classification_pluto.jl deleted file mode 100644 index 7a664869a..000000000 --- a/GraphNeuralNetworks/docs/src_tutorials/temporalconv_tutorials/temporal_graph_classification_pluto.jl +++ /dev/null @@ -1,1735 +0,0 @@ -### A Pluto.jl notebook ### -# v0.19.45 - -#> [frontmatter] -#> author = "[Aurora Rossi](https://github.com/aurorarossi)" -#> title = "Temporal Graph classification with Graph Neural Networks" -#> date = "2024-03-06" -#> description = "Temporal Graph classification with GraphNeuralNetworks.jl" -#> cover = "assets/brain_gnn.gif" - -using Markdown -using InteractiveUtils - -# ╔═╡ b8df1800-c69d-4e18-8a0a-097381b62a4c -begin - using Flux - using GraphNeuralNetworks - using Statistics, Random - using LinearAlgebra - using MLDatasets: TemporalBrains - using CUDA - using cuDNN -end - -# ╔═╡ 69d00ec8-da47-11ee-1bba-13a14e8a6db2 -md" -# Temporal Graph classification with GraphNeuralNetworks.jl - -In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying. - -We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU. -" - -# ╔═╡ ef8406e4-117a-4cc6-9fa5-5028695b1a4f -md" -## Import - -We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others. -" - -# ╔═╡ 2544d468-1430-4986-88a9-be4df2a7cf27 -md" -## Dataset: TemporalBrains -The TemporalBrains dataset contains a collection of functional brain connectivity networks from 1000 subjects obtained from resting-state functional MRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation). -Functional connectivity is defined as the temporal dependence of neuronal activation patterns of anatomically separated brain regions. - -The graph nodes represent brain regions and their number is fixed at 102 for each of the 27 snapshots, while the edges, representing functional connectivity, change over time. -For each snapshot, the feature of a node represents the average activation of the node during that snapshot. -Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+). -The network's edge weights are binarized, and the threshold is set to 0.6 by default. -" - -# ╔═╡ f2dbc66d-b8b7-46ae-ad5b-cbba1af86467 -brain_dataset = TemporalBrains() - -# ╔═╡ d9e4722d-6f02-4d41-955c-8bb3e411e404 -md"After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format. -So we create a function called `data_loader` that implements the latter and splits the dataset into the training set that will be used to train the model and the test set that will be used to test the performance of the model. -" - -# ╔═╡ bb36237a-5545-47d0-a873-7ddff3efe8ba -function data_loader(brain_dataset) - graphs = brain_dataset.graphs - dataset = Vector{TemporalSnapshotsGNNGraph}(undef, length(graphs)) - for i in 1:length(graphs) - graph = graphs[i] - dataset[i] = TemporalSnapshotsGNNGraph(GraphNeuralNetworks.mlgraph2gnngraph.(graph.snapshots)) - # Add graph and node features - for t in 1:27 - s = dataset[i].snapshots[t] - s.ndata.x = [I(102); s.ndata.x'] - end - dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"])) - end - # Split the dataset into a 80% training set and a 20% test set - train_loader = dataset[1:200] - test_loader = dataset[201:250] - return train_loader, test_loader -end; - -# ╔═╡ d4732340-9179-4ada-b82e-a04291d745c2 -md" -The first part of the `data_loader` function calls the `mlgraph2gnngraph` function for each snapshot, which takes the graph and converts it to a `GNNGraph`. The vector of `GNNGraph`s is then rewritten to a `TemporalSnapshotsGNNGraph`. - -The second part adds the graph and node features to the temporal graphs, in particular it adds the one-hot encoding of the label of the graph (in this case we directly use the identity matrix) and appends the mean activation of the node of the snapshot (which is contained in the vector `dataset[i].snapshots[t].ndata.x`, where `i` is the index indicating the subject and `t` is the snapshot). For the graph feature, it adds the one-hot encoding of gender. - -The last part splits the dataset. -" - - -# ╔═╡ ec088a59-2fc2-426a-a406-f8f8d6784128 -md" -## Model - -We now implement a simple model that takes a `TemporalSnapshotsGNNGraph` as input. -It consists of a `GINConv` applied independently to each snapshot, a `GlobalPool` to get an embedding for each snapshot, a pooling on the time dimension to get an embedding for the whole temporal graph, and finally a `Dense` layer. - -First, we start by adapting the `GlobalPool` to the `TemporalSnapshotsGNNGraphs`. -" - -# ╔═╡ 5ea98df9-4920-4c94-9472-3ef475af89fd -function (l::GlobalPool)(g::TemporalSnapshotsGNNGraph, x::AbstractVector) - h = [reduce_nodes(l.aggr, g[i], x[i]) for i in 1:(g.num_snapshots)] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], length(h)) -end - -# ╔═╡ cfda2cf4-d08b-4f46-bd39-02ae3ed53369 -md" -Then we implement the constructor of the model, which we call `GenderPredictionModel`, and the foward pass. -" - -# ╔═╡ 2eedd408-67ee-47b2-be6f-2caec94e95b5 -begin - struct GenderPredictionModel - gin::GINConv - mlp::Chain - globalpool::GlobalPool - f::Function - dense::Dense - end - - Flux.@layer GenderPredictionModel - - function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) - mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) - gin = GINConv(mlp, 0.5) - globalpool = GlobalPool(mean) - f = x -> mean(x, dims = 2) - dense = Dense(nhidden, 2) - GenderPredictionModel(gin, mlp, globalpool, f, dense) - end - - function (m::GenderPredictionModel)(g::TemporalSnapshotsGNNGraph) - h = m.gin(g, g.ndata.x) - h = m.globalpool(g, h) - h = m.f(h) - m.dense(h) - end - -end - -# ╔═╡ 76780020-406d-4803-9af0-d928e54fc18c -md" -## Training - -We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the `logitbinarycrossentropy` as the loss function, which is typically used as the loss in two-class classification, where the labels are given in a one-hot format. -The accuracy expresses the number of correct classifications. -" - -# ╔═╡ 0a1e07b0-a4f3-4a4b-bcd1-7fe200967cf8 -lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y); - -# ╔═╡ cc2ebdcf-72de-4a3b-af46-5bddab6689cc -function eval_loss_accuracy(model, data_loader) - error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader]) - acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader]) - return (loss = error, acc = acc) -end; - -# ╔═╡ d64be72e-8c1f-4551-b4f2-28c8b78466c0 -function train(dataset; usecuda::Bool, kws...) - - if usecuda && CUDA.functional() #check if GPU is available - my_device = gpu - @info "Training on GPU" - else - my_device = cpu - @info "Training on CPU" - end - - function report(epoch) - train_loss, train_acc = eval_loss_accuracy(model, train_loader) - test_loss, test_acc = eval_loss_accuracy(model, test_loader) - println("Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc))") - return (train_loss, train_acc, test_loss, test_acc) - end - - model = GenderPredictionModel() |> my_device - - opt = Flux.setup(Adam(1.0f-3), model) - - train_loader, test_loader = data_loader(dataset) - train_loader = train_loader |> my_device - test_loader = test_loader |> my_device - - report(0) - for epoch in 1:100 - for g in train_loader - grads = Flux.gradient(model) do model - ŷ = model(g) - lossfunction(vec(ŷ), g.tgdata.g) - end - Flux.update!(opt, model, grads[1]) - end - if epoch % 10 == 0 - report(epoch) - end - end - return model -end; - - -# ╔═╡ 483f17ba-871c-4769-88bd-8ec781d1909d -train(brain_dataset; usecuda = true) - -# ╔═╡ b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5 -md" -We set up the training on the GPU because training takes a lot of time, especially when working on the CPU. -" - -# ╔═╡ cb4eed19-2658-411d-886c-e0c9c2b44219 -md" -## Conclusions - -In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but can be improved by fine-tuning the parameters and training on more data. -" - -# ╔═╡ 00000000-0000-0000-0000-000000000001 -PLUTO_PROJECT_TOML_CONTENTS = """ -[deps] -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[compat] -CUDA = "~5.4.3" -Flux = "~0.14.16" -GraphNeuralNetworks = "~0.6.19" -MLDatasets = "~0.7.16" -cuDNN = "~1.3.2" -""" - -# ╔═╡ 00000000-0000-0000-0000-000000000002 -PLUTO_MANIFEST_TOML_CONTENTS = """ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "25724970092e282d6cd2d6ea9e021d61f3714205" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.AtomsBase]] -deps = ["LinearAlgebra", "PeriodicTable", "Printf", "Requires", "StaticArrays", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "995c2b6b17840cd87b722ce9c6cdd72f47bab545" -uuid = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" -version = "0.3.5" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BitFlags]] -git-tree-sha1 = "0691e34b3bb8be9307330f88d1a3c3f25466c24d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.9" - -[[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" - -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] -git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.4.3" - - [deps.CUDA.extensions] - ChainRulesCoreExt = "ChainRulesCore" - EnzymeCoreExt = "EnzymeCore" - SpecialFunctionsExt = "SpecialFunctions" - - [deps.CUDA.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" - -[[deps.CUDA_Driver_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "97df9d4d6be8ac6270cb8fd3b8fc413690820cbd" -uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.9.1+1" - -[[deps.CUDA_Runtime_Discovery]] -deps = ["Libdl"] -git-tree-sha1 = "f3b237289a5a77c759b2dd5d4c2ff641d67c4030" -uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.3.4" - -[[deps.CUDA_Runtime_jll]] -deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a" -uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.14.1+0" - -[[deps.CUDNN_jll]] -deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" -uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" -version = "9.0.0+1" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.Chemfiles]] -deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] -git-tree-sha1 = "82fe5e341c793cb51149d993307da9543824b206" -uuid = "46823bd8-5fb3-5f92-9aa0-96921f3dd015" -version = "0.10.41" - -[[deps.Chemfiles_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f3743181e30d87c23d9c8ebd493b77f43d8f1890" -uuid = "78a364fa-1a3c-552a-b4bb-8fa0f9c1fcca" -version = "0.10.4+0" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "ea32b83ca4fefa1768dc84e504cc0a94fb1ab8d1" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.4.2" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataDeps]] -deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] -git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" -uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.13" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.11" -weakdeps = ["ChainRulesCore", "SparseArrays"] - - [deps.Distances.extensions] - DistancesChainRulesCoreExt = "ChainRulesCore" - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.10" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - - [deps.FillArrays.weakdeps] - PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" - SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.26.7" - -[[deps.GZip]] -deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.2" - -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - -[[deps.GraphNeuralNetworks]] -deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "6716650d17bf36a41921c679c4d046ac375d5907" -uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" -version = "0.6.19" - - [deps.GraphNeuralNetworks.extensions] - GraphNeuralNetworksCUDAExt = "CUDA" - GraphNeuralNetworksSimpleWeightedGraphsExt = "SimpleWeightedGraphs" - - [deps.GraphNeuralNetworks.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.2" - - [deps.HDF5.extensions] - MPIExt = "MPI" - - [deps.HDF5.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - -[[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+3" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.8" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" - -[[deps.ImageCore]] -deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] -git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.2" - -[[deps.ImageShow]] -deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "3b5344bcdbdc11ad58f3b1956709b5b9345355de" -uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.8" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.2" - - [deps.InlineStrings.extensions] - ArrowTypesExt = "ArrowTypes" - ParsersExt = "Parsers" - - [deps.InlineStrings.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] - - [deps.InverseFunctions.extensions] - DatesExt = "Dates" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" -uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" - - [deps.JSON3.extensions] - JSON3ArrowExt = ["ArrowTypes"] - - [deps.JSON3.weakdeps] - ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" - -[[deps.JuliaNVTXCallbacks_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" -uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" -version = "0.2.1+0" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KrylovKit]] -deps = ["ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "Printf", "VectorInterface"] -git-tree-sha1 = "3f3a92bbe8f568b689a7f7bc193f7c717d793751" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.7.1" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" -weakdeps = ["BFloat16s"] - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" - -[[deps.LLVMLoopInfo]] -git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" -uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" -version = "1.0.0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.3" - -[[deps.MAT]] -deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.7" - -[[deps.MLDatasets]] -deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" - -[[deps.MLStyle]] -git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.17" - -[[deps.MLUtils]] -deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.4" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" - -[[deps.MPIPreferences]] -deps = ["Libdl", "Preferences"] -git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" -uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.11" - -[[deps.MPItrampoline_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" -uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.13" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] -git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.9" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.MicroCollections]] -deps = ["Accessors", "BangBang", "InitialValues"] -git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.2.0" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f12a29c4400ba812841c6ace3f4efbb6dbb3ba01" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.4+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.2.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" - - [deps.NNlib.extensions] - NNlibAMDGPUExt = "AMDGPU" - NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] - NNlibCUDAExt = "CUDA" - NNlibEnzymeCoreExt = "EnzymeCore" - NNlibFFTWExt = "FFTW" - - [deps.NNlib.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NPZ]] -deps = ["FileIO", "ZipFile"] -git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" -uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" -version = "0.4.3" - -[[deps.NVTX]] -deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] -git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" -uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" -version = "0.3.4" - -[[deps.NVTX_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" -uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" -version = "3.1.0+2" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "91a67b4d73842da90b526011fa85c5c4c9343fe0" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.18" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OffsetArrays]] -git-tree-sha1 = "1a27764e945a152f7ca7efa04de513d473e9542e" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.14.1" -weakdeps = ["Adapt"] - - [deps.OffsetArrays.extensions] - OffsetArraysAdaptExt = "Adapt" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.6+0" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.3" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" - -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.PackageExtensionCompat]] -git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" -uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" -version = "1.0.2" -weakdeps = ["Requires", "TOML"] - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.PeriodicTable]] -deps = ["Base64", "Unitful"] -git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" -uuid = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" -version = "1.2.1" - -[[deps.Pickle]] -deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] -git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.5" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.3" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.7.0" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.1" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.5" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" - -[[deps.StridedViews]] -deps = ["LinearAlgebra", "PackageExtensionCompat"] -git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" -uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" -version = "0.2.2" -weakdeps = ["CUDA"] - - [deps.StridedViews.extensions] - StridedViewsCUDAExt = "CUDA" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.7" - -[[deps.StringManipulation]] -deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.24" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] - -[[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" - - [deps.Transducers.extensions] - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - TransducersReferenceablesExt = "Referenceables" - - [deps.Transducers.weakdeps] - BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" - OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" - Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Unitful]] -deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "d95fe458f26209c66a187b1114df96fd70839efd" -uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.21.0" -weakdeps = ["ConstructionBase", "InverseFunctions"] - - [deps.Unitful.extensions] - ConstructionBaseUnitfulExt = "ConstructionBase" - InverseFunctionsUnitfulExt = "InverseFunctions" - -[[deps.UnitfulAtomic]] -deps = ["Unitful"] -git-tree-sha1 = "903be579194534af1c4b4778d1ace676ca042238" -uuid = "a7773ee8-282e-5fa2-be4e-bd808c38a91a" -version = "1.0.0" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" - -[[deps.VectorInterface]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "7aff7d62bffad9bba9928eb6ab55226b32a351eb" -uuid = "409d34a3-91d5-4945-b6ec-7529ddf182d8" -version = "0.4.6" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.10.1" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" - -[[deps.cuDNN]] -deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] -git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" -uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.3.2" - -[[deps.libaec_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" -uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" -version = "1.1.2+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" -""" - -# ╔═╡ Cell order: -# ╟─69d00ec8-da47-11ee-1bba-13a14e8a6db2 -# ╟─ef8406e4-117a-4cc6-9fa5-5028695b1a4f -# ╠═b8df1800-c69d-4e18-8a0a-097381b62a4c -# ╟─2544d468-1430-4986-88a9-be4df2a7cf27 -# ╠═f2dbc66d-b8b7-46ae-ad5b-cbba1af86467 -# ╟─d9e4722d-6f02-4d41-955c-8bb3e411e404 -# ╠═bb36237a-5545-47d0-a873-7ddff3efe8ba -# ╟─d4732340-9179-4ada-b82e-a04291d745c2 -# ╟─ec088a59-2fc2-426a-a406-f8f8d6784128 -# ╠═5ea98df9-4920-4c94-9472-3ef475af89fd -# ╟─cfda2cf4-d08b-4f46-bd39-02ae3ed53369 -# ╠═2eedd408-67ee-47b2-be6f-2caec94e95b5 -# ╟─76780020-406d-4803-9af0-d928e54fc18c -# ╠═0a1e07b0-a4f3-4a4b-bcd1-7fe200967cf8 -# ╠═cc2ebdcf-72de-4a3b-af46-5bddab6689cc -# ╠═d64be72e-8c1f-4551-b4f2-28c8b78466c0 -# ╠═483f17ba-871c-4769-88bd-8ec781d1909d -# ╟─b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5 -# ╟─cb4eed19-2658-411d-886c-e0c9c2b44219 -# ╟─00000000-0000-0000-0000-000000000001 -# ╟─00000000-0000-0000-0000-000000000002