Skip to content

Commit

Permalink
movielens
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 25, 2024
1 parent e95f70d commit 71dfa03
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"

[compat]
DLPack = "0.3.0"
GNNGraphs = "1.3.1"
GNNGraphs = "1.4.1"
PythonCall = "0.9.23"
Scratch = "1.2.1"
julia = "1.10"
Expand Down
2 changes: 0 additions & 2 deletions src/graph_conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,8 @@ function to_gnnheterograph(data)
jt = to_edge_t(t)
for k in data[t].keys()
jk = Symbol(k)
@show jt jk
jk == :edge_index && continue
py_x = data[t][k]
@show pytype(py_x)
x = try_from_dlpack(py_x)
last_dim = size(x, ndims(x))
if last_dim != num_edges[jt] || jk == :edge_label_index
Expand Down
16 changes: 12 additions & 4 deletions test/datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,15 @@ end
@test length(src) == length(dst) == 4278
end

# @testitem "AMiner" setup=[TestModule] begin
# using .TestModule
# dataset = load_dataset("AMiner")
# end
@testitem "MovieLens100K" setup=[TestModule] begin
using .TestModule
dataset = load_dataset("MovieLens100K")
@test length(dataset) == 1
g = dataset[1]
@test g.num_nodes == Dict(:user => 943, :movie => 1682)
@test g.num_edges == Dict((:movie, :rated_by, :user) => 80000, (:user, :rates, :movie) => 80000)
@test g.gdata.edge_label[(:user, :rates, :movie)] isa Vector{Float32}
@test length(g.gdata.edge_label[(:user, :rates, :movie)]) == 20000
@test g.gdata.edge_label_index[(:user, :rates, :movie)] isa Matrix{Int}
@test size(g.gdata.edge_label_index[(:user, :rates, :movie)]) == (20000, 2)
end

0 comments on commit 71dfa03

Please sign in to comment.