Fix vision model graph capture not creating static buffers for embedding #942
This change essentially reverses the ownership of the embeddings memory. Instead of creating the embeddings tensor in the embedding model and pointing the text model's embeddings input at it, we now create the embeddings tensor inside the text model and point the embedding model's embeddings output at it.
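A minimal sketch of the ownership reversal, using hypothetical types and names rather than the actual onnxruntime-genai classes: the text model now allocates the embeddings buffer once, and the embedding (vision) model binds its output to that buffer instead of allocating its own.

```cpp
// Hypothetical illustration of the new ownership direction; not the real API.
#include <memory>
#include <vector>

struct Tensor {
  std::vector<float> data;  // backing storage for the embeddings
};

struct TextModel {
  // Allocated by the text model. With graph capture enabled this acts as a
  // static buffer that must stay valid across iterations and generators.
  std::shared_ptr<Tensor> embeddings = std::make_shared<Tensor>();
};

struct EmbeddingModel {
  // No longer allocates its own output tensor; it aliases the text model's
  // buffer so the captured graph always sees live memory.
  std::shared_ptr<Tensor> output;

  void BindOutputTo(TextModel& text_model) { output = text_model.embeddings; }
};
```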
The reason for this is that the text model may be running in "graph capture mode", which means it allocates static buffers that are reused between iterations, and even between generators. If we allocate the memory in the embedding model and point the text model to it, that memory becomes invalid when the generator is destroyed, and the captured graph exhibits undefined behavior (mostly producing garbage output). By pointing the embeddings output of the embedding model at the static buffer created by the text model, we can be certain the memory stays alive for the lifetime of the model.
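For contrast, a hypothetical sketch of the old failure mode: a captured graph effectively bakes in the address of the embeddings buffer at capture time, so if that buffer is owned by an object that dies with the generator, later replays read freed memory. The type and method names here are illustrative only.

```cpp
// Hypothetical illustration of why per-generator ownership breaks graph capture.
struct CapturedGraph {
  const float* embeddings_ptr = nullptr;  // address recorded at capture time

  void Replay() const {
    // Reads through embeddings_ptr on every replay. If the buffer was owned by
    // the embedding model and freed when the generator was destroyed, this is
    // a dangling read: undefined behavior, typically garbage output.
  }
};
```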
This PR doesn't change the behavior of the non-graph-capture mode, since in that scenario it does not matter whether the tensor is created by the embedding model or the text model, but it fixes graph capture usage for vision models.