Fix bug when Embedding has >2 outputs (#20678)
guyang3532 authored May 17, 2024
1 parent 6b58fcc commit d7f7c3b
Showing 1 changed file with 31 additions and 23 deletions.
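In brief, from the diff below: PaddingElimination previously skipped any candidate embedding node whose output-edge count was not exactly 1, and it removed the FlagAndPrintDensity PythonOp inline, only for the embedding it had just matched. This commit deletes the edge-count guard, moves the flag removal into a new helper, RemovePrintDensityFlag, that sweeps the whole graph after the search whenever print_density_ is false, and hoists the input_ids_arg/subgraph seeding past the null check on embedding_node.

One detail worth spelling out, since it is likely what the title's ">2 outputs" refers to: Node::OutputEdgesBegin()/OutputEdgesEnd() iterate output edges, one per downstream consumer connection, not distinct output defs. A minimal self-contained illustration of the deleted guard's arithmetic (a std::list stands in for the edge container; nothing here is the onnxruntime API):

#include <cassert>
#include <iterator>
#include <list>

// Stand-in for the deleted guard: it counted output *edges* (one per
// consumer connection), so an embedding whose result fed several
// downstream nodes failed the `!= 1` test and the pass bailed out.
int main() {
  std::list<int> output_edges = {101, 102, 103};  // three consumer edges
  const auto outputNodeCount = std::distance(output_edges.begin(), output_edges.end());
  assert(outputNodeCount == 3);  // != 1, so the old code hit `continue`
  return 0;
}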
@@ -359,6 +359,30 @@ void IterateSubgraphFromNode(Graph& graph,
 }
 } // namespace
 
+void RemovePrintDensityFlag(Graph& graph,
+                            const std::vector<NodeIndex>& node_topology_list,
+                            bool& modified,
+                            const logging::Logger& logger) {
+  for (auto node_index : node_topology_list) {
+    Node* node = graph.GetNode(node_index);
+    if (node == nullptr) {
+      continue;
+    }
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(*node, "PythonOp", {1}, kMSDomain) &&
+        static_cast<std::string>(node->GetAttributes().at("func_name").s()) == kFlagAndPrintDensityFuncName) {
+      if (graph_utils::CanRemoveNode(graph, *node, logger)) {
+        if (graph_utils::RemoveNode(graph, *node)) {
+          modified = true;
+        } else {
+          LOG_DEBUG_INFO(logger, "Failed to remove node " + node->Name() + "(" + node->OpType() + ")");
+        }
+      } else {
+        LOG_DEBUG_INFO(logger, "Can not remove node " + node->Name() + "(" + node->OpType() + ")");
+      }
+    }
+  }
+}
+
 Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
   LOG_DEBUG_INFO(logger, "Enter PaddingElimination");
 
@@ -392,10 +416,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
         node.InputDefs()[1]->Exists() &&
         node.InputDefs()[1]->Shape() &&
         node.InputDefs()[1]->Shape()->dim_size() >= 2) {
-      const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
-      if (outputNodeCount != 1) {
-        continue;
-      }
       Node* embedding_input_node = graph.GetMutableProducerNode(node.MutableInputDefs()[1]->Name());
       if (embedding_input_node == nullptr ||
           !graph_utils::IsSupportedOptypeVersionAndDomain(*embedding_input_node, "PythonOp", {1}, kMSDomain) ||
@@ -404,21 +424,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
         LOG_DEBUG_INFO(logger, "not find PythonOp of flagPaddingElimination after embedding node");
         continue;
       }
-      if (!print_density_) {
-        if (graph_utils::CanRemoveNode(graph, *embedding_input_node, logger)) {
-          if (graph_utils::RemoveNode(graph, *embedding_input_node)) {
-            modified = true;
-          } else {
-            LOG_DEBUG_INFO(logger, "Failed to remove node " + embedding_input_node->Name() +
-                           "(" + embedding_input_node->OpType() + ")");
-            continue;
-          }
-        } else {
-          LOG_DEBUG_INFO(logger, "Can not remove node " + embedding_input_node->Name() +
-                         "(" + embedding_input_node->OpType() + ")");
-          continue;
-        }
-      }
       const ONNX_NAMESPACE::TensorProto* padding_initializer =
           graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
       if (padding_initializer != nullptr &&
@@ -430,19 +435,22 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
           continue;
         }
         embedding_node = &node;
-        input_ids_arg = embedding_node->MutableInputDefs()[1];
-        for (auto output_defs : embedding_node->MutableOutputDefs()) {
-          subgraph.insert(output_defs);
-        }
         break;
       }
     }
   }
 
+  if (!print_density_) {
+    RemovePrintDensityFlag(graph, node_topology_list, modified, logger);
+  }
   if (!embedding_node) {
     LOG_DEBUG_INFO(logger, "Exit PaddingElimination optimization for not finding any valid embedding node.");
     return Status::OK();
   }
+  input_ids_arg = embedding_node->MutableInputDefs()[1];
+  for (auto output_defs : embedding_node->MutableOutputDefs()) {
+    subgraph.insert(output_defs);
+  }
 
   if (!input_ids_arg->Shape()) {
     LOG_DEBUG_INFO(logger, "Exit PaddingElimination optimization for not finding shape of input_ids.");

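The structural point of the fix is easier to see outside the diff: the search loop now only selects a candidate, and the flag-node cleanup runs as a separate sweep afterwards, so the markers are removed even when no valid embedding is found or when a match is rejected later. A minimal sketch of that pattern, with purely illustrative names (ToyNode and RemoveFlagNodes are not the onnxruntime API):

#include <cstddef>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy model of the restructured pass; names are illustrative only.
struct ToyNode {
  std::string op_type;
  int output_edge_count = 1;
};

// Counterpart of RemovePrintDensityFlag: sweep the whole node list and erase
// every flag node, independent of whether a candidate embedding was found.
void RemoveFlagNodes(std::vector<ToyNode>& nodes, bool& modified) {
  for (auto it = nodes.begin(); it != nodes.end();) {
    if (it->op_type == "FlagAndPrintDensity") {
      it = nodes.erase(it);
      modified = true;
    } else {
      ++it;
    }
  }
}

int main() {
  std::vector<ToyNode> graph = {
      {"FlagAndPrintDensity", 1}, {"Embedding", 3}, {"Gather", 1}};

  // Search phase: only *find* the candidate. The old bail-out
  // `if (output_edge_count != 1) continue;` is gone, so an embedding that
  // feeds several consumers is still eligible.
  std::optional<std::size_t> embedding_index;
  for (std::size_t i = 0; i < graph.size(); ++i) {
    if (graph[i].op_type == "Embedding") {
      embedding_index = i;
      break;
    }
  }

  // Cleanup phase: runs whether or not a candidate exists, mirroring the
  // hoisted `if (!print_density_) RemovePrintDensityFlag(...)` call.
  bool modified = false;
  RemoveFlagNodes(graph, modified);

  std::cout << "found embedding: " << embedding_index.has_value()
            << ", graph modified: " << modified << '\n';
  return 0;
}

Tying the cleanup to the success of the search, as the old inline code did, is why graphs whose embedding had extra output edges kept their FlagAndPrintDensity nodes; sweeping after the loop decouples the two concerns.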