From 52a28d14925a49c3e84f75ce3b934123f74ef936 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Thu, 21 Dec 2023 17:24:29 +0100 Subject: [PATCH] bug fixes --- ext/CUDASupportExt.jl | 19 +++++++++---------- test/runtests.jl | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/ext/CUDASupportExt.jl b/ext/CUDASupportExt.jl index 2829db2..0c31823 100644 --- a/ext/CUDASupportExt.jl +++ b/ext/CUDASupportExt.jl @@ -5,21 +5,20 @@ using ShiftedArrays using Base # to allow displaying such arrays without causing the single indexing CUDA error Adapt.adapt_structure(to, x::CircShiftedArray{T, D}) where {T, D} = CircShiftedArray(adapt(to, parent(x)), shifts(x)); -parent_type(::Type{CircShiftedArray{T, N, S}) where {T, N, S} = S -Base.Broadcast.BroadcastStyle(::Type{T}) where (T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T)) +parent_type(::Type{CircShiftedArray{T, N, S}}) where {T, N, S} = S +Base.Broadcast.BroadcastStyle(::Type{T}) where {T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T)) -Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = # lets do this for the ShiftedArray type -ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x)) +Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x)); function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: ShiftedArray{<:Any,<:Any,<:Any,<:CuArray}) CUDA.CuArrayStyle{ndims(T)}() end -function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray) - CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) -end +# function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray) +# CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +# end -function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray) - CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) -end +# function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray) +# CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs) +# end end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ff0a3d5..bca6452 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using ShiftedArrays, Test using AbstractFFTs -use_cuda = false; # set this to true to test ShiftedArrays for the CuArray datatype +use_cuda = true; # set this to true to test ShiftedArrays for the CuArray datatype if (use_cuda) using CUDA CUDA.allowscalar(true); # needed for some of the comparisons