Fix executorch kv cache incompatibility with to_executorch lowering #7279
Open

dvorjackz wants to merge 17 commits into main from jz/fix-prefill
+364 −31
Commits (17; the diff below shows changes from 11 commits)
All commits authored by dvorjackz:

aac90a0 Add tests that localize the prefill issue to the kv cache
917fb0d Fixes test but not model
46ea733 Updated pass
5db136c Fix segmentation fault
9cdfb43 Lint
9e68531 Only add pass when vision model
925409d Add comments
2a3fe8b Remove import
61101c2 Add pass
4ee95d3 PR review
e297c9b Fix test
8145cda Last changes
73591f1 Merge branch 'main' into jz/fix-prefill
a2b7ee3 Update attention test
93f99ad Tests
69e36fb Dave pr comment
5c53856 Merge branch 'main' into jz/fix-prefill
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

from executorch.exir.pass_base import ExportPass


class InitializedMutableBufferPass(ExportPass):
    """
    If the buffer has the name "cache_pos", such as in a kv_cache
    module with `self.register_buffer("cache_pos", torch.arange(10))`,
    mark it with a custom tag which is later used by the emitter to
    flag spec.const to True, which provides the mutable buffer with
    an initialized state.
    """

    def __init__(self, patterns: List[str]) -> None:
        super().__init__()
        self.patterns = patterns

    def placeholder(self, name: str, arg, meta):
        for pattern in self.patterns:
            if pattern in name:
                meta["et_init_buffer"] = True

        return super().placeholder(name, arg, meta)
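For reference, a minimal usage sketch of the new pass, mirroring the tests in the file below: the pass is handed to to_executorch through ExecutorchBackendConfig, so the emitter keeps the initialized state of any mutable buffer whose name matches one of the given patterns (here "cache_pos"). In this sketch, exported_program is a placeholder for a torch.export.ExportedProgram of a module whose kv cache registers such a buffer.

from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# `exported_program` is assumed to come from torch.export.export(...) on a model
# whose kv_cache registers a "cache_pos" buffer.
edge_program = to_edge(exported_program)
et_program = edge_program.to_executorch(
    config=ExecutorchBackendConfig(
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)
# et_program.buffer can then be loaded by the runtime, as the tests below do.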
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from typing import Callable, Tuple

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

from executorch.extension.export_util.utils import save_pte_program
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from executorch.runtime import Runtime
from torch.testing import assert_close
from torchtune.modules.kv_cache import KVCache


def generate_cache_inputs(
    seq_len: int,
    batch_size: int = 1,
    num_kv_heads: int = 64,
    head_dim: int = 8,
) -> Tuple[torch.Tensor, ...]:
    """Helper to generate k_val and v_val for both et and tt caches."""
    k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)
    v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)

    # For torchtune, the kv cache takes in transposed k and v.
    k_val_trans = k_val.transpose(1, 2)
    v_val_trans = v_val.transpose(1, 2)

    return (k_val, v_val, k_val_trans, v_val_trans)


class KVCacheTest(unittest.TestCase):
    def setUp(self):
        self.batch_size = 1
        self.max_seq_len = 10
        self.num_kv_heads = 1  # For testing purposes, usually this is 64.
        self.head_dim = 8
        self.dtype = torch.float

        self.tt_kv_cache = KVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            dtype=self.dtype,
        )
        self.et_kv_cache = InferenceKVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            dtype=self.dtype,
            transpose_cache=False,
        )
    def _test_kv_cache(self, et_cache_module: Callable):
        """
        Given an executorch kv cache anywhere along the export chain, compare its
        results against torchtune and run basic tests.
        """
        # Case 1: Prefill (seq_len = 3).
        prefill_seq_len = 3
        k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
            prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
        )

        et_res = et_cache_module(k_val, v_val)
        tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

        # Check torchtune matches executorch.
        assert_close(et_res, tt_res_transposed)

        # Check the values are correct: all rows in the seq_len dim should be
        # filled with 1s up to and including the 3rd.
        et_k_cache = et_res[0]
        for i in range(prefill_seq_len):
            self.assertTrue(et_k_cache[0][i][0][0] == 1)
        self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0)

        # Case 2: Token-by-token decode (seq_len = 1).
        seq_len = 1
        k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
            seq_len, self.batch_size, self.num_kv_heads, self.head_dim
        )

        et_res = et_cache_module(k_val, v_val)
        tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

        # Check torchtune matches executorch.
        assert_close(tt_res_transposed, et_res)

        # All rows in the seq_len dim up to and including the (prefill_seq_len + 1)th
        # should now be filled with 1s.
        et_k_cache = et_res[0]
        for i in range(prefill_seq_len + 1):
            self.assertTrue(et_k_cache[0][i][0][0] == 1)

        self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)
    def export_kv_cache(
        self,
        kv_cache: torch.nn.Module,
    ) -> torch.export.ExportedProgram:
        # Wrapper since torch.export only exports forward().
        class EtCacheWrapper(torch.nn.Module):
            def __init__(self, kv_cache: torch.nn.Module):
                super().__init__()
                self.kv_cache = kv_cache

            def forward(self, k_val: torch.Tensor, v_val: torch.Tensor):
                return self.kv_cache.update(k_val, v_val)

        dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len)
        exported_kv_cache = torch.export.export(
            EtCacheWrapper(self.et_kv_cache),
            (
                torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim),
                torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim),
            ),  # 3 as example prefill seq_len.
            dynamic_shapes={
                "k_val": {
                    0: torch.export.Dim.STATIC,
                    1: dim,
                    2: torch.export.Dim.STATIC,
                    3: torch.export.Dim.STATIC,
                },
                "v_val": {
                    0: torch.export.Dim.STATIC,
                    1: dim,
                    2: torch.export.Dim.STATIC,
                    3: torch.export.Dim.STATIC,
                },
            },
        )
        return exported_kv_cache

    def test_kv_cache_eager(self):
        self._test_kv_cache(self.et_kv_cache.update)

    def test_kv_cache_export(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        self._test_kv_cache(exported_kv_cache.module())

    def test_kv_cache_edge(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        self._test_kv_cache(edge_program._edge_programs["forward"].module())

    def test_kv_cache_executorch(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        et_config = ExecutorchBackendConfig(
            passes=[InitializedMutableBufferPass(["cache_pos"])],
        )
        et_program = edge_program.to_executorch(config=et_config)

        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")

        # Since method.execute expects a tuple of args.
        def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
            return method.execute((k_val, v_val))

        self._test_kv_cache(wrapped_callable)

    def test_kv_cache_executorch_from_file(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        et_config = ExecutorchBackendConfig(
            passes=[InitializedMutableBufferPass(["cache_pos"])],
        )
        et_program = edge_program.to_executorch(config=et_config)

        with tempfile.TemporaryDirectory() as tempdir:
            pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir)
            with open(pte_path, "rb") as f:
                model_bytes = f.read()
            loaded_et_program = _load_for_executorch_from_buffer(model_bytes)

            # Since method.execute expects a tuple of args.
            def wrapped_callable(
                k_val: torch.Tensor, v_val: torch.Tensor
            ) -> torch.Tensor:
                return loaded_et_program.forward((k_val, v_val))

            self._test_kv_cache(wrapped_callable)
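As a small convenience that is not part of this diff, appending a standard unittest entry point at the bottom of the test file would let the suite be run directly with Python; this is only a sketch.

import unittest  # already imported at the top of the test file

if __name__ == "__main__":
    # Runs every test method of KVCacheTest defined above.
    unittest.main(verbosity=2)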
Review comment: Please add unit tests for this logic; tests that would have broken before this fix and would have caught this kv cache incompatibility.