diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 5adfc7ba0392..ac4d0c4afe6c 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -1799,7 +1799,7 @@ if (onnxruntime_USE_XNNPACK)
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_xnnpack_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_xnnpack ${onnxruntime_providers_xnnpack_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_xnnpack
- onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool Boost::mp11 safeint_interface
+ onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} XNNPACK pthreadpool flatbuffers::flatbuffers Boost::mp11 safeint_interface
)
add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 4243031045b7..d7712a7b70c9 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -277,19 +277,29 @@ else()
"SHELL:-s EXPORT_NAME=ortWasmThreaded"
"SHELL:-s DEFAULT_PTHREAD_STACK_SIZE=131072"
)
- if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
- set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm-simd-threaded")
- else()
- set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm-threaded")
- endif()
else()
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORT_NAME=ortWasm"
)
- if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
- set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm-simd")
- else()
- set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm")
- endif()
endif()
+
+ set(target_name ort)
+
+ if (onnxruntime_ENABLE_TRAINING_APIS)
+ list(APPEND target_name "training")
+ endif()
+
+ list(APPEND target_name "wasm")
+
+ if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
+ list(APPEND target_name "simd")
+ endif()
+
+ if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
+ list(APPEND target_name "threaded")
+ endif()
+
+ list(JOIN target_name "-" target_name)
+
+ set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME ${target_name})
endif()
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 408610711dc1..b27d215c0856 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -759,6 +759,7 @@ Do not modify directly.*
|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)|
|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 81015b25bc9f..19caa69d94cc 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -20,6 +20,8 @@
#pragma warning(pop)
#endif
+#include "flatbuffers/flatbuffers.h"
+
#include "core/common/gsl.h"
#include "core/common/common.h"
@@ -43,12 +45,6 @@
#include "core/graph/node_arg.h"
#include "core/graph/ort_format_load_options.h"
-namespace flatbuffers {
-class FlatBufferBuilder;
-template <typename T>
-struct Offset;
-} // namespace flatbuffers
-
namespace onnxruntime {
class Graph;
struct IndexedSubGraph;
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index f1f8a8aad56a..525272294c58 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -61,6 +61,10 @@ export declare namespace Env {
* @defaultValue `'webgl2'`
*/
contextId?: 'webgl'|'webgl2';
+ /**
+ * Get the WebGL rendering context.
+ */
+ readonly context: WebGLRenderingContext;
/**
* Set or get the maximum batch size for matmul. 0 means to disable batching.
*
@@ -88,7 +92,19 @@ export declare namespace Env {
}
export interface WebGpuFlags {
+ /**
+ * Set or get the profiling mode.
+ */
profilingMode?: 'off'|'default';
+ /**
+ * Get the device for WebGPU.
+ *
+ * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
+ * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
+ *
+ * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types".
+ */
+ readonly device: unknown;
}
}
@@ -110,27 +126,27 @@ export interface Env {
* Get version of the current package.
*/
readonly versions: {
- common: string;
- web?: string;
- node?: string;
+ readonly common: string;
+ readonly web?: string;
+ readonly node?: string;
// eslint-disable-next-line @typescript-eslint/naming-convention
- 'react-native'?: string;
+ readonly 'react-native'?: string;
};
/**
* Represent a set of flags for WebAssembly
*/
- wasm: Env.WebAssemblyFlags;
+ readonly wasm: Env.WebAssemblyFlags;
/**
* Represent a set of flags for WebGL
*/
- webgl: Env.WebGLFlags;
+ readonly webgl: Env.WebGLFlags;
/**
* Represent a set of flags for WebGPU
*/
- webgpu: Env.WebGpuFlags;
+ readonly webgpu: Env.WebGpuFlags;
[name: string]: unknown;
}
diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index 834b1f670f16..ec030084c967 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {InferenceSession as InferenceSessionImpl} from './inference-session-impl.js';
-import {OnnxValue} from './onnx-value.js';
+import {OnnxValue, OnnxValueDataLocation} from './onnx-value.js';
/* eslint-disable @typescript-eslint/no-redeclare */
@@ -138,6 +138,14 @@ export declare namespace InferenceSession {
*/
logVerbosityLevel?: number;
+ /**
+ * Specify string as a preferred data location for all outputs, or an object that use output names as keys and a
+ * preferred data location as corresponding values.
+ *
+ * This setting is available only in ONNXRuntime Web for WebGL and WebGPU EP.
+ */
+ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation};
+
/**
* Store configurations for a session. See
* https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/
diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts
index 29b9d64d9be2..a16a30d25d83 100644
--- a/js/common/lib/onnx-value.ts
+++ b/js/common/lib/onnx-value.ts
@@ -11,3 +11,8 @@ type NonTensorType = never;
* NOTE: currently not support non-tensor
*/
export type OnnxValue = Tensor|NonTensorType;
+
+/**
+ * Type OnnxValueDataLocation represents the location of the data of an OnnxValue.
+ */
+export type OnnxValueDataLocation = Tensor.DataLocation;
diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts
index c02ff1bb24a9..926312e62c85 100644
--- a/js/common/lib/tensor-factory-impl.ts
+++ b/js/common/lib/tensor-factory-impl.ts
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {OptionsDimensions, OptionsFormat, OptionsNormalizationParameters, OptionsTensorFormat, OptionsTensorLayout, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromUrlOptions} from './tensor-factory.js';
-import {Tensor, TypedTensor} from './tensor.js';
+import {GpuBufferDataTypes, OptionsDimensions, OptionsFormat, OptionsNormalizationParameters, OptionsTensorFormat, OptionsTensorLayout, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureDataTypes} from './tensor-factory.js';
+import {Tensor} from './tensor-impl.js';
+import {Tensor as TensorInterface} from './tensor.js';
interface BufferToTensorOptions extends OptionsDimensions, OptionsTensorLayout, OptionsNormalizationParameters,
OptionsFormat, OptionsTensorFormat {}
@@ -14,87 +15,84 @@ interface BufferToTensorOptions extends OptionsDimensions, OptionsTensorLayout,
* @param imageFormat - input image configuration - required configurations height, width, format
* @param tensorFormat - output tensor configuration - Default is RGB format
*/
-export const bufferToTensor =
- (buffer: Uint8ClampedArray|undefined, options: BufferToTensorOptions): TypedTensor<'float32'>|
- TypedTensor<'uint8'> => {
- if (buffer === undefined) {
- throw new Error('Image buffer must be defined');
- }
- if (options.height === undefined || options.width === undefined) {
- throw new Error('Image height and width must be defined');
- }
- if (options.tensorLayout === 'NHWC') {
- throw new Error('NHWC Tensor layout is not supported yet');
- }
+export const bufferToTensor = (buffer: Uint8ClampedArray|undefined, options: BufferToTensorOptions): Tensor => {
+ if (buffer === undefined) {
+ throw new Error('Image buffer must be defined');
+ }
+ if (options.height === undefined || options.width === undefined) {
+ throw new Error('Image height and width must be defined');
+ }
+ if (options.tensorLayout === 'NHWC') {
+ throw new Error('NHWC Tensor layout is not supported yet');
+ }
- const {height, width} = options;
+ const {height, width} = options;
- const norm = options.norm ?? {mean: 255, bias: 0};
- let normMean: [number, number, number, number];
- let normBias: [number, number, number, number];
+ const norm = options.norm ?? {mean: 255, bias: 0};
+ let normMean: [number, number, number, number];
+ let normBias: [number, number, number, number];
- if (typeof (norm.mean) === 'number') {
- normMean = [norm.mean, norm.mean, norm.mean, norm.mean];
- } else {
- normMean = [norm.mean![0], norm.mean![1], norm.mean![2], norm.mean![3] ?? 255];
- }
+ if (typeof (norm.mean) === 'number') {
+ normMean = [norm.mean, norm.mean, norm.mean, norm.mean];
+ } else {
+ normMean = [norm.mean![0], norm.mean![1], norm.mean![2], norm.mean![3] ?? 255];
+ }
- if (typeof (norm.bias) === 'number') {
- normBias = [norm.bias, norm.bias, norm.bias, norm.bias];
- } else {
- normBias = [norm.bias![0], norm.bias![1], norm.bias![2], norm.bias![3] ?? 0];
- }
+ if (typeof (norm.bias) === 'number') {
+ normBias = [norm.bias, norm.bias, norm.bias, norm.bias];
+ } else {
+ normBias = [norm.bias![0], norm.bias![1], norm.bias![2], norm.bias![3] ?? 0];
+ }
- const inputformat = options.format !== undefined ? options.format : 'RGBA';
- // default value is RGBA since imagedata and HTMLImageElement uses it
-
- const outputformat = options.tensorFormat !== undefined ?
- (options.tensorFormat !== undefined ? options.tensorFormat : 'RGB') :
- 'RGB';
- const stride = height * width;
- const float32Data = outputformat === 'RGBA' ? new Float32Array(stride * 4) : new Float32Array(stride * 3);
-
- // Default pointer assignments
- let step = 4, rImagePointer = 0, gImagePointer = 1, bImagePointer = 2, aImagePointer = 3;
- let rTensorPointer = 0, gTensorPointer = stride, bTensorPointer = stride * 2, aTensorPointer = -1;
-
- // Updating the pointer assignments based on the input image format
- if (inputformat === 'RGB') {
- step = 3;
- rImagePointer = 0;
- gImagePointer = 1;
- bImagePointer = 2;
- aImagePointer = -1;
- }
+ const inputformat = options.format !== undefined ? options.format : 'RGBA';
+ // default value is RGBA since imagedata and HTMLImageElement uses it
- // Updating the pointer assignments based on the output tensor format
- if (outputformat === 'RGBA') {
- aTensorPointer = stride * 3;
- } else if (outputformat === 'RBG') {
- rTensorPointer = 0;
- bTensorPointer = stride;
- gTensorPointer = stride * 2;
- } else if (outputformat === 'BGR') {
- bTensorPointer = 0;
- gTensorPointer = stride;
- rTensorPointer = stride * 2;
- }
+ const outputformat =
+ options.tensorFormat !== undefined ? (options.tensorFormat !== undefined ? options.tensorFormat : 'RGB') : 'RGB';
+ const stride = height * width;
+ const float32Data = outputformat === 'RGBA' ? new Float32Array(stride * 4) : new Float32Array(stride * 3);
- for (let i = 0; i < stride;
- i++, rImagePointer += step, bImagePointer += step, gImagePointer += step, aImagePointer += step) {
- float32Data[rTensorPointer++] = (buffer[rImagePointer] + normBias[0]) / normMean[0];
- float32Data[gTensorPointer++] = (buffer[gImagePointer] + normBias[1]) / normMean[1];
- float32Data[bTensorPointer++] = (buffer[bImagePointer] + normBias[2]) / normMean[2];
- if (aTensorPointer !== -1 && aImagePointer !== -1) {
- float32Data[aTensorPointer++] = (buffer[aImagePointer] + normBias[3]) / normMean[3];
- }
- }
+ // Default pointer assignments
+ let step = 4, rImagePointer = 0, gImagePointer = 1, bImagePointer = 2, aImagePointer = 3;
+ let rTensorPointer = 0, gTensorPointer = stride, bTensorPointer = stride * 2, aTensorPointer = -1;
- // Float32Array -> ort.Tensor
- const outputTensor = outputformat === 'RGBA' ? new Tensor('float32', float32Data, [1, 4, height, width]) :
- new Tensor('float32', float32Data, [1, 3, height, width]);
- return outputTensor;
- };
+ // Updating the pointer assignments based on the input image format
+ if (inputformat === 'RGB') {
+ step = 3;
+ rImagePointer = 0;
+ gImagePointer = 1;
+ bImagePointer = 2;
+ aImagePointer = -1;
+ }
+
+ // Updating the pointer assignments based on the output tensor format
+ if (outputformat === 'RGBA') {
+ aTensorPointer = stride * 3;
+ } else if (outputformat === 'RBG') {
+ rTensorPointer = 0;
+ bTensorPointer = stride;
+ gTensorPointer = stride * 2;
+ } else if (outputformat === 'BGR') {
+ bTensorPointer = 0;
+ gTensorPointer = stride;
+ rTensorPointer = stride * 2;
+ }
+
+ for (let i = 0; i < stride;
+ i++, rImagePointer += step, bImagePointer += step, gImagePointer += step, aImagePointer += step) {
+ float32Data[rTensorPointer++] = (buffer[rImagePointer] + normBias[0]) / normMean[0];
+ float32Data[gTensorPointer++] = (buffer[gImagePointer] + normBias[1]) / normMean[1];
+ float32Data[bTensorPointer++] = (buffer[bImagePointer] + normBias[2]) / normMean[2];
+ if (aTensorPointer !== -1 && aImagePointer !== -1) {
+ float32Data[aTensorPointer++] = (buffer[aImagePointer] + normBias[3]) / normMean[3];
+ }
+ }
+
+ // Float32Array -> ort.Tensor
+ const outputTensor = outputformat === 'RGBA' ? new Tensor('float32', float32Data, [1, 4, height, width]) :
+ new Tensor('float32', float32Data, [1, 3, height, width]);
+ return outputTensor;
+};
/**
* implementation of Tensor.fromImage().
@@ -102,7 +100,7 @@ export const bufferToTensor =
export const tensorFromImage = async(
image: ImageData|HTMLImageElement|ImageBitmap|string,
options?: TensorFromImageDataOptions|TensorFromImageElementOptions|TensorFromImageBitmapOptions|
- TensorFromUrlOptions): Promise<TypedTensor<'float32'>|TypedTensor<'uint8'>> => {
+ TensorFromUrlOptions): Promise<Tensor> => {
// checking the type of image object
const isHTMLImageEle = typeof (HTMLImageElement) !== 'undefined' && image instanceof HTMLImageElement;
const isImageDataEle = typeof (ImageData) !== 'undefined' && image instanceof ImageData;
@@ -237,3 +235,30 @@ export const tensorFromImage = async(
throw new Error('Input data provided is not supported - aborted tensor creation');
}
};
+
+/**
+ * implementation of Tensor.fromTexture().
+ */
+export const tensorFromTexture = <T extends TextureDataTypes>(
+ texture: TensorInterface.TextureType, options: TensorFromTextureOptions<T>): Tensor => {
+ const {width, height, download, dispose} = options;
+ // Always assume RGBAF32. TODO: support different texture format
+ const dims = [1, height, width, 4];
+ return new Tensor({location: 'texture', type: 'float32', texture, dims, download, dispose});
+};
+
+/**
+ * implementation of Tensor.fromGpuBuffer().
+ */
+export const tensorFromGpuBuffer = <T extends GpuBufferDataTypes>(
+ gpuBuffer: TensorInterface.GpuBufferType, options: TensorFromGpuBufferOptions<T>): Tensor => {
+ const {dataType, dims, download, dispose} = options;
+ return new Tensor({location: 'gpu-buffer', type: dataType ?? 'float32', gpuBuffer, dims, download, dispose});
+};
+
+/**
+ * implementation of Tensor.fromPinnedBuffer().
+ */
+export const tensorFromPinnedBuffer = <T extends Exclude<TensorInterface.Type, 'string'>>(
+ type: T, buffer: TensorInterface.DataTypeMap[T], dims?: readonly number[]): Tensor =>
+ new Tensor({location: 'cpu-pinned', type, data: buffer, dims: dims ?? [buffer.length]});
diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts
index 3eac33c0e849..38d3106d56bc 100644
--- a/js/common/lib/tensor-factory.ts
+++ b/js/common/lib/tensor-factory.ts
@@ -1,12 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {TypedTensor} from './tensor.js';
+import {Tensor, TypedTensor} from './tensor.js';
export type ImageFormat = 'RGB'|'RGBA'|'BGR'|'RBG';
export type ImageTensorLayout = 'NHWC'|'NCHW';
-// the following session contains type definitions of each individual options.
+// the following region contains type definitions for constructing tensor from a specific location.
+
+// #region types for constructing a tensor from a specific location
+
+/**
+ * represent common properties of the parameter for constructing a tensor from a specific location.
+ */
+interface CommonConstructorParameters<T> extends Pick<Tensor, 'dims'> {
+ /**
+ * Specify the data type of the tensor.
+ */
+ readonly type: T;
+}
+
+/**
+ * represent the parameter for constructing a tensor from a GPU resource.
+ */
+interface GpuResourceConstructorParameters<T extends Tensor.Type> {
+ /**
+ * an optional callback function to download data from GPU to CPU.
+ *
+ * If not provided, the tensor treat the GPU data as external resource.
+ */
+ download?(): Promise<Tensor.DataTypeMap[T]>;
+
+ /**
+ * an optional callback function that will be called when the tensor is disposed.
+ *
+ * If not provided, the tensor treat the GPU data as external resource.
+ */
+ dispose?(): void;
+}
+
+/**
+ * supported data types for constructing a tensor from a pinned CPU buffer
+ */
+export type CpuPinnedDataTypes = Exclude<Tensor.Type, 'string'>;
+
+/**
+ * represent the parameter for constructing a tensor from a pinned CPU buffer
+ */
+export interface CpuPinnedConstructorParameters<T extends CpuPinnedDataTypes = CpuPinnedDataTypes> extends
+ CommonConstructorParameters<T> {
+ /**
+ * Specify the location of the data to be 'cpu-pinned'.
+ */
+ readonly location: 'cpu-pinned';
+ /**
+ * Specify the CPU pinned buffer that holds the tensor data.
+ */
+ readonly data: Tensor.DataTypeMap[T];
+}
+
+/**
+ * supported data types for constructing a tensor from a WebGL texture
+ */
+export type TextureDataTypes = 'float32';
+
+/**
+ * represent the parameter for constructing a tensor from a WebGL texture
+ */
+export interface TextureConstructorParameters<T extends TextureDataTypes = TextureDataTypes> extends
+ CommonConstructorParameters<T>, GpuResourceConstructorParameters<T> {
+ /**
+ * Specify the location of the data to be 'texture'.
+ */
+ readonly location: 'texture';
+ /**
+ * Specify the WebGL texture that holds the tensor data.
+ */
+ readonly texture: Tensor.TextureType;
+}
+
+/**
+ * supported data types for constructing a tensor from a WebGPU buffer
+ */
+export type GpuBufferDataTypes = 'float32'|'int32';
+
+/**
+ * represent the parameter for constructing a tensor from a WebGPU buffer
+ */
+export interface GpuBufferConstructorParameters<T extends GpuBufferDataTypes = GpuBufferDataTypes> extends
+ CommonConstructorParameters<T>, GpuResourceConstructorParameters<T> {
+ /**
+ * Specify the location of the data to be 'gpu-buffer'.
+ */
+ readonly location: 'gpu-buffer';
+ /**
+ * Specify the WebGPU buffer that holds the tensor data.
+ */
+ readonly gpuBuffer: Tensor.GpuBufferType;
+}
+
+// #endregion
+
+// the following region contains type definitions of each individual options.
// the tensor factory functions use a composition of those options as the parameter type.
// #region Options fields
@@ -92,6 +187,8 @@ export interface OptionsNormalizationParameters {
// #endregion
+// #region Options composition
+
export interface TensorFromImageDataOptions extends OptionResizedDimensions, OptionsTensorFormat, OptionsTensorLayout,
OptionsTensorDataType, OptionsNormalizationParameters {}
@@ -106,6 +203,23 @@ export interface TensorFromUrlOptions extends OptionsDimensions, OptionResizedDi
export interface TensorFromImageBitmapOptions extends OptionResizedDimensions, OptionsTensorFormat, OptionsTensorLayout,
OptionsTensorDataType, OptionsNormalizationParameters {}
+export interface TensorFromTextureOptions<T extends TextureDataTypes> extends
+ Required<OptionsDimensions>, OptionsFormat, GpuResourceConstructorParameters<T>/* TODO: add more */ {}
+
+export interface TensorFromGpuBufferOptions<T extends GpuBufferDataTypes> extends Pick<Tensor, 'dims'>,
+ GpuResourceConstructorParameters<T> {
+ /**
+ * Describes the data type of the tensor.
+ */
+ dataType?: T;
+}
+
+// #endregion
+
+/**
+ * type TensorFactory defines the factory functions of 'Tensor' to create tensor instances from existing data or
+ * resources.
+ */
export interface TensorFactory {
/**
* create a tensor from an ImageData object
@@ -165,4 +279,57 @@ export interface TensorFactory {
*/
fromImage(bitmap: ImageBitmap, options: TensorFromImageBitmapOptions):
Promise<TypedTensor<'float32'>|TypedTensor<'uint8'>>;
+
+ /**
+ * create a tensor from a WebGL texture
+ *
+ * @param texture - the WebGLTexture object to create tensor from
+ * @param options - An optional object representing options for creating tensor from WebGL texture.
+ *
+ * The options include following properties:
+ * - `width`: the width of the texture. Required.
+ * - `height`: the height of the texture. Required.
+ * - `format`: the format of the texture. If omitted, assume 'RGBA'.
+ * - `download`: an optional function to download the tensor data from GPU to CPU. If omitted, the GPU data
+ * will not be able to download. Usually, this is provided by a GPU backend for the inference outputs. Users don't
+ * need to provide this function.
+ * - `dispose`: an optional function to dispose the tensor data on GPU. If omitted, the GPU data will not be disposed.
+ * Usually, this is provided by a GPU backend for the inference outputs. Users don't need to provide this function.
+ *
+ * @returns a tensor object
+ */
+ fromTexture<T extends TextureDataTypes>(
+ texture: Tensor.TextureType, options: TensorFromTextureOptions<T>): TypedTensor<'float32'>;
+
+ /**
+ * create a tensor from a WebGPU buffer
+ *
+ * @param buffer - the GPUBuffer object to create tensor from
+ * @param options - An optional object representing options for creating tensor from WebGPU buffer.
+ *
+ * The options include following properties:
+ * - `dataType`: the data type of the tensor. If omitted, assume 'float32'.
+ * - `dims`: the dimension of the tensor. Required.
+ * - `download`: an optional function to download the tensor data from GPU to CPU. If omitted, the GPU data
+ * will not be able to download. Usually, this is provided by a GPU backend for the inference outputs. Users don't
+ * need to provide this function.
+ * - `dispose`: an optional function to dispose the tensor data on GPU. If omitted, the GPU data will not be disposed.
+ * Usually, this is provided by a GPU backend for the inference outputs. Users don't need to provide this function.
+ *
+ * @returns a tensor object
+ */
+ fromGpuBuffer<T extends GpuBufferDataTypes>(
+ buffer: Tensor.GpuBufferType, options: TensorFromGpuBufferOptions<T>): TypedTensor<T>;
+
+ /**
+ * create a tensor from a pre-allocated buffer. The buffer will be used as a pinned buffer.
+ *
+ * @param type - the tensor element type.
+ * @param buffer - a TypedArray corresponding to the type.
+ * @param dims - specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
+ *
+ * @returns a tensor object
+ */
+ fromPinnedBuffer<T extends Exclude<Tensor.Type, 'string'>>(
+ type: T, buffer: Tensor.DataTypeMap[T], dims?: readonly number[]): TypedTensor<T>;
}
diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts
new file mode 100644
index 000000000000..c4a43ea27fea
--- /dev/null
+++ b/js/common/lib/tensor-impl-type-mapping.ts
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {Tensor} from './tensor.js';
+
+export type SupportedTypedArrayConstructors = Float32ArrayConstructor|Uint8ArrayConstructor|Int8ArrayConstructor|
+ Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|Uint8ArrayConstructor|
+ Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor;
+export type SupportedTypedArray = InstanceType<SupportedTypedArrayConstructors>;
+
+// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.
+export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<Tensor.Type, SupportedTypedArrayConstructors>([
+ ['float32', Float32Array],
+ ['uint8', Uint8Array],
+ ['int8', Int8Array],
+ ['uint16', Uint16Array],
+ ['float16', Uint16Array],
+ ['int16', Int16Array],
+ ['int32', Int32Array],
+ ['bool', Uint8Array],
+ ['float64', Float64Array],
+ ['uint32', Uint32Array],
+]);
+
+// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.
+export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArrayConstructors, Tensor.Type>([
+ [Float32Array, 'float32'],
+ [Uint8Array, 'uint8'],
+ [Int8Array, 'int8'],
+ [Uint16Array, 'uint16'],
+ [Int16Array, 'int16'],
+ [Int32Array, 'int32'],
+ [Float64Array, 'float64'],
+ [Uint32Array, 'uint32'],
+]);
+
+// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
+// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
+// if available.
+let isBigIntChecked = false;
+export const checkBigInt = () => {
+ if (!isBigIntChecked) {
+ isBigIntChecked = true;
+ const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
+ const isBigUint64ArrayAvailable =
+ typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
+
+ if (isBigInt64ArrayAvailable) {
+ NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
+ NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigInt64Array, 'int64');
+ }
+ if (isBigUint64ArrayAvailable) {
+ NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
+ NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
+ }
+ }
+};
diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index 2ac13d42b995..dbd8685de43f 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -3,201 +3,257 @@
import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
-import {tensorFromImage} from './tensor-factory-impl.js';
-import {TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromUrlOptions} from './tensor-factory.js';
+import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
+import {CpuPinnedConstructorParameters, CpuPinnedDataTypes, GpuBufferConstructorParameters, GpuBufferDataTypes, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
+import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
import {Tensor as TensorInterface} from './tensor.js';
+// type aliases for those exported from Tensor interface
+
type TensorType = TensorInterface.Type;
type TensorDataType = TensorInterface.DataType;
+type TensorDataLocation = TensorInterface.DataLocation;
+type TensorTextureType = TensorInterface.TextureType;
+type TensorGpuBufferType = TensorInterface.GpuBufferType;
-type SupportedTypedArrayConstructors = Float32ArrayConstructor|Uint8ArrayConstructor|Int8ArrayConstructor|
- Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|Uint8ArrayConstructor|
- Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor;
-type SupportedTypedArray = InstanceType<SupportedTypedArrayConstructors>;
-
-// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.
-const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<TensorType, SupportedTypedArrayConstructors>([
- ['float32', Float32Array],
- ['uint8', Uint8Array],
- ['int8', Int8Array],
- ['uint16', Uint16Array],
- ['float16', Uint16Array],
- ['int16', Int16Array],
- ['int32', Int32Array],
- ['bool', Uint8Array],
- ['float64', Float64Array],
- ['uint32', Uint32Array],
-]);
-
-// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.
-const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArrayConstructors, TensorType>([
- [Float32Array, 'float32'],
- [Uint8Array, 'uint8'],
- [Int8Array, 'int8'],
- [Uint16Array, 'uint16'],
- [Int16Array, 'int16'],
- [Int32Array, 'int32'],
- [Float64Array, 'float64'],
- [Uint32Array, 'uint32'],
-]);
-
-// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
-// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
-// if available.
-let isBigIntChecked = false;
-const checkBigInt = () => {
- if (!isBigIntChecked) {
- isBigIntChecked = true;
- const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
- const isBigUint64ArrayAvailable =
- typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
-
- if (isBigInt64ArrayAvailable) {
- NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
- NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigInt64Array, 'int64');
- }
- if (isBigUint64ArrayAvailable) {
- NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
- NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
- }
- }
-};
-
-
+/**
+ * the implementation of Tensor interface.
+ *
+ * @internal
+ */
export class Tensor implements TensorInterface {
// #region constructors
- constructor(type: TensorType, data: TensorDataType|readonly number[]|readonly boolean[], dims?: readonly number[]);
- constructor(data: TensorDataType|readonly boolean[], dims?: readonly number[]);
+
+ /**
+ * Construct a new CPU tensor object from the given type, data and dims.
+ */
+ constructor(
+ type: TensorType, data: TensorDataType|readonly string[]|readonly number[]|readonly boolean[],
+ dims?: readonly number[]);
+ /**
+ * Construct a new CPU tensor object from the given data and dims. Type is inferred from data.
+ */
+ constructor(data: TensorDataType|readonly string[]|readonly boolean[], dims?: readonly number[]);
+ /**
+ * Construct a new tensor object from the pinned CPU data with the given type and dims.
+ *
+ * Tensor's location will be set to 'cpu-pinned'.
+ *
+ * @param params - Specify the parameters to construct the tensor.
+ */
+ constructor(params: CpuPinnedConstructorParameters);
+ /**
+ * Construct a new tensor object from the WebGL texture with the given type and dims.
+ *
+ * Tensor's location will be set to 'texture'.
+ *
+ * @param params - Specify the parameters to construct the tensor.
+ */
+ constructor(params: TextureConstructorParameters);
+ /**
+ * Construct a new tensor object from the WebGPU buffer with the given type and dims.
+ *
+ * Tensor's location will be set to 'gpu-buffer'.
+ *
+ * @param params - Specify the parameters to construct the tensor.
+ */
+ constructor(params: GpuBufferConstructorParameters);
+
+ /**
+ * implementation.
+ */
constructor(
- arg0: TensorType|TensorDataType|readonly boolean[], arg1?: TensorDataType|readonly number[]|readonly boolean[],
- arg2?: readonly number[]) {
+ arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
+ TextureConstructorParameters|GpuBufferConstructorParameters,
+ arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
+ // perform one-time check for BigInt support
checkBigInt();
let type: TensorType;
- let data: TensorDataType;
- let dims: typeof arg1|typeof arg2;
- // check whether arg0 is type or data
- if (typeof arg0 === 'string') {
+ let dims: readonly number[];
+
+ if (typeof arg0 === 'object' && 'location' in arg0) {
//
- // Override: constructor(type, data, ...)
+ // constructing tensor from specific location
//
- type = arg0;
- dims = arg2;
- if (arg0 === 'string') {
- // string tensor
- if (!Array.isArray(arg1)) {
- throw new TypeError('A string tensor\'s data must be a string array.');
+ this.dataLocation = arg0.location;
+ type = arg0.type;
+ dims = arg0.dims;
+ switch (arg0.location) {
+ case 'cpu-pinned': {
+ const expectedTypedArrayConstructor = NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.get(type);
+ if (!expectedTypedArrayConstructor) {
+ throw new TypeError(`unsupported type "${type}" to create tensor from pinned buffer`);
+ }
+ if (!(arg0.data instanceof expectedTypedArrayConstructor)) {
+ throw new TypeError(`buffer should be of type ${expectedTypedArrayConstructor.name}`);
+ }
+ this.cpuData = arg0.data;
+ break;
}
- // we don't check whether every element in the array is string; this is too slow. we assume it's correct and
- // error will be populated at inference
- data = arg1;
- } else {
- // numeric tensor
- const typedArrayConstructor = NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.get(arg0);
- if (typedArrayConstructor === undefined) {
- throw new TypeError(`Unsupported tensor type: ${arg0}.`);
+ case 'texture': {
+ if (type !== 'float32') {
+ throw new TypeError(`unsupported type "${type}" to create tensor from texture`);
+ }
+ this.gpuTextureData = arg0.texture;
+ this.downloader = arg0.download;
+ this.disposer = arg0.dispose;
+ break;
}
- if (Array.isArray(arg1)) {
- if (arg0 === 'float16') {
- // Throw error here because when user try to use number array as data,
- // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
- // Uint16Array.from(arg1) which generates wrong data.
- throw new TypeError(
- 'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.');
- } else if (arg0 === 'uint64' || arg0 === 'int64') {
- // use 'as any' here because:
- // 1. TypeScript's check on type of 'Array.isArray()' does not work with readonly arrays.
- // see https://github.com/microsoft/TypeScript/issues/17002
- // 2. TypeScript's check on union type of '(BigInt64ArrayConstructor|BigUint64ArrayConstructor).from()' does
- // not accept parameter mapFn.
- // 3. parameters of 'SupportedTypedArrayConstructors.from()' does not match the requirement of the union
- // type.
-
- // assume 'arg1' is of type "readonly number[]|readonly bigint[]" here.
-
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- data = (typedArrayConstructor as any).from(arg1, BigInt);
- } else {
- // assume 'arg1' is of type "readonly number[]" here.
-
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- data = (typedArrayConstructor as any).from(arg1);
+ case 'gpu-buffer': {
+ if (type !== 'float32' && type !== 'int32') {
+ throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
}
- } else if (arg1 instanceof typedArrayConstructor) {
- data = arg1;
- } else {
- throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
+ this.gpuBufferData = arg0.gpuBuffer;
+ this.downloader = arg0.download;
+ this.disposer = arg0.dispose;
+ break;
}
+ default:
+ throw new Error(`Tensor constructor: unsupported location '${this.dataLocation}'`);
}
} else {
//
- // Override: constructor(data, ...)
+ // constructing tensor of location 'cpu'
//
- dims = arg1;
- if (Array.isArray(arg0)) {
- // only boolean[] and string[] is supported
- if (arg0.length === 0) {
- throw new TypeError('Tensor type cannot be inferred from an empty array.');
- }
- const firstElementType = typeof arg0[0];
- if (firstElementType === 'string') {
- type = 'string';
- data = arg0;
- } else if (firstElementType === 'boolean') {
- type = 'bool';
- // 'arg0' is of type 'boolean[]'. Uint8Array.from(boolean[]) actually works, but typescript thinks this is
- // wrong type. We use 'as any' to make it happy.
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- data = Uint8Array.from(arg0 as any[]);
+ let data: TensorDataType;
+ let maybeDims: typeof arg1|typeof arg2;
+ // check whether arg0 is type or data
+ if (typeof arg0 === 'string') {
+ //
+ // Override: constructor(type, data, ...)
+ //
+ type = arg0;
+ maybeDims = arg2;
+ if (arg0 === 'string') {
+ // string tensor
+ if (!Array.isArray(arg1)) {
+ throw new TypeError('A string tensor\'s data must be a string array.');
+ }
+ // we don't check whether every element in the array is a string; this is too slow. we assume it's correct and
+ // errors will be populated at inference
+ data = arg1;
} else {
- throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
+ // numeric tensor
+ const typedArrayConstructor = NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.get(arg0);
+ if (typedArrayConstructor === undefined) {
+ throw new TypeError(`Unsupported tensor type: ${arg0}.`);
+ }
+ if (Array.isArray(arg1)) {
+ if (arg0 === 'float16') {
+ // Throw an error here because when a user tries to use a number array as data,
+ // e.g. new Tensor('float16', [1, 2, 3, 4], dims), it will actually call
+ // Uint16Array.from(arg1), which generates wrong data.
+ throw new TypeError(
+ 'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.');
+ } else if (arg0 === 'uint64' || arg0 === 'int64') {
+ // use 'as any' here because:
+ // 1. TypeScript's check on type of 'Array.isArray()' does not work with readonly arrays.
+ // see https://github.com/microsoft/TypeScript/issues/17002
+ // 2. TypeScript's check on union type of '(BigInt64ArrayConstructor|BigUint64ArrayConstructor).from()'
+ // does not accept parameter mapFn.
+ // 3. parameters of 'SupportedTypedArrayConstructors.from()' does not match the requirement of the union
+ // type.
+
+ // assume 'arg1' is of type "readonly number[]|readonly bigint[]" here.
+
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ data = (typedArrayConstructor as any).from(arg1, BigInt);
+ } else {
+ // assume 'arg1' is of type "readonly number[]" here.
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ data = (typedArrayConstructor as any).from(arg1);
+ }
+ } else if (arg1 instanceof typedArrayConstructor) {
+ data = arg1;
+ } else {
+ throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
+ }
}
} else {
- // get tensor type from TypedArray
- const mappedType =
- NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(arg0.constructor as SupportedTypedArrayConstructors);
- if (mappedType === undefined) {
- throw new TypeError(`Unsupported type for tensor data: ${arg0.constructor}.`);
+ //
+ // Override: constructor(data, ...)
+ //
+ maybeDims = arg1;
+ if (Array.isArray(arg0)) {
+ // only boolean[] and string[] is supported
+ if (arg0.length === 0) {
+ throw new TypeError('Tensor type cannot be inferred from an empty array.');
+ }
+ const firstElementType = typeof arg0[0];
+ if (firstElementType === 'string') {
+ type = 'string';
+ data = arg0;
+ } else if (firstElementType === 'boolean') {
+ type = 'bool';
+ // 'arg0' is of type 'boolean[]'. Uint8Array.from(boolean[]) actually works, but typescript thinks this is
+ // wrong type. We use 'as any' to make it happy.
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ data = Uint8Array.from(arg0 as any[]);
+ } else {
+ throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
+ }
+ } else {
+ // get tensor type from TypedArray
+ const mappedType =
+ NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(arg0.constructor as SupportedTypedArrayConstructors);
+ if (mappedType === undefined) {
+ throw new TypeError(`Unsupported type for tensor data: ${arg0.constructor}.`);
+ }
+ type = mappedType;
+ data = arg0 as SupportedTypedArray;
}
- type = mappedType;
- data = arg0 as SupportedTypedArray;
}
- }
- // type and data is processed, now processing dims
- if (dims === undefined) {
- // assume 1-D tensor if dims omitted
- dims = [data.length];
- } else if (!Array.isArray(dims)) {
- throw new TypeError('A tensor\'s dims must be a number array');
+ // type and data are processed, now processing dims
+ if (maybeDims === undefined) {
+ // assume 1-D tensor if dims omitted
+ maybeDims = [data.length];
+ } else if (!Array.isArray(maybeDims)) {
+ throw new TypeError('A tensor\'s dims must be a number array');
+ }
+ dims = maybeDims as readonly number[];
+
+ this.cpuData = data;
+ this.dataLocation = 'cpu';
}
- // perform check
+ // perform check on dims
const size = calculateSize(dims);
- if (size !== data.length) {
- throw new Error(`Tensor's size(${size}) does not match data length(${data.length}).`);
+ // if data is on CPU, check whether data length matches tensor size
+ if (this.cpuData && size !== this.cpuData.length) {
+ throw new Error(`Tensor's size(${size}) does not match data length(${this.cpuData.length}).`);
}
- this.dims = dims as readonly number[];
this.type = type;
- this.data = data;
+ this.dims = dims;
this.size = size;
}
// #endregion
// #region factory
- static async fromImage(imageData: ImageData, options?: TensorFromImageDataOptions): Promise;
- static async fromImage(imageElement: HTMLImageElement, options?: TensorFromImageElementOptions): Promise;
- static async fromImage(bitmap: ImageBitmap, options: TensorFromImageBitmapOptions): Promise;
- static async fromImage(urlSource: string, options?: TensorFromUrlOptions): Promise;
-
static async fromImage(
image: ImageData|HTMLImageElement|ImageBitmap|string,
options?: TensorFromImageDataOptions|TensorFromImageElementOptions|TensorFromImageBitmapOptions|
- TensorFromUrlOptions): Promise {
+ TensorFromUrlOptions): Promise {
return tensorFromImage(image, options);
}
+
+ static fromTexture(texture: TensorTextureType, options: TensorFromTextureOptions<'float32'>): TensorInterface {
+ return tensorFromTexture(texture, options);
+ }
+
+ static fromGpuBuffer(
+ gpuBuffer: TensorGpuBufferType, options: TensorFromGpuBufferOptions): TensorInterface {
+ return tensorFromGpuBuffer(gpuBuffer, options);
+ }
+
+ static fromPinnedBuffer(
+ type: T, buffer: TensorInterface.DataTypeMap[T], dims?: readonly number[]): Tensor {
+ return tensorFromPinnedBuffer(type, buffer, dims);
+ }
+
// #endregion
// #region conversions
@@ -210,15 +266,153 @@ export class Tensor implements TensorInterface {
}
// #endregion
- // #region fields
+ // #region public fields
readonly dims: readonly number[];
readonly type: TensorType;
- readonly data: TensorDataType;
readonly size: number;
// #endregion
+ // #region private fields
+
+ /**
+ * stores the location of the data.
+ */
+ private dataLocation: TensorDataLocation;
+
+ /**
+ * stores the data on CPU, if location is 'cpu' or 'cpu-pinned'. otherwise empty.
+ */
+ private cpuData?: TensorDataType;
+
+ /**
+ * stores the underlying texture when location is 'texture'. otherwise empty.
+ */
+ private gpuTextureData?: TensorTextureType;
+
+ /**
+ * stores the underlying GPU buffer when location is 'gpu-buffer'. otherwise empty.
+ */
+ private gpuBufferData?: TensorGpuBufferType;
+
+ /**
+ * stores an optional downloader function to download data from GPU to CPU.
+ */
+ private downloader?(): Promise;
+
+ /**
+ * a flag indicating whether the data is being downloaded from GPU to CPU.
+ */
+ private isDownloading?: boolean;
+
+ /**
+ * stores an optional disposer function to dispose the underlying data.
+ */
+ private disposer?(): void;
+ // #endregion
+
+ // #region properties
+ get data(): TensorDataType {
+ this.ensureValid();
+ if (!this.cpuData) {
+ throw new Error(
+ 'The data is not on CPU. Use `getData()` to download GPU data to CPU, ' +
+ 'or use `texture` property to access the GPU data directly.');
+ }
+ return this.cpuData;
+ }
+
+ get location(): TensorDataLocation {
+ return this.dataLocation;
+ }
+
+ get texture(): TensorTextureType {
+ this.ensureValid();
+ if (!this.gpuTextureData) {
+ throw new Error('The data is not stored as a WebGL texture.');
+ }
+ return this.gpuTextureData;
+ }
+
+ get gpuBuffer(): TensorGpuBufferType {
+ this.ensureValid();
+ if (!this.gpuBufferData) {
+ throw new Error('The data is not stored as a WebGPU buffer.');
+ }
+ return this.gpuBufferData;
+ }
+ // #endregion
+
+ // #region methods
+
+ async getData(releaseData?: boolean): Promise {
+ this.ensureValid();
+ switch (this.dataLocation) {
+ case 'cpu':
+ case 'cpu-pinned':
+ return this.data;
+ case 'texture':
+ case 'gpu-buffer': {
+ if (!this.downloader) {
+ throw new Error('The current tensor is not created with a specified data downloader.');
+ }
+ if (this.isDownloading) {
+ throw new Error('The current tensor is being downloaded.');
+ }
+ try {
+ this.isDownloading = true;
+ const data = await this.downloader();
+ this.downloader = undefined;
+ this.dataLocation = 'cpu';
+ this.cpuData = data;
+
+ if (releaseData && this.disposer) {
+ this.disposer();
+ this.disposer = undefined;
+ }
+
+ return data;
+
+ } finally {
+ this.isDownloading = false;
+ }
+ }
+ default:
+ throw new Error(`cannot get data from location: ${this.dataLocation}`);
+ }
+ }
+
+ dispose(): void {
+ if (this.isDownloading) {
+ throw new Error('The current tensor is being downloaded.');
+ }
+
+ if (this.disposer) {
+ this.disposer();
+ this.disposer = undefined;
+ }
+ this.cpuData = undefined;
+ this.gpuTextureData = undefined;
+ this.gpuBufferData = undefined;
+ this.downloader = undefined;
+ this.isDownloading = undefined;
+
+ this.dataLocation = 'none';
+ }
+
+ // #endregion
+
// #region tensor utilities
- reshape(dims: readonly number[]): Tensor {
+ private ensureValid(): void {
+ if (this.dataLocation === 'none') {
+ throw new Error('The tensor is disposed.');
+ }
+ }
+
+ reshape(dims: readonly number[]): TensorInterface {
+ this.ensureValid();
+ if (this.downloader || this.disposer) {
+ throw new Error('Cannot reshape a tensor that owns GPU resource.');
+ }
return tensorReshape(this, dims);
}
// #endregion
diff --git a/js/common/lib/tensor-utils-impl.ts b/js/common/lib/tensor-utils-impl.ts
index 8a259b236157..bd3080b72465 100644
--- a/js/common/lib/tensor-utils-impl.ts
+++ b/js/common/lib/tensor-utils-impl.ts
@@ -1,7 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Tensor} from './tensor.js';
+import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TextureConstructorParameters} from './tensor-factory.js';
+import {Tensor} from './tensor-impl.js';
/**
* calculate size from dims.
@@ -26,5 +27,32 @@ export const calculateSize = (dims: readonly unknown[]): number => {
/**
* implementation of Tensor.reshape()
*/
-export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor =>
- new Tensor(tensor.type, tensor.data, dims);
+export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor => {
+ switch (tensor.location) {
+ case 'cpu':
+ return new Tensor(tensor.type, tensor.data, dims);
+ case 'cpu-pinned':
+ return new Tensor({
+ location: 'cpu-pinned',
+ data: tensor.data as CpuPinnedConstructorParameters['data'],
+ type: tensor.type as CpuPinnedConstructorParameters['type'],
+ dims,
+ });
+ case 'texture':
+ return new Tensor({
+ location: 'texture',
+ texture: tensor.texture,
+ type: tensor.type as TextureConstructorParameters['type'],
+ dims,
+ });
+ case 'gpu-buffer':
+ return new Tensor({
+ location: 'gpu-buffer',
+ gpuBuffer: tensor.gpuBuffer,
+ type: tensor.type as GpuBufferConstructorParameters['type'],
+ dims,
+ });
+ default:
+ throw new Error(`tensorReshape: tensor location ${tensor.location} is not supported`);
+ }
+};
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index 90e3be9acbd2..10071eda3940 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -21,8 +21,46 @@ interface TypedTensorBase {
readonly type: T;
/**
* Get the buffer data of the tensor.
+ *
+ * If the data is not on CPU (eg. it's in the form of WebGL texture or WebGPU buffer), throw error.
*/
readonly data: Tensor.DataTypeMap[T];
+ /**
+ * Get the location of the data.
+ */
+ readonly location: Tensor.DataLocation;
+ /**
+ * Get the WebGL texture that holds the tensor data.
+ *
+ * If the data is not on GPU as WebGL texture, throw error.
+ */
+ readonly texture: Tensor.TextureType;
+ /**
+ * Get the WebGPU buffer that holds the tensor data.
+ *
+ * If the data is not on GPU as WebGPU buffer, throw error.
+ */
+ readonly gpuBuffer: Tensor.GpuBufferType;
+
+ /**
+ * Get the buffer data of the tensor.
+ *
+ * If the data is on CPU, returns the data immediately.
+ * If the data is on GPU, downloads the data and returns the promise.
+ *
+ * @param releaseData - whether release the data on GPU. Ignore if data is already on CPU.
+ */
+ getData(releaseData?: boolean): Promise;
+
+ /**
+ * Dispose the tensor data.
+ *
+ * If the data is on CPU, remove its internal reference to the underlying data.
+ * If the data is on GPU, release the data on GPU.
+ *
+ * After calling this function, the tensor is considered no longer valid. Its location will be set to 'none'.
+ */
+ dispose(): void;
}
export declare namespace Tensor {
@@ -67,6 +105,28 @@ export declare namespace Tensor {
type DataType = DataTypeMap[Type];
type ElementType = ElementTypeMap[Type];
+ /**
+ * type alias for WebGL texture
+ */
+ export type TextureType = WebGLTexture;
+
+ /**
+ * type alias for WebGPU buffer
+ *
+ * The reason we don't use the type "GPUBuffer" defined in webgpu.d.ts from @webgpu/types is that "@webgpu/types"
+ * requires "@types/dom-webcodecs" as a peer dependency when using TypeScript < v5.1, and its version needs to be
+ * chosen carefully according to the TypeScript version being used. This means that so far there is no way to keep
+ * every TypeScript version happy. It turns out that we could easily break users on some TypeScript versions.
+ *
+ * for more info see https://github.com/gpuweb/types/issues/127
+ */
+ export type GpuBufferType = {size: number; mapState: 'unmapped' | 'pending' | 'mapped'};
+
+ /**
+ * represent where the tensor data is stored
+ */
+ export type DataLocation = 'none'|'cpu'|'cpu-pinned'|'texture'|'gpu-buffer';
+
/**
* represent the data type of a tensor
*/
@@ -82,13 +142,16 @@ export interface TypedTensor extends TypedTensorBase,
*/
export interface Tensor extends TypedTensorBase, TypedTensorUtils {}
+/**
+ * type TensorConstructor defines the constructors of 'Tensor' to create CPU tensor instances.
+ */
export interface TensorConstructor {
- // #region specify element type
+ // #region CPU tensor - specify element type
/**
* Construct a new string tensor object from the given type, data and dims.
*
* @param type - Specify the element type.
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(type: 'string', data: Tensor.DataTypeMap['string']|readonly string[],
@@ -98,7 +161,7 @@ export interface TensorConstructor {
* Construct a new bool tensor object from the given type, data and dims.
*
* @param type - Specify the element type.
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(type: 'bool', data: Tensor.DataTypeMap['bool']|readonly boolean[], dims?: readonly number[]): TypedTensor<'bool'>;
@@ -107,7 +170,7 @@ export interface TensorConstructor {
* Construct a new 64-bit integer typed tensor object from the given type, data and dims.
*
* @param type - Specify the element type.
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(
@@ -118,19 +181,19 @@ export interface TensorConstructor {
* Construct a new numeric tensor object from the given type, data and dims.
*
* @param type - Specify the element type.
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new>(
type: T, data: Tensor.DataTypeMap[T]|readonly number[], dims?: readonly number[]): TypedTensor;
// #endregion
- // #region infer element types
+ // #region CPU tensor - infer element types
/**
* Construct a new float32 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Float32Array, dims?: readonly number[]): TypedTensor<'float32'>;
@@ -138,7 +201,7 @@ export interface TensorConstructor {
/**
* Construct a new int8 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Int8Array, dims?: readonly number[]): TypedTensor<'int8'>;
@@ -146,7 +209,7 @@ export interface TensorConstructor {
/**
* Construct a new uint8 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>;
@@ -154,7 +217,7 @@ export interface TensorConstructor {
/**
* Construct a new uint16 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Uint16Array, dims?: readonly number[]): TypedTensor<'uint16'>;
@@ -162,7 +225,7 @@ export interface TensorConstructor {
/**
* Construct a new int16 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Int16Array, dims?: readonly number[]): TypedTensor<'int16'>;
@@ -170,7 +233,7 @@ export interface TensorConstructor {
/**
* Construct a new int32 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Int32Array, dims?: readonly number[]): TypedTensor<'int32'>;
@@ -178,7 +241,7 @@ export interface TensorConstructor {
/**
* Construct a new int64 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: BigInt64Array, dims?: readonly number[]): TypedTensor<'int64'>;
@@ -186,7 +249,7 @@ export interface TensorConstructor {
/**
* Construct a new string tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: readonly string[], dims?: readonly number[]): TypedTensor<'string'>;
@@ -194,7 +257,7 @@ export interface TensorConstructor {
/**
* Construct a new bool tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: readonly boolean[], dims?: readonly number[]): TypedTensor<'bool'>;
@@ -202,7 +265,7 @@ export interface TensorConstructor {
/**
* Construct a new float64 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Float64Array, dims?: readonly number[]): TypedTensor<'float64'>;
@@ -210,7 +273,7 @@ export interface TensorConstructor {
/**
* Construct a new uint32 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Uint32Array, dims?: readonly number[]): TypedTensor<'uint32'>;
@@ -218,20 +281,20 @@ export interface TensorConstructor {
/**
* Construct a new uint64 tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: BigUint64Array, dims?: readonly number[]): TypedTensor<'uint64'>;
// #endregion
- // #region fall back to non-generic tensor type declaration
+ // #region CPU tensor - fall back to non-generic tensor type declaration
/**
* Construct a new tensor object from the given type, data and dims.
*
* @param type - Specify the element type.
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(type: Tensor.Type, data: Tensor.DataType|readonly number[]|readonly string[]|readonly bigint[]|readonly boolean[],
@@ -240,7 +303,7 @@ export interface TensorConstructor {
/**
* Construct a new tensor object from the given data and dims.
*
- * @param data - Specify the tensor data.
+ * @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new(data: Tensor.DataType, dims?: readonly number[]): Tensor;
diff --git a/js/node/lib/index.ts b/js/node/lib/index.ts
index 9dba44bce43b..69b1ef1d96af 100644
--- a/js/node/lib/index.ts
+++ b/js/node/lib/index.ts
@@ -12,4 +12,4 @@ for (const backend of backends) {
registerBackend(backend.name, onnxruntimeBackend, 100);
}
-env.versions.node = version;
+Object.defineProperty(env.versions, 'node', {value: version, enumerable: true});
diff --git a/js/react_native/lib/index.ts b/js/react_native/lib/index.ts
index b6b559ceb3cd..3bf9da3719e9 100644
--- a/js/react_native/lib/index.ts
+++ b/js/react_native/lib/index.ts
@@ -15,4 +15,4 @@ if (Platform.OS === 'android') {
registerBackend('coreml', onnxruntimeBackend, 1);
}
-env.versions['react-native'] = version;
+Object.defineProperty(env.versions, 'react-native', {value: version, enumerable: true});
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index c56bf4c6ff02..a969e1b86bf9 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -38,6 +38,7 @@ Do not modify directly.*
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
+| GatherElements | ai.onnx(11-12,13+) | |
| Gelu | com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts
index e3f2cf7300c8..d5ed536034f3 100644
--- a/js/web/lib/index.ts
+++ b/js/web/lib/index.ts
@@ -26,4 +26,4 @@ if (!BUILD_DEFS.DISABLE_WASM) {
registerBackend('webnn', wasmBackend, 9);
}
-env.versions.web = version;
+Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
diff --git a/js/web/lib/onnxjs/backends/backend-webgl.ts b/js/web/lib/onnxjs/backends/backend-webgl.ts
index cc00b8be809e..74716ca0edcb 100644
--- a/js/web/lib/onnxjs/backends/backend-webgl.ts
+++ b/js/web/lib/onnxjs/backends/backend-webgl.ts
@@ -72,6 +72,8 @@ export class WebGLBackend implements Backend {
Logger.setWithEnv(env);
+ Object.defineProperty(env.webgl, 'context', {value: this.glContext.gl});
+
Logger.verbose(
'WebGLBackend',
`Created WebGLContext: ${typeof this.glContext} with matmulMaxBatchSize: ${
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 06fcbf634408..7f0430b7b28b 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -64,6 +64,38 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtEndProfiling(sessionHandle: number): number;
// #endregion
+ // #region ORT Training APIs
+ _OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number;
+
+ _OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void;
+
+ _OrtTrainingCreateSession?
+ (sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
+ evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
+
+ _OrtTrainingLazyResetGrad?(trainingHandle: number): number;
+
+ _OrtTrainingRunTrainStep?
+ (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
+ runOptionsHandle: number): number;
+
+ _OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number;
+
+ _OrtTrainingEvalStep?
+ (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
+ runOptionsHandle: number): number;
+
+ _OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
+
+ _OrtTrainingCopyParametersToBuffer?
+ (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+
+ _OrtTrainingCopyParametersFromBuffer?
+ (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+
+ _OrtTrainingReleaseSession?(trainingHandle: number): void;
+ // #endregion
+
// #region config
mainScriptUrlOrBlob?: string|Blob;
// #endregion
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 861562d2e0e5..9b97a45d7580 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -155,6 +155,8 @@ export class WebGpuBackend {
count: 2,
});
}
+
+ Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
}
dispose(): void {
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index ae4b754f7628..23aabb6531f0 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -8,6 +8,7 @@ import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
+import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
@@ -58,6 +59,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['Gather', [gather, parseGatherAttributes]],
+ ['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
index b77e9bea7b87..02507ad802b3 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -174,7 +174,7 @@ export const createConv2DMatMulProgramInfo =
const dispatch = [
Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
- Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[1])
+ Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2])
];
LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);
@@ -242,9 +242,10 @@ export const createConv2DMatMulProgramInfo =
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0],
elementsSize[1], elementsSize[2])}
${
- isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, !isChannelsLast, tileInner) :
- makeMatMulPackedSource(
- elementsPerThread, workGroupSize, !isChannelsLast, tileInner, false, undefined,
- sequentialAccessByThreads)}`
+ isVec4 ?
+ makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) :
+ makeMatMulPackedSource(
+ elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined,
+ sequentialAccessByThreads)}`
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index d30821e50808..fee872f4120e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -19,19 +19,27 @@
//
// modified to fit the needs of the project
-const writeDataToSubAVec4Snippet = (transpose: boolean) => {
+import {TensorView} from '../../../tensor';
+import {ShapeUtil} from '../../../util';
+import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types';
+import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common';
+import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils';
+
+import {typeSnippet} from './activation_util';
+
+const writeDataToSubAVec4Snippet = (transpose: boolean, batchDims?: IndicesHelper) => {
if (transpose) {
return `
mm_Asub[inputRow][inputCol] = mm_readA(batch,
kStart + inputRow,
- globalRowStart / innerElementSize + inputCol);
+ globalRowStart / innerElementSize + inputCol${batchDims ? ', batchIndices' : ''});
`;
} else {
return `
mm_Asub[inputRow][inputCol] = mm_readA(batch,
globalRow + innerRow,
- kStart / innerElementSize + inputCol);
+ kStart / innerElementSize + inputCol${batchDims ? ', batchIndices' : ''});
`;
}
};
@@ -62,8 +70,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) =
};
export const makeMatMulPackedVec4Source =
- (workPerThread: number[], workgroupSize: [number, number, number], transposeA = false, tileInner = 32,
- splitK = false, splitedDimInner = 32, isVectorA = false): string => {
+ (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false,
+ tileInner = 32, splitK = false, splitedDimInner = 32): string => {
const tileAOuter = workgroupSize[1] * workPerThread[1];
const tileBOuter = workgroupSize[0] * workPerThread[0];
const tileAWidth = transposeA ? tileAOuter : tileInner;
@@ -95,12 +103,13 @@ fn main(@builtin(local_invocation_id) localId : vec3,
@builtin(global_invocation_id) globalId : vec3,
@builtin(workgroup_id) workgroupId : vec3) {
let localRow = i32(localId.y);
- let tileRow = ${isVectorA ? '0' : 'localRow * rowPerThread'};
+ let tileRow = localRow * rowPerThread;
let tileCol = i32(localId.x);
- let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * rowPerThread'};
+ let globalRow =i32(globalId.y) * rowPerThread;
let globalCol = i32(globalId.x);
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
+ ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
@@ -115,14 +124,15 @@ fn main(@builtin(local_invocation_id) localId : vec3,
for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {
let inputRow = tileRow + innerRow;
let inputCol = tileCol;
- ${writeDataToSubAVec4Snippet(transposeA)}
+ ${writeDataToSubAVec4Snippet(transposeA, batchDims)}
}
// Load one tile of B into local memory.
for (var innerRow = 0; innerRow < ${rowPerThreadB}; innerRow = innerRow + 1) {
let inputRow = tileRowB + innerRow;
let inputCol = tileCol;
- mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol);
+ mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol${
+ batchDims ? ', batchIndices' : ''});
}
kStart = kStart + tileInner;
workgroupBarrier();
@@ -146,19 +156,19 @@ fn main(@builtin(local_invocation_id) localId : vec3,
}`;
};
-const writeDataToSubASnippet = (transpose: boolean) => {
+const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) => {
if (transpose) {
return `
mm_Asub[inputRow][inputCol] = mm_readA(batch,
kStart + inputRow,
- globalRowStart + inputCol);
+ globalRowStart + inputCol${batchDims ? ', batchIndices' : ''});
`;
} else {
return `
mm_Asub[inputRow][inputCol] = mm_readA(batch,
globalRowStart + inputRow,
- kStart + inputCol);
+ kStart + inputCol${batchDims ? ', batchIndices' : ''});
`;
}
};
@@ -169,8 +179,8 @@ const readDataFromSubASnippet = (transposeA: boolean) =>
// sequentialAccessByThreads means sequential data in memory is accessed by
// threads, instead of a single thread (default behavior).
export const makeMatMulPackedSource =
- (workPerThread: number[], workgroupSize: [number, number, number], transposeA = false, tileInner = 32,
- splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => {
+ (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false,
+ tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => {
const tileAOuter = workPerThread[1] * workgroupSize[1];
const tileBOuter = workPerThread[0] * workgroupSize[0];
const tileAWidth = transposeA ? tileAOuter : tileInner;
@@ -197,7 +207,7 @@ export const makeMatMulPackedSource =
// Load one tile of A into local memory.
for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) {
for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) {
- ${writeDataToSubASnippet(transposeA)}
+ ${writeDataToSubASnippet(transposeA, batchDims)}
}
}
// Load one tile of B into local memory.
@@ -205,7 +215,7 @@ export const makeMatMulPackedSource =
for (var inputCol = localCol; inputCol < ${tileBOuter}; inputCol = inputCol + ${workgroupSize[0]}) {
mm_Bsub[inputRow][inputCol] = mm_readB(batch,
kStart + inputRow,
- globalColStart + inputCol);
+ globalColStart + inputCol${batchDims ? ', batchIndices' : ''});
}
}
kStart = kStart + tileInner;
@@ -255,7 +265,7 @@ for (var t = 0; t < numTiles; t = t + 1) {
for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) {
let inputRow = tileRowA + innerRow;
let inputCol = tileColA + innerCol;
- ${writeDataToSubASnippet(transposeA)}
+ ${writeDataToSubASnippet(transposeA, batchDims)}
}
}
@@ -266,7 +276,7 @@ for (var t = 0; t < numTiles; t = t + 1) {
let inputCol = tileCol + innerCol;
mm_Bsub[inputRow][inputCol] = mm_readB(batch,
kStart + inputRow,
- globalCol + innerCol);
+ globalCol + innerCol${batchDims ? ', batchIndices' : ''});
}
}
kStart = kStart + tileInner;
@@ -310,6 +320,7 @@ fn main(@builtin(local_invocation_id) localId : vec3,
@builtin(global_invocation_id) globalId : vec3,
@builtin(workgroup_id) workgroupId : vec3) {
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
+ ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
@@ -325,3 +336,144 @@ fn main(@builtin(local_invocation_id) localId : vec3,
}
`;
};
+
+const matMulReadWriteFnSource =
+ (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => {
+ const batchAVariable = variables[0];
+ const batchBVariable = variables[1];
+ const batchVariable = variables[2];
+ const aVariable = variables[3];
+ const bVariable = variables[4];
+ const outputVariable = variables[5];
+ const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape);
+ const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape);
+ const getAIndices = () => {
+ const aRank = aVariable.shape.length;
+ const batchRank = batchVariable.shape.length;
+ let resStr = `var aIndices: ${aVariable.type.indices};`;
+ for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
+ resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
+ }
+ broadCastADims.forEach(i => {
+ resStr += `\naIndices[${i}] = 0;`;
+ });
+ resStr += `\naIndices[${aRank - 2}] = u32(row);
+ aIndices[${aRank - 1}] = u32(colIn);`;
+ return resStr;
+ };
+ const getBIndices = () => {
+ const bRank = bVariable.shape.length;
+ const batchRank = batchVariable.shape.length;
+ let resStr = `var bIndices: ${bVariable.type.indices};`;
+ for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
+ resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`;
+ }
+ broadCastBDims.forEach(i => {
+ resStr += `\nbIndices[${i}] = 0;`;
+ });
+ resStr += `\nbIndices[${bRank - 2}] = u32(row);
+ bIndices[${bRank - 1}] = u32(colIn);`;
+ return resStr;
+ };
+ const source = `
+ fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${
+ typeSnippet(component)} {
+ var value = ${typeSnippet(component)}(0.0);
+ let col = colIn * ${component};
+ if(row < dimAOuter && col < dimInner)
+ {
+ ${getAIndices()}
+ value = ${aVariable.getByIndices('aIndices')};
+ }
+ return value;
+ }
+
+ fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${
+ typeSnippet(component)} {
+ var value = ${typeSnippet(component)}(0.0);
+ let col = colIn * ${component};
+ if(row < dimInner && col < dimBOuter)
+ {
+ ${getBIndices()}
+ value = ${bVariable.getByIndices('bIndices')};
+ }
+ return value;
+ }
+
+ fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) {
+ let col = colIn * ${component};
+ if (row < dimAOuter && col < dimBOuter) {
+ var value = valueIn;
+ let coords = vec3(batch, row, colIn);
+ ${hasBias ? 'value = value + bias[colIn];' : ''}
+ ${applyActivation}
+ ${outputVariable.setByIndices('vec3(coords)', 'value')}
+ }
+ }
+ `;
+ return source;
+ };
+
+export const createMatmulProgramInfo =
+ (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes,
+ outputShape: readonly number[]): ProgramInfo => {
+ const aShape = inputs[0].dims;
+ const bShape = inputs[1].dims;
+
+ const outerDimsA = aShape.slice(0, -2);
+ const outerDimsB = bShape.slice(0, -2);
+ const outerDims = outputShape.slice(0, -2);
+ const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims);
+ const batchADims = inputVariable('batchADims', inputs[0].dataType, outerDimsA);
+ const batchBDims = inputVariable('batchBDims', inputs[0].dataType, outerDimsB);
+ const variables = [batchADims, batchBDims, batchDims];
+ const batchSize = ShapeUtil.size(outerDims);
+
+ const dimAOuter = outputShape[outputShape.length - 2];
+ const dimInner = aShape[aShape.length - 1];
+ const dimBOuter = outputShape[outputShape.length - 1];
+ const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0;
+ const component = isVec4 ? 4 : 1;
+ const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
+
+ // TODO: fine tune size
+ const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
+ const workgroupSize: [number, number, number] = [8, 8, 1];
+ const dispatch = [
+ Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]),
+ Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]),
+ Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2])
+ ];
+
+ const components = isVec4 ? 4 : 1;
+ const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components);
+ const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components);
+ const output =
+ outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components);
+ variables.push(A);
+ variables.push(B);
+ variables.push(output);
+ const inputVariables = [A, B];
+ const hasBias = inputs.length > 2;
+ const declareFunctions = matMulReadWriteFnSource(component, hasBias, applyActivation, variables);
+ if (hasBias) {
+ inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components));
+ }
+ const getShaderSource = (shaderHelper: ShaderHelper) => `
+ const dimAOuter: i32 = ${dimAOuter};
+ const dimBOuter: i32 = ${dimBOuter};
+ const dimInner: i32 = ${dimInner};
+ ${shaderHelper.declareVariables(...inputVariables, output)}
+ ${declareFunctions}
+ ${activationFunction}
+ ${
+ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) :
+ makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)}
+ ${batchDims.impl()}`;
+ return {
+ ...metadata,
+ outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
+ getShaderSource,
+ dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]})
+ };
+ };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index 75c37b3ed09e..c96f4858db2a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -625,3 +625,27 @@ class ShaderHelperImpl implements ShaderHelper {
export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper =>
new ShaderHelperImpl(dispatchGroup);
+
+/**
+ * This function comes from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/ops/broadcast_util.ts#L18-L40
+ * Returns the dimensions in the input shape that are broadcasted to
+ * produce the provided output shape.
+ *
+ * The returned dimensions are 0-indexed and sorted. An example:
+ * inShape = [4, 1, 3]
+ * outShape = [5, 4, 3, 3]
+ * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
+ */
+export const getBroadcastDims = (inShape: readonly number[], outShape: readonly number[]): number[] => {
+ const inRank = inShape.length;
+ const dims: number[] = [];
+ for (let i = 0; i < inRank; i++) {
+ const dim = inRank - 1 - i;
+ const a = inShape[dim] || 1;
+ const b = outShape[outShape.length - 1 - i] || 1;
+ if (b > 1 && a === 1) {
+ dims.unshift(dim);
+ }
+ }
+ return dims;
+};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
index f01e6e0d97ee..afac503290c4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -10,6 +10,7 @@ import {ComputeContext} from '../types';
import {createGroupedConvProgramInfoLoader} from './conv-grouped';
import {createConv2DMatMulProgramInfoLoader} from './conv2d-mm';
import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
+import {createMatmulProgramInfoLoader} from './matmul';
import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';
export const calculateOutputShape =
@@ -160,16 +161,39 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
const outChannels = outputShape[isChannelsLast ? 3 : 1];
+ const batch = outputShape[0];
const sameSize =
isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID';
if (sameSize ||
(weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 &&
- attributes.strides[0] === 1 && attributes.strides[1] === 1 &&
- (attributes.autoPad === 'SAME_UPPER' || attributes.autoPad === 'SAME_LOWER' ||
- attributes.autoPad === 'VALID'))) {
- // TODO: implement conv2dByMatMul()
- context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
+ attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 &&
+ attributes.pads[1] === 0)) {
+ if (isChannelsLast && attributes.group === 1) {
+ // conv2dByMatMul
+ const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
+ context.compute(
+ {
+ ...transposeProgramMetadata,
+ cacheHint: weightTransposeAttribute.cacheKey,
+ get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm)
+ },
+ {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
+ if (attributes.wIsConst && !context.kernelCustomData.wT) {
+ context.kernelCustomData.wT = transposedWeight;
+ }
+
+ const matmulInputs = [];
+ matmulInputs.push(inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]));
+ matmulInputs.push(transposedWeight.reshape([1, inputChannels, outChannels]));
+ if (hasBias) {
+ matmulInputs.push(inputs[2]);
+ }
+ context.compute(
+ createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape), {inputs: matmulInputs});
+ } else {
+ context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
+ }
return;
}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
new file mode 100644
index 000000000000..57c5fccfd8c2
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
@@ -0,0 +1,110 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {TensorView} from '../../tensor';
+import {ShapeUtil} from '../../util';
+import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
+import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
+
+import {inputVariable, outputVariable, ShaderHelper} from './common';
+
+export interface GatherElementsAttributes extends AttributeWithCacheKey {
+ axis: number;
+}
+
+const validateInputs = (inputs: readonly TensorView[]): void => {
+ if (!inputs || inputs.length !== 2) {
+ throw new Error('GatherElements requires 2 inputs.');
+ }
+
+ if (inputs[0].dims.length < 1) {
+ throw new Error('GatherElements requires that the data input be rank >= 1.');
+ }
+
+ if (inputs[0].dims.length !== inputs[1].dims.length) {
+ throw new Error(`GatherElements requires that the data input and
+ indices input tensors be of same rank.`);
+ }
+};
+
+const createGatherElementsProgramInfo =
+ (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => {
+ const inputShape = inputs[0].dims;
+ const inputOutputDataType = inputs[0].dataType;
+ const inputRank = inputShape.length;
+ const inputStrides = ShapeUtil.computeStrides(inputShape);
+ const inputSize = ShapeUtil.size(inputShape);
+
+ const indicesShape = inputs[1].dims;
+ const indicesDataType = inputs[1].dataType;
+ const indicesSize = ShapeUtil.size(indicesShape);
+
+ const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
+ const axisDimLimit = inputShape[axis];
+
+ const outputShape = indicesShape.slice(0);
+ const outputSize = ShapeUtil.size(outputShape);
+
+ const input = inputVariable('input', inputOutputDataType, inputShape);
+ const indices = inputVariable('indices', indicesDataType, [indicesSize]);
+ const output = outputVariable('output', inputOutputDataType, outputShape);
+
+
+ // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
+ // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
+ // Input data will be treated as u32 or two u32 for 8-byte tensors
+ const getShaderSource = (shaderHelper: ShaderHelper) => `
+ const inputStrides = array(${inputStrides.map(i => `${i}u`).join(',')});
+ ${shaderHelper.declareVariables(input, indices, output)}
+ ${shaderHelper.mainStart()}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
+
+ let outputIndices = ${output.offsetToIndices('global_idx')};
+
+ var idx = ${indices.getByOffset('global_idx')};
+ if (idx < 0) {
+ idx = idx + ${axisDimLimit};
+ }
+
+ var srcOffset = u32(0);
+
+ for (var i = 0; i < ${inputShape.length}; i++) {
+ if (i == ${axis}) {
+ srcOffset += u32(idx) * inputStrides[i];
+ } else {
+ srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i];
+ }
+ }
+
+ // Should never hit this with valid values in indices
+ // This is a guard against malicious data in the indices input
+ if (srcOffset < 0 || srcOffset >= ${inputSize}) {
+ return;
+ }
+
+ output[global_idx] = input[srcOffset];
+ }`;
+
+ return {
+ ...metadata,
+ outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
+ getShaderSource,
+ dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
+ };
+ };
+
+export const parseGatherElementsAttributes = (attributes: Record): GatherElementsAttributes =>
+ createAttributeWithCacheKey({axis: attributes.axis as number});
+
+export const gatherElements = (context: ComputeContext, attributes: GatherElementsAttributes): void => {
+ const inputs = context.inputs;
+ validateInputs(inputs);
+
+ const metadata = {
+ name: 'GatherElements',
+ inputTypes: [GpuDataType.default, GpuDataType.default],
+ cacheHint: attributes.cacheKey,
+ };
+
+ context.compute(createGatherElementsProgramInfo(metadata, context.inputs, attributes));
+};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
index 75191be3cf1e..2d5750c3e2a8 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
@@ -3,11 +3,11 @@
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
-import {BroadcastUtil, ShapeUtil} from '../../util';
-import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
+import {BroadcastUtil} from '../../util';
+import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types';
-import {ShaderHelper} from './common';
-import {getActicationSnippet, InternalActivationAttributes} from './fuse-utils';
+import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
+import {InternalActivationAttributes} from './fuse-utils';
const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
@@ -17,66 +17,12 @@ const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
cacheHint
});
-const createMatmulProgramInfo =
- (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes):
- ProgramInfo => {
- const aShape = inputs[0].dims;
- const bShape = inputs[1].dims;
- const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
- if (!outputShape) {
- throw new Error('Can\'t use matmul on the given tensors');
- }
- const outputSize = ShapeUtil.size(outputShape);
- // TODO: support broadcasting
-
- const dataType = 'f32'; // TODO: support other data type
- const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
-
- const M = outputShape[outputShape.length - 2];
- const K = aShape[aShape.length - 1];
- const N = outputShape[outputShape.length - 1];
- const getShaderSource = (shaderHelper: ShaderHelper) => `
- const M: u32 = ${M}u;
- const N: u32 = ${N}u;
- const K: u32 = ${K}u;
-
- @group(0) @binding(0) var a : array<${dataType}>;
- @group(0) @binding(1) var b : array<${dataType}>;
- @group(0) @binding(2) var output : array<${dataType}>;
-
- ${activationFunction}
-
- ${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
-
- let stack = global_idx / (M * N);
- let mn = global_idx % (M * N);
- let n = global_idx % N;
- let m = mn / N;
-
- let offsetA = stack * (M * K);
- let offsetB = stack * (K * N);
-
- var value = ${dataType}(0);
- for (var k: u32 = 0u; k<${K}u; k++) {
- value += a[offsetA + m * K + k] * b[offsetB + k * N + n];
- }
- ${applyActivation}
- output[global_idx] = value;
- }`;
- return {
- ...metadata,
- outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
- getShaderSource,
- dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
- };
- };
-
export const createMatmulProgramInfoLoader =
- (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes): ProgramInfoLoader => {
- const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
- return {...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes)};
- };
+ (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[]):
+ ProgramInfoLoader => {
+ const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
+ return {...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes, outputShape)};
+ };
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
@@ -94,6 +40,9 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
export const matMul = (context: ComputeContext): void => {
validateInputs(context.inputs);
-
- context.compute(createMatmulProgramInfoLoader(context.inputs, {activation: '', activationCacheKey: ''}));
+ const outputShape = BroadcastUtil.calcShape(context.inputs[0].dims, context.inputs[1].dims, true);
+ if (!outputShape) {
+ throw new Error('Can\'t use matmul on the given tensors');
+ }
+ context.compute(createMatmulProgramInfoLoader(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
};
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index f2f44b795abe..e34529fa1037 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -295,7 +295,7 @@ function parseWebglOptions(_args: minimist.ParsedArgs): InferenceSession.WebGLEx
return {name: 'webgl'};
}
-function parseWebglFlags(args: minimist.ParsedArgs): Env.WebGLFlags {
+function parseWebglFlags(args: minimist.ParsedArgs): Partial {
const contextId = args['webgl-context-id'];
if (contextId !== undefined && contextId !== 'webgl' && contextId !== 'webgl2') {
throw new Error('Flag "webgl-context-id" is invalid');
@@ -319,7 +319,7 @@ function parseWebglFlags(args: minimist.ParsedArgs): Env.WebGLFlags {
return {contextId, matmulMaxBatchSize, textureCacheMode, pack};
}
-function parseWebgpuFlags(args: minimist.ParsedArgs): Env.WebGpuFlags {
+function parseWebgpuFlags(args: minimist.ParsedArgs): Partial {
const profilingMode = args['webgpu-profiling-mode'];
if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') {
throw new Error('Flag "webgpu-profiling-mode" is invalid');
diff --git a/js/web/test/data/ops/gather-elements.jsonc b/js/web/test/data/ops/gather-elements.jsonc
new file mode 100644
index 000000000000..caab3c11f64d
--- /dev/null
+++ b/js/web/test/data/ops/gather-elements.jsonc
@@ -0,0 +1,234 @@
+[
+ {
+ "name": "GatherElements float32 data + int32 indices-1",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "float32 data + int32 indices-1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0, 0, 1, 0],
+ "dims": [2, 2],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 1, 4, 3],
+ "dims": [2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements float32 data + int32 indices-2",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "float32 data + int32 indices-2",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0, 1, 1, 0],
+ "dims": [2, 2],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 4, 3],
+ "dims": [2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements float32 data + int64 indices - 1",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "float32 data + int64 indices - 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0, 0, -1, 0],
+ "dims": [2, 2],
+ "type": "int64"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 1, 4, 3],
+ "dims": [2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements float32 data + int64 indices - 2",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "float32 data + int64 indices - 2",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0, 0, -2, 0],
+ "dims": [2, 2],
+ "type": "int64"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 1, 3, 3],
+ "dims": [2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements int32 data + int32 indices-1",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "int32 data + int32 indices-1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "int32"
+ },
+ {
+ "data": [0, 0, 1, 0],
+ "dims": [2, 2],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 1, 4, 3],
+ "dims": [2, 2],
+ "type": "int32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements uint32 data + int32 indices-1",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "uint32 data + int32 indices-1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "uint32"
+ },
+ {
+ "data": [0, 0, 1, 0],
+ "dims": [2, 2],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 1, 4, 3],
+ "dims": [2, 2],
+ "type": "uint32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": -1, "type": "int" }],
+ "cases": [
+ {
+ "name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0, 0, -1, 0],
+ "dims": [2, 2],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 1, 4, 3],
+ "dims": [2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GatherElements float32 data + int32 indices-3",
+ "operator": "GatherElements",
+ "attributes": [{ "name": "axis", "data": 0, "type": "int" }],
+ "cases": [
+ {
+ "name": "GatherElements float32 data + int32 indices-3",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2, 0, 2, 0, 0],
+ "dims": [2, 3],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [4, 8, 3, 7, 2, 3],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ }
+]
diff --git a/js/web/test/data/ops/matmul.jsonc b/js/web/test/data/ops/matmul.jsonc
index 6b3d93f019bd..2c2cf509d7e3 100644
--- a/js/web/test/data/ops/matmul.jsonc
+++ b/js/web/test/data/ops/matmul.jsonc
@@ -246,6 +246,73 @@
"type": "float32"
}
]
+ },
+ {
+ "name": "multiplies 2D with 4D tensors vec4",
+ "inputs": [
+ {
+ "data": [1, 2, 1, 3, 2, 3, 1, 2],
+ "dims": [2, 4],
+ "type": "float32"
+ },
+ {
+ "data": [
+ 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
+ 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100,
+ 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
+ 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
+ 30, 31
+ ],
+ "dims": [3, 2, 4, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 395, 402, 409, 416, 436, 444, 452, 460, 507, 514, 521, 528, 564, 572, 580, 588, 619, 626, 633, 640, 692,
+ 700, 708, 716, 731, 738, 745, 752, 820, 828, 836, 844, 843, 850, 857, 864, 948, 956, 964, 972, 955, 962,
+ 630, 637, 1076, 1084, 866, 874
+ ],
+ "dims": [3, 2, 2, 4],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "multiplies 5D with 3D tensors vec4",
+ "inputs": [
+ {
+ "data": [
+ 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
+ 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100,
+ 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
+ 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
+ 30, 31
+ ],
+ "dims": [3, 1, 2, 4, 4],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2, 1, 3, 2, 3, 1, 2, 1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [1, 4, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 460, 662, 616, 867, 496, 714, 664, 935, 532, 766, 712, 1003, 568, 818, 760, 1071, 604, 870, 808, 1139,
+ 640, 922, 856, 1207, 676, 974, 904, 1275, 712, 1026, 952, 1343, 748, 1078, 1000, 1411, 784, 1130, 1048,
+ 1479, 820, 1182, 1096, 1547, 856, 1234, 1144, 1615, 892, 1286, 1192, 1683, 928, 1338, 1240, 1751, 964,
+ 1390, 1288, 1819, 1000, 1442, 1336, 1887, 1036, 1494, 1384, 1955, 1072, 1546, 1432, 2023, 1108, 1598,
+ 1480, 2091, 1144, 1650, 1528, 2159, 1180, 1702, 1576, 2227, 1216, 1754, 1624, 2295, 1252, 1806, 1672,
+ 2363, 610, 954, 590, 1075
+ ],
+ "dims": [3, 1, 2, 4, 4],
+ "type": "float32"
+ }
+ ]
}
]
}
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index e0b0207c9fe7..ace53701455f 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -539,9 +539,9 @@
"test_gather_1",
"test_gather_2d_indices",
"test_gather_negative_indices",
- // "test_gather_elements_0",
- // "test_gather_elements_1",
- // "test_gather_elements_negative_indices",
+ "test_gather_elements_0",
+ "test_gather_elements_1",
+ "test_gather_elements_negative_indices",
// "test_gather_negative_indices",
// // "test_gathernd_example_float32",
// // "test_gathernd_example_int32_batch_dim1",
@@ -1339,12 +1339,13 @@
"exp.jsonc",
"expand.jsonc",
"floor.jsonc",
+ "gather-elements.jsonc",
"gemm.jsonc",
"global-average-pool.jsonc",
"greater.jsonc",
"less.jsonc",
"log.jsonc",
- //"matmul.jsonc", // <--- some tests fail (when input is 3D/4D/5D)
+ "matmul.jsonc",
"mul.jsonc",
"mul_int32.jsonc",
//"neg.jsonc",
diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts
index db01082b9f9b..1f95d1cd8e68 100644
--- a/js/web/test/test-types.ts
+++ b/js/web/test/test-types.ts
@@ -110,6 +110,12 @@ export declare namespace Test {
[backend: string]: {[group: string]: readonly TestList.Test[]};
}
+ interface EnvOptions extends Partial> {
+ wasm: Partial;
+ webgl: Partial;
+ webgpu: Partial;
+ }
+
/**
* Represent ONNX Runtime Web global options
*/
@@ -122,7 +128,7 @@ export declare namespace Test {
cudaFlags?: Record;
wasmOptions?: InferenceSession.WebAssemblyExecutionProviderOption;
webglOptions?: InferenceSession.WebGLExecutionProviderOption;
- globalEnvFlags?: Partial;
+ globalEnvFlags?: EnvOptions;
}
/**
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu
index dcbc733f2acb..5ac10f6321e6 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu
@@ -139,10 +139,7 @@ __launch_bounds__(thread_block_size) __global__ void BeamSearchOnlineTopKStage2K
input_tokens += vector_id * k * parts_per_beam;
TopK thread_topk;
- for (int i = 0; i < max_k; ++i) {
- thread_topk.key[i] = -1;
- thread_topk.value[i] = NumericLimits::Min();
- }
+ thread_topk.Init();
for (int idx = thread_id; idx < k * parts_per_beam; idx += thread_block_size) {
value_shared_buf[idx] = input_values[idx];
diff --git a/onnxruntime/core/common/logging/sinks/ostream_sink.cc b/onnxruntime/core/common/logging/sinks/ostream_sink.cc
index 3b832c9d63c1..0db3d8709d48 100644
--- a/onnxruntime/core/common/logging/sinks/ostream_sink.cc
+++ b/onnxruntime/core/common/logging/sinks/ostream_sink.cc
@@ -46,7 +46,7 @@ void OStreamSink::SendImpl(const Timestamp& timestamp, const std::string& logger
#endif
msg << timestamp << " [" << message.SeverityPrefix() << ":" << message.Category() << ":" << logger_id << ", "
- << message.Location().ToString() << "] " << message.Message() << "\n";
+ << message.Location().ToString() << "] " << message.Message();
#ifndef ORT_MINIMAL_BUILD
if (message.Severity() == Severity::kWARNING ||
@@ -55,6 +55,7 @@ void OStreamSink::SendImpl(const Timestamp& timestamp, const std::string& logger
msg << Color::kEnd;
}
#endif
+ msg << "\n";
(*stream_) << msg.str();
@@ -87,7 +88,7 @@ void WOStreamSink::SendImpl(const Timestamp& timestamp, const std::string& logge
#endif
msg << timestamp << L" [" << message.SeverityPrefix() << L":" << message.Category() << L":" << ToWideString(logger_id) << L", "
- << ToWideString(message.Location().ToString()) << L"] " << ToWideString(message.Message()) << L"\n";
+ << ToWideString(message.Location().ToString()) << L"] " << ToWideString(message.Message());
#ifndef ORT_MINIMAL_BUILD
if (message.Severity() == Severity::kWARNING ||
@@ -96,6 +97,7 @@ void WOStreamSink::SendImpl(const Timestamp& timestamp, const std::string& logge
msg << Color::kLEnd;
}
#endif
+ msg << L"\n";
(*stream_) << msg.str();
diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.h b/onnxruntime/core/flatbuffers/flatbuffers_utils.h
index 4e7db4df9ae2..55bde0b2df80 100644
--- a/onnxruntime/core/flatbuffers/flatbuffers_utils.h
+++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.h
@@ -5,6 +5,8 @@
#include
+#include "flatbuffers/flatbuffers.h"
+
#include "core/common/common.h"
#include "core/common/path_string.h"
#include "core/common/status.h"
@@ -13,18 +15,6 @@ namespace ONNX_NAMESPACE {
class ValueInfoProto;
}
-namespace flatbuffers {
-class FlatBufferBuilder;
-
-template
-struct Offset;
-
-struct String;
-
-template
-class Vector;
-} // namespace flatbuffers
-
namespace onnxruntime {
namespace fbs {
diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.h b/onnxruntime/core/framework/kernel_type_str_resolver.h
index 75fc2fa894f8..31a806dd5229 100644
--- a/onnxruntime/core/framework/kernel_type_str_resolver.h
+++ b/onnxruntime/core/framework/kernel_type_str_resolver.h
@@ -7,6 +7,8 @@
#include
#include
+#include "flatbuffers/flatbuffers.h"
+
#if !defined(ORT_MINIMAL_BUILD)
#include "core/graph/onnx_protobuf.h"
#endif // !defined(ORT_MINIMAL_BUILD)
@@ -18,12 +20,6 @@
#include "core/graph/graph.h"
#include "core/platform/ort_mutex.h"
-namespace flatbuffers {
-class FlatBufferBuilder;
-template
-struct Offset;
-} // namespace flatbuffers
-
namespace onnxruntime {
namespace fbs {
diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h
index d546f264a9d5..51bb02918d82 100644
--- a/onnxruntime/core/framework/session_state.h
+++ b/onnxruntime/core/framework/session_state.h
@@ -8,6 +8,8 @@
#include
#include
+#include "flatbuffers/flatbuffers.h"
+
#include "core/common/gsl.h"
#include "core/common/common.h"
@@ -43,12 +45,6 @@
#include "core/framework/program_region.h"
#endif
-namespace flatbuffers {
-class FlatBufferBuilder;
-template
-struct Offset;
-} // namespace flatbuffers
-
namespace onnxruntime {
namespace fbs {
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 5a42f5d34b93..08ed811d9ac3 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -1492,7 +1492,7 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer,
if (initializer.data_location() == TensorProto_DataLocation_EXTERNAL) {
ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(
initializer,
- model_path.IsEmpty() ? nullptr : model_path.ParentPath().ToPathString().c_str(),
+ (model_path.IsEmpty() || model_path.ParentPath().IsEmpty()) ? nullptr : model_path.ParentPath().ToPathString().c_str(),
unpacked_tensor));
return Status::OK();
}
diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.h b/onnxruntime/core/graph/graph_flatbuffers_utils.h
index f4899ffc1281..b625cbf3ca49 100644
--- a/onnxruntime/core/graph/graph_flatbuffers_utils.h
+++ b/onnxruntime/core/graph/graph_flatbuffers_utils.h
@@ -5,6 +5,8 @@
#include
+#include "flatbuffers/flatbuffers.h"
+
#include "core/common/status.h"
#include "core/graph/ort_format_load_options.h"
#include "core/framework/tensor.h"
@@ -18,12 +20,6 @@ class SparseTensorProto;
#endif // !defined(DISABLE_SPARSE_TENSORS)
} // namespace ONNX_NAMESPACE
-namespace flatbuffers {
-class FlatBufferBuilder;
-template
-struct Offset;
-} // namespace flatbuffers
-
namespace onnxruntime {
class Graph;
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 5337211ae79d..7e3942b02925 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -7,6 +7,9 @@
#include
#include
#include
+
+#include "flatbuffers/flatbuffers.h"
+
#include "core/common/path.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/ort_format_load_options.h"
@@ -15,12 +18,6 @@
#include "core/graph/function_template.h"
#endif
-namespace flatbuffers {
-class FlatBufferBuilder;
-template
-struct Offset;
-} // namespace flatbuffers
-
namespace onnxruntime {
namespace fbs {
diff --git a/onnxruntime/core/graph/op_identifier_utils.h b/onnxruntime/core/graph/op_identifier_utils.h
index 265364a88d3e..8a9351a2d0dd 100644
--- a/onnxruntime/core/graph/op_identifier_utils.h
+++ b/onnxruntime/core/graph/op_identifier_utils.h
@@ -3,21 +3,14 @@
#pragma once
+#include "flatbuffers/flatbuffers.h"
+
#include "core/graph/op_identifier.h"
#include "core/common/status.h"
#include "core/graph/graph.h"
#include "core/graph/onnx_protobuf.h"
-namespace flatbuffers {
-class FlatBufferBuilder;
-
-template
-struct Offset;
-
-struct String;
-} // namespace flatbuffers
-
namespace onnxruntime {
namespace fbs::utils {
diff --git a/onnxruntime/core/graph/runtime_optimization_record_container.h b/onnxruntime/core/graph/runtime_optimization_record_container.h
index 5db784f1a27a..a28b19e786de 100644
--- a/onnxruntime/core/graph/runtime_optimization_record_container.h
+++ b/onnxruntime/core/graph/runtime_optimization_record_container.h
@@ -9,17 +9,11 @@
#include
#include
+#include "flatbuffers/flatbuffers.h"
+
#include "core/common/common.h"
#include "core/graph/runtime_optimization_record.h"
-namespace flatbuffers {
-class FlatBufferBuilder;
-template
-struct Offset;
-template
-class Vector;
-} // namespace flatbuffers
-
namespace onnxruntime {
namespace fbs {
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 9a1e327c6185..f517be185b3f 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -51,7 +51,14 @@ Module Name:
#endif
#if defined(__x86_64__) || defined(__i386__)
#include
+#if defined(__GNUC__) && __GNUC__ >= 12
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
#include
+#pragma GCC diagnostic pop
+#else
+#include
+#endif
#endif
#if defined(__VSX__)
#include
diff --git a/onnxruntime/core/mlas/lib/q4_dq_cli.cpp b/onnxruntime/core/mlas/lib/q4_dq_cli.cpp
index b994f171c67d..9c330b9eaf12 100644
--- a/onnxruntime/core/mlas/lib/q4_dq_cli.cpp
+++ b/onnxruntime/core/mlas/lib/q4_dq_cli.cpp
@@ -218,13 +218,21 @@ quantize(const Cli& cli)
} else {
buf = std::cout.rdbuf();
}
+#if defined(__GNUC__) && __GNUC__ >= 12
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored \
+ "-Wdangling-pointer" // TODO: suppress warning about dangling pointer until we have a fix
std::ostream stream(buf);
+#pragma GCC diagnostic pop
+#else
+ std::ostream stream(buf);
+#endif
+
writeUint8Txt(stream, dstbuf.data(), dstbuf.size());
}
return 0;
}
-
int
dequantize(const Cli& cli)
{
@@ -254,13 +262,14 @@ dequantize(const Cli& cli)
out.write((const char*)dstbuf.data(), std::streamsize(dstbuf.size()) * sizeof(float));
} else {
std::streambuf* buf;
+ std::ofstream file_output_stream;
if (cli.output_file) {
- std::ofstream out(cli.output_file, std::ios::out);
- if (!out) {
+ file_output_stream.open(cli.output_file, std::ios::out);
+ if (file_output_stream.fail()) {
std::cerr << "Cannot open output file " << cli.output_file << std::endl;
return -1;
}
- buf = out.rdbuf();
+ buf = file_output_stream.rdbuf();
} else {
buf = std::cout.rdbuf();
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
index eed7ef506b49..f725bc40e542 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
@@ -47,6 +47,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() {
static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
return {{"AveragePool", {}},
{"GlobalAveragePool", {}},
+ {"GlobalMaxPool", {}},
{"LeakyRelu", {}},
{"ReduceMean", {}},
{"ReduceMin", {}},
@@ -79,7 +80,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
{"Div", {}},
{"Mul", {}},
{"Pow", {}},
- {"Sub", {}}};
+ {"Sub", {}},
+ {"GridSample", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() {
return {{"Concat", {}}};
diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
index 0de7dccd2a5f..ce834e371fde 100644
--- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
@@ -890,49 +890,49 @@ Status ReduceL1::Compute(OpKernelContext* ctx) const {
// The following variable does not change if the input tensor and the
// axes do not either. It could be either cached in ctx or precomputed
// in the constructor if shape and axes are known at this stage.
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceL2::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceLogSum::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceLogSumExp::Compute(OpKernelContext* ctx) const {
- CommonReduce2Loops>(ctx, axes_, keepdims_);
+ CommonReduce2Loops>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceMax::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceMean::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceMin::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
template
Status ReduceProd::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
@@ -1017,7 +1017,7 @@ std::unique_ptr ReduceSum::Impl(const Tensor& input, gsl::span
Status ReduceSumSquare::Compute(OpKernelContext* ctx) const {
- CommonReduce1Loop>(ctx, axes_, keepdims_);
+ CommonReduce1Loop>(ctx, axes_, keepdims_, noop_with_empty_axes_);
return Status::OK();
}
diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc
index f87788e8f477..8844b7e7a26c 100644
--- a/onnxruntime/core/providers/cpu/tensor/scatter.cc
+++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc
@@ -308,7 +308,7 @@ Status ScatterData(
const auto& upd_shape = updates_input->Shape();
const auto num_dims = input_data_shape.NumDimensions();
- assert(num_dims > 0);
+ ORT_RETURN_IF_NOT(num_dims > 0, "ScatterElements op: input tensor must have at least one dimension");
// Allocate and zero out counts. The input/output is of the same rank as
// indices/updates but the actual dimensions of indices/updates must be less or equal
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index a50b53315ec9..0d9928baa86e 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -20,7 +20,7 @@ namespace cuda {
// float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions
// CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2)))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
__device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); }
__device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); }
__device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); }
@@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
+template
+__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; }
+
+template
+__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); }
+
+template
+__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); }
+
+template <>
+__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); }
+
template
__device__ __inline__ T _Normcdf(T a);
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index aa60db4d0722..ad892eab3b84 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub);
@@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index f026444328b2..9ede1f8d90ec 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
UNARY_OP_HFD(Exp, 13)
UNARY_OP_HFD(Erf, 13)
+UNARY_OP_BWUZCSILHFD(Sign, 13)
UNARY_LOGICALOP_NOT_TYPED(1, bool)
UNARY_OP_HFD(Round, 11)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
index 3ff97a60114d..775b78c43a73 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise {
Status ComputeInternal(OpKernelContext* context) const override;
};
+template
+class Sign final : public UnaryElementwise {
+ public:
+ Sign(const OpKernelInfo& info) : UnaryElementwise(info) {}
+ Status ComputeInternal(OpKernelContext* context) const override;
+};
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index ac7cc1126acb..1298d5333833 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign)
// When casting, half needs to be converted via float type from most other types
template
@@ -119,52 +120,52 @@ struct OP_Cast {
}
};
-#define IMPL_CAST_IMPL(InT, OutT) \
+#define IMPL_CAST_IMPL(InT, OutT) \
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
- UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \
+ UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \
}
-#define IMPL_CAST_IMPL_THROW(InT, OutT) \
+#define IMPL_CAST_IMPL_THROW(InT, OutT) \
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
- ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
+ ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
}
#if !defined(DISABLE_FLOAT8_TYPES)
-#define IMPL_CAST_IMPL_FROM(T) \
- IMPL_CAST_IMPL(T, half) \
- IMPL_CAST_IMPL(T, float) \
- IMPL_CAST_IMPL(T, double) \
- IMPL_CAST_IMPL(T, int8_t) \
- IMPL_CAST_IMPL(T, int16_t) \
- IMPL_CAST_IMPL(T, int32_t) \
- IMPL_CAST_IMPL(T, int64_t) \
- IMPL_CAST_IMPL(T, uint8_t) \
- IMPL_CAST_IMPL(T, uint16_t) \
- IMPL_CAST_IMPL(T, uint32_t) \
- IMPL_CAST_IMPL(T, uint64_t) \
- IMPL_CAST_IMPL(T, bool) \
- IMPL_CAST_IMPL(T, BFloat16) \
- IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
- IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
+#define IMPL_CAST_IMPL_FROM(T) \
+ IMPL_CAST_IMPL(T, half) \
+ IMPL_CAST_IMPL(T, float) \
+ IMPL_CAST_IMPL(T, double) \
+ IMPL_CAST_IMPL(T, int8_t) \
+ IMPL_CAST_IMPL(T, int16_t) \
+ IMPL_CAST_IMPL(T, int32_t) \
+ IMPL_CAST_IMPL(T, int64_t) \
+ IMPL_CAST_IMPL(T, uint8_t) \
+ IMPL_CAST_IMPL(T, uint16_t) \
+ IMPL_CAST_IMPL(T, uint32_t) \
+ IMPL_CAST_IMPL(T, uint64_t) \
+ IMPL_CAST_IMPL(T, bool) \
+ IMPL_CAST_IMPL(T, BFloat16) \
+ IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \
+ IMPL_CAST_IMPL_THROW(T, Float8E5M2) \
IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \
IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ)
#else
-#define IMPL_CAST_IMPL_FROM(T) \
- IMPL_CAST_IMPL(T, half) \
- IMPL_CAST_IMPL(T, float) \
- IMPL_CAST_IMPL(T, double) \
- IMPL_CAST_IMPL(T, int8_t) \
- IMPL_CAST_IMPL(T, int16_t) \
- IMPL_CAST_IMPL(T, int32_t) \
- IMPL_CAST_IMPL(T, int64_t) \
- IMPL_CAST_IMPL(T, uint8_t) \
- IMPL_CAST_IMPL(T, uint16_t) \
- IMPL_CAST_IMPL(T, uint32_t) \
- IMPL_CAST_IMPL(T, uint64_t) \
- IMPL_CAST_IMPL(T, bool) \
+#define IMPL_CAST_IMPL_FROM(T) \
+ IMPL_CAST_IMPL(T, half) \
+ IMPL_CAST_IMPL(T, float) \
+ IMPL_CAST_IMPL(T, double) \
+ IMPL_CAST_IMPL(T, int8_t) \
+ IMPL_CAST_IMPL(T, int16_t) \
+ IMPL_CAST_IMPL(T, int32_t) \
+ IMPL_CAST_IMPL(T, int64_t) \
+ IMPL_CAST_IMPL(T, uint8_t) \
+ IMPL_CAST_IMPL(T, uint16_t) \
+ IMPL_CAST_IMPL(T, uint32_t) \
+ IMPL_CAST_IMPL(T, uint64_t) \
+ IMPL_CAST_IMPL(T, bool) \
IMPL_CAST_IMPL(T, BFloat16)
#endif
@@ -199,58 +200,58 @@ struct OP_CastNoSat {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
-#define OP_CAST(T, NVT) \
- template <> \
- struct OP_CastSat { \
- __device__ __inline__ T operator()(const half& v) const { \
+#define OP_CAST(T, NVT) \
+ template <> \
+ struct OP_CastSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
- __device__ __inline__ T operator()(const half& v) const { \
- return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
- } \
- }; \
- template <> \
- struct OP_CastSat { \
- __device__ __inline__ T operator()(const float& v) const { \
- return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
- __device__ __inline__ T operator()(const float& v) const { \
- return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
- } \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
+ return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastSat { \
+ __device__ __inline__ T operator()(const float& v) const { \
+ return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
+ __device__ __inline__ T operator()(const float& v) const { \
+ return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \
+ } \
};
#else
-#define OP_CAST(T, NVT) \
- template <> \
- struct OP_CastSat { \
- __device__ __inline__ T operator()(const half& v) const { \
- return T(__half2float(v), true); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
- __device__ __inline__ T operator()(const half& v) const { \
- return T(__half2float(v), false); \
- } \
- }; \
- template <> \
- struct OP_CastSat { \
+#define OP_CAST(T, NVT) \
+ template <> \
+ struct OP_CastSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
+ return T(__half2float(v), true); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
+ __device__ __inline__ T operator()(const half& v) const { \
+ return T(__half2float(v), false); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastSat { \
__device__ __inline__ T operator()(const float& v) const { \
- return T(v, true); \
- } \
- }; \
- template <> \
- struct OP_CastNoSat { \
+ return T(v, true); \
+ } \
+ }; \
+ template <> \
+ struct OP_CastNoSat { \
__device__ __inline__ T operator()(const float& v) const { \
- return T(v, false); \
- } \
+ return T(v, false); \
+ } \
};
#endif
@@ -260,14 +261,13 @@ struct OP_CastNoSat {
OP_CAST(Float8E4M3FN, __NV_E4M3)
OP_CAST(Float8E5M2, __NV_E5M2)
-
-#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
+#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \
void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \
- if (saturate) { \
- UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \
- } else { \
- UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \
- } \
+ if (saturate) { \
+ UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \
+ } else { \
+ UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \
+ } \
}
EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
index 3d4868b54abe..608a81a24cf4 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -31,7 +31,8 @@ namespace cuda {
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
- UNARY_OP_NAME_EXPR(Cos, _Cos(a))
+ UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \
+ UNARY_OP_NAME_EXPR(Sign, _Sign(a))
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template \
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index 2732eb0c3d7b..829f3e5f4f14 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -291,6 +291,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);
+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
@@ -532,6 +535,9 @@ std::unique_ptr RegisterKernels() {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/js/operators/gather_elements.cc b/onnxruntime/core/providers/js/operators/gather_elements.cc
new file mode 100644
index 000000000000..b4db122341bc
--- /dev/null
+++ b/onnxruntime/core/providers/js/operators/gather_elements.cc
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/js/js_kernel.h"
+#include "core/providers/js/js_data_types.h"
+#include "gather_elements.h"
+
+namespace onnxruntime {
+namespace js {
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+ GatherElements,
+ kOnnxDomain,
+ 11,
+ 12,
+ kJsExecutionProvider,
+ (*KernelDefBuilder::Create())
+        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
+                              DataTypeImpl::GetTensorType<int32_t>(),
+                              DataTypeImpl::GetTensorType<uint32_t>()})
+        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
+ GatherElements);
+
+ONNX_OPERATOR_KERNEL_EX(
+ GatherElements,
+ kOnnxDomain,
+ 13,
+ kJsExecutionProvider,
+ (*KernelDefBuilder::Create())
+        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
+                              DataTypeImpl::GetTensorType<int32_t>(),
+                              DataTypeImpl::GetTensorType<uint32_t>()})
+        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
+ GatherElements);
+
+} // namespace js
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/js/operators/gather_elements.h b/onnxruntime/core/providers/js/operators/gather_elements.h
new file mode 100644
index 000000000000..ce9014513377
--- /dev/null
+++ b/onnxruntime/core/providers/js/operators/gather_elements.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/js/js_kernel.h"
+
+namespace onnxruntime {
+namespace js {
+
+class GatherElements : public JsKernel {
+ public:
+ GatherElements(const OpKernelInfo& info) : JsKernel(info) {
+ int64_t axis = info.GetAttrOrDefault("axis", 0);
+
+ JSEP_INIT_KERNEL_ATTRIBUTE(GatherElements, ({
+ "axis" : Number($1),
+ }),
+                              static_cast<int32_t>(axis));
+ }
+};
+
+} // namespace js
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
index 99f35f9e660e..58ac3ad45a57 100644
--- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
+++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -63,6 +63,8 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateSimpleOpBuilder("DepthToSpace", *this);
CreateSimpleOpBuilder("SpaceToDepth", *this);
+
+ CreateSimpleOpBuilder("GridSample", *this);
}
{
@@ -86,6 +88,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreatePoolOpBuilder("GlobalAveragePool", *this);
CreatePoolOpBuilder("AveragePool", *this);
CreatePoolOpBuilder("MaxPool", *this);
+ CreatePoolOpBuilder("GlobalMaxPool", *this);
}
{
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
index 75f76e7c9b10..14d5e45799b8 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
@@ -120,6 +120,7 @@ class BaseOpBuilder : public IOpBuilder {
{"Sub", QNN_OP_ELEMENT_WISE_SUBTRACT},
{"Tanh", QNN_OP_TANH},
{"Transpose", QNN_OP_TRANSPOSE},
+ {"GridSample", QNN_OP_GRID_SAMPLE},
{"DequantizeLinear", QNN_OP_DEQUANTIZE},
{"QuantizeLinear", QNN_OP_QUANTIZE},
@@ -140,6 +141,7 @@ class BaseOpBuilder : public IOpBuilder {
{"GlobalAveragePool", QNN_OP_POOL_AVG_2D},
{"AveragePool", QNN_OP_POOL_AVG_2D},
{"MaxPool", QNN_OP_POOL_MAX_2D},
+ {"GlobalMaxPool", QNN_OP_POOL_MAX_2D},
{"Reshape", QNN_OP_RESHAPE},
{"Resize", QNN_OP_RESIZE},
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
index c2909c9e0d79..a44640b37ae3 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc
@@ -58,7 +58,17 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
std::vector input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape");
if (input_shape.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only support 2D!");
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Pool2D only support 2D!");
+ }
+
+ if (node_unit.Outputs().size() > 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN only support 1 output!");
+ }
+
+ const std::string& op_type = node_unit.OpType();
+ // Onnx GlobalMaxPool doesn't have any attributes
+ if (op_type == "GlobalMaxPool") {
+ return Status::OK();
}
NodeAttrHelper node_helper(node_unit);
@@ -67,11 +77,7 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN does not support Dilation attribute");
}
- if (node_unit.Outputs().size() > 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN only support 1 output!");
- }
-
- if (node_unit.OpType() == "MaxPool" || node_unit.OpType() == "AveragePool") {
+ if (op_type == "MaxPool" || op_type == "AveragePool") {
auto auto_pad = node_helper.Get("auto_pad", std::string("NOTSET"));
ORT_RETURN_IF(auto_pad != "NOTSET" && auto_pad != "SAME_LOWER" && auto_pad != "SAME_UPPER",
"QNN Pool operators do not support 'auto_pad' value: ", auto_pad.c_str());
@@ -121,6 +127,21 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper,
return Status::OK();
} // namespace qnn
+void SetPoolParam(const NodeUnit& node_unit,
+                  const std::string& param_name,
+                  std::vector<uint32_t>&& parm_shape,
+                  std::vector<uint32_t>&& parm_data,
+                  std::vector<std::string>& param_tensor_names,
+                  QnnModelWrapper& qnn_model_wrapper) {
+ QnnParamWrapper qnn_param(node_unit.Index(),
+ node_unit.Name(),
+ param_name,
+ std::move(parm_shape),
+ std::move(parm_data));
+ param_tensor_names.push_back(qnn_param.GetParamTensorName());
+ qnn_model_wrapper.AddParamWrapper(std::move(qnn_param));
+}
+
Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector&& input_names,
@@ -142,7 +163,25 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
std::vector pad_amount{0, 0, 0, 0};
std::vector pad_amount_dim{2, 2};
int32_t ceil_mode = 0;
- if (node_unit.OpType() == "MaxPool" || node_unit.OpType() == "AveragePool") {
+
+  std::vector<std::string> param_tensor_names;
+ const std::string& op_type = node_unit.OpType();
+ if (op_type == "GlobalMaxPool") {
+ // set default params for Qnn PoolMax2D
+ SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper);
+ SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper);
+ SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_STRIDE, std::move(stride_dim), std::move(stride), param_tensor_names, qnn_model_wrapper);
+
+ ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
+ std::move(input_names),
+ std::move(param_tensor_names),
+ logger,
+ do_op_validation,
+ GetQnnOpType(op_type)));
+ return Status::OK();
+ }
+
+ if (op_type == "MaxPool" || op_type == "AveragePool") {
const auto& outputs = node_unit.Outputs();
std::vector output_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(outputs[0].node_arg, output_shape), "Cannot get shape");
@@ -151,30 +190,10 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
std::move(input_shape), std::move(output_shape)));
}
- std::vector param_tensor_names;
- QnnParamWrapper filter_size_param(node_unit.Index(),
- node_unit.Name(),
- QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE,
- std::move(filter_size_dim),
- std::move(filter_size));
- param_tensor_names.push_back(filter_size_param.GetParamTensorName());
- qnn_model_wrapper.AddParamWrapper(std::move(filter_size_param));
-
- QnnParamWrapper pad_amount_param(node_unit.Index(),
- node_unit.Name(),
- QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT,
- std::move(pad_amount_dim),
- std::move(pad_amount));
- param_tensor_names.push_back(pad_amount_param.GetParamTensorName());
- qnn_model_wrapper.AddParamWrapper(std::move(pad_amount_param));
-
- QnnParamWrapper stride_param(node_unit.Index(),
- node_unit.Name(),
- QNN_OP_POOL_MAX_2D_PARAM_STRIDE,
- std::move(stride_dim),
- std::move(stride));
- param_tensor_names.push_back(stride_param.GetParamTensorName());
- qnn_model_wrapper.AddParamWrapper(std::move(stride_param));
+ SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper);
+ SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper);
+ SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_STRIDE, std::move(stride_dim), std::move(stride), param_tensor_names, qnn_model_wrapper);
+
if (0 != ceil_mode) {
Qnn_Scalar_t rounding_mode_param = QNN_SCALAR_INIT;
rounding_mode_param.dataType = QNN_DATATYPE_UINT_32;
@@ -186,7 +205,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
param_tensor_names.push_back(rounding_mode_param_wrapper.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(rounding_mode_param_wrapper));
}
- if (node_unit.OpType() == "GlobalAveragePool") {
+ if (op_type == "GlobalAveragePool") {
Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT;
scalar_param.dataType = QNN_DATATYPE_BOOL_8;
scalar_param.bool8Value = 1;
@@ -196,7 +215,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
scalar_param);
param_tensor_names.push_back(count_pad_for_edges_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param));
- } else if (node_unit.OpType() == "AveragePool") {
+ } else if (op_type == "AveragePool") {
Qnn_Scalar_t scalar_param = QNN_SCALAR_INIT;
scalar_param.dataType = QNN_DATATYPE_BOOL_8;
scalar_param.bool8Value = static_cast(node_helper.Get("count_include_pad", static_cast(0)) != 0);
@@ -211,7 +230,9 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
- logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
+ logger,
+ do_op_validation,
+ GetQnnOpType(op_type)));
return Status::OK();
}
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc
index f36854cfea76..511f2a5149f2 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc
@@ -14,6 +14,7 @@
#include "core/common/safeint.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
+#include "core/providers/qnn/builder/qnn_utils.h"
namespace onnxruntime {
namespace qnn {
@@ -157,19 +158,6 @@ Status ResizeOpBuilder::GetQnnModeFromString(const std::array<std::string_view, N>& supported_modes,
-template <std::size_t N>
-static bool ArrayHasString(const std::array<std::string_view, N>& strings, std::string_view str) {
- for (auto s : strings) {
- if (s == str) {
- return true;
- }
- }
-
- return false;
-}
-
// Resize ops are sensitive with data layout, no special validation so far
// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
@@ -252,6 +240,7 @@ Status ResizeOpBuilder::ValidateOp(QnnModelWrapper& qnn_model_wrapper, const Nod
Status ResizeOpBuilder::ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
NodeAttrHelper node_helper(node_unit);
+ using namespace onnxruntime::qnn::utils;
// Check mode
const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr);
ORT_RETURN_IF_NOT(ArrayHasString(supported_modes, interp_mode), "QNN EP: Resize does not support mode ",
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
index 8d9a79ddf888..ca18c051a992 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc
@@ -30,18 +30,9 @@ class SimpleOpBuilder : public BaseOpBuilder {
private:
Status ExplictOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
- Status ProcessAlphaAttribute(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- std::vector& param_tensor_names) const;
- Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- const std::string input_name) const;
- Status ProcessBlockSizeAttribute(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- std::vector& param_tensor_names) const;
- Status ProcessModeAttribute(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- std::vector& param_tensor_names) const;
+
+  static constexpr std::array<std::string_view, 2> gridsample_supported_modes = {"bilinear", "nearest"};
+  static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
};
Status SimpleOpBuilder::ExplictOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
@@ -57,12 +48,22 @@ Status SimpleOpBuilder::ExplictOpCheck(const QnnModelWrapper& qnn_model_wrapper,
"QNN Softmax only supports an `axis` attribute equal to input_rank-1 (or -1)");
}
+ if (node_unit.OpType() == "GridSample") {
+ NodeAttrHelper node_helper(node_unit);
+ std::string mode = node_helper.Get("mode", "linear");
+ ORT_RETURN_IF_NOT(utils::ArrayHasString(gridsample_supported_modes, mode), "GridSample does not support mode ",
+ mode.c_str());
+ std::string padding_mode = node_helper.Get("padding_mode", "zeros");
+ ORT_RETURN_IF_NOT(utils::ArrayHasString(gridsample_supported_padding_modes, padding_mode), "GridSample does not support padding_mode ",
+ padding_mode.c_str());
+ }
+
return Status::OK();
}
-Status SimpleOpBuilder::ProcessAlphaAttribute(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- std::vector& param_tensor_names) const {
+Status ProcessAlphaAttribute(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+ std::vector& param_tensor_names) {
NodeAttrHelper node_helper(node_unit);
float alpha = node_helper.Get("alpha", 1.0f);
Qnn_Scalar_t alpha_qnn_scalar = QNN_SCALAR_INIT;
@@ -76,9 +77,9 @@ Status SimpleOpBuilder::ProcessAlphaAttribute(QnnModelWrapper& qnn_model_wrapper
return Status::OK();
}
-Status SimpleOpBuilder::ProcessBlockSizeAttribute(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- std::vector& param_tensor_names) const {
+Status ProcessBlockSizeAttribute(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+ std::vector& param_tensor_names) {
NodeAttrHelper node_helper(node_unit);
uint32_t block_size = node_helper.Get("blocksize", static_cast(0));
std::vector block_size_shape{2};
@@ -91,9 +92,9 @@ Status SimpleOpBuilder::ProcessBlockSizeAttribute(QnnModelWrapper& qnn_model_wra
return Status::OK();
}
-Status SimpleOpBuilder::ProcessModeAttribute(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- std::vector& param_tensor_names) const {
+Status ProcessModeAttribute(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+ std::vector& param_tensor_names) {
NodeAttrHelper node_helper(node_unit);
std::string mode = node_helper.Get("mode", "DCR");
Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT;
@@ -114,9 +115,9 @@ Status SimpleOpBuilder::ProcessModeAttribute(QnnModelWrapper& qnn_model_wrapper,
}
// Process alpha attribute as input for Qnn LeakyRelu
-Status SimpleOpBuilder::ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
- const NodeUnit& node_unit,
- const std::string input_name) const {
+Status ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+ const std::string input_name) {
NodeAttrHelper node_helper(node_unit);
Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32;
@@ -149,6 +150,51 @@ Status SimpleOpBuilder::ProcessAlphaAttributeAsInput(QnnModelWrapper& qnn_model_
return Status::OK();
}
+Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper,
+ const NodeUnit& node_unit,
+ std::vector& param_tensor_names) {
+ NodeAttrHelper node_helper(node_unit);
+  int64_t align_corners = node_helper.Get("align_corners", static_cast<int64_t>(0));
+  Qnn_Scalar_t align_corners_qnn_scalar = QNN_SCALAR_INIT;
+  align_corners_qnn_scalar.dataType = QNN_DATATYPE_BOOL_8;
+  align_corners_qnn_scalar.bool8Value = static_cast<uint8_t>(align_corners == 0 ? 0 : 1);
+ QnnParamWrapper align_corners_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_ALIGN_CORNERS, align_corners_qnn_scalar);
+ param_tensor_names.push_back(align_corners_param.GetParamTensorName());
+ qnn_model_wrapper.AddParamWrapper(std::move(align_corners_param));
+
+ std::string mode = node_helper.Get("mode", "linear");
+ Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT;
+ mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
+ if ("bilinear" == mode) {
+ mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR;
+ } else if ("nearest" == mode) {
+ mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST;
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest.");
+ }
+ QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar);
+ param_tensor_names.push_back(mode_param.GetParamTensorName());
+ qnn_model_wrapper.AddParamWrapper(std::move(mode_param));
+
+ std::string padding_mode = node_helper.Get("padding_mode", "zeros");
+ Qnn_Scalar_t padding_mode_qnn_scalar = QNN_SCALAR_INIT;
+ padding_mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
+ if ("zeros" == padding_mode) {
+ padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_ZEROS;
+ } else if ("border" == padding_mode) {
+ padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_BORDER;
+ } else if ("reflection" == padding_mode) {
+ padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION;
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection.");
+ }
+ QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar);
+ param_tensor_names.push_back(padding_mode_param.GetParamTensorName());
+ qnn_model_wrapper.AddParamWrapper(std::move(padding_mode_param));
+
+ return Status::OK();
+}
+
Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector&& input_names,
@@ -163,7 +209,7 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
if (do_op_validation) {
ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit));
// Skip the op validation for DepthToSpace & SpaceToDepth if it's not NHWC data layout
- if (node_unit.Domain() != kMSInternalNHWCDomain && (op_type == "DepthToSpace" || op_type == "SpaceToDepth")) {
+ if (node_unit.Domain() != kMSInternalNHWCDomain && (op_type == "DepthToSpace" || op_type == "SpaceToDepth" || op_type == "GridSample")) {
return Status::OK();
}
}
@@ -211,6 +257,10 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
ORT_RETURN_IF_ERROR(ProcessBlockSizeAttribute(qnn_model_wrapper, node_unit, param_tensor_names));
}
+ if (op_type == "GridSample") {
+ ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names));
+ }
+
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h
index 1c4d85a0d147..a54e0c8276e7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h
@@ -35,6 +35,19 @@ inline void InitializeQuantizeParam(Qnn_QuantizeParams_t& quantize_param, bool i
quantize_param.scaleOffsetEncoding.offset = offset;
}
+// Utility function that checks if an array of strings contains a specific string.
+// Used to validate ONNX operator attributes.
+template <std::size_t N>
+static bool ArrayHasString(const std::array<std::string_view, N>& strings, std::string_view str) {
+ for (auto s : strings) {
+ if (s == str) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
} // namespace utils
} // namespace qnn
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index 5c516aac65aa..429ceb1f7c69 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template