From c83e37bf5efd31a153adea0ee065a93e5e98e991 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 23 Aug 2024 16:52:46 +0200 Subject: [PATCH] remove redundant calls to mtlconvert --- src/compiler/execution.jl | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 1fd4f3b9d..3f381f9de 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -84,7 +84,7 @@ macro metal(ex...) $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...} $kernel = $mtlfunction($kernel_f, $kernel_tt; $(compiler_kwargs...)) if $launch - $kernel($(var_exprs...); $(call_kwargs...)) + $kernel($kernel_args...; $(call_kwargs...)) end $kernel end @@ -227,7 +227,7 @@ const _kernel_instances = Dict{UInt, Any}() else # everything else is passed by reference, in an argument buffer append!(ex.args, (quote - buf = encode_argument!(kernel, mtlconvert($(argex), cce)) + buf = encode_argument!(kernel, $argex) set_buffer!(cce, buf, 0, $idx) push!(bufs, buf) end).args) @@ -259,8 +259,8 @@ end return argument_buffer end -@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, - queue=global_queue(device())) +@autoreleasepool function (kernel::HostKernel{F,TT})(args...; groups=1, threads=1, + queue=global_queue(device())) where {F,TT} groups = MTLSize(groups) threads = MTLSize(threads) (groups.width>0 && groups.height>0 && groups.depth>0) || @@ -274,9 +274,24 @@ end cmdbuf = MTLCommandBuffer(queue) cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))" cce = MTLComputeCommandEncoder(cmdbuf) + + argexprs = [:(kernel.f)] + for i in 1:length(args) + if args[i] != TT.parameters[i] + # if there's a type mismatch between the argument and the compiled kernel, + # assume we still have to call `mtlconvert`. 
this is a bit of a hack, but + # it's generally true that `mtlconvert` will change the type of the argument, + # and it makes it possible to pass both original and converted arguments + # to compiled kernel objects, avoiding the need for multiple conversions. + push!(argexprs, :(mtlconvert(args[$i]))) + else + push!(argexprs, :(args[$i])) + end + end + argument_buffers = try MTL.set_function!(cce, kernel.pipeline) - bufs = encode_arguments!(cce, kernel, kernel.f, args...) + bufs = encode_arguments!(cce, kernel, kernel.f, argexprs...) MTL.append_current_function!(cce, groups, threads) bufs finally