Skip to content

Commit

Permalink
Merged PR 6606929: RI 10/26 from github into fork #2
Browse files Browse the repository at this point in the history
Related work items: #36831318
  • Loading branch information
Sheil Kumar committed Oct 27, 2021
1 parent 08dd8a0 commit 8dd4d20
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 83 deletions.
57 changes: 31 additions & 26 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cmake_policy(SET CMP0104 OLD)
# Project
project(onnxruntime C CXX)
# Needed for Java
set (CMAKE_C_STANDARD 99)
set(CMAKE_C_STANDARD 99)

include(CheckCXXCompilerFlag)
include(CheckLanguage)
Expand Down Expand Up @@ -109,10 +109,11 @@ option(onnxruntime_USE_ROCM "Build with AMD GPU support" OFF)
option(onnxruntime_DISABLE_CONTRIB_OPS "Disable contrib ops" OFF)
option(onnxruntime_DISABLE_ML_OPS "Disable traditional ML ops" OFF)
option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF)
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF)
# For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone
option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." OFF)
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF)

option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF)
option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." OFF)
option(onnxruntime_REDUCED_OPS_BUILD "Reduced set of kernels are registered in build via modification of the kernel registration source files." OFF)
Expand Down Expand Up @@ -180,7 +181,7 @@ elseif (NOT WIN32 AND NOT APPLE)
endif()

# Single output director for all binaries
set (RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE PATH "Single output directory for all binaries.")
set(RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE PATH "Single output directory for all binaries.")

function(set_msvc_c_cpp_compiler_warning_level warning_level)
if (NOT "${warning_level}" MATCHES "^[0-4]$")
Expand Down Expand Up @@ -283,8 +284,8 @@ endif()
if (onnxruntime_USE_OPENMP)
find_package(OpenMP)
if (OPENMP_FOUND)
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
include_directories(${OpenMP_CXX_INCLUDE_DIR})
else()
message(WARNING "Flag --use_openmp is specified, but OpenMP is not found in current build environment. Setting it to OFF.")
Expand Down Expand Up @@ -354,7 +355,7 @@ if (onnxruntime_MINIMAL_BUILD)
if (MSVC)
# turn on LTO (which adds some compiler flags and turns on LTCG) unless it's a Debug build to minimize binary size
if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set (onnxruntime_ENABLE_LTO ON)
set(onnxruntime_ENABLE_LTO ON)
endif()

# undocumented internal flag to allow analysis of a minimal build binary size
Expand Down Expand Up @@ -393,7 +394,7 @@ if (onnxruntime_ENABLE_LTO)
message(WARNING "IPO is not supported by this compiler")
set(onnxruntime_ENABLE_LTO OFF)
else()
set (CMAKE_INTERPROCEDURAL_OPTIMIZATION ON)
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON)
endif()
endif()

Expand All @@ -406,7 +407,7 @@ if (onnxruntime_DISABLE_EXTERNAL_INITIALIZERS)
endif()

if (onnxruntime_DISABLE_RTTI)
add_compile_definitions(ORT_NO_RTTI GOOGLE_PROTOBUF_NO_RTTI)
add_compile_definitions(ORT_NO_RTTI)
if (MSVC)
# Disable RTTI and turn usage of dynamic_cast and typeid into errors
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/GR->" "$<$<COMPILE_LANGUAGE:CXX>:/we4541>")
Expand All @@ -428,6 +429,10 @@ if (onnxruntime_DISABLE_EXCEPTIONS)
message(FATAL_ERROR "onnxruntime_MINIMAL_BUILD required for onnxruntime_DISABLE_EXCEPTIONS")
endif()

if (onnxruntime_ENABLE_PYTHON)
# pybind11 highly depends on C++ exceptions.
message(FATAL_ERROR "onnxruntime_ENABLE_PYTHON must be disabled for onnxruntime_DISABLE_EXCEPTIONS")
endif()
add_compile_definitions("ORT_NO_EXCEPTIONS")
add_compile_definitions("MLAS_NO_EXCEPTION")
add_compile_definitions("ONNX_NO_EXCEPTIONS")
Expand Down Expand Up @@ -562,30 +567,30 @@ if (MSVC)
endif()

if (onnxruntime_ENABLE_LTO AND NOT onnxruntime_USE_CUDA)
SET (CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL")
SET (CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL")
SET (CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /Gw /GL")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Gw /GL")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /Gw /GL")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /Gw /GL")
endif()

# The WinML build tool chain builds ARM/ARM64, and the internal tool chain does not have folders for spectre mitigation libs.
# WinML performs spectre mitigation differently.
if (NOT DEFINED onnxruntime_DISABLE_QSPECTRE_CHECK)
check_cxx_compiler_flag(-Qspectre HAS_QSPECTRE)
if (HAS_QSPECTRE)
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Qspectre")
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Qspectre")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Qspectre")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Qspectre")
endif()
endif()
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DYNAMICBASE")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DYNAMICBASE")
check_cxx_compiler_flag(-guard:cf HAS_GUARD_CF)
if (HAS_GUARD_CF)
SET(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /guard:cf")
SET(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /guard:cf")
SET(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} /guard:cf")
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /guard:cf")
SET(CMAKE_C_FLAGS_MINSIZEREL "${CMAKE_C_FLAGS_MINSIZEREL} /guard:cf")
SET(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /guard:cf")
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /guard:cf")
set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /guard:cf")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /guard:cf")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} /guard:cf")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /guard:cf")
set(CMAKE_C_FLAGS_MINSIZEREL "${CMAKE_C_FLAGS_MINSIZEREL} /guard:cf")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /guard:cf")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /guard:cf")
endif()
else()
if (NOT APPLE)
Expand Down Expand Up @@ -640,16 +645,16 @@ endif()
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
#For Mac compliance
message("Adding flags for Mac builds")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-strong")
endif()

if (${CMAKE_SYSTEM_NAME} MATCHES "iOSCross")
#For ios compliance
message("Adding flags for ios builds")
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -target arm64-apple-darwin-macho")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -target arm64-apple-darwin-macho")
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "arm")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -target armv7a-apple-darwin-macho")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -target armv7a-apple-darwin-macho")
endif()
endif()

Expand Down Expand Up @@ -760,7 +765,7 @@ else()
set(protobuf_WITH_ZLIB OFF CACHE BOOL "Build with zlib support" FORCE)
endif()
if (onnxruntime_DISABLE_RTTI)
set(protobuf_DISABLE_RTTI OFF CACHE BOOL "Remove runtime type information in the binaries" FORCE)
set(protobuf_DISABLE_RTTI ON CACHE BOOL "Remove runtime type information in the binaries" FORCE)
endif()
add_subdirectory(${PROJECT_SOURCE_DIR}/external/protobuf/cmake EXCLUDE_FROM_ALL)

Expand Down
5 changes: 1 addition & 4 deletions cmake/onnxruntime_common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ if(WIN32)
"${ONNXRUNTIME_ROOT}/core/platform/windows/logging/*.h"
"${ONNXRUNTIME_ROOT}/core/platform/windows/logging/*.cc"
)
# Windows platform adapter code uses advapi32, which isn't linked in by default in desktop ARM
if (NOT WINDOWS_STORE)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES advapi32)
endif()

else()
list(APPEND onnxruntime_common_src_patterns
"${ONNXRUNTIME_ROOT}/core/platform/posix/*.h"
Expand Down
13 changes: 12 additions & 1 deletion cmake/onnxruntime_nodejs.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ if(had_error)
message(FATAL_ERROR "Failed to find NPM: " ${had_error})
endif()

# setup ARCH
if (APPLE AND CMAKE_OSX_ARCHITECTURES_LEN GREATER 1)
message(FATAL_ERROR "CMake.js does not support multi-architecture for macOS")
endif()
if (APPLE AND CMAKE_OSX_ARCHITECTURES STREQUAL "arm64")
set(NODEJS_BINDING_ARCH arm64)
# elseif()
else()
set(NODEJS_BINDING_ARCH x64)
endif()

if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS)
# add custom target
add_custom_target(js_npm_ci ALL
Expand All @@ -42,7 +53,7 @@ add_custom_target(js_common_npm_ci ALL

add_custom_target(nodejs_binding_wrapper ALL
COMMAND ${NPM_CLI} ci
COMMAND ${NPM_CLI} run build -- --onnxruntime-build-dir=${CMAKE_CURRENT_BINARY_DIR} --config=${CMAKE_BUILD_TYPE} --arch=x64
COMMAND ${NPM_CLI} run build -- --onnxruntime-build-dir=${CMAKE_CURRENT_BINARY_DIR} --config=${CMAKE_BUILD_TYPE} --arch=${NODEJS_BINDING_ARCH}
WORKING_DIRECTORY ${JS_NODE_ROOT}
COMMENT "Using cmake-js to build OnnxRuntime Node.js binding")
add_dependencies(js_common_npm_ci js_npm_ci)
Expand Down
2 changes: 1 addition & 1 deletion java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins {
id 'maven-publish'
id 'signing'
id 'jacoco'
id 'com.diffplug.spotless' version '5.9.0'
id 'com.diffplug.spotless' version '5.17.0'
}

allprojects {
Expand Down
8 changes: 4 additions & 4 deletions js/node/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ add_compile_definitions(NAPI_VERSION=${napi_build_version})
add_compile_definitions(ORT_API_MANUAL_INIT)

# dist variables
execute_process(COMMAND node -e "console.log(process.arch)"
OUTPUT_VARIABLE node_arch OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND node -e "console.log(process.platform)"
OUTPUT_VARIABLE node_platform OUTPUT_STRIP_TRAILING_WHITESPACE)
file(READ ${CMAKE_SOURCE_DIR}/../../VERSION_NUMBER ort_version)
string(STRIP "${ort_version}" ort_version)
set(dist_folder "${CMAKE_SOURCE_DIR}/bin/napi-v3/${node_platform}/${node_arch}/")
set(dist_folder "${CMAKE_SOURCE_DIR}/bin/napi-v3/${node_platform}/${NODE_ARCH}/")

# onnxruntime.dll dir
if(NOT ONNXRUNTIME_BUILD_DIR)
if (WIN32)
set(ONNXRUNTIME_BUILD_DIR ${CMAKE_SOURCE_DIR}/../../build/Windows/${CMAKE_BUILD_TYPE})
elseif(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(ONNXRUNTIME_BUILD_DIR ${CMAKE_SOURCE_DIR}/../../build/MacOS/${CMAKE_BUILD_TYPE})
else()
set(ONNXRUNTIME_BUILD_DIR ${CMAKE_SOURCE_DIR}/../../build/Linux)
set(ONNXRUNTIME_BUILD_DIR ${CMAKE_SOURCE_DIR}/../../build/Linux/${CMAKE_BUILD_TYPE})
endif()
endif()

Expand Down
11 changes: 9 additions & 2 deletions js/node/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
import {execSync, spawnSync} from 'child_process';
import * as fs from 'fs-extra';
import minimist from 'minimist';
import * as os from 'os';
import * as path from 'path';

// command line flags
const buildArgs = minimist(process.argv.slice(2));

// --config=Debug|Release|RelWithDebInfo
const CONFIG: 'Debug'|'Release'|'RelWithDebInfo' = buildArgs.config || 'RelWithDebInfo';
const CONFIG: 'Debug'|'Release'|'RelWithDebInfo' =
buildArgs.config || (os.platform() === 'win32' ? 'RelWithDebInfo' : 'Release');
if (CONFIG !== 'Debug' && CONFIG !== 'Release' && CONFIG !== 'RelWithDebInfo') {
throw new Error(`unrecognized config: ${CONFIG}`);
}
// --arch=x64|ia32|arm64|arm
const ARCH: 'x64'|'ia32'|'arm64'|'arm' = buildArgs.arch || 'x64';
const ARCH: 'x64'|'ia32'|'arm64'|'arm' = buildArgs.arch || os.arch();
if (ARCH !== 'x64' && ARCH !== 'ia32' && ARCH !== 'arm64' && ARCH !== 'arm') {
throw new Error(`unrecognized architecture: ${ARCH}`);
}
Expand Down Expand Up @@ -49,6 +51,11 @@ if (ONNXRUNTIME_BUILD_DIR && typeof ONNXRUNTIME_BUILD_DIR === 'string') {
args.push(`--CDONNXRUNTIME_BUILD_DIR=${ONNXRUNTIME_BUILD_DIR}`);
}

// set cross-compile for arm64 on macOS
if (os.platform() === 'darwin' && ARCH === 'arm64') {
args.push('--CDCMAKE_OSX_ARCHITECTURES=arm64');
}

// launch cmake-js configure
const procCmakejs = spawnSync(command, args, {shell: true, stdio: 'inherit', cwd: ROOT_FOLDER});
if (procCmakejs.status !== 0) {
Expand Down
47 changes: 13 additions & 34 deletions onnxruntime/contrib_ops/cpu/qlinear_global_average_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ Status ComputeQLinearGlobalAvgPool(
int64_t image_size,
bool channels_last,
concurrency::ThreadPool* tp) {
static constexpr int64_t kMiniChannelGroup = 64;

if (!channels_last || C == 1) {
auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) {
const uint8_t* input = (const uint8_t*)(x + (first * image_size));
Expand All @@ -38,38 +36,19 @@ Status ComputeQLinearGlobalAvgPool(
concurrency::ThreadPool::TryParallelFor(
tp, static_cast<std::ptrdiff_t>(N * C), {1.0 * image_size, 1.0, 8.0 * image_size}, worker);
} else {
if (N == 1) {
int64_t channel_padded = (C + kMiniChannelGroup - 1) & (~(kMiniChannelGroup - 1));
int64_t channel_groups = channel_padded / kMiniChannelGroup;
auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) {
std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), C));
std::vector<uint8_t> zero_buffer(MlasQLinearSafePaddingElementCount(sizeof(uint8_t), C), 0);
const uint8_t* input = x + first * kMiniChannelGroup;
uint8_t* output = y + first * kMiniChannelGroup;
int64_t channel_count = (last == channel_groups) ? (C - first * kMiniChannelGroup) : ((last - first) * kMiniChannelGroup);
MlasQLinearGlobalAveragePoolNhwc(
input, x_scale, x_zero_point, output, y_scale, y_zero_point,
N, image_size, C, channel_count, acc_buffer.data(), zero_buffer.data());
};
concurrency::ThreadPool::TryParallelFor(
tp, static_cast<std::ptrdiff_t>(channel_groups),
{1.0 * N * image_size * kMiniChannelGroup, 1.0 * N * kMiniChannelGroup, 8.0 * N * image_size * kMiniChannelGroup},
worker);
} else {
auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) {
const uint8_t* input = x + first * C * image_size;
uint8_t* output = y + first * C;
std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), C));
std::vector<uint8_t> zero_buffer(MlasQLinearSafePaddingElementCount(sizeof(uint8_t), C), 0);
MlasQLinearGlobalAveragePoolNhwc(
input, x_scale, x_zero_point, output, y_scale, y_zero_point,
last - first, image_size, C, C, acc_buffer.data(), zero_buffer.data());
};
concurrency::ThreadPool::TryParallelFor(
tp, static_cast<std::ptrdiff_t>(N),
{1.0 * image_size * C, 1.0 * C, 8.0 *image_size * C},
worker);
}
auto worker = [=](std::ptrdiff_t first, std::ptrdiff_t last) {
const uint8_t* input = x + first * C * image_size;
uint8_t* output = y + first * C;
std::vector<int32_t> acc_buffer(MlasQLinearSafePaddingElementCount(sizeof(int32_t), C));
std::vector<uint8_t> zero_buffer(MlasQLinearSafePaddingElementCount(sizeof(uint8_t), C), 0);
MlasQLinearGlobalAveragePoolNhwc(
input, x_scale, x_zero_point, output, y_scale, y_zero_point,
last - first, image_size, C, C, acc_buffer.data(), zero_buffer.data());
};
concurrency::ThreadPool::TryParallelFor(
tp, static_cast<std::ptrdiff_t>(N),
{1.0 * image_size * C, 1.0 * C, 8.0 *image_size * C},
worker);
}
return Status::OK();
}
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/graph/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ static void InitNestedModelLocalFunction(onnxruntime::Graph& graph,
}
ORT_CATCH(const std::exception& e) {
LOGS(logger, WARNING) << "Function body initialization failed for Function '"
<< onnx_function_proto.name() << "'. Error message " << e.what()
<< onnx_function_proto.name()
#ifndef ORT_NO_EXCEPTIONS
<< "'. Error message " << e.what()
#endif //ORT_NO_EXCEPTIONS
<< ". Execution will fail if ORT does not have a specialized kernel for this op";
// Return without using this function op's expansion. No need to fail just yet.
// If ORT has a specialized kernel for this op then execution will proceed
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2537,7 +2537,9 @@ void Graph::InitFunctionBodyForNode(Node& node) {
ORT_CATCH(const std::exception& e) {
LOGS(logger_, WARNING) << "Function body initialization failed for node '"
<< node.Name() << "' optype " << node.OpType()
#ifndef ORT_NO_EXCEPTIONS
<< ". Error message " << e.what()
#endif //ORT_NO_EXCEPTIONS
<< ". Execution will fail if ORT does not have a specialized kernel for this op";
// Return without using this function op's expansion. No need to fail just yet.
// If ORT has a specialized kernel for this op then execution will proceed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,15 @@ namespace Dml

HRESULT STDMETHODCALLTYPE ExecutionProviderImpl::FillTensorWithPattern(
IMLOperatorTensor* dst,
gsl::span<const std::byte> value // Data type agnostic value, treated as raw bits
gsl::span<const std::byte> rawValue // Data type agnostic rawValue, treated as raw bits
) const noexcept try
{
auto mlTensor = MLOperatorTensor(dst).GetDataInterface();
if (mlTensor != nullptr)
{
const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(mlTensor.Get());
ID3D12Resource* dstData = dstAllocInfo->GetResource();
m_context->FillBufferWithPattern(dstData, value);
m_context->FillBufferWithPattern(dstData, rawValue);
}

return S_OK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace Dml

STDMETHOD(FillTensorWithPattern)(
IMLOperatorTensor* dst,
gsl::span<const std::byte> value
gsl::span<const std::byte> rawValue
) const noexcept final;

STDMETHOD(UploadToResource)(ID3D12Resource* dstData, const void* srcData, uint64_t srcDataSize) const noexcept final;
Expand Down
Loading

0 comments on commit 8dd4d20

Please sign in to comment.