Skip to content

Commit

Permalink
Remove attribute f16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Aug 12, 2024
1 parent 512574a commit 2ba3131
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
9 changes: 1 addition & 8 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,7 @@ export class WebGpuBackend {
} else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === DataType.float16) {
if (typeof Float16Array !== 'undefined') {
new Float16Array(arrayBuffer, offset, data.length).set(data);
} else {
// Fallback to Uint16Array when Float16Array polyfill is not available, unit test only.
// eslint-disable-next-line no-console
console.warn('Unit test only, please make sure the float16 data has been encoded as float 16 bits.');
new Uint16Array(arrayBuffer, offset, data.length).set(data);
}
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === DataType.float) {
new Float32Array(arrayBuffer, offset, data.length).set(data);
} else {
Expand Down
12 changes: 6 additions & 6 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import {Env} from 'onnxruntime-common';

import {DataType, getTensorElementSize} from '../wasm-common';

import type {OrtWasmModule} from '../wasm-types';
import {DataType, Float16ArrayType, getTensorElementSize} from '../wasm-common';

import {WebGpuBackend} from './backend-webgpu';
import {LOG_DEBUG} from './log';
Expand All @@ -19,16 +20,15 @@ class TensorViewImpl implements TensorView {
private module: OrtWasmModule, public readonly dataType: number, public readonly data: number,
public readonly dims: readonly number[]) {}

getFloat16Array(): Float16ArrayType {
if (this.dataType !== DataType.float16) {
getUint16Array(): Uint16Array {
if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
const float16ViewConstructor = typeof Float16Array !== 'undefined' ? Float16Array : Uint16Array;
return elementCount === 0 ? new float16ViewConstructor() :
new float16ViewConstructor(this.module.HEAP8.buffer, this.data, elementCount);
return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
}


getFloat32Array(): Float32Array {
if (this.dataType !== DataType.float) {
throw new Error('Invalid data type');
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/tensor-view.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {Tensor} from 'onnxruntime-common';

import {Float16ArrayType, tensorTypeToTypedArrayConstructor} from '../wasm-common';
import {tensorTypeToTypedArrayConstructor} from '../wasm-common';

export const createView = (dataBuffer: ArrayBuffer, type: Tensor.Type): Int32Array|Uint32Array|BigInt64Array|
BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array =>
Expand All @@ -20,7 +20,7 @@ export interface TensorView {
/**
* get a Float16Array data view of the tensor data. tensor data must be on CPU.
*/
getFloat16Array(): Float16ArrayType;
getUint16Array(): Uint16Array;

/**
* get a Float32Array data view of the tensor data. tensor data must be on CPU.
Expand Down
17 changes: 13 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/pad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,14 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] =
[{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: attributes.pads}];

const isValueFromInput = (inputs.length >= 3 && inputs[2].data);
if (attributes.mode === 0) {
programUniforms.push({type: inputs[0].dataType, data: attributes.value});
programUniforms.push({
type: inputs[0].dataType == DataType.float16 ? (isValueFromInput ? DataType.float16 : DataType.float) :
inputs[0].dataType,
data: attributes.value
});
}

programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape));
Expand All @@ -169,7 +175,10 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
const uniforms: UniformsArrayType =
[{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}];
if (attributes.mode === 0) {
uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType});
uniforms.push({
name: 'constant_value',
type: dataType == 'f16' ? (isValueFromInput ? 'f16' : 'f32') : dataType as UniformDataElementType
});
}

return `
Expand All @@ -187,7 +196,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr

return {
name: 'Pad',
shaderCache: {hint: `${attributes.mode}`, inputDependencies},
shaderCache: {hint: `${attributes.mode}${isValueFromInput}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
Expand All @@ -201,7 +210,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes
if (inputs.length > 1) {
const bigInt64Pads = inputs[1].getBigInt64Array();
const value = (inputs.length >= 3 && inputs[2].data) ?
(inputs[2].dataType === DataType.float16 ? inputs[2].getFloat16Array()[0] : inputs[2].getFloat32Array()[0]) :
(inputs[2].dataType === DataType.float16 ? inputs[2].getUint16Array()[0] : inputs[2].getFloat32Array()[0]) :
0.0;

const inputRank = inputs[0].dims.length;
Expand Down
2 changes: 0 additions & 2 deletions js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ declare global {
var Float16Array: any;
}

export type Float16ArrayType = InstanceType<typeof Float16Array>;

// This file includes common definitions. They do NOT have dependency on the WebAssembly instance.

/**
Expand Down
4 changes: 2 additions & 2 deletions js/web/test/data/ops/pad_f16.jsonc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[
{
"name": "constant 2D float16",
"name": "constant 2D float16 v10",
"operator": "Pad",
"opset": { "domain": "", "version": 10 },
"attributes": [
Expand Down Expand Up @@ -33,7 +33,7 @@
]
},
{
"name": "constant 2D float16",
"name": "constant 2D float16 v19",
"operator": "Pad",
"opset": { "domain": "", "version": 19 },
"attributes": [{ "name": "mode", "data": "constant", "type": "string" }],
Expand Down

0 comments on commit 2ba3131

Please sign in to comment.