Skip to content

Commit

Permalink
[js/webgpu] Fix shader errors in indicesGet/Set when rank > 4 (micros…
Browse files Browse the repository at this point in the history
…oft#18661)

### Description
Currently, for non-uniform variables, we still use `array<u32, N>` type
instead of array<vec4<u32>, N1>`. So we can't always treat all variables
with rank > 4 as uniforms to index.

This PR fixes below errors:
```
error(s) generated while compiling the shader:
:5:44 error: index 4 out of bounds [0..1]
             return uniforms.input_strides[4] * (outputIndices[4] % uniforms.input_shape[4])+uniforms.input_strides[3] * (outputIndices[3] % uniforms.input_shape[3])+uniforms.input_strides[2] * (outputIndices[2] % uniforms.input_shape[2])+uniforms.input_strides[1] * (outputIndices[1] % uniforms.input_shape[1])+uniforms.input_strides[0] * (outputIndices[0] % uniforms.input_shape[0]);
                                           ^
FAILED #OpTest# - expand.jsonc [webgpu]Expand - Expand 5D - float32 Expand 5 - float32
FAILED #OpTest# - expand.jsonc [webgpu]Expand - Expand 5D - float32 Expand 5 - shape < input.size()
  • Loading branch information
qjia7 authored Dec 1, 2023
1 parent eaaf270 commit 92ee664
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
30 changes: 17 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,20 @@ export const sumVector = (name: string, components: number) => {
};

/**
* A helper function that returns uniform element at index.
* @param name - the name of uniform element.
* @param index - the index of uniform element.
* @param length - the length of uniform element.
* A helper function that returns variable element at index.
* @param name - the name of variable.
* @param index - the index of variable element.
* @param length - the length of variable.
*/
export const getUniformElementAt = (name: string, index: number|string, length: number): string => {
if (typeof (index) === 'string') {
return length > 4 ? `${name}[(${index}) / 4][(${index}) % 4]` : length > 1 ? `${name}[${index}]` : name;
export const getElementAt = (name: string, index: number|string, length: number): string => {
if (name.startsWith('uniforms.') && length > 4) {
if (typeof (index) === 'string') {
return `${name}[(${index}) / 4][(${index}) % 4]`;
} else {
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
}
} else {
return length > 4 ? `${name}[${Math.floor(index / 4)}][${index % 4}]` : length > 1 ? `${name}[${index}]` : name;
return length > 1 ? `${name}[${index}]` : name;
}
};

Expand Down Expand Up @@ -380,8 +384,8 @@ const createIndicesHelper =
let o2iSnippet = '';
for (let i = 0; i < rank - 1; i++) {
o2iSnippet += `
let dim${i} = current / ${getUniformElementAt(strides, i, rank)};
let rest${i} = current % ${getUniformElementAt(strides, i, rank)};
let dim${i} = current / ${getElementAt(strides, i, rank)};
let rest${i} = current % ${getElementAt(strides, i, rank)};
indices[${i}] = dim${i};
current = rest${i};
`;
Expand All @@ -404,7 +408,7 @@ const createIndicesHelper =
const offsets: string[] = [];
if (rank >= 2) {
for (let i = rank - 1; i >= 0; i--) {
offsets.push(`${getUniformElementAt(strides, i, rank)} * (indices[${i}])`);
offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`);
}
}

Expand All @@ -425,15 +429,15 @@ const createIndicesHelper =
if (rank < 2) {
return `${varIndices}`;
} else {
return `${varIndices}[${idx}]`;
return `${getElementAt(varIndices, idx, rank)}`;
}
};

const indicesSet = (varIndices: string, idx: number|string, value: string) => {
if (rank < 2) {
return `${varIndices}=${value};`;
} else {
return `${varIndices}[${idx}]=${value};`;
return `${getElementAt(varIndices, idx, rank)}=${value};`;
}
};

Expand Down
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';

import {createTensorShapeVariables, getUniformElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';

export interface SliceAttributes extends AttributeWithCacheKey {
readonly starts: number[];
Expand Down Expand Up @@ -82,10 +82,10 @@ const calculateInputIndicesImpl =
var inputIndices: ${input.type.indices};
var carry = 0u;
for (var i = ${inputShape.length}; i >= 0; i--) {
let input_shape_i = ${getUniformElementAt('uniforms.input_shape', 'i', inputShape.length)};
let steps_i = ${getUniformElementAt('uniforms.steps', 'i', inputShape.length)};
let signs_i = ${getUniformElementAt('uniforms.signs', 'i', inputShape.length)};
let starts_i = ${getUniformElementAt('uniforms.starts', 'i', inputShape.length)};
let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)};
let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)};
let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)};
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex = outputIndex * steps_i + starts_i + carry;
carry = inputIndex / input_shape_i;
Expand Down

0 comments on commit 92ee664

Please sign in to comment.