Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ExecuTorch][BE] Split kv cache and SDPA for better code sharing #7413

Open
wants to merge 3 commits into
base: gh/kimishpatel/149/base
Choose a base branch
from

Conversation

kimishpatel
Copy link
Contributor

@kimishpatel kimishpatel commented Dec 20, 2024

Stack from ghstack (oldest at bottom):

Summary:

Why?
We have coupled SDPA with kv cache for a while. Initially this was done
as we implemented sdpa_with_kv_cache custom op to reduce multiple copy
overheads from kv cache update. (This could have been done by having
separate custom kv cache update and custom sdpa op. Recent changes
enabled this.)
As a result of SDPA module owning kv cache, we get a) non-composable
implementation and b) harder to reuse model definition and components
from repos like tune. Output of this is that we have multiple definition
of the same model, llama, lying around in ET, TorchChat and Tune. This
diff and subsequent ones will try to move in the direction where custom
kv cache and custom sdpa become decoupled and composable, making it more
module-swap friendly with tune's model definition.

How.
Earlier PRs decoupled kv cache update from sdpa. So now

  1. Decouple SDPA nn.Module from KV cache.
  2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA
    both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted
    tensors.
  3. 2 will introduce multiple tranposes when KVCache and SDPA are
    replaced by custom modules, but we will write graph pass to undo
    those.

Test Plan:
Existing tests.
Make sure perf doesnt regress

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Dec 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7413

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 New Failures

As of commit 275144b with merge base 49cc399 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kimishpatel added a commit that referenced this pull request Dec 20, 2024
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 6356acba83a82cb7d19747187a254a735fa77d28
Pull Request resolved: #7413
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 20, 2024
Copy link

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
kimishpatel added a commit that referenced this pull request Dec 20, 2024
Summary:

+ Make all the backend specific kvcache and sdpa implementation abide by
  the new API

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 369434c4d64e6d4500ecfea03b0fd99945b30461
Pull Request resolved: #7413
…sharing"

Summary:

Why?
We have coupled SDPA with kv cache for a while. Initially this was done
as we implemented sdpa_with_kv_cache custom op to reduce multiple copy
overheads from kv cache update. (This could have been done by having
separate custom kv cache update and custom sdpa op. Recent changes
enabled this.)
As a result of SDPA module owning kv cache, we get a) non-composable
implementation and b) harder to reuse model definition and components
from repos like tune. Output of this is that we have multiple definition
of the same model, llama, lying around in ET, TorchChat and Tune. This
diff and subsequent ones will try to move in the direction where custom
kv cache and custom sdpa become decoupled and composable, making it more
module-swap friendly with tune's model definition.

How.
Earlier PRs decoupled kv cache update from sdpa. So now
1. Decouple SDPA nn.Module from KV cache.
2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA
   both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted
   tensors.
3. 2 will introduce multiple tranposes when KVCache and SDPA are
   replaced by custom modules, but we will write graph pass to undo
   those.

Test Plan:
Existing tests.
Make sure perf doesnt regress

[ghstack-poisoned]
@kimishpatel kimishpatel changed the title Changes to split kv cache and sdpa [ExecuTorch][BE] Split kv cache and SDPA for better code sharing Dec 21, 2024
kimishpatel added a commit that referenced this pull request Dec 21, 2024
Summary:

Why?
We have coupled SDPA with kv cache for a while. Initially this was done
as we implemented sdpa_with_kv_cache custom op to reduce multiple copy
overheads from kv cache update. (This could have been done by having
separate custom kv cache update and custom sdpa op. Recent changes
enabled this.)
As a result of SDPA module owning kv cache, we get a) non-composable
implementation and b) harder to reuse model definition and components
from repos like tune. Output of this is that we have multiple definition
of the same model, llama, lying around in ET, TorchChat and Tune. This
diff and subsequent ones will try to move in the direction where custom
kv cache and custom sdpa become decoupled and composable, making it more
module-swap friendly with tune's model definition.

How.
Earlier PRs decoupled kv cache update from sdpa. So now
1. Decouple SDPA nn.Module from KV cache.
2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA
   both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted
   tensors.
3. 2 will introduce multiple tranposes when KVCache and SDPA are
   replaced by custom modules, but we will write graph pass to undo
   those.

Test Plan:
Existing tests.
Make sure perf doesnt regress

ghstack-source-id: 6289ce22a2c190da7e38e098ba8a5d0254d6bf9d
Pull Request resolved: #7413
@kimishpatel kimishpatel requested a review from cccclai December 21, 2024 00:21
@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager":

return self

def run_canonical_optimizations(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on this function

@@ -47,20 +37,21 @@ def forward(
seqlen,
mask,
):
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just thought about it again, and adding this transpose here and also before in the llama_transformer.py so that we can share code for kv_cache.py (this is the reason right?) doesn't really make sense since we are using a custom export-friendly KV cache already anyways: https://github.com/pytorch/executorch/blob/main/extension/llm/modules/kv_cache.py#L13

@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager":

return self

def run_canonical_optimizations(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add checks to make sure self.pre_autograd_graph_module is not None, basically this needs to be run after export().

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants