[js/webgpu] Support where
The where_broadcast.jsonc test case is not enabled due to microsoft#17405.
axinging committed Sep 14, 2023
1 parent ad369a1 commit c167ea4
Showing 9 changed files with 354 additions and 140 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -91,3 +91,4 @@ Do not modify directly.*
| Tile | ai.onnx(6-12,13+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |
| Where | ai.onnx(9-15,16+) | |
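
For context, Where is an elementwise select: wherever the boolean condition input is true, the output takes the value from the first data input, otherwise from the second, with NumPy-style broadcasting across all three inputs. A minimal TypeScript sketch of the non-broadcast case, purely as a reference for the semantics (the function and array types are illustrative, not code from this commit):

// Reference semantics of ONNX Where without broadcasting; names are illustrative only.
const whereReference = (condition: Uint8Array, x: Float32Array, y: Float32Array): Float32Array => {
  const output = new Float32Array(x.length);
  for (let i = 0; i < output.length; i++) {
    // Take x[i] where the condition is true (non-zero), otherwise y[i].
    output[i] = condition[i] !== 0 ? x[i] : y[i];
  }
  return output;
};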
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -23,6 +23,7 @@ import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import {where} from './ops/where';
import * as unaryOps from './ops/unary-op';
import {ComputeContext} from './types';

@@ -108,4 +109,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
['Tile', [tile]],
['Transpose', [transpose, parseTransposeAttributes]],
['Where', [where]],
]);
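
For orientation, this map is keyed by ONNX operator type; the JSEP backend resolves each node's type through it to find the kernel to run (and an attribute parser, when one is registered). A minimal lookup sketch, assuming the simplified usage below rather than the backend's actual dispatch code:

import {WEBGPU_OP_RESOLVE_RULES} from './op-resolve-rules';

// Hypothetical lookup: 'Where' now resolves to the `where` kernel added by this commit.
const rule = WEBGPU_OP_RESOLVE_RULES.get('Where');
if (!rule) {
  throw new Error('Where is not supported by the WebGPU backend');
}
const [whereKernel] = rule;
// The backend then invokes the kernel with a ComputeContext wrapping the node's inputs.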
175 changes: 36 additions & 139 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -3,22 +3,19 @@

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';

type BuiltinFunctionName = string;
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
scalar: BinaryCustomExpression;
vector: BinaryCustomExpression;
};

const createBinaryOpProgramShader =
(shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number,
typeOutput: number, additionalImplementation?: string) => {
import {ShapeUtil} from '../../util';
import {ComputeContext} from '../types';

import {BinaryCustomExpression, BinaryFunctionCall, createOpProgramInfoLoader, fourAssignment, getBroadcastIndexComponent} from './binary-like-util';
import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common';

const createOpProgramShader =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], vectorize: boolean,
doBroadcast: boolean, funcCall: BinaryFunctionCall, typeOutput: number, additionalImplementation?: string) => {
const typeA = inputs[0].dataType;
const typeB = inputs[1].dataType;
const dimsA = inputs[0].dims;
const dimsB = inputs[1].dims;
const outputSize = ShapeUtil.size(dimsOutput);
const vecSize = Math.ceil(outputSize / 4);

@@ -33,39 +30,18 @@ const createBinaryOpProgramShader =
expressionVector = funcCall.vector;
}

let broadcastImpl = '';
const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
const a = inputVariable('aData', typeA, dimsA, 4);
const b = inputVariable('bData', typeB, dimsB, 4);
if (doBroadcast) {
const calcOffsetImpl = (dims: readonly number[]) => {
const strides = ShapeUtil.computeStrides(dims);
const offsets: string[] = [];
for (let i = dims.length - 1; i >= 0; i--) {
const idx = output.indicesGet('outputIndices', i + dimsOutput.length - dims.length);
offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`);
}
return offsets.length > 0 ? offsets.join('+') : '0u';
};

broadcastImpl = `
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsA)};
}
fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsB)};
}
`;
}
const broadcastImpl = doBroadcast ? createBroadcastHelper([a, b], output).broadcastIndicesToOffset() : '';

let assignment: string;
if (vectorize) {
if (doBroadcast) {
assignment = `
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
let offsetA = calcOffsetA(outputIndices);
let offsetB = calcOffsetB(outputIndices);
let offsetA = broadcastIndicesToOffsetA(outputIndices);
let offsetB = broadcastIndicesToOffsetB(outputIndices);
${
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
@@ -84,31 +60,12 @@ const createBinaryOpProgramShader =
const expressionB = `bData[indexB${x}][componentB${x}]`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = calcOffsetA(outputIndices${x});
let offsetB${x} = calcOffsetB(outputIndices${x});
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
let componentB${x} = offsetB${x} % 4u;
${getBroadcastIndexComponent('A', x)}
${getBroadcastIndexComponent('B', x)}
${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)});
`;
};
if (typeOutput === DataType.bool) {
assignment = `
var data = vec4<u32>(0);
${singleAssignment('data', 0, 'u32')}
${singleAssignment('data', 1, 'u32')}
${singleAssignment('data', 2, 'u32')}
${singleAssignment('data', 3, 'u32')}
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
} else {
assignment = `
${singleAssignment('outputData[global_idx]', 0)}
${singleAssignment('outputData[global_idx]', 1)}
${singleAssignment('outputData[global_idx]', 2)}
${singleAssignment('outputData[global_idx]', 3)}
`;
}
assignment = fourAssignment(singleAssignment, typeOutput);
}

return `
@@ -123,91 +80,31 @@ const createBinaryOpProgramShader =
}`;
};

const createBinaryOpProgramInfo =
(metadata: ProgramMetadata, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall,
additionalImplementation?: string, outputDataType: number = a.dataType): ProgramInfo => {
const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims);
let outputShape = a.dims;
let outputSize = ShapeUtil.size(a.dims);

let vectorize = false;

// TODO: deal with zero-sized tensors (eg. dims=[1,0])

if (isBroadcast) {
const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false);
if (!calculatedShape) {
throw new Error('Can\'t perform binary op on the given tensors');
}
outputShape = calculatedShape;
outputSize = ShapeUtil.size(outputShape);

// check whether vectorize can be enabled
let sharedDimension = 1;
for (let i = 1; i < outputShape.length; i++) {
const dimA = a.dims[a.dims.length - i] ?? 1;
const dimB = b.dims[b.dims.length - i] ?? 1;
if (dimA === dimB) {
sharedDimension *= dimA;
} else {
break;
}
}
if (sharedDimension % 4 === 0) {
vectorize = true;
}
} else {
// element-wise
vectorize = true;
}

return {
...metadata,
getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType,
outputDataType, additionalImplementation),
outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () =>
({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)})
};
};

const createBinaryOpProgramInfoLoader =
(inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string,
cacheKey?: string, outputDataType?: number): ProgramInfoLoader => {
const metadata:
ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey};
return {
...metadata,
get: () => createBinaryOpProgramInfo(
metadata, inputs[0], inputs[1], funcCall, additionalImplementation, outputDataType)
};
};

export const add = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`));
context.compute(createOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`, createOpProgramShader));
};

export const div = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`));
context.compute(createOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`, createOpProgramShader));
};

export const equal = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.compute(createOpProgramInfoLoader(
context.inputs, 'Equal', ({scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4<u32>(${a}==${b})`}),
undefined, undefined, DataType.bool));
createOpProgramShader, undefined, undefined, DataType.bool));
};

export const mul = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`));
context.compute(createOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`, createOpProgramShader));
};

export const pow = (context: ComputeContext): void => {
const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value;
const roundStr = type === 'i32' ? 'round' : '';
context.compute(createBinaryOpProgramInfoLoader(
context.compute(createOpProgramInfoLoader(
context.inputs, 'Pow',
({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}),
createOpProgramShader,
`
fn pow_custom(a : ${type}, b : ${type}) -> ${type} {
if (b == ${type}(0.0)) {
@@ -226,30 +123,30 @@ export const pow = (context: ComputeContext): void => {
};

export const sub = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`));
context.compute(createOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`, createOpProgramShader));
};

export const greater = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.compute(createOpProgramInfoLoader(
context.inputs, 'Greater', ({scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4<u32>(${a}>${b})`}),
undefined, undefined, DataType.bool));
createOpProgramShader, undefined, undefined, DataType.bool));
};

export const less = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.compute(createOpProgramInfoLoader(
context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4<u32>(${a}<${b})`}),
undefined, undefined, DataType.bool));
createOpProgramShader, undefined, undefined, DataType.bool));
};

export const greaterOrEqual = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.compute(createOpProgramInfoLoader(
context.inputs, 'GreaterOrEqual',
({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})`}), undefined, undefined,
DataType.bool));
({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4<u32>(${a}>=${b})`}), createOpProgramShader,
undefined, undefined, DataType.bool));
};

export const lessOrEqual = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.compute(createOpProgramInfoLoader(
context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4<u32>(${a}<=${b})`}),
undefined, undefined, DataType.bool));
createOpProgramShader, undefined, undefined, DataType.bool));
};
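
With this refactor, every elementwise binary kernel goes through the shared createOpProgramInfoLoader from binary-like-util, passing createOpProgramShader as its shader builder. Going only by the call sites above, a further comparison op would follow the same pattern; the sketch below uses a hypothetical notEqual kernel (not part of this commit) and relies on the imports already present in binary-op.ts:

// Hypothetical extra comparison kernel, mirroring `equal`/`less` above (illustration only).
export const notEqual = (context: ComputeContext): void => {
  context.compute(createOpProgramInfoLoader(
      context.inputs, 'NotEqual',
      ({scalar: (a, b) => `u32(${a}!=${b})`, vector: (a, b) => `vec4<u32>(${a}!=${b})`}),
      createOpProgramShader, undefined, undefined, DataType.bool));
};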
37 changes: 37 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -524,6 +524,43 @@ export const outputVariable =
(name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shape, false, components);

/**
* A helper class for generating WGSL code for manipulating broadcast indices for a shader's input.
*/
export interface BroadcastHelper {
/**
* WGSL code for getting offset from broadcast indices.
*
*/
broadcastIndicesToOffset(): string;
}

class BroadcastHelperImpl implements BroadcastHelper {
constructor(private inputs: IndicesHelper[], private output: IndicesHelper) {}

broadcastIndicesToOffset(): string {
let implementation = '';
for (let j = 0; j < this.inputs.length; j++) {
const dims = this.inputs[j].shape;
const name = this.inputs[j].name.substring(0, 1).toUpperCase();
const strides = ShapeUtil.computeStrides(dims);
const offsets: string[] = [];
for (let i = dims.length - 1; i >= 0; i--) {
const idx = this.output.indicesGet('outputIndices', i + this.output.shape.length - dims.length);
offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`);
}
implementation += `fn broadcastIndicesToOffset${name}(outputIndices: ${this.output.type.indices}) -> u32 {
return ${offsets.length > 0 ? offsets.join('+') : '0u'};
}
`;
}
return implementation;
}
}

export const createBroadcastHelper = (inputs: IndicesHelper[], output: IndicesHelper): BroadcastHelper =>
new BroadcastHelperImpl(inputs, output);
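
As a concrete illustration of what broadcastIndicesToOffset generates, take an input of shape [2, 1, 4] broadcast against an output of shape [2, 3, 4]: the input strides are [4, 4, 1], and each output index is folded back with a modulo by the corresponding input dimension, so the size-1 middle dimension contributes nothing to the offset. A usage sketch with hypothetical shapes, assuming indicesGet renders as plain array indexing (the exact WGSL depends on the IndicesHelper):

import {DataType} from '../../../wasm-common';
import {createBroadcastHelper, inputVariable, outputVariable} from './common';

// Hypothetical shapes for illustration: input 'aData' is [2, 1, 4], output is [2, 3, 4].
const a = inputVariable('aData', DataType.float, [2, 1, 4]);
const output = outputVariable('outputData', DataType.float, [2, 3, 4]);
const wgsl = createBroadcastHelper([a], output).broadcastIndicesToOffset();
// `wgsl` would contain roughly:
//   fn broadcastIndicesToOffsetA(outputIndices: vec3<u32>) -> u32 {
//     return 1u * (outputIndices[2] % 4u)+4u * (outputIndices[1] % 1u)+4u * (outputIndices[0] % 2u);
//   }
// The `% 1u` term zeroes out the broadcast (size-1) dimension's contribution to the offset.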

/**
* A ShaderHelper is a helper class for generating WGSL code.
*/
