Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jul 19, 2024
1 parent a27dd3c commit 1b6d2c7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
intra_op_thread_pool_{intra_op_thread_pool} {
ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");

// WebAssembly only has a global thread pool. It's not possible to create a new thread pool.
#if !defined(__wasm__)
if (!intra_op_thread_pool) {
OrtThreadPoolParams to;
Expand Down Expand Up @@ -321,6 +322,7 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state)
Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
const NodesToOptimize& selected_nodes,
Node& replacement_node) const {
// WebAssembly only has a global thread pool. To call into this method, the global thread pool must be passed in.
#if defined(__wasm__)
ORT_RETURN_IF_NOT(intra_op_thread_pool_, "Thread pool is required for DQMatMulToMatMulNBitsAction");
#endif
Expand Down Expand Up @@ -351,24 +353,24 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
// to what we need. But it does not handle external data.
Initializer weight_src(*weight_tensor_proto, graph.ModelPath());
Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
std::optional<std::unique_ptr<Initializer>> zp_src_ptr;
std::optional<Initializer> zp_src;
Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(weight_arg->Name() + "_T"),
std::vector<int64_t>{N, quant_num, blob_bytes});
Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
graph.GenerateNodeArgName(scale_arg->Name() + "_T"),
std::vector<int64_t>{N * quant_num});
std::optional<std::unique_ptr<Initializer>> zp_dst_ptr;
std::optional<Initializer> zp_dst;

if (zp_tensor_proto) {
zp_src_ptr.emplace(std::make_unique<Initializer>(*zp_tensor_proto, graph.ModelPath()));
zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
zp_src.emplace(*zp_tensor_proto, graph.ModelPath());
zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)});
} else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)});
}

auto* thread_pool = intra_op_thread_pool_
Expand All @@ -380,10 +382,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<float, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
Expand All @@ -393,10 +395,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<float, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
Expand All @@ -408,10 +410,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
Expand All @@ -422,10 +424,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
zp_src ? zp_src->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
zp_dst ? zp_dst->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
Expand All @@ -436,15 +438,15 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,

ONNX_NAMESPACE::TensorProto weight_T_tp;
ONNX_NAMESPACE::TensorProto scale_T_tp;
std::optional<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> zp_T_tp_ptr;
std::optional<ONNX_NAMESPACE::TensorProto> zp_T_tp;

// TODO(fajin): external_data to memory location to avoid arena allocation
// https://github.com/microsoft/onnxruntime/pull/12465
weight_dst.ToProto(weight_T_tp);
scale_dst.ToProto(scale_T_tp);
if (zp_dst_ptr) {
zp_T_tp_ptr = std::make_unique<ONNX_NAMESPACE::TensorProto>();
zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value());
if (zp_dst) {
zp_T_tp.emplace();
zp_dst->ToProto(zp_T_tp.value());
}

auto& input_defs = replacement_node.MutableInputDefs();
Expand All @@ -453,8 +455,8 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);

if (zp_T_tp_ptr) {
input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value()));
if (zp_T_tp) {
input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value()));
replacement_node.MutableInputArgsCount().push_back(1);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,11 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
ORT_UNUSED_PARAMETER(q_nodes);
// Should not have any Q nodes
if (!q_nodes.empty()) {
return false;
}

const auto& graph = graph_viewer.GetGraph();

// MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/optimizer/selectors_actions/actions.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ struct ReplaceWithNew : public Action {
// specifies how the inputs and outputs for the replaced nodes are moved to the new node
virtual std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState&) const = 0;

// For changes that cannot be made by simply moving node args around, use this method to make
// additional changes to the new node and the graph. E.g., DQMatMulToMatMulNBitsAction transposes
// the second weight of MatMul ops, deletes old node args, and creates new node args.
virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); }

RemoveNodes node_remover_;
Expand Down

0 comments on commit 1b6d2c7

Please sign in to comment.