diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl index 0796de3ee..be9023132 100644 --- a/src/gpuarrays.jl +++ b/src/gpuarrays.jl @@ -14,11 +14,8 @@ struct mtlKernelContext <: AbstractKernelContext end kernel = @metal launch=false f(mtlKernelContext(), args...) # The pipeline state automatically computes occupancy stats - threads_needed = cld(elements, elements_per_thread) - - # Limit the threadgroup size - threads = min(threads_needed, kernel.pipeline_state.maxTotalThreadsPerThreadgroup) - blocks = cld(threads_needed, threads) + threads = min(elements, kernel.pipeline_state.maxTotalThreadsPerThreadgroup) + blocks = cld(elements, threads) return (; threads, blocks) end