Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Dec 21, 2023
1 parent 6fc77cd commit 52a28d1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
19 changes: 9 additions & 10 deletions ext/CUDASupportExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@ using ShiftedArrays
using Base # to allow displaying such arrays without causing the single indexing CUDA error

Adapt.adapt_structure(to, x::CircShiftedArray{T, D}) where {T, D} = CircShiftedArray(adapt(to, parent(x)), shifts(x));
parent_type(::Type{CircShiftedArray{T, N, S}) where {T, N, S} = S
Base.Broadcast.BroadcastStyle(::Type{T}) where (T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T))
parent_type(::Type{CircShiftedArray{T, N, S}}) where {T, N, S} = S
Base.Broadcast.BroadcastStyle(::Type{T}) where {T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T))

Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} =
# lets do this for the ShiftedArray type
ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x))
Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} = ShiftedArray(adapt(to, parent(x)), shifts(x); default=ShiftedArrays.default(x));
function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: ShiftedArray{<:Any,<:Any,<:Any,<:CuArray})
CUDA.CuArrayStyle{ndims(T)}()
end

function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray)
CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs)
end
# function Base.show(io::IO, mm::MIME"text/plain", cs::CircShiftedArray)
# CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs)
# end

function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray)
CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs)
end
# function Base.show(io::IO, mm::MIME"text/plain", cs::ShiftedArray)
# CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs)
# end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ShiftedArrays, Test
using AbstractFFTs
use_cuda = false; # set this to true to test ShiftedArrays for the CuArray datatype
use_cuda = true; # set this to true to test ShiftedArrays for the CuArray datatype
if (use_cuda)
using CUDA
CUDA.allowscalar(true); # needed for some of the comparisons
Expand Down

0 comments on commit 52a28d1

Please sign in to comment.