Skip to content

Commit

Permalink
[GPU] Fix double jit constants
Browse files Browse the repository at this point in the history
  • Loading branch information
p-durandin committed Oct 3, 2024
1 parent 2507d89 commit f70edc0
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
size_t tile_k_ofm_packed = tile_k_ofm;
size_t quantize_grp_size = get_dynamic_quantize_group_size(params);

bool add_decompress_scale_post_op = false;
WeightsType weights_dt = params.weights.GetDType();
if (weights_dt == WeightsType::UINT4 || weights_dt == WeightsType::INT4) {
tile_k_ofm_packed /= 2;
Expand All @@ -542,7 +543,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
// Do not use SCALE_POST_OP for SLM kernel, since it demonstrates worse performance
if (scale_group_size % simd == 0 && !dispatchData.use_slm)
jit.AddConstant(MakeJitConstant("DECOMPRESSION_SCALE_POST_OP", 1));
add_decompress_scale_post_op = true;
}
if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) {
jit.AddConstant(MakeJitConstant("W_IDX", "fi * TILE_K + kii"));
Expand Down Expand Up @@ -619,6 +620,8 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
jit.AddConstant(MakeJitConstant("DQ_TYPE", "char"));
jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", quantize_grp_size));
} else {
if (add_decompress_scale_post_op)
jit.AddConstant(MakeJitConstant("DECOMPRESSION_SCALE_POST_OP", 1));
jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 0));
jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", min_quantize_grp_size));
}
Expand Down

0 comments on commit f70edc0

Please sign in to comment.