-
Notifications
You must be signed in to change notification settings - Fork 1
/
AddMM.swift
62 lines (49 loc) · 2.28 KB
/
AddMM.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import Foundation
import CoreML
// Core ML resolves this custom layer by its Objective-C runtime name, so the
// string passed to @objc(...) below must stay exactly as it is.
/// GPU-only Core ML custom layer that runs the `dneprDroid::addmm` Metal kernel.
///
/// There is no CPU fallback: `evaluate(inputs:outputs:)` always throws
/// `ErrorCommon.cpuNotImplemented`.
@objc(dneprDroid_addmm)
final class AddMM: NSObject, MLCustomLayer {
    /// Compiled pipeline for the `dneprDroid::addmm` kernel, built once in `init`
    /// on the system default Metal device.
    let pipelineState: MTLComputePipelineState

    /// Builds the compute pipeline for the layer's kernel.
    ///
    /// - Parameter parameters: Layer parameters supplied by Core ML; not read by this layer.
    /// - Throws: `ErrorCommon.metalNotSupported` when no Metal device exists,
    ///   `ErrorCommon.shaderNotFound` when the kernel function is missing from the library,
    ///   plus any error from `moduleLibrary()` or pipeline creation.
    init(parameters: [String : Any]) throws {
        guard let device = MTLCreateSystemDefaultDevice() else {
            throw ErrorCommon.metalNotSupported
        }
        // NOTE(review): `moduleLibrary()` is a project-level MTLDevice helper not shown here;
        // presumably it loads the library holding this module's compiled kernels.
        guard let kernel = try device.moduleLibrary().makeFunction(name: "dneprDroid::addmm") else {
            throw ErrorCommon.shaderNotFound
        }
        pipelineState = try device.makeComputePipelineState(function: kernel)
        super.init()
    }

    /// Accepts and ignores any weight blobs: the layer keeps no weight state.
    func setWeightData(_ weights: [Data]) throws {
        // Intentionally empty.
    }

    /// Derives the single output shape from the second input's shape, with the
    /// second-to-last dimension forced to 1.
    ///
    /// Traps (index out of range) when fewer than two inputs are given or when the
    /// second input's shape has fewer than two dimensions.
    /// NOTE(review): assumes Core ML's rank-5 neural-network layout, in which this axis
    /// would be "height"; confirm.
    func outputShapes(forInputShapes inputShapes: [[NSNumber]]) throws -> [[NSNumber]] {
        var dims = inputShapes[1].map(\.intValue)
        let collapsedAxis = dims.index(dims.endIndex, offsetBy: -2)
        dims[collapsedAxis] = 1
        let shape: [NSNumber] = dims.map { dim in NSNumber(value: dim) }
        return [shape]
    }

    /// Records one dispatch of the addmm kernel into `commandBuffer`.
    ///
    /// Texture bindings: `inputs[0]` is at index 0, `inputs[1]` at index 1 and
    /// `outputs[0]` at index 2. The threadgroup grid covers the output texture's width,
    /// height and array length, each rounded up to whole threadgroups.
    ///
    /// - Throws: `ErrorCommon.encoderInvalid` when no compute encoder can be created.
    func encode(commandBuffer: any MTLCommandBuffer, inputs: [any MTLTexture], outputs: [any MTLTexture]) throws {
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw ErrorCommon.encoderInvalid
        }
        // Finish the encoder after the dispatch below has been recorded.
        defer { encoder.endEncoding() }

        let output = outputs[0]
        // NOTE(review): slot order presumably mirrors the kernel's texture arguments; confirm in the .metal source.
        for (slot, texture) in [inputs[0], inputs[1], output].enumerated() {
            encoder.setTexture(texture, index: slot)
        }

        // Integer ceiling division: the number of `tile`-sized groups needed to span `extent`.
        // Because the result is rounded up, threads can land past the texture edge;
        // presumably the kernel bounds-checks its thread position (not visible here).
        func groupCount(covering extent: Int, per tile: Int) -> Int {
            (extent + tile - 1) / tile
        }

        let executionWidth = pipelineState.threadExecutionWidth
        let threadsPerGroup = MTLSize(
            width: executionWidth,
            height: pipelineState.maxTotalThreadsPerThreadgroup / executionWidth,
            depth: 1
        )
        let grid = MTLSize(
            width: groupCount(covering: output.width, per: threadsPerGroup.width),
            height: groupCount(covering: output.height, per: threadsPerGroup.height),
            depth: groupCount(covering: output.arrayLength, per: threadsPerGroup.depth)
        )

        encoder.setComputePipelineState(pipelineState)
        encoder.dispatchThreadgroups(grid, threadsPerThreadgroup: threadsPerGroup)
    }

    /// CPU evaluation path; not supported by this layer.
    ///
    /// - Throws: Always `ErrorCommon.cpuNotImplemented`.
    func evaluate(inputs: [MLMultiArray], outputs: [MLMultiArray]) throws {
        throw ErrorCommon.cpuNotImplemented
    }
}