Fix executorch kv cache incompatibility with to_executorch lowering #7279

Open

Wants to merge 17 commits into base: main
10 changes: 8 additions & 2 deletions examples/models/llama/export_llama_lib.py
@@ -23,6 +23,7 @@
import torch

from executorch.devtools.etrecord import generate_etrecord
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

@@ -760,6 +761,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -774,7 +778,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(
passes=additional_passes,
)

# Generate ETRecord
if edge_manager_copy:
@@ -792,7 +798,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(passes=additional_passes)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
18 changes: 15 additions & 3 deletions examples/models/llama3_2_vision/runner/native.py
@@ -18,13 +18,15 @@
TorchTuneLlamaRunner,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa


@@ -43,7 +45,17 @@ def __init__(self, args):
use_kv_cache=args.kv_cache,
vocab_size=params["vocab_size"],
)
self.model = _load_for_executorch(args.pte)
# Save the loaded model bytes to prevent data from going out of
# scope after the `with` and getting cleaned up by Python's
# garbage collector.
self.model_bytes = None
with open(args.pte, "rb") as f:
self.model_bytes = f.read()
# Need to use _load_for_executorch_from_buffer instead of
# _load_for_executorch because the latter uses MmapDataLoader,
# which doesn't have load_into() implemented, which is needed
# for loading initialized mutable buffers.
self.model = _load_for_executorch_from_buffer(self.model_bytes)
self.use_kv_cache = args.kv_cache

def forward(
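The comments in the constructor above explain the switch from _load_for_executorch to _load_for_executorch_from_buffer: the file-backed path uses MmapDataLoader, which does not implement load_into(), and load_into() is needed to load initialized mutable buffers. A minimal standalone sketch of the same loading pattern (the path and input shapes are illustrative only, not taken from this PR):

import torch

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# Read the whole .pte into memory; keep a reference to the bytes so they are
# not garbage-collected while the loaded program is still in use.
with open("model.pte", "rb") as f:  # "model.pte" is a placeholder path
    model_bytes = f.read()

et_module = _load_for_executorch_from_buffer(model_bytes)
# forward() takes a tuple of inputs; these shapes match the KV cache test
# added later in this PR (batch=1, seq_len=3, num_kv_heads=1, head_dim=8).
k_val = torch.ones(1, 3, 1, 8)
v_val = torch.ones(1, 3, 1, 8)
outputs = et_module.forward((k_val, v_val))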
7 changes: 6 additions & 1 deletion exir/emit/_emitter.py
@@ -1602,6 +1602,7 @@ def placeholder(
"""
spec = self.node.meta["spec"]
constant_tag = self.node.meta.get("constant_tag", None)
initialize_buffer = self.node.meta.get("et_init_buffer", None)
is_user_input = True

if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1655,7 +1656,11 @@
spec.storage = real_tensor.untyped_storage()

# User inputs and mutable buffers are not constants, other buffers or parameters are.
spec.const = not (is_user_input or is_mutable_buffer)
if initialize_buffer:
assert is_mutable_buffer
spec.const = True
else:
spec.const = not (is_user_input or is_mutable_buffer)
Reviewer comment (Contributor):
Please add unit tests for this logic; tests that would have broken before this fix, and would have caught this kv cache incompatibility


evalue = (
self._tensor_spec_to_evalue(spec, constant_tag)
31 changes: 31 additions & 0 deletions exir/passes/init_mutable_pass.py
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

from executorch.exir.pass_base import ExportPass


class InitializedMutableBufferPass(ExportPass):
"""
If the buffer has the name "cache_pos", such as in an kv_cache
module with `self.register_buffer("cache_pos", torch.arange(10))`,
mark it with a custom tag which later is used by the emitter to
flag spec.const to True, which provides the mutable buffer with
an initialized state.
"""

def __init__(self, patterns: List[str]) -> None:
super().__init__()
self.patterns = patterns

def placeholder(self, name: str, arg, meta):
for pattern in self.patterns:
if pattern in name:
meta["et_init_buffer"] = True

dvorjackz marked this conversation as resolved.
Show resolved Hide resolved
return super().placeholder(name, arg, meta)
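For orientation, this is how the new pass is expected to be wired in, mirroring the tests added at the bottom of this PR. A minimal sketch, assuming `exported_program` is the result of `torch.export.export()` on a module that registers a `cache_pos` buffer:

from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# `exported_program` is a placeholder for a torch.export.ExportedProgram of a
# module that does `self.register_buffer("cache_pos", torch.arange(max_seq_len))`.
edge_program = to_edge(exported_program)
et_program = edge_program.to_executorch(
    config=ExecutorchBackendConfig(
        # Buffers whose names contain "cache_pos" are emitted with their
        # initial values instead of starting out uninitialized.
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)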
23 changes: 15 additions & 8 deletions extension/llm/export/builder.py
@@ -25,6 +25,7 @@
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.pass_manager import PassType
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -395,21 +396,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag

return self

def to_executorch(self) -> "LLMEdgeManager":
def to_executorch(
self, passes: Optional[List[PassType]] = None
) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.
"""
assert self.edge_manager, "Need to run export_to_edge() first"
to_executorch_passes = [
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
]
if passes:
to_executorch_passes.extend(passes)

self.export_program = self.edge_manager.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
],
passes=to_executorch_passes,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
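On the caller side this mirrors the export_llama_lib.py change earlier in this diff: passes handed to LLMEdgeManager.to_executorch() are appended after the default ConvertToLinearPass and QuantFusionPass. A short sketch, assuming `builder` is an LLMEdgeManager that has already gone through export_to_edge() and to_backend():

from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# Emit initialized values for any buffer whose name contains "cache_pos"
# (needed here only for the torchtune-defined models, per export_llama_lib.py).
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
builder = builder.to_executorch(passes=additional_passes)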
217 changes: 217 additions & 0 deletions extension/llm/modules/test/test_kv_cache.py
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from typing import Callable, Tuple

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

from executorch.extension.export_util.utils import save_pte_program
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from executorch.runtime import Runtime
from torch.testing import assert_close
from torchtune.modules.kv_cache import KVCache


def generate_cache_inputs(
    seq_len: int,
    batch_size: int = 1,
    num_kv_heads: int = 64,
    head_dim: int = 8,
) -> Tuple[torch.Tensor, ...]:
    """Helper to generate k_val and v_val for both et and tt caches."""
    k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)
    v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim)

    # For torchtune, the kv cache takes in transposed k and v.
    k_val_trans = k_val.transpose(1, 2)
    v_val_trans = v_val.transpose(1, 2)

    return (k_val, v_val, k_val_trans, v_val_trans)


class KVCacheTest(unittest.TestCase):
    def setUp(self):
        self.batch_size = 1
        self.max_seq_len = 10
        self.num_kv_heads = 1  # For testing purposes, usually this is 64.
        self.head_dim = 8
        self.dtype = torch.float

        self.tt_kv_cache = KVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            dtype=self.dtype,
        )
        self.et_kv_cache = InferenceKVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            dtype=self.dtype,
            transpose_cache=False,
        )

    def _test_kv_cache(self, et_cache_module: Callable):
        """
        Given an executorch kv cache anywhere along the export chain, compare its
        results against torchtune and run basic tests.
"""
prefill_seq_len = 3
k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
)

et_res = et_cache_module(k_val, v_val)
tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

# Check torchtune matches executorch.
assert_close(et_res, tt_res_transposed)

# Check the values are correct, all rows in the seq_len dim should be
# filled with 1s up to and including the 3rd.
et_k_cache = et_res[0]
for i in range(prefill_seq_len):
self.assertTrue(et_k_cache[0][i][0][0] == 1)
self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0)

"""Case 2: Token-by-token (seq_len = 0)"""
dvorjackz marked this conversation as resolved.
Show resolved Hide resolved
        seq_len = 1
        k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs(
            seq_len, self.batch_size, self.num_kv_heads, self.head_dim
        )

        et_res = et_cache_module(k_val, v_val)
        tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

        # Check torchtune matches executorch.
        assert_close(tt_res_transposed, et_res)

        # All rows up to and including the (prefill_seq_len + 1)th should be filled with 1s.
        et_k_cache = et_res[0]
        for i in range(prefill_seq_len + 1):
            self.assertTrue(et_k_cache[0][i][0][0] == 1)

        self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)

    def export_kv_cache(
        self,
        kv_cache: torch.nn.Module,
    ) -> torch.export.ExportedProgram:
        # Wrapper since torch.export only exports forward().
        class EtCacheWrapper(torch.nn.Module):
            def __init__(self, kv_cache: torch.nn.Module):
                super().__init__()
                self.kv_cache = kv_cache

            def forward(self, k_val: torch.Tensor, v_val: torch.Tensor):
                return self.kv_cache.update(k_val, v_val)

        dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len)
        exported_kv_cache = torch.export.export(
            EtCacheWrapper(self.et_kv_cache),
            (
                torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim),
                torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim),
            ),  # 3 as example prefill seq_len.
            dynamic_shapes={
                "k_val": {
                    0: torch.export.Dim.STATIC,
                    1: dim,
                    2: torch.export.Dim.STATIC,
                    3: torch.export.Dim.STATIC,
                },
                "v_val": {
                    0: torch.export.Dim.STATIC,
                    1: dim,
                    2: torch.export.Dim.STATIC,
                    3: torch.export.Dim.STATIC,
                },
            },
        )
        return exported_kv_cache

    def test_kv_cache_eager(self):
        self._test_kv_cache(self.et_kv_cache.update)

    def test_kv_cache_export(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        self._test_kv_cache(exported_kv_cache.module())

    def test_kv_cache_edge(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        self._test_kv_cache(edge_program._edge_programs["forward"].module())

    def test_kv_cache_executorch(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        et_config = ExecutorchBackendConfig(
            passes=[InitializedMutableBufferPass(["cache_pos"])],
        )
        et_program = edge_program.to_executorch(config=et_config)

        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")

        # Since method.execute expects a tuple of args.
        def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
            return method.execute((k_val, v_val))

        self._test_kv_cache(wrapped_callable)

    def test_kv_cache_executorch_from_file(self):
        exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
        edge_program = to_edge(
            exported_kv_cache,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
                _check_ir_validity=False,
            ),
        )
        et_config = ExecutorchBackendConfig(
            passes=[InitializedMutableBufferPass(["cache_pos"])],
        )
        et_program = edge_program.to_executorch(config=et_config)

        with tempfile.TemporaryDirectory() as tempdir:
            pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir)
            with open(pte_path, "rb") as f:
                model_bytes = f.read()
                loaded_et_program = _load_for_executorch_from_buffer(model_bytes)

                # Since the loaded program's forward expects a tuple of args.
                def wrapped_callable(
                    k_val: torch.Tensor, v_val: torch.Tensor
                ) -> torch.Tensor:
                    return loaded_et_program.forward((k_val, v_val))

                self._test_kv_cache(wrapped_callable)