🚀 The Feature
This RFC proposes to use BFloat16 for GEMM/CONV/RNN internal computations on the CPU device, controlled by a user-facing frontend API. Currently, we have torch.set_float32_matmul_precision, which allows float32 matrix multiplications to run in lower precision (usage sketch below the list):
highest -> Do not use lower precision.
high -> Use TF32 as the internal computation data type.
medium -> Use BF16 as the internal computation data type.
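For reference, a minimal usage sketch of the existing matmul control (both functions already exist in PyTorch):

```python
import torch

# Existing frontend control for float32 matmul internal precision.
torch.set_float32_matmul_precision("medium")   # "highest" | "high" | "medium"
print(torch.get_float32_matmul_precision())    # prints "medium"
```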
To let CONV/RNN also have a controllable internal computation data type for float32, and to integrate mkldnn BF16 as an internal computation data type for GEMM/CONV/RNN on the CPU device, we propose the high-level code changes below.
Frontend changes:
torch.set_float32_conv_precision, torch.get_float32_conv_precision
torch.set_float32_rnn_precision, torch.get_float32_rnn_precision
These frontend APIs should behave the same way as torch.set_float32_matmul_precision and torch.get_float32_matmul_precision. Users can set the precision to highest, high, or medium. When the precision is high, the CUDA/cuDNN backend is allowed to use TF32 as the internal computation data type. When the precision is medium, the MKLDNN backend is allowed to use BF16 as the internal computation data type.
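A minimal sketch of how the proposed conv/rnn counterparts might be used, assuming they mirror the matmul API exactly (these functions do not exist yet; the names are the ones proposed above):

```python
import torch

# Proposed APIs (not in PyTorch yet); assumed to mirror
# torch.set/get_float32_matmul_precision.
torch.set_float32_conv_precision("medium")   # MKLDNN may use BF16 internally for conv
torch.set_float32_rnn_precision("medium")    # MKLDNN may use BF16 internally for RNN
print(torch.get_float32_conv_precision())    # would print "medium"
print(torch.get_float32_rnn_precision())     # would print "medium"
```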
Backend changes:
For matmul: currently, we only dispatch at::matmul to mkldnn_matmul when the input tensors are BFloat16. We propose to additionally dispatch at::matmul to mkldnn_matmul when:
(1) float32_matmul_precision is medium, and
(2) the input tensors are float32.
In that case we will use BF16 as the internal computation data type (a PR has already been created); see the sketch after this list.
For Conv: we will check float32_conv_precision in mkldnn_conv and use BF16 as the internal computation data type.
For RNN: we will check float32_rnn_precision in mkldnn_rnn_layer and use BF16 as the internal computation data type.
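To make the matmul dispatch condition concrete, here is a hedged example; the BF16 internal computation described in the comments is what this RFC proposes, not current PyTorch semantics:

```python
import torch

torch.set_float32_matmul_precision("medium")

a = torch.randn(128, 256)   # float32 CPU tensors
b = torch.randn(256, 64)

# Both proposed dispatch conditions hold here:
#   (1) float32_matmul_precision is "medium"
#   (2) the inputs are float32
# so at::matmul would be routed to mkldnn_matmul and computed with a
# BF16 internal data type; the result stays float32.
c = a @ b
print(c.dtype)              # torch.float32
```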
Inductor changes:
We will pack addmm/mm to mkldnn_linear when float32_matmul_precision is medium.
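A minimal sketch of the scenario the Inductor change targets; the packing to mkldnn_linear described in the comments is the proposed compiler behavior and is not visible in user code:

```python
import torch

torch.set_float32_matmul_precision("medium")

model = torch.nn.Linear(256, 64)    # float32 weights
compiled = torch.compile(model)     # Inductor backend

x = torch.randn(32, 256)
# Under this proposal, the addmm/mm emitted for this Linear would be
# packed into mkldnn_linear because the precision is "medium", enabling
# BF16 internal computation and additional fusion opportunities.
y = compiled(x)
```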
Motivation
The new BF16 TMUL instruction set on Intel Xeon server products can improve user application performance. With these frontend APIs, users can control the internal computation data types for GEMM/CONV/RNN even when the model's data type is Float32. This will:
Provide higher precision than the Autocast feature, since only GEMM/CONV/RNN use BF16 internal computation data types, while Autocast may compute more ops in BF16.
Let users enable BF16 without having to find a place to enable Autocast in their model scripts.
Pitch
Provide float32_conv_precision and float32_rnn_precision, and enable the bfloat16 data type for internal computations with the MKLDNN backend when the precision is set to medium.
Additional context
Design option
Frontend API:
Option 1: provide a backend-agnostic API, get/set_float32_conv/rnn_precision, like float32_matmul_precision.
Pros:
The user-facing API is unified. Users can use lower-precision computation data types without knowing the backend details.
Cons:
Less fine-grained control over different backends.
Option 2: provide allow_bf32 in the mkldnn backend, like allow_tf32 in the cudnn backend (see the sketch after this list).
Pros:
Fine-grained control: users will be able to run BF16 as the internal computation data type on CPU while running FP32 on GPU if the model is distributed across multiple kinds of devices.
Cons:
Users need to learn about backend-specific details, and more code changes are required in their applications.
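For comparison, option 2 would mirror the existing cuDNN switch. A minimal sketch, assuming a hypothetical torch.backends.mkldnn.allow_bf32 flag (only the cudnn flag exists today):

```python
import torch

# Existing backend-specific switch: allow TF32 for cuDNN convolutions.
torch.backends.cudnn.allow_tf32 = True

# Hypothetical counterpart proposed in option 2 (does not exist yet):
# allow BF16 internal computation in the mkldnn backend.
torch.backends.mkldnn.allow_bf32 = True
```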
Design option
Inductor linear packing rules:
Option 1: only pack to mkldnn_linear when the precision is medium.
Pros:
No performance changes for the pure FP32 case; no regression risk.
Cons:
Fewer fusion opportunities.
Option 2: always pack to mkldnn_linear.
Pros:
mkldnn_linear will introduce more fusion opportunities.
Cons:
May introduce regression risks for the pure FP32 case.