From 92ee664f64e96a8cc7308302a3e4f67f95254d1f Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sat, 2 Dec 2023 07:35:35 +0800 Subject: [PATCH] [js/webgpu] Fix shader errors in indicesGet/Set when rank > 4 (#18661) ### Description Currently, for non-uniform variables, we still use `array` type instead of array, N1>`. So we can't always treat all variables with rank > 4 as uniforms to index. This PR fixes below errors: ``` error(s) generated while compiling the shader: :5:44 error: index 4 out of bounds [0..1] return uniforms.input_strides[4] * (outputIndices[4] % uniforms.input_shape[4])+uniforms.input_strides[3] * (outputIndices[3] % uniforms.input_shape[3])+uniforms.input_strides[2] * (outputIndices[2] % uniforms.input_shape[2])+uniforms.input_strides[1] * (outputIndices[1] % uniforms.input_shape[1])+uniforms.input_strides[0] * (outputIndices[0] % uniforms.input_shape[0]); ^ FAILED #OpTest# - expand.jsonc [webgpu]Expand - Expand 5D - float32 Expand 5 - float32 FAILED #OpTest# - expand.jsonc [webgpu]Expand - Expand 5D - float32 Expand 5 - shape < input.size() --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 30 +++++++++++++---------- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 10 ++++---- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index af7202903d36..5fffa2f26660 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -326,16 +326,20 @@ export const sumVector = (name: string, components: number) => { }; /** - * A helper function that returns uniform element at index. - * @param name - the name of uniform element. - * @param index - the index of uniform element. - * @param length - the length of uniform element. + * A helper function that returns variable element at index. + * @param name - the name of variable. + * @param index - the index of variable element. + * @param length - the length of variable. */ -export const getUniformElementAt = (name: string, index: number|string, length: number): string => { - if (typeof (index) === 'string') { - return length > 4 ? `${name}[(${index}) / 4][(${index}) % 4]` : length > 1 ? `${name}[${index}]` : name; +export const getElementAt = (name: string, index: number|string, length: number): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } } else { - return length > 4 ? `${name}[${Math.floor(index / 4)}][${index % 4}]` : length > 1 ? `${name}[${index}]` : name; + return length > 1 ? `${name}[${index}]` : name; } }; @@ -380,8 +384,8 @@ const createIndicesHelper = let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { o2iSnippet += ` - let dim${i} = current / ${getUniformElementAt(strides, i, rank)}; - let rest${i} = current % ${getUniformElementAt(strides, i, rank)}; + let dim${i} = current / ${getElementAt(strides, i, rank)}; + let rest${i} = current % ${getElementAt(strides, i, rank)}; indices[${i}] = dim${i}; current = rest${i}; `; @@ -404,7 +408,7 @@ const createIndicesHelper = const offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { - offsets.push(`${getUniformElementAt(strides, i, rank)} * (indices[${i}])`); + offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`); } } @@ -425,7 +429,7 @@ const createIndicesHelper = if (rank < 2) { return `${varIndices}`; } else { - return `${varIndices}[${idx}]`; + return `${getElementAt(varIndices, idx, rank)}`; } }; @@ -433,7 +437,7 @@ const createIndicesHelper = if (rank < 2) { return `${varIndices}=${value};`; } else { - return `${varIndices}[${idx}]=${value};`; + return `${getElementAt(varIndices, idx, rank)}=${value};`; } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index aa68cd0b2c61..43d4e5356d1d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {createTensorShapeVariables, getUniformElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -82,10 +82,10 @@ const calculateInputIndicesImpl = var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { - let input_shape_i = ${getUniformElementAt('uniforms.input_shape', 'i', inputShape.length)}; - let steps_i = ${getUniformElementAt('uniforms.steps', 'i', inputShape.length)}; - let signs_i = ${getUniformElementAt('uniforms.signs', 'i', inputShape.length)}; - let starts_i = ${getUniformElementAt('uniforms.starts', 'i', inputShape.length)}; + let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; + let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; + let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; + let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; var inputIndex = outputIndex * steps_i + starts_i + carry; carry = inputIndex / input_shape_i;