From aec3bf285882254ebe2a644f0808e6585ba89435 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Wed, 13 Dec 2023 18:00:49 +0100 Subject: [PATCH 1/6] Added support for CuArray via Adapt.jl --- Project.toml | 11 ++++++++++- ext/CUDASupportExt.jl | 26 ++++++++++++++++++++++++++ test/runtests.jl | 22 +++++++++++++++++++++- 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 ext/CUDASupportExt.jl diff --git a/Project.toml b/Project.toml index 8f5990b..af5e5f1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,8 +3,17 @@ uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a" repo = "https://github.com/JuliaArrays/ShiftedArrays.jl.git" version = "2.0.0" +[weakdeps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" + +[extensions] +CUDASupportExt = ["CUDA", "Adapt"] + [compat] -julia = "1" +CUDA = "5.1.1" +Adapt = "3.7.2" +julia = "1.9" [extras] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl new file mode 100644 index 0000000..61fad24 --- /dev/null +++ b/ext/CUDASupportExt.jl @@ -0,0 +1,26 @@ +module CUDASupportExt +using CUDA +using Adapt +using ShiftedArrays +using Base # to allow displaying such arrays without causing the single indexing CUDA error + +Adapt.adapt_structure(to, x::CircShiftedArray{T, D, CT}) where {T,D,CT<:CuArray} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); +function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: CircShiftedArray{<:Any,<:Any,<:CuArray}) + CUDA.CuArrayStyle{ndims(T)}() +end + +Adapt.adapt_structure(to, x::ShiftedArray{T, M, N, <:CuArray}) where {T,M,N} = +# lets do this for the ShiftedArray type +ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x)) +function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: ShiftedArray{<:Any,<:Any,<:Any,<:CuArray}) + CUDA.CuArrayStyle{ndims(T)}() +end + +function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray) + CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +end + +function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray) + CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 4415ed5..ff0a3d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,22 @@ using ShiftedArrays, Test using AbstractFFTs +use_cuda = false; # set this to true to test ShiftedArrays for the CuArray datatype +if (use_cuda) + using CUDA + CUDA.allowscalar(true); # needed for some of the comparisons +end + +function opt_convert(v) + if (use_cuda) + CuArray(v) + else + v + end +end @testset "ShiftedVector" begin v = [1, 3, 5, 4] + v = opt_convert(v); @test all(v .== ShiftedVector(v)) sv = ShiftedVector(v, -1) @test isequal(sv, ShiftedVector(v, (-1,))) @@ -28,6 +42,7 @@ end @testset "ShiftedArray" begin v = reshape(1:16, 4, 4) + v = opt_convert(v); @test all(v .== ShiftedArray(v)) sv = ShiftedArray(v, (-2, 0)) @test length(sv) == 16 @@ -64,6 +79,7 @@ end @testset "padded_tuple" begin v = rand(2, 2) + v = opt_convert(v); @test (1, 0) == @inferred ShiftedArrays.padded_tuple(v, 1) @test (0, 0) == @inferred ShiftedArrays.padded_tuple(v, ()) @test (3, 0) == @inferred ShiftedArrays.padded_tuple(v, (3,)) @@ -82,11 +98,12 @@ end @testset "CircShiftedVector" begin v = [1, 3, 5, 4] + v = opt_convert(v); @test all(v .== CircShiftedVector(v)) sv = CircShiftedVector(v, -1) @test isequal(sv, CircShiftedVector(v, (-1,))) @test length(sv) == 4 - @test all(sv .== [3, 5, 4, 1]) + @test all(sv .== opt_convert([3, 5, 4, 1])) diff = v .- sv @test diff == [-2, -2, 1, 3] @test shifts(sv) == (3,) @@ -110,6 +127,7 @@ end @testset "CircShiftedArray" begin v = reshape(1:16, 4, 4) + v = opt_convert(v); @test all(v .== CircShiftedArray(v)) sv = CircShiftedArray(v, (-2, 0)) @test length(sv) == 16 @@ -130,6 +148,7 @@ end @testset "circshift" begin v = reshape(1:16, 4, 4) + v = opt_convert(v); @test all(circshift(v, (1, -1)) .== ShiftedArrays.circshift(v, (1, -1))) @test all(circshift(v, (1,)) .== ShiftedArrays.circshift(v, (1,))) @test all(circshift(v, 3) .== ShiftedArrays.circshift(v, 3)) @@ -163,6 +182,7 @@ end @testset "laglead" begin v = [1, 3, 8, 12] + v = opt_convert(v); diff = v .- ShiftedArrays.lag(v) @test isequal(diff, [missing, 2, 5, 4]) From dee91443c1a1f046ee106b35ee29b73c65c7daac Mon Sep 17 00:00:00 2001 From: Rainer Heintzmann Date: Thu, 21 Dec 2023 16:46:16 +0100 Subject: [PATCH 2/6] Update ext/CUDASupportExt.jl Co-authored-by: Pietro Vertechi --- ext/CUDASupportExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl index 61fad24..1acf132 100644 --- a/ext/CUDASupportExt.jl +++ b/ext/CUDASupportExt.jl @@ -4,7 +4,7 @@ using Adapt using ShiftedArrays using Base # to allow displaying such arrays without causing the single indexing CUDA error -Adapt.adapt_structure(to, x::CircShiftedArray{T, D, CT}) where {T,D,CT<:CuArray} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); +Adapt.adapt_structure(to, x::CircShiftedArray{T, D}) where {T, D} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: CircShiftedArray{<:Any,<:Any,<:CuArray}) CUDA.CuArrayStyle{ndims(T)}() end From 2f751c023079d9ab5085444e427cba9389fd6d54 Mon Sep 17 00:00:00 2001 From: Rainer Heintzmann Date: Thu, 21 Dec 2023 16:46:24 +0100 Subject: [PATCH 3/6] Update ext/CUDASupportExt.jl Co-authored-by: Pietro Vertechi --- ext/CUDASupportExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl index 1acf132..98e2cb0 100644 --- a/ext/CUDASupportExt.jl +++ b/ext/CUDASupportExt.jl @@ -9,7 +9,7 @@ function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: CircShiftedArray{< CUDA.CuArrayStyle{ndims(T)}() end -Adapt.adapt_structure(to, x::ShiftedArray{T, M, N, <:CuArray}) where {T,M,N} = +Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = # lets do this for the ShiftedArray type ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x)) function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: ShiftedArray{<:Any,<:Any,<:Any,<:CuArray}) From 6fc77cdf407453016a53bb226bafb47fda0b6bf7 Mon Sep 17 00:00:00 2001 From: Rainer Heintzmann Date: Thu, 21 Dec 2023 16:46:35 +0100 Subject: [PATCH 4/6] Update ext/CUDASupportExt.jl Co-authored-by: Pietro Vertechi --- ext/CUDASupportExt.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl index 98e2cb0..2829db2 100644 --- a/ext/CUDASupportExt.jl +++ b/ext/CUDASupportExt.jl @@ -5,9 +5,8 @@ using ShiftedArrays using Base # to allow displaying such arrays without causing the single indexing CUDA error Adapt.adapt_structure(to, x::CircShiftedArray{T, D}) where {T, D} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); -function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: CircShiftedArray{<:Any,<:Any,<:CuArray}) - CUDA.CuArrayStyle{ndims(T)}() -end +parent_type(::Type{CircShiftedArray{T, N, S}) where {T, N, S} = S +Base.Broadcast.BroadcastStyle(::Type{T}) where (T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T)) Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = # lets do this for the ShiftedArray type From 52a28d14925a49c3e84f75ce3b934123f74ef936 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Thu, 21 Dec 2023 17:24:29 +0100 Subject: [PATCH 5/6] bug fixes --- ext/CUDASupportExt.jl | 19 +++++++++---------- test/runtests.jl | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl index 2829db2..0c31823 100644 --- a/ext/CUDASupportExt.jl +++ b/ext/CUDASupportExt.jl @@ -5,21 +5,20 @@ using ShiftedArrays using Base # to allow displaying such arrays without causing the single indexing CUDA error Adapt.adapt_structure(to, x::CircShiftedArray{T, D}) where {T, D} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); -parent_type(::Type{CircShiftedArray{T, N, S}) where {T, N, S} = S -Base.Broadcast.BroadcastStyle(::Type{T}) where (T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T)) +parent_type(::Type{CircShiftedArray{T, N, S}}) where {T, N, S} = S +Base.Broadcast.BroadcastStyle(::Type{T}) where {T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T)) -Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = # lets do this for the ShiftedArray type -ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x)) +Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x)); function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: ShiftedArray{<:Any,<:Any,<:Any,<:CuArray}) CUDA.CuArrayStyle{ndims(T)}() end -function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray) - CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) -end +# function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray) +# CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +# end -function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray) - CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) -end +# function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray) +# CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +# end end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ff0a3d5..bca6452 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using ShiftedArrays, Test using AbstractFFTs -use_cuda = false; # set this to true to test ShiftedArrays for the CuArray datatype +use_cuda = true; # set this to true to test ShiftedArrays for the CuArray datatype if (use_cuda) using CUDA CUDA.allowscalar(true); # needed for some of the comparisons From ec23e5a2a6b277d30bd76d4df6f0867b65ff0460 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Thu, 21 Dec 2023 17:25:34 +0100 Subject: [PATCH 6/6] put the use_cuda back to false as default --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index bca6452..ff0a3d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using ShiftedArrays, Test using AbstractFFTs -use_cuda = true; # set this to true to test ShiftedArrays for the CuArray datatype +use_cuda = false; # set this to true to test ShiftedArrays for the CuArray datatype if (use_cuda) using CUDA CUDA.allowscalar(true); # needed for some of the comparisons