From 1b6d2c7591f4985e1d8eb98672583b528eb572c9 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Fri, 19 Jul 2024 11:12:17 -0700
Subject: [PATCH] resolve comments

---
 .../selectors_actions/qdq_actions.cc        | 48 ++++++++++---------
 .../selectors_actions/qdq_selectors.cc      |  6 ++-
 .../optimizer/selectors_actions/actions.h   |  3 ++
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index 3c0902889b1d..829d1858adbb 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -289,6 +289,7 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
       intra_op_thread_pool_{intra_op_thread_pool} {
   ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");
 
+  // WebAssembly only has a global thread pool. It's not possible to create a new thread pool.
 #if !defined(__wasm__)
   if (!intra_op_thread_pool) {
     OrtThreadPoolParams to;
@@ -321,6 +322,7 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state)
 Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
                                                    const NodesToOptimize& selected_nodes,
                                                    Node& replacement_node) const {
+  // WebAssembly only has a global thread pool. To call into this method, the global thread pool must be passed in.
 #if defined(__wasm__)
   ORT_RETURN_IF_NOT(intra_op_thread_pool_, "Thread pool is required for DQMatMulToMatMulNBitsAction");
 #endif
@@ -351,24 +353,24 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
   // to what we need. But it does not handle external data.
   Initializer weight_src(*weight_tensor_proto, graph.ModelPath());
   Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
-  std::optional<std::unique_ptr<Initializer>> zp_src_ptr;
+  std::optional<Initializer> zp_src;
   Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
                          graph.GenerateNodeArgName(weight_arg->Name() + "_T"),
                          std::vector<int64_t>{N, quant_num, blob_bytes});
   Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
                         graph.GenerateNodeArgName(scale_arg->Name() + "_T"),
                         std::vector<int64_t>{N * quant_num});
-  std::optional<std::unique_ptr<Initializer>> zp_dst_ptr;
+  std::optional<Initializer> zp_dst;
 
   if (zp_tensor_proto) {
-    zp_src_ptr.emplace(std::make_unique<Initializer>(*zp_tensor_proto, graph.ModelPath()));
-    zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
-                                                     graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
-                                                     std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
+    zp_src.emplace(*zp_tensor_proto, graph.ModelPath());
+    zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
+                   graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
+                   std::vector<int64_t>{N * ((quant_num + 1) / 2)});
   } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
-    zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
-                                                     graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
-                                                     std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
+    zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
+                   graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
+                   std::vector<int64_t>{N * ((quant_num + 1) / 2)});
   }
 
   auto* thread_pool = intra_op_thread_pool_
@@ -380,10 +382,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
       MlasQDQTransposeBlockwiseQuantized(
           weight_src.DataAsByteSpan().data(),
           scale_src.data(),
-          zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
+          zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
           weight_dst.data(),
           scale_dst.data(),
-          zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr,
+          zp_dst ? zp_dst->data() : nullptr,
           true,
           static_cast(K),
           static_cast(N),
@@ -393,10 +395,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
       MlasQDQTransposeBlockwiseQuantized(
           weight_src.DataAsByteSpan().data(),
           scale_src.data(),
-          zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
+          zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
           weight_dst.data(),
           scale_dst.data(),
-          zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr,
+          zp_dst ? zp_dst->data() : nullptr,
           true,
           static_cast(K),
           static_cast(N),
@@ -408,10 +410,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
      MlasQDQTransposeBlockwiseQuantized(
           weight_src.DataAsByteSpan().data(),
           scale_src.data(),
-          zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
+          zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
           weight_dst.data(),
           scale_dst.data(),
-          zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr,
+          zp_dst ? zp_dst->data() : nullptr,
           true,
           static_cast(K),
           static_cast(N),
@@ -422,10 +424,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
      MlasQDQTransposeBlockwiseQuantized(
           weight_src.DataAsByteSpan().data(),
           scale_src.data(),
-          zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
+          zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
           weight_dst.data(),
           scale_dst.data(),
-          zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr,
+          zp_dst ? zp_dst->data() : nullptr,
           true,
           static_cast(K),
           static_cast(N),
@@ -436,15 +438,15 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
 
   ONNX_NAMESPACE::TensorProto weight_T_tp;
   ONNX_NAMESPACE::TensorProto scale_T_tp;
-  std::optional<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> zp_T_tp_ptr;
+  std::optional<ONNX_NAMESPACE::TensorProto> zp_T_tp;
 
   // TODO(fajin): external_data to memory location to avoid arena allocation
   // https://github.com/microsoft/onnxruntime/pull/12465
   weight_dst.ToProto(weight_T_tp);
   scale_dst.ToProto(scale_T_tp);
-  if (zp_dst_ptr) {
-    zp_T_tp_ptr = std::make_unique<ONNX_NAMESPACE::TensorProto>();
-    zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value());
+  if (zp_dst) {
+    zp_T_tp.emplace();
+    zp_dst->ToProto(zp_T_tp.value());
   }
 
   auto& input_defs = replacement_node.MutableInputDefs();
@@ -453,8 +455,8 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
   input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp));
   replacement_node.MutableInputArgsCount().push_back(1);
 
-  if (zp_T_tp_ptr) {
-    input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value()));
+  if (zp_T_tp) {
+    input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value()));
     replacement_node.MutableInputArgsCount().push_back(1);
   }
 
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 11185491da9b..6e93445c7c5c 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -418,7 +418,11 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                       const Node& node,
                                       const std::vector<const Node*>& dq_nodes,
                                       const std::vector<const Node*>& q_nodes) const {
-  ORT_UNUSED_PARAMETER(q_nodes);
+  // Should not have any Q nodes
+  if (!q_nodes.empty()) {
+    return false;
+  }
+
   const auto& graph = graph_viewer.GetGraph();
 
   // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output
diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h
index 4d5b520cc47c..8ff5ced17c49 100644
--- a/onnxruntime/core/optimizer/selectors_actions/actions.h
+++ b/onnxruntime/core/optimizer/selectors_actions/actions.h
@@ -158,6 +158,9 @@ struct ReplaceWithNew : public Action {
   // specifies how the inputs and outputs for the replaced nodes are moved to the new node
   virtual std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState&) const = 0;
 
+  // For changes that cannot be expressed by simply moving node args around, use this method to make
+  // additional changes to the new node and the graph. E.g., DQMatMulToMatMulNBitsAction transposes
+  // the second weight of MatMul ops, deletes the old node args, and creates new ones.
   virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); }
 
   RemoveNodes node_remover_;