Skip to content

Commit

Permalink
MatMulNBits + Add fusion (microsoft#20587)
Browse files Browse the repository at this point in the history
- Add MatMulNBits Bias input
- Add graph transformer to fuse MatMulNBits + Add
  • Loading branch information
edgchen1 authored May 16, 2024
1 parent 1e1b3f9 commit e81c867
Show file tree
Hide file tree
Showing 28 changed files with 880 additions and 256 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ if (onnxruntime_MINIMAL_BUILD)
"${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer_utils.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/initializer.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/initializer.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/matmul_nbits_fusion.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/matmul_nbits_fusion.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/nhwc_transformer.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/nhwc_transformer.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_final_cleanup.cc"
Expand Down
10 changes: 5 additions & 5 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -876,11 +876,11 @@ if (MSVC)
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26451>")
target_compile_options(onnxruntime_test_all PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4244>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4244>")
# Avoid the error for Win arm64 build. error C1128: number of sections exceeded object file format limit: compile with /bigobj
string(TOLOWER ${onnxruntime_target_platform} GEN_PLATFORM)
if (${GEN_PLATFORM} STREQUAL "arm64")
target_compile_options(onnxruntime_test_all PRIVATE "/bigobj")
endif()

# Avoid this compile error in graph_transform_test.cc:
# fatal error C1128: number of sections exceeded object file format limit: compile with /bigobj
set_property(SOURCE "${TEST_SRC_DIR}/optimizer/graph_transform_test.cc"
APPEND PROPERTY COMPILE_OPTIONS "/bigobj")
else()
target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses")
endif()
Expand Down
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2915,7 +2915,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.</dd>
</dl>

#### Inputs (3 - 5)
#### Inputs (3 - 6)

<dl>
<dt><tt>A</tt> : T1</dt>
Expand All @@ -2928,6 +2928,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>quantization zero points</dd>
<dt><tt>g_idx</tt> (optional) : T4</dt>
<dd>group_idx</dd>
<dt><tt>bias</tt> (optional) : T1</dt>
<dd>Bias to add to result. It should have shape [N].</dd>
</dl>

#### Outputs
Expand Down
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
Expand Down Expand Up @@ -880,7 +880,7 @@ Do not modify directly.*
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
Expand Down Expand Up @@ -1304,7 +1304,7 @@ Do not modify directly.*
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
Expand Down
Loading

0 comments on commit e81c867

Please sign in to comment.