From 7489bfee5324aa38e182c0fef6e388ac07ff0432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erick=20Mu=C3=B1oz?= Date: Mon, 9 Sep 2024 22:19:31 -0600 Subject: [PATCH] Enable AVX NE CONVERT for FP16 to FP32 cast (#21183) ### Description Implementation of a new cast assembly kernel that uses AVX_NE_CONVERT instructions to accelerate casting from FP16 to FP32. Added CPUID checks to determine support of the ISA. ### Motivation and Context Currently, FP16 models executed on systems that lack complete FP16 operator support fall back to single precision on every node, which means the original FP16 weights have to be cast to FP32 before the model can run. This change accelerates that cast by using upconvert instructions, improving performance. --- cmake/onnxruntime_mlas.cmake | 19 +++ onnxruntime/core/mlas/inc/mlas.h | 3 +- .../core/mlas/lib/amd64/cvtfp16Avx.asm | 151 ++++++++++++++++++ onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm | 6 +- onnxruntime/core/mlas/lib/cast.cpp | 59 +++++++ onnxruntime/core/mlas/lib/mlasi.h | 14 ++ onnxruntime/core/mlas/lib/platform.cpp | 14 ++ onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S | 143 +++++++++++++++++ onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S | 129 +++++++++++++++ .../core/providers/cpu/tensor/cast_op.cc | 10 +- 10 files changed, 537 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm create mode 100644 onnxruntime/core/mlas/lib/cast.cpp create mode 100644 onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S create mode 100644 onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index c02ac2096db2..cf23416943c1 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -40,6 +40,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/cast.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -212,6 +213,12 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm ) + if(MSVC_VERSION GREATER_EQUAL 1933) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm + ) + endif() + if (NOT onnxruntime_ORT_MINIMAL_BUILD) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/q4gemm_avx512.cpp @@ -522,6 +529,12 @@ else() ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S ) + if(NOT APPLE) + set(mlas_platform_srcs_sse2 + ${mlas_platform_srcs_sse2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S + ) + endif() set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") set(mlas_platform_srcs_avx @@ -555,6 +568,12 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) + if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S + ) + endif() message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bea4b91ebaa7..8b3156d77e57 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1029,14 +1029,13 @@ MlasComputeTanh( // Half-precision floating-point routines.
// -extern "C" void MLASCALL MlasConvertHalfToFloatBuffer( const unsigned short* Source, float* Destination, size_t Count - ); +); // // Transpose routines. diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm new file mode 100644 index 000000000000..c7f6342c527b --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm @@ -0,0 +1,151 @@ +;++ +; +; Copyright (c) Intel Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; cvtfp16Avx2.asm +; +; Abstract: +; +; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + .const + +SINGLE_SIZE equ 4 +HALF_SIZE equ 2 +LOW_SELECTOR equ 00100000b +HIGH_SELECTOR equ 00110001b + + SUBTTL "Convert buffer of half-precision floats to single-precision floats" +;++ +; +; Routine Description: +; +; This routine converts the source buffer of half-precision floats to the +; destination buffer of single-precision floats. +; +; This implementation uses AVX2 instructions. +; +; Arguments: +; +; Source (rcx) - Supplies the address of the source buffer of half-precision +; floats. +; +; Destination (rdx) - Supplies the address of the destination buffer of +; single-precision floats. +; +; Count (r8) - Supplies the number of elements to convert. +; +; Return Value: +; +; None. +; +;-- + + +LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT + + test r8, r8 ; Check if we have any elements to convert + jz ExitRoutine + cmp r8, 8 + jb ConvertMaskedVectors + cmp r8, 16 + jb Convert128Vectors + + + +Convert256Vectors: + vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes + vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes + vunpcklps ymm2, ymm0, ymm1 ; Interleave low part + vunpckhps ymm1, ymm0, ymm1 ; Interleave high part + vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order + vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order + vmovups ymmword PTR [rdx], ymm0 ; Store the low part + vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part + + add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements + add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements + sub r8, 16 ; Reduce the counter by 16 elements + + jz ExitRoutine ; If we are done, exit + cmp r8, 16 ; If the vector is big enough, we go again + jae Convert256Vectors + + + +Convert128Vectors: + vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes + vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order + vmovups xmmword PTR [rdx], xmm0 ; Store the low part + vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part + + add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements + add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements + sub r8, 8 ; Reduce the counter by 8 elements + + jz ExitRoutine ; If we are done, exit + + + +ConvertMaskedVectors: + vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes + vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order + + cmp r8, 4 ; Check if we can store the complete lower vector + jae ConvertLowerVector + + vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones + cmp r8, 2 ; Check how many converts we need + jb ConvertLower1 + ja ConvertLower3 + vpsrldq xmm2, xmm2, 
SINGLE_SIZE*2 ; Shift the memory store two values + jmp ConvertLowerMaskedVector +ConvertLower1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value + jmp ConvertLowerMaskedVector +ConvertLower3: + vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values +ConvertLowerMaskedVector: + vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples + jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing +ConvertLowerVector: + vmovups xmmword PTR [rdx], xmm0 ; Store the low part + sub r8, 4 ; Check if we still need to convert + jz ExitRoutine + + + add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements + vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones + cmp r8, 2 ; Check how many converts we need + jb ConvertUpper1 + ja ConvertUpper3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values + jmp ConvertMaskedUpperVector +ConvertUpper1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value + jmp ConvertMaskedUpperVector +ConvertUpper3: + vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values +ConvertMaskedUpperVector: + vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples + +ExitRoutine: + ret + + LEAF_END MlasCastF16ToF32KernelAvx, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm index 50315146ca79..0ad98d311520 100644 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm @@ -42,7 +42,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h) ; Source (rcx) - Supplies the address of the source buffer of half-precision ; floats. ; -; Destination (edx) - Supplies the address of the destination buffer of +; Destination (rdx) - Supplies the address of the destination buffer of ; single-precision floats. ; ; Count (r8) - Supplies the number of elements to convert. @@ -53,7 +53,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h) ; ;-- - LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT + LEAF_ENTRY MlasCastF16ToF32KernelSse, _TEXT test r8,r8 jz ExitRoutine @@ -119,6 +119,6 @@ StoreLastElement: ExitRoutine: ret - LEAF_END MlasConvertHalfToFloatBuffer, _TEXT + LEAF_END MlasCastF16ToF32KernelSse, _TEXT END diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp new file mode 100644 index 000000000000..24af4064bbd9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/cast.cpp @@ -0,0 +1,59 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cast.cpp + +Abstract: + + This module implements Half (F16) to Single (F32) precision casting. + +--*/ +#include "mlasi.h" + +union fp32_bits { + uint32_t u; + float f; +}; + +void +MLASCALL +MlasConvertHalfToFloatBuffer( + const unsigned short* Source, + float* Destination, + size_t Count +) +{ + + if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { + // If there is no kernel use the reference implementation, adapted from mlas_float16.h. + constexpr fp32_bits magic = {113 << 23}; + constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + + for (size_t i = 0; i < Count; ++i) { + fp32_bits o; + o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits + uint32_t exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? 
+ o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (Source[i] & 0x8000) << 16; // sign bit + Destination[i] = o.f; + } + + } else { + // If the kernel is available, use it to perform the conversion. + GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4239e2ecaeb6..6f5db766b7de 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,6 +610,13 @@ void size_t N ); +typedef +void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( + const unsigned short* Source, + float* Destination, + size_t Count +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -870,6 +877,11 @@ extern "C" { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx; #endif +#if defined(MLAS_TARGET_AMD64) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; +#endif + } // @@ -1151,6 +1163,8 @@ struct MLAS_PLATFORM { const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + + MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ed437f20f7c2..4cd7faaa9e6f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -244,6 +244,7 @@ Return Value: this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; + this->CastF16ToF32Kernel = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) @@ -283,6 +284,9 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; +#ifndef __APPLE__ + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; +#endif // __APPLE__ this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -469,6 +473,16 @@ Return Value: } #ifndef __APPLE__ +#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) + // + // Check if the processor supports AVX NE CONVERT. + // + if ((Cpuid7_1[3] & (0b1 << 5)) != 0) { + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx; + } +#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) + + // // Check if the processor supports AMX-TILE and AMX-INT8 // features. diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S new file mode 100644 index 000000000000..1a70061460e5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S @@ -0,0 +1,143 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cvtfp16Avx2.asm + +Abstract: + + This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. 
+ +--*/ + +#include "asmmacro.h" + +.data +.equ SINGLE_SIZE, 4 +.equ HALF_SIZE, 2 +.equ LOW_SELECTOR, 0b00100000 +.equ HIGH_SELECTOR, 0b00110001 + +.text +.intel_syntax noprefix + +/*++ Routine Description: + + This routine converts the source buffer of half-precision floats to the + destination buffer of single-precision floats. + + This implementation uses AVX2 instructions. + + Arguments: + + Source (rdi) - Supplies the address of the source buffer of half-precision + floats. + + Destination (rsi) - Supplies the address of the destination buffer of + single-precision floats. + + Count (rdx) - Supplies the number of elements to convert. + + Return Value: + + None. + +--*/ +FUNCTION_ENTRY MlasCastF16ToF32KernelAvx + + test rdx, rdx // Check if we have any elements to convert + jz ExitRoutine + +AVX_NE_CONVERT: + cmp rdx, 8 + jb ConvertMaskedVectors + cmp rdx, 16 + jb Convert128Vectors + +Convert256Vectors: + vcvtneeph2ps ymm0, ymmword PTR [rdi] // Load even indexes + vcvtneoph2ps ymm1, ymmword PTR [rdi] // Load odd indexes + vunpcklps ymm2, ymm0, ymm1 // Interleave low part + vunpckhps ymm1, ymm0, ymm1 // Interleave high part + vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR // Fix the order + vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR // Fix the order + vmovups ymmword PTR [rsi], ymm0 // Store the low part + vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part + + add rdi, 16*HALF_SIZE // Advance src ptr by 16 elements + add rsi, 16*SINGLE_SIZE // Advance dest ptr by 16 elements + sub rdx, 16 // Reduce the counter by 16 elements + + jz ExitRoutine // If we are done, exit + cmp rdx, 16 // If the vector is big enough, we go again + jae Convert256Vectors + + + +Convert128Vectors: + vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes + vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order + vmovups xmmword PTR [rsi], xmm0 // Store the low part + vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part + + add rdi, 8*HALF_SIZE // Advance src ptr by 8 elements + add rsi, 8*SINGLE_SIZE // Advance dest ptr by 8 elements + sub rdx, 8 // Reduce the counter by 8 elements + + jz ExitRoutine // If we are done, exit + + + +ConvertMaskedVectors: + vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes + vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order + + cmp rdx, 4 // Check if we can store the complete lower vector + jae ConvertLowerVector + + vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones + cmp rdx, 2 // Check how many converts we need + jb ConvertLower1 + ja ConvertLower3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values + jmp ConvertLowerMaskedVector +ConvertLower1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value + jmp ConvertLowerMaskedVector +ConvertLower3: + vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values +ConvertLowerMaskedVector: + vmaskmovps xmmword PTR [rsi], xmm2, xmm0 // Store the masked data, the shift is done in 8bit multiples + jmp ExitRoutine // If we ran into any of the cases above, means we are done after storing +ConvertLowerVector: + vmovups xmmword PTR [rsi], xmm0 // Store the low part + sub rdx, 4 // Check if we still need to convert + jz ExitRoutine + + + add rsi, 4*SINGLE_SIZE // Advance 
dest ptr by 4 elements + vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones + cmp rdx, 2 // Check how many converts we need + jb ConvertUpper1 + ja ConvertUpper3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values + jmp ConvertMaskedUpperVector +ConvertUpper1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value + jmp ConvertMaskedUpperVector +ConvertUpper3: + vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values +ConvertMaskedUpperVector: + vmaskmovps xmmword PTR [rsi], xmm2, xmm1 // Store the masked data, the shift is done in 8bit multiples + + jmp ExitRoutine +ExitRoutine: + ret diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S new file mode 100644 index 000000000000..f27114c183f4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cvtfp16a.S + +Abstract: + + This module implements routines to convert between FP16 and FP32 formats using SSE2 instructions. + +--*/ + +#include "asmmacro.h" + +// We use RIP relative addressing to avoid relocation related errors +.section .rodata +MlasFp16MaskSign: .long 0x00007FFF +MlasFp16CompareInfinity: .long 0x00007C00 +MlasFp16CompareSmallest: .long 0x00000400 +MlasFp16AdjustExponent: .long 0x38000000 +MlasFp16MagicDenormal: .long 0x38800000 + +.text +.intel_syntax noprefix + +/*++ Routine Description: + + This routine converts the source buffer of half-precision floats to the + destination buffer of single-precision floats. + + This implementation uses SSE2 instructions. + + Arguments: + + Source (rdi) - Supplies the address of the source buffer of half-precision + floats. + + Destination (rsi) - Supplies the address of the destination buffer of + single-precision floats. + + Count (rdx) - Supplies the number of elements to convert. + + Return Value: + + None. + +--*/ + +FUNCTION_ENTRY MlasCastF16ToF32KernelSse + + test rdx,rdx + jz ExitRoutine + + // Load xmm constants + movd xmm5, DWORD PTR [rip + MlasFp16MaskSign] + pshufd xmm5, xmm5, 0x00 + movd xmm6, DWORD PTR [rip + MlasFp16AdjustExponent] + pshufd xmm6, xmm6, 0x00 + movd xmm7, DWORD PTR [rip + MlasFp16MagicDenormal] + pshufd xmm7, xmm7, 0x00 + + + cmp rdx,4 + jb LoadPartialVector + +LoadFullVector: + movq xmm0,QWORD PTR [rdi] + add rdi,4*2 // advance S by 4 elements + +ConvertHalfToFloat: + punpcklwd xmm0,xmm0 // duplicate 4 WORDs to 4 DWORDs + movaps xmm1,xmm0 // isolate exponent/mantissa + pand xmm1,xmm5 + pxor xmm0,xmm1 // isolate sign bit + movd xmm2, DWORD PTR [rip + MlasFp16CompareInfinity] + pshufd xmm2, xmm2, 0x00 + pcmpgtd xmm2,xmm1 // test for infinity/NaNs + movd xmm3, DWORD PTR [rip + MlasFp16CompareSmallest] + pshufd xmm3, xmm3, 0x00 + pcmpgtd xmm3,xmm1 // test for denormals + pandn xmm2,xmm6 + pslld xmm1,13 // shift exponent/mask into place + movaps xmm4,xmm1 + paddd xmm1,xmm6 + paddd xmm1,xmm2 // adjust exponent again for infinity/NaNs + paddd xmm4,xmm7 + pslld xmm0,16 // shift sign into place + subps xmm4,xmm7 + pand xmm4,xmm3 // select elements that are denormals + pandn xmm3,xmm1 // select elements that are not denormals + por xmm3,xmm4 // blend the selected values together + por xmm0,xmm3 // merge sign into exponent/mantissa + + cmp rdx,4 // storing full vector?
+ jb StorePartialVector + movups XMMWORD PTR [rsi],xmm0 + add rsi,4*4 // advance D by 4 elements + sub rdx,4 + jz ExitRoutine + cmp rdx,4 + jae LoadFullVector + +LoadPartialVector: + pxor xmm0,xmm0 + pinsrw xmm0,WORD PTR [rdi],0 + cmp rdx,2 + jb ConvertHalfToFloat + pinsrw xmm0,WORD PTR [rdi+2],1 + je ConvertHalfToFloat + pinsrw xmm0,WORD PTR [rdi+4],2 + jmp ConvertHalfToFloat + +StorePartialVector: + cmp rdx,2 + jb StoreLastElement + movsd QWORD PTR [rsi],xmm0 + je ExitRoutine + movhlps xmm0,xmm0 // shift third element down + add rsi,4*2 // advance D by 2 elements + +StoreLastElement: + movss DWORD PTR [rsi],xmm0 + +ExitRoutine: + ret diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 6742bab4fa4a..f2aaa75cadd8 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -22,9 +22,8 @@ #include "Eigen/src/Core/arch/Default/BFloat16.h" #include "Eigen/src/Core/arch/Default/Half.h" -#if defined(_M_AMD64) && !defined(_M_ARM64EC) #include "core/mlas/inc/mlas.h" -#endif +#include "core/common/cpuid_info.h" namespace onnxruntime { @@ -252,10 +251,6 @@ struct TensorCasterNoSat { #endif -#if defined(_M_AMD64) && !defined(_M_ARM64EC) -// specializations to use optimized and Windows x64-specific -// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion - // tensor MLFloat16 -> float template <> struct TensorCaster { @@ -267,6 +262,9 @@ struct TensorCaster { } }; +#if defined(_M_AMD64) && !defined(_M_ARM64EC) +// specializations to use optimized and Windows x64-specific + Tensor GetIntermediateMLFloat16ToFloatTensor( const OpKernelContext& context, const TensorShape& shape, const Tensor& in) { AllocatorPtr allocator;
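
For context on the dispatch path added in platform.cpp above: AVX-NE-CONVERT is enumerated by CPUID.(EAX=7, ECX=1):EDX bit 5, which is exactly what the `Cpuid7_1[3] & (0b1 << 5)` test checks before pointing `CastF16ToF32Kernel` at `MlasCastF16ToF32KernelAvx`. The snippet below is a minimal standalone sketch of that feature check, assuming a GCC/Clang toolchain that provides `<cpuid.h>`; the `HasAvxNeConvert` helper is illustrative only and is not part of this patch (MLAS performs the query through its own CPUID plumbing).

```cpp
// Standalone sketch of the AVX-NE-CONVERT feature check mirrored from platform.cpp.
// Assumes GCC/Clang <cpuid.h>; not part of the MLAS sources.
#include <cpuid.h>
#include <cstdio>

static bool HasAvxNeConvert() {
    unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
    // Leaf 7, sub-leaf 1; __get_cpuid_count returns 0 if the leaf is unsupported.
    if (__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx) == 0) {
        return false;
    }
    return (edx & (1u << 5)) != 0;  // EDX bit 5 enumerates AVX-NE-CONVERT
}

int main() {
    std::printf("AVX-NE-CONVERT supported: %s\n", HasAvxNeConvert() ? "yes" : "no");
    return 0;
}
```

When that bit is set (and the compiler is new enough to assemble the new kernel), `MlasConvertHalfToFloatBuffer` dispatches to `MlasCastF16ToF32KernelAvx`; otherwise it falls back to `MlasCastF16ToF32KernelSse` or to the scalar reference loop in cast.cpp.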