From 35697d242111b20ac2160197ff9fe90ee0ca63bc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 31 May 2024 03:25:14 -0700 Subject: [PATCH] [js/webnn] update API of session options for WebNN (#20816) ### Description This PR is an API-only change to address the requirements being discussed in #20729. There are multiple ways that users may create an ORT session by specifying the session options differently. All the code snippets below will use the variable `webnnOptions` as this: ```js const myWebnnSession = await ort.InferenceSession.create('./model.onnx', { executionProviders: [ webnnOptions ] }); ``` ### The old way (backward-compatibility) ```js // all-default, name only const webnnOptions_0 = 'webnn'; // all-default, properties omitted const webnnOptions_1 = { name: 'webnn' }; // partial const webnnOptions_2 = { name: 'webnn', deviceType: 'cpu' }; // full const webnnOptions_3 = { name: 'webnn', deviceType: 'gpu', numThreads: 1, powerPreference: 'high-performance' }; ``` ### The new way (specify with MLContext) ```js // options to create MLContext const options = { deviceType: 'gpu', powerPreference: 'high-performance' }; const myMlContext = await navigator.ml.createContext(options); // options for session options const webnnOptions = { name: 'webnn', context: myMlContext, ...options }; ``` This should throw (because no deviceType is specified): ```js const myMlContext = await navigator.ml.createContext({ ... }); const webnnOptions = { name: 'webnn', context: myMlContext }; ``` ### Interop with WebGPU ```js // get WebGPU device const adaptor = await navigator.gpu.requestAdapter({ ... }); const device = await adaptor.requestDevice({ ... 
}); // set WebGPU adaptor and device ort.env.webgpu.adaptor = adaptor; ort.env.webgpu.device = device; const myMlContext = await navigator.ml.createContext(device); const webnnOptions = { name: 'webnn', context: myMlContext, gpuDevice: device }; ``` This should throw (because cannot specify both gpu device and MLContext option at the same time): ```js const webnnOptions = { name: 'webnn', context: myMlContext, gpuDevice: device, deviceType: 'gpu' }; ``` --- js/common/lib/inference-session.ts | 52 +++++++++++++++++++++++++++++- js/web/lib/wasm/session-options.ts | 30 +++++++++-------- 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 353d93bbc34a..069fd9b49e48 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -242,12 +242,62 @@ export declare namespace InferenceSession { readonly name: 'webgpu'; preferredLayout?: 'NCHW'|'NHWC'; } - export interface WebNNExecutionProviderOption extends ExecutionProviderOption { + + // #region WebNN options + + interface WebNNExecutionProviderName extends ExecutionProviderOption { readonly name: 'webnn'; + } + + /** + * Represents a set of options for creating a WebNN MLContext. + * + * @see https://www.w3.org/TR/webnn/#dictdef-mlcontextoptions + */ + export interface WebNNContextOptions { deviceType?: 'cpu'|'gpu'|'npu'; numThreads?: number; powerPreference?: 'default'|'low-power'|'high-performance'; } + + /** + * Represents a set of options for WebNN execution provider without MLContext. + */ + export interface WebNNOptionsWithoutMLContext extends WebNNExecutionProviderName, WebNNContextOptions { + context?: never; + } + + /** + * Represents a set of options for WebNN execution provider with MLContext. + * + * When MLContext is provided, the deviceType is also required so that the WebNN EP can determine the preferred + * channel layout. 
+ * + * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext + */ + export interface WebNNOptionsWithMLContext extends WebNNExecutionProviderName, + Omit<WebNNContextOptions, 'deviceType'>, + Required<Pick<WebNNContextOptions, 'deviceType'>> { + context: unknown /* MLContext */; + } + + /** + * Represents a set of options for WebNN execution provider with MLContext which is created from GPUDevice. + * + * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice + */ + export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName { + context: unknown /* MLContext */; + gpuDevice: unknown /* GPUDevice */; + } + + /** + * Options for WebNN execution provider. + */ + export type WebNNExecutionProviderOption = WebNNOptionsWithoutMLContext|WebNNOptionsWithMLContext|WebNNOptionsWebGpu; + + // #endregion + export interface QnnExecutionProviderOption extends ExecutionProviderOption { readonly name: 'qnn'; // TODO add flags diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 48eac5749472..4d2b80e31a47 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -64,34 +64,36 @@ const setExecutionProviders = epName = 'WEBNN'; if (typeof ep !== 'string') { const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; - if (webnnOptions?.deviceType) { + // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; + const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; + if (deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); - const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs); + const valueDataOffset = allocWasmString(deviceType, allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - 
checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`); + checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); } } - if (webnnOptions?.numThreads) { - let numThreads = webnnOptions.numThreads; + if (numThreads !== undefined) { // Just ignore invalid webnnOptions.numThreads. - if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) { - numThreads = 0; - } + const validatedNumThreads = + (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 : + numThreads; const keyDataOffset = allocWasmString('numThreads', allocs); - const valueDataOffset = allocWasmString(numThreads.toString(), allocs); + const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`); + checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`); } } - if (webnnOptions?.powerPreference) { + if (powerPreference) { const keyDataOffset = allocWasmString('powerPreference', allocs); - const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs); + const valueDataOffset = allocWasmString(powerPreference, allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError( - `Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}.`); + checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`); } } }