Skip to content

Commit

Permalink
remove redundant calls to mtlconvert
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Aug 23, 2024
1 parent 28576b3 commit c83e37b
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ macro metal(ex...)
$kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
$kernel = $mtlfunction($kernel_f, $kernel_tt; $(compiler_kwargs...))
if $launch
$kernel($(var_exprs...); $(call_kwargs...))
$kernel($kernel_args...; $(call_kwargs...))
end
$kernel
end
Expand Down Expand Up @@ -227,7 +227,7 @@ const _kernel_instances = Dict{UInt, Any}()
else
# everything else is passed by reference, in an argument buffer
append!(ex.args, (quote
buf = encode_argument!(kernel, mtlconvert($(argex), cce))
buf = encode_argument!(kernel, $argex)
set_buffer!(cce, buf, 0, $idx)
push!(bufs, buf)
end).args)
Expand Down Expand Up @@ -259,8 +259,8 @@ end
return argument_buffer
end

@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1,
queue=global_queue(device()))
@autoreleasepool function (kernel::HostKernel{F,TT})(args...; groups=1, threads=1,
queue=global_queue(device())) where {F,TT}
groups = MTLSize(groups)
threads = MTLSize(threads)
(groups.width>0 && groups.height>0 && groups.depth>0) ||
Expand All @@ -274,9 +274,24 @@ end
cmdbuf = MTLCommandBuffer(queue)
cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))"
cce = MTLComputeCommandEncoder(cmdbuf)

argexprs = [:(kernel.f)]
for i in 1:length(args)
if args[i] != TT.parameters[i]
# if there's a type mismatch between the argument and the compiled kernel,
# assume we still have to call `cudaconvert`. this is a bit of a hack, but
# it's generally true that `cudaconvert` will change the type of the argument,
# and it makes it possible to pass both original and converted arguments
# to compiled kernel objects, avoiding the need for multiple conversions.
push!(argexprs, :(mtlconvert(args[$i])))
else
push!(argexprs, :(args[$i]))
end
end

argument_buffers = try
MTL.set_function!(cce, kernel.pipeline)
bufs = encode_arguments!(cce, kernel, kernel.f, args...)
bufs = encode_arguments!(cce, kernel, kernel.f, argexprs...)
MTL.append_current_function!(cce, groups, threads)
bufs
finally
Expand Down

0 comments on commit c83e37b

Please sign in to comment.