[CPU]whisper readvalue optimize #26130
base: master
@@ -183,6 +183,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
    MatchSdpaKvCache(graph);
    graph.RemoveDroppedNodes();

    OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "ReplaceMemoryOutputWithMemoryOutputStub");
    ReplaceMemoryOutputWithMemoryOutputStub(graph);
    graph.RemoveDroppedNodes();

    OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveDroppedEdges");
    graph.RemoveDroppedEdges();
}

@@ -3064,6 +3068,90 @@ void GraphOptimizer::RemoveConvertMemoryOutput(Graph &graph) {
    }
}

void GraphOptimizer::ReplaceMemoryOutputWithMemoryOutputStub(Graph& graph) {
    auto& graphNodes = graph.GetNodes();

    auto isSuitableMemInput = [](const NodePtr& node) -> bool {
        if (Type::MemoryInput != node->getType()) {
            return false;
        }
        auto memInput = std::dynamic_pointer_cast<node::MemoryInput>(node);
        if (memInput) {
            return memInput->haveSubgraph();
        }

Comment on lines +3078 to +3081

To my understanding, whenever an Assign is attached directly to the ReadValue node, we should replace it with a stub, since the Assign node is practically useless. A ReadValue->Assign pair means that the state values aren't really changed by the Assign node, so it looks like it can always be safely replaced with a stub.

Yes, that matches my initial understanding.

To my understanding, when we have a direct ReadValue->Assign pair, we should definitely use a single buffer, as nothing new will be written to the state during the assign stage. Maybe we can check for the MemoryOutput child in the MemoryInput node to select an appropriate state type (i.e. single buffer or double buffer).


        return false;
    };

    for (size_t i = 0; i < graphNodes.size(); i++) {
        auto node = graphNodes[i];
        if (!isSuitableMemInput(node)) {
            continue;
        }

        CPU_GRAPH_OPTIMIZER_SCOPE(ReplaceMemoryOutputWithMemoryOutputStub);

        auto memoryNode = std::dynamic_pointer_cast<node::MemoryNode>(node);
        if (nullptr == memoryNode) {
            continue;
        }

        // Find sibling MemoryOutput
        std::shared_ptr<MemoryOutput> memOutput = nullptr;
        bool isReplaced = false;
        for (auto&& edge : node->getChildEdgesAtPort(0)) {
            auto child = edge->getChild();
            if (Type::MemoryOutput == child->getType()) {
                memOutput = std::dynamic_pointer_cast<MemoryOutput>(child);
                if (memOutput && memOutput->getId() == memoryNode->getId()) {
                    break;
                }

                auto memOutputStub = std::dynamic_pointer_cast<MemoryOutputStub>(child);
                if (memOutputStub && memOutputStub->getId() == memoryNode->getId()) {
                    isReplaced = true;
                    break;
                }
            }
        }

        if (isReplaced) {
            continue;
        }
        if (memOutput == nullptr) {
            OPENVINO_THROW("Can't find ", node->getName(), " corresponding sibling node.");
        }

        auto memInputNode = std::dynamic_pointer_cast<node::MemoryInputBase>(node);
        OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type");

        ov::optional<Shape> input_shape;
        ov::optional<ov::element::Type> input_prc;

        if (!node->getParentEdges().empty()) {
            input_shape = ov::optional<Shape>(node->getInputShapeAtPort(0));
            input_prc = ov::optional<ov::element::Type>(node->getOriginalInputPrecisionAtPort(0));
        }

        // Capture reference to the original mem output before graph transformations
        auto& memOutputBase = memInputNode->getOutputNode();

        // Create a stub memory output
        auto memOutputStub = std::make_shared<MemoryOutputStub>(memOutputBase.getId(),
                                                                memOutputBase.getName() + "_MemoryOutputStub",
                                                                memOutputBase.getTypeStr(),
                                                                memOutputBase.getInputShapeAtPort(0),
                                                                memOutputBase.getOriginalInputPrecisionAtPort(0),
                                                                graph.getGraphContext());

        auto memOutputEdge = memOutputBase.getParentEdgeAt(0);
        const auto inputNum = memOutputEdge->getInputNum();
        graph.RemoveEdge(memOutputEdge);
        graph.CreateEdge(node, memOutputStub, inputNum, 0);
        graph.AddNode(memOutputStub);
    }
}

void GraphOptimizer::MatchSdpaKvCache(Graph &graph) {
    auto& graphNodes = graph.GetNodes();

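The review discussion on the `isSuitableMemInput` lambda suggests two things: a MemoryOutput (Assign) attached directly to its MemoryInput (ReadValue) can be swapped for a stub, and the presence of such a direct child could be used to pick a single-buffer state. The sketch below illustrates that idea on a toy graph. It is only an illustration under stated assumptions: `ToyNode`, `ToyType`, `ToyStateKind`, `replaceDirectAssignWithStub`, and `chooseStateKind` are hypothetical names and do not correspond to the plugin's actual classes or functions.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for graph nodes; these are NOT the CPU plugin classes.
enum class ToyType { MemoryInput, MemoryOutput, MemoryOutputStub, Other };

struct ToyNode {
    ToyType type;
    std::string id;                                  // state (variable) id
    std::vector<std::shared_ptr<ToyNode>> children;  // direct consumers
};

using ToyNodePtr = std::shared_ptr<ToyNode>;

// Replace a MemoryOutput that is a direct child of `input` and carries the
// same state id with a stub node. Returns true if a replacement happened.
bool replaceDirectAssignWithStub(const ToyNodePtr& input) {
    assert(input && input->type == ToyType::MemoryInput);
    for (auto& child : input->children) {
        if (child->type == ToyType::MemoryOutput && child->id == input->id) {
            child = std::make_shared<ToyNode>(
                ToyNode{ToyType::MemoryOutputStub, input->id, {}});
            return true;
        }
    }
    return false;
}

enum class ToyStateKind { SingleBuffer, DoubleBuffer };

// Assumption taken from the review comment: a direct ReadValue->Assign pair
// writes nothing new to the state, so one buffer is enough in that case.
ToyStateKind chooseStateKind(const ToyNodePtr& input) {
    for (const auto& child : input->children) {
        const bool isOutput = child->type == ToyType::MemoryOutput ||
                              child->type == ToyType::MemoryOutputStub;
        if (isOutput && child->id == input->id) {
            return ToyStateKind::SingleBuffer;
        }
    }
    return ToyStateKind::DoubleBuffer;
}
```

In this toy model the state kind does not change after the replacement, since the stub keeps the same id as the MemoryOutput it replaces.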
Hi @maxnick,
For example:
MemoryInput(ReadValueWithSubgraph)->compute nodes
MemoryInput(ReadValueWithSubgraph)->MemoryOutputStub
MemoryInput(ReadValueWithSubgraph)->Stateful
I had to add this branch to call MemoryInput::resolveInPlaceEdges, which lets the Stateful, MemoryOutputStub, and compute nodes share memory. That seems reasonable. If it is not called, these nodes' outputs will be in-place, but MemoryOutputStub will allocate its own memory, so these outputs may end up with an empty pointer. Just to clarify.
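
The memory-sharing concern above can be pictured with a small toy model: if every child edge of the producer is pointed at the producer's buffer, all consumers see the same memory, whereas an edge that is never resolved holds no memory at all. This is only a sketch under assumptions; `ToyBuffer`, `ToyEdge`, `shareProducerMemory`, and `allShare` are hypothetical names and say nothing about how `MemoryInput::resolveInPlaceEdges` is actually implemented.

```cpp
#include <cassert>
#include <memory>
#include <vector>

using ToyBuffer = std::vector<float>;

// A toy edge that either aliases some buffer or holds nothing yet.
struct ToyEdge {
    std::shared_ptr<ToyBuffer> memory;  // null until the edge is resolved
};

// Hypothetical analogue of resolving in-place edges from the producer side:
// every child edge is pointed at the producer's buffer instead of its own.
void shareProducerMemory(const std::shared_ptr<ToyBuffer>& producer,
                         std::vector<ToyEdge>& childEdges) {
    assert(producer);
    for (auto& edge : childEdges) {
        edge.memory = producer;
    }
}

// True only if every edge aliases the producer's buffer.
bool allShare(const std::vector<ToyEdge>& edges,
              const std::shared_ptr<ToyBuffer>& producer) {
    for (const auto& edge : edges) {
        if (edge.memory != producer) {
            return false;
        }
    }
    return true;
}
```

With three child edges standing in for the compute nodes, the MemoryOutputStub, and the Stateful consumer, a write through the producer's buffer becomes visible through every edge once `shareProducerMemory` has run.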

The problem is in the MemoryInput::selectOptimalPrimitiveDescriptor implementation you developed for the subgraph scenario. There you:
resolveInPlaceEdges(Edge::LOOK_UP) isn't being called for such a node.
Thus I would recommend the following:

1: Removed the redefinition of supportedPrimitiveDescriptors.
2: The internal subgraph provides an InputConfig interface for configuring inputs, but there is no OutputConfig setting interface. @maxnick