Skip to content

Commit

Permalink
Update ext/CUDASupportExt.jl
Browse files Browse the repository at this point in the history
Co-authored-by: Pietro Vertechi <pietro.vertechi@protonmail.com>
  • Loading branch information
RainerHeintzmann and piever authored Dec 21, 2023
1 parent 2f751c0 commit 6fc77cd
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions ext/CUDASupportExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ using ShiftedArrays
using Base # to allow displaying such arrays without causing the single indexing CUDA error

Adapt.adapt_structure(to, x::CircShiftedArray{T, D}) where {T, D} = CircShiftedArray(adapt(to, parent(x)), shifts(x));
function Base.Broadcast.BroadcastStyle(::Type{T}) where (T<: CircShiftedArray{<:Any,<:Any,<:CuArray})
CUDA.CuArrayStyle{ndims(T)}()
end
parent_type(::Type{CircShiftedArray{T, N, S}) where {T, N, S} = S
Base.Broadcast.BroadcastStyle(::Type{T}) where (T<:CircShiftedArray} = Base.Broadcast.BroadcastStyle(parent_type(T))

Adapt.adapt_structure(to, x::ShiftedArray{T, M, N}) where {T, M, N} =
# lets do this for the ShiftedArray type
Expand Down

0 comments on commit 6fc77cd

Please sign in to comment.