Skip to content

Commit

Permalink
[WebNN EP] Support numThreads option for WebNN CPU device (microsoft#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry authored Nov 13, 2023
1 parent cbf0cf0 commit 73ed34a
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 6 deletions.
1 change: 1 addition & 0 deletions js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ export declare namespace InferenceSession {
/**
 * Options for the WebNN execution provider.
 */
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webnn';
// Device to run the WebNN graph on. NOTE(review): presumably defaults to 'cpu' when omitted — confirm against the option-parsing code in session-options.ts.
deviceType?: 'cpu'|'gpu';
// Thread count for the WebNN CPU device. Non-integer or negative values are
// ignored (treated as 0, i.e. let the runtime decide); the value is only
// applied when deviceType is 'cpu'.
numThreads?: number;
// Power/performance hint forwarded to the WebNN context.
powerPreference?: 'default'|'low-power'|'high-performance';
}
export interface CoreMLExecutionProviderOption extends ExecutionProviderOption {
Expand Down
13 changes: 13 additions & 0 deletions js/web/lib/wasm/session-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ const setExecutionProviders =
checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`);
}
}
if (webnnOptions?.numThreads) {
let numThreads = webnnOptions.numThreads;
// Just ignore invalid webnnOptions.numThreads: anything that is not a
// non-negative integer falls back to 0 (runtime default).
// Use strict `!==` for consistency with the other comparisons in this file.
if (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) {
numThreads = 0;
}
const keyDataOffset = allocWasmString('numThreads', allocs);
const valueDataOffset = allocWasmString(numThreads.toString(), allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`);
}
}
if (webnnOptions?.powerPreference) {
const keyDataOffset = allocWasmString('powerPreference', allocs);
const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs);
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

namespace onnxruntime {

WebNNExecutionProvider::WebNNExecutionProvider(
const std::string& webnn_device_flags, const std::string& webnn_power_flags)
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
const std::string& webnn_threads_number, const std::string& webnn_power_flags)
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} {
// Create WebNN context and graph builder.
const emscripten::val ml = emscripten::val::global("navigator")["ml"];
Expand All @@ -31,6 +31,10 @@ WebNNExecutionProvider::WebNNExecutionProvider(
if (webnn_device_flags.compare("cpu") == 0) {
preferred_layout_ = DataLayout::NHWC;
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
// Set "numThreads" if it's not default 0.
if (webnn_threads_number.compare("0") != 0) {
context_options.set("numThreads", stoi(webnn_threads_number));
}
} else {
preferred_layout_ = DataLayout::NCHW;
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webnn/webnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class Model;

class WebNNExecutionProvider : public IExecutionProvider {
public:
WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_power_flags);
WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
const std::string& webnn_power_flags);
virtual ~WebNNExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
Expand Down
9 changes: 6 additions & 3 deletions onnxruntime/core/providers/webnn/webnn_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@ using namespace onnxruntime;

namespace onnxruntime {
// Factory that captures the WebNN provider flags (device type, thread count,
// power preference — all as strings) and materializes a provider on demand.
struct WebNNProviderFactory : IExecutionProviderFactory {
  WebNNProviderFactory(const std::string& webnn_device_flags,
                       const std::string& webnn_threads_number,
                       const std::string& webnn_power_flags)
      : webnn_device_flags_{webnn_device_flags},
        webnn_threads_number_{webnn_threads_number},
        webnn_power_flags_{webnn_power_flags} {}
  ~WebNNProviderFactory() override = default;

  std::unique_ptr<IExecutionProvider> CreateProvider() override;

  // Flags forwarded verbatim to WebNNExecutionProvider.
  std::string webnn_device_flags_;
  std::string webnn_threads_number_;
  std::string webnn_power_flags_;
};

// Builds a WebNNExecutionProvider from the flags captured at factory
// construction time; ownership of the provider transfers to the caller.
std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
  auto provider = std::make_unique<WebNNExecutionProvider>(webnn_device_flags_,
                                                           webnn_threads_number_,
                                                           webnn_power_flags_);
  return provider;
}

std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
const ProviderOptions& provider_options) {
return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"),
provider_options.at("numThreads"),
provider_options.at("powerPreference"));
}

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
} else if (strcmp(provider_name, "WEBNN") == 0) {
#if defined(USE_WEBNN)
std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
provider_options["deviceType"] = deviceType;
provider_options["numThreads"] = numThreads;
provider_options["powerPreference"] = powerPreference;
options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
#else
Expand Down

0 comments on commit 73ed34a

Please sign in to comment.