From aac90a05bff8e1f216cf7006fccb687a88a06027 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 10 Dec 2024 16:39:52 -0800 Subject: [PATCH 01/15] Add tests that localize the prefill issue to the kv cache --- extension/llm/modules/test/test_kv_cache.py | 176 ++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 extension/llm/modules/test/test_kv_cache.py diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py new file mode 100644 index 0000000000..721da9de05 --- /dev/null +++ b/extension/llm/modules/test/test_kv_cache.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Callable, Tuple + +import torch + +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache +from executorch.runtime import Runtime +from torch.testing import assert_close +from torchtune.modules.kv_cache import KVCache + + +def generate_cache_inputs( + seq_len: int, + batch_size: int = 1, + num_kv_heads: int = 64, + head_dim: int = 8, +) -> Tuple[torch.Tensor, ...]: + """Helper to generate k_val and v_val for both et and tt caches.""" + k_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) + v_val = torch.ones(batch_size, seq_len, num_kv_heads, head_dim) + + # For torchtune, the kv cache takes in transposed k and v. + k_val_trans = k_val.transpose(1, 2) + v_val_trans = v_val.transpose(1, 2) + + return (k_val, v_val, k_val_trans, v_val_trans) + + +class KVCacheTest(unittest.TestCase): + def setUp(self): + self.batch_size = 1 + self.max_seq_len = 10 + self.num_kv_heads = 1 # For testing purposes, usually this is 64. + self.head_dim = 8 + self.dtype = torch.float + + self.tt_kv_cache = KVCache( + batch_size=self.batch_size, + max_seq_len=self.max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=self.dtype, + ) + self.et_kv_cache = InferenceKVCache( + batch_size=self.batch_size, + max_seq_len=self.max_seq_len, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=self.dtype, + transpose_cache=False, + ) + + def _test_kv_cache(self, et_cache_module: Callable): + """ + Given an executorch kv cache anywhere along the export chain, compare it's results + against torchtune and run basic tests. + """ + prefill_seq_len = 3 + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( + prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim + ) + + et_res = et_cache_module(k_val, v_val) + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + + # Check torchtune matches executorch. + assert_close(et_res, tt_res_transposed) + + # Check the values are correct, all rows in the seq_len dim should be + # filled with 1s up to and including the 3rd. 
+ et_k_cache = et_res[0] + for i in range(prefill_seq_len): + self.assertTrue(et_k_cache[0][i][0][0] == 1) + self.assertTrue(et_k_cache[0][prefill_seq_len][0][0] == 0) + + """Case 2: Token-by-token (seq_len = 0)""" + seq_len = 1 + k_val, v_val, k_val_trans, v_val_trans = generate_cache_inputs( + seq_len, self.batch_size, self.num_kv_heads, self.head_dim + ) + + et_res = et_cache_module(k_val, v_val) + tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + + # Check torchtune matches executorch. + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + assert_close(tt_res_transposed, et_res) + + # All rows should be filled with 1s up to 3 + 1th row. + et_k_cache = et_res[0] + for i in range(prefill_seq_len + 1): + self.assertTrue(et_k_cache[0][i][0][0] == 1) + self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0) + + def export_kv_cache( + self, + kv_cache: torch.nn.Module, + ) -> torch.export.ExportedProgram: + # Wrapper since torch.export only exports forward(). + class EtCacheWrapper(torch.nn.Module): + def __init__(self, kv_cache: torch.nn.Module): + super().__init__() + self.kv_cache = kv_cache + + def forward(self, k_val: torch.Tensor, v_val: torch.Tensor): + return self.kv_cache.update(k_val, v_val) + + dim = torch.export.Dim("seq_len_dim", min=1, max=self.max_seq_len) + exported_kv_cache = torch.export.export( + EtCacheWrapper(self.et_kv_cache), + ( + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), + torch.Tensor(self.batch_size, 3, self.num_kv_heads, self.head_dim), + ), # 3 as example prefill seq_len. + dynamic_shapes={ + "k_val": { + 0: torch.export.Dim.STATIC, + 1: dim, + 2: torch.export.Dim.STATIC, + 3: torch.export.Dim.STATIC, + }, + "v_val": { + 0: torch.export.Dim.STATIC, + 1: dim, + 2: torch.export.Dim.STATIC, + 3: torch.export.Dim.STATIC, + }, + }, + ) + return exported_kv_cache + + def test_kv_cache_eager(self): + self._test_kv_cache(self.et_kv_cache.update) + + def test_kv_cache_export(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + self._test_kv_cache(exported_kv_cache.module()) + + def test_kv_cache_edge(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + self._test_kv_cache(edge_program._edge_programs["forward"].module()) + + def test_kv_cache_executorch(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + et_program = edge_program.to_executorch() + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + method = program.load_method("forward") + + # Since method.execute expects a tuple of args. 
+ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor: + return method.execute((k_val, v_val)) + + self._test_kv_cache(wrapped_callable) From 917fb0deac4758ad7cf71c22d20bca6e78ddb423 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 13 Dec 2024 15:10:17 -0800 Subject: [PATCH 02/15] Fixes test but not model --- .../models/llama3_2_vision/runner/native.py | 2 +- exir/emit/_emitter.py | 4 +- exir/passes/init_mutable_buffer_pass.py | 47 +++++++++++++++++++ exir/passes/spec_prop_pass.py | 4 +- exir/program/_program.py | 4 ++ extension/llm/export/builder.py | 1 + extension/llm/modules/kv_cache.py | 2 +- extension/llm/modules/test/test_kv_cache.py | 20 +++++++- runtime/executor/method.cpp | 6 +++ 9 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 exir/passes/init_mutable_buffer_pass.py diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index 9a28c94f9c..ae36fa9cf8 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -24,7 +24,7 @@ from executorch.extension.pybindings import portable_lib # noqa # usort: skip # Note: import this after portable_lib -from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip from executorch.kernels import quantized # noqa diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 88247d2a27..f55a57b3ff 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1566,6 +1566,7 @@ def _find_fqn_for_placeholder( fqn = self.exported_program.graph_signature.inputs_to_parameters[target] elif target in self.exported_program.graph_signature.inputs_to_buffers: + breakpoint() fqn = self.exported_program.graph_signature.inputs_to_buffers[target] # if the buffer is mutated then record that @@ -1606,6 +1607,7 @@ def placeholder( if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) + print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}") # If the placeholder has a constant_tag, it is external to the PTE file # and requires a fqn and location=TensorDataLocation.EXTERNAL @@ -1655,7 +1657,7 @@ def placeholder( spec.storage = real_tensor.untyped_storage() # User inputs and mutable buffers are not constants, other buffers or parameters are. - spec.const = not (is_user_input or is_mutable_buffer) + spec.const = not is_user_input evalue = ( self._tensor_spec_to_evalue(spec, constant_tag) diff --git a/exir/passes/init_mutable_buffer_pass.py b/exir/passes/init_mutable_buffer_pass.py new file mode 100644 index 0000000000..a3d7bcb520 --- /dev/null +++ b/exir/passes/init_mutable_buffer_pass.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.passes.spec_prop_pass import make_spec + +class InitMutableBufferPass(ExportPass): + def __init__(self) -> None: + super().__init__() + + def update_placeholder_tensor_specs( + self, + exported_program: torch.export.ExportedProgram, + graph_module: torch.fx.GraphModule, + ) -> None: + """ + Update the tensor specs for all placeholder nodes such that + placeholders that are parameters are marked as constant. + """ + for node in graph_module.graph.nodes: + if node.op != "placeholder": + continue + if "spec" not in node.meta: + raise RuntimeError(f"Placeholder node {node} missing meta['spec']") + # print(node) + spec = node.meta["spec"] + if (isinstance(node.target, str) and + node.target in exported_program.graph_signature.inputs_to_buffers and exported_program.graph_signature.inputs_to_buffers[node.target] in exported_program.state_dict): + # print(f"Setting {node.target}.const = True") + # breakpoint() + # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]]) + spec.const = True + + # pyre-ignore + def placeholder(self, name: str, arg, meta): + # print(name) + meta["spec"] = make_spec(arg, const=meta.data['spec'].const) + # if name == "b_kv_cache_cache_pos": + # print("breakpoint") + # breakpoint() + + return super().placeholder(name, arg, meta) diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 25eb5beaa7..b979a266c6 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -18,9 +18,9 @@ # pyre-ignore -def make_spec(x): +def make_spec(x, const=False): if isinstance(x, torch.Tensor): - return TensorSpec.from_tensor(x) + return TensorSpec.from_tensor(x, const) elif isinstance(x, (int, bool, float)): return x else: diff --git a/exir/program/_program.py b/exir/program/_program.py index fd1d0aca3d..fafe99dcfc 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -46,6 +46,7 @@ from executorch.exir.passes.replace_view_copy_with_view_pass import ( ReplaceViewCopyWithViewPass, ) +from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass from executorch.exir.print_program import pretty_print, print_program @@ -706,6 +707,7 @@ def edge_to_executorch_passes( passes: List[PassType] = [ *config.passes, SpecPropPass(), + InitMutableBufferPass(), # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if # there exists an unbacked symint operation. 
@@ -1352,6 +1354,8 @@ def to_executorch( gm, new_signature = insert_write_back_for_buffers_pass(program) new_gm = program.graph_module for p in edge_to_executorch_passes(config, name): + if isinstance(p, InitMutableBufferPass): + p.update_placeholder_tensor_specs(program, new_gm) new_gm_res = p(new_gm) assert new_gm_res is not None new_gm = new_gm_res.graph_module diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ebc7f02ee1..e4f950d682 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -414,6 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager": sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) ) + print(self.export_program.to_executorch_program(verbose=True)) logging.info( "Required memory for activation in bytes: {}".format( self.export_program._emitter_output.program.execution_plan[ diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py index db940bca3f..477a3b6f77 100644 --- a/extension/llm/modules/kv_cache.py +++ b/extension/llm/modules/kv_cache.py @@ -56,7 +56,7 @@ def __init__( "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False ) self.register_buffer( - "cache_pos", torch.arange(0, self.max_seq_len), persistent=False + "cache_pos", torch.arange(0, self.max_seq_len), persistent=True ) self.batch_size = batch_size diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py index 721da9de05..6379d8abd5 100644 --- a/extension/llm/modules/test/test_kv_cache.py +++ b/extension/llm/modules/test/test_kv_cache.py @@ -67,10 +67,21 @@ def _test_kv_cache(self, et_cache_module: Callable): prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim ) + print() + print("Prefilling...") + print() + et_res = et_cache_module(k_val, v_val) tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) + print() + print("Final tt kv_cache.cache_pos") + print(self.tt_kv_cache.cache_pos) + print("Final tt kv_cache.k_cache") + print(self.tt_kv_cache.k_cache) + print() + # Check torchtune matches executorch. assert_close(et_res, tt_res_transposed) @@ -89,17 +100,19 @@ def _test_kv_cache(self, et_cache_module: Callable): et_res = et_cache_module(k_val, v_val) tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) + tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) # Check torchtune matches executorch. - tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) assert_close(tt_res_transposed, et_res) # All rows should be filled with 1s up to 3 + 1th row. 
et_k_cache = et_res[0] for i in range(prefill_seq_len + 1): self.assertTrue(et_k_cache[0][i][0][0] == 1) + self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0) + def export_kv_cache( self, kv_cache: torch.nn.Module, @@ -165,6 +178,10 @@ def test_kv_cache_executorch(self): ), ) et_program = edge_program.to_executorch() + + """DEBUG the executorch program""" + et_program.dump_executorch_program(verbose=True) + runtime = Runtime.get() program = runtime.load_program(et_program.buffer) method = program.load_method("forward") @@ -174,3 +191,4 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor: return method.execute((k_val, v_val)) self._test_kv_cache(wrapped_callable) + diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index eb0a8c8bb6..72110879b3 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -11,7 +11,9 @@ #include // @donotremove #include #include +#include +#include #include #include #include @@ -1179,6 +1181,10 @@ Error Method::execute_instruction() { if (err == Error::Ok) { step_state_.instr_idx = next_instr_idx; } + + // TODO: Print an EValue. + std::cout << "(" << values_[1] << " ) Printing kv_cache k_cache: " << executorch::extension::evalue_edge_items(9216) << values_[2] << std::endl; + return err; } From 46ea733269ed21a928cce348d9311b1d7d374b3d Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Mon, 16 Dec 2024 13:06:34 -0800 Subject: [PATCH 03/15] Updated pass --- .../models/llama3_2_vision/runner/native.py | 6 +++- exir/emit/_emitter.py | 8 +++-- exir/passes/init_mutable_buffer_pass.py | 34 +++---------------- exir/passes/spec_prop_pass.py | 4 +-- exir/program/_program.py | 2 -- extension/llm/export/builder.py | 2 +- extension/llm/modules/kv_cache.py | 2 +- 7 files changed, 19 insertions(+), 39 deletions(-) diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index ae36fa9cf8..a510db5c29 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -19,6 +19,7 @@ ) from executorch.extension.pybindings.portable_lib import _load_for_executorch +from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer # Load custom ops and quantized ops. 
from executorch.extension.pybindings import portable_lib # noqa # usort: skip @@ -43,7 +44,10 @@ def __init__(self, args): use_kv_cache=args.kv_cache, vocab_size=params["vocab_size"], ) - self.model = _load_for_executorch(args.pte) + with open(args.pte, "rb") as f: + model_bytes = f.read() + self.model = _load_for_executorch_from_buffer(model_bytes) + # self.model = _load_for_executorch(args.pte) self.use_kv_cache = args.kv_cache def forward( diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index f55a57b3ff..2ee6bb60b6 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1566,7 +1566,6 @@ def _find_fqn_for_placeholder( fqn = self.exported_program.graph_signature.inputs_to_parameters[target] elif target in self.exported_program.graph_signature.inputs_to_buffers: - breakpoint() fqn = self.exported_program.graph_signature.inputs_to_buffers[target] # if the buffer is mutated then record that @@ -1603,6 +1602,7 @@ def placeholder( """ spec = self.node.meta["spec"] constant_tag = self.node.meta.get("constant_tag", None) + initialize_buffer = self.node.meta.get("et_init_buffer", None) is_user_input = True if isinstance(target, str) and isinstance(spec, TensorSpec): @@ -1657,7 +1657,11 @@ def placeholder( spec.storage = real_tensor.untyped_storage() # User inputs and mutable buffers are not constants, other buffers or parameters are. - spec.const = not is_user_input + if initialize_buffer: + assert is_mutable_buffer + spec.const = True + else: + spec.const = not (is_user_input or is_mutable_buffer) evalue = ( self._tensor_spec_to_evalue(spec, constant_tag) diff --git a/exir/passes/init_mutable_buffer_pass.py b/exir/passes/init_mutable_buffer_pass.py index a3d7bcb520..5ffb223165 100644 --- a/exir/passes/init_mutable_buffer_pass.py +++ b/exir/passes/init_mutable_buffer_pass.py @@ -13,35 +13,9 @@ class InitMutableBufferPass(ExportPass): def __init__(self) -> None: super().__init__() - def update_placeholder_tensor_specs( - self, - exported_program: torch.export.ExportedProgram, - graph_module: torch.fx.GraphModule, - ) -> None: - """ - Update the tensor specs for all placeholder nodes such that - placeholders that are parameters are marked as constant. 
- """ - for node in graph_module.graph.nodes: - if node.op != "placeholder": - continue - if "spec" not in node.meta: - raise RuntimeError(f"Placeholder node {node} missing meta['spec']") - # print(node) - spec = node.meta["spec"] - if (isinstance(node.target, str) and - node.target in exported_program.graph_signature.inputs_to_buffers and exported_program.graph_signature.inputs_to_buffers[node.target] in exported_program.state_dict): - # print(f"Setting {node.target}.const = True") - # breakpoint() - # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]]) - spec.const = True - - # pyre-ignore def placeholder(self, name: str, arg, meta): - # print(name) - meta["spec"] = make_spec(arg, const=meta.data['spec'].const) - # if name == "b_kv_cache_cache_pos": - # print("breakpoint") - # breakpoint() - + if "cache_pos" in name: + meta["et_init_buffer"] = True + return super().placeholder(name, arg, meta) + diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index b979a266c6..25eb5beaa7 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -18,9 +18,9 @@ # pyre-ignore -def make_spec(x, const=False): +def make_spec(x): if isinstance(x, torch.Tensor): - return TensorSpec.from_tensor(x, const) + return TensorSpec.from_tensor(x) elif isinstance(x, (int, bool, float)): return x else: diff --git a/exir/program/_program.py b/exir/program/_program.py index fafe99dcfc..7b6686b826 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1354,8 +1354,6 @@ def to_executorch( gm, new_signature = insert_write_back_for_buffers_pass(program) new_gm = program.graph_module for p in edge_to_executorch_passes(config, name): - if isinstance(p, InitMutableBufferPass): - p.update_placeholder_tensor_specs(program, new_gm) new_gm_res = p(new_gm) assert new_gm_res is not None new_gm = new_gm_res.graph_module diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index e4f950d682..8bb98ebeae 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -414,7 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager": sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) ) - print(self.export_program.to_executorch_program(verbose=True)) + print(self.export_program.dump_executorch_program(verbose=True)) logging.info( "Required memory for activation in bytes: {}".format( self.export_program._emitter_output.program.execution_plan[ diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py index 477a3b6f77..db940bca3f 100644 --- a/extension/llm/modules/kv_cache.py +++ b/extension/llm/modules/kv_cache.py @@ -56,7 +56,7 @@ def __init__( "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False ) self.register_buffer( - "cache_pos", torch.arange(0, self.max_seq_len), persistent=True + "cache_pos", torch.arange(0, self.max_seq_len), persistent=False ) self.batch_size = batch_size From 5db136c564e552a9df29061ca97156cb4c94fd05 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 17 Dec 2024 18:38:41 -0800 Subject: [PATCH 04/15] Fix segmentation fault --- examples/models/llama3_2_vision/runner/native.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index a510db5c29..ee8c2ec624 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -44,9 +44,10 @@ def 
__init__(self, args): use_kv_cache=args.kv_cache, vocab_size=params["vocab_size"], ) + self.model_bytes = None with open(args.pte, "rb") as f: - model_bytes = f.read() - self.model = _load_for_executorch_from_buffer(model_bytes) + self.model_bytes = f.read() + self.model = _load_for_executorch_from_buffer(self.model_bytes) # self.model = _load_for_executorch(args.pte) self.use_kv_cache = args.kv_cache From 9cdfb4381c1ac7544a7008806a0c6713893263fc Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 17 Dec 2024 19:57:03 -0800 Subject: [PATCH 05/15] Lint --- .../models/llama3_2_vision/runner/native.py | 6 ++- exir/passes/init_mutable_buffer_pass.py | 2 +- exir/program/_program.py | 2 +- extension/llm/modules/test/test_kv_cache.py | 47 ++++++++++++------- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index ee8c2ec624..8180f1abbf 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -18,8 +18,10 @@ TorchTuneLlamaRunner, ) -from executorch.extension.pybindings.portable_lib import _load_for_executorch -from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch, + _load_for_executorch_from_buffer, +) # Load custom ops and quantized ops. from executorch.extension.pybindings import portable_lib # noqa # usort: skip diff --git a/exir/passes/init_mutable_buffer_pass.py b/exir/passes/init_mutable_buffer_pass.py index 5ffb223165..688410cc2f 100644 --- a/exir/passes/init_mutable_buffer_pass.py +++ b/exir/passes/init_mutable_buffer_pass.py @@ -9,6 +9,7 @@ from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue from executorch.exir.passes.spec_prop_pass import make_spec + class InitMutableBufferPass(ExportPass): def __init__(self) -> None: super().__init__() @@ -18,4 +19,3 @@ def placeholder(self, name: str, arg, meta): meta["et_init_buffer"] = True return super().placeholder(name, arg, meta) - diff --git a/exir/program/_program.py b/exir/program/_program.py index 7b6686b826..e6247231f0 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -34,6 +34,7 @@ OpReplacePass, ) from executorch.exir.passes.external_constants_pass import external_constants_pass +from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) @@ -46,7 +47,6 @@ from executorch.exir.passes.replace_view_copy_with_view_pass import ( ReplaceViewCopyWithViewPass, ) -from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass from executorch.exir.print_program import pretty_print, print_program diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py index 6379d8abd5..b90f9257c1 100644 --- a/extension/llm/modules/test/test_kv_cache.py +++ b/extension/llm/modules/test/test_kv_cache.py @@ -4,13 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import tempfile import unittest from typing import Callable, Tuple import torch - from executorch.exir import EdgeCompileConfig, to_edge + +from executorch.extension.export_util.utils import save_pte_program from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache + +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) from executorch.runtime import Runtime from torch.testing import assert_close from torchtune.modules.kv_cache import KVCache @@ -67,21 +73,10 @@ def _test_kv_cache(self, et_cache_module: Callable): prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim ) - print() - print("Prefilling...") - print() - et_res = et_cache_module(k_val, v_val) tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans) tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2)) - print() - print("Final tt kv_cache.cache_pos") - print(self.tt_kv_cache.cache_pos) - print("Final tt kv_cache.k_cache") - print(self.tt_kv_cache.k_cache) - print() - # Check torchtune matches executorch. assert_close(et_res, tt_res_transposed) @@ -112,7 +107,6 @@ def _test_kv_cache(self, et_cache_module: Callable): self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0) - def export_kv_cache( self, kv_cache: torch.nn.Module, @@ -179,9 +173,6 @@ def test_kv_cache_executorch(self): ) et_program = edge_program.to_executorch() - """DEBUG the executorch program""" - et_program.dump_executorch_program(verbose=True) - runtime = Runtime.get() program = runtime.load_program(et_program.buffer) method = program.load_method("forward") @@ -192,3 +183,27 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor: self._test_kv_cache(wrapped_callable) + def test_kv_cache_executorch_from_file(self): + exported_kv_cache = self.export_kv_cache(self.et_kv_cache) + edge_program = to_edge( + exported_kv_cache, + compile_config=EdgeCompileConfig( + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, + ), + ) + et_program = edge_program.to_executorch() + + with tempfile.TemporaryDirectory() as tempdir: + pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir) + with open(pte_path, "rb") as f: + model_bytes = f.read() + loaded_et_program = _load_for_executorch_from_buffer(model_bytes) + + # Since method.execute expects a tuple of args. 
+ def wrapped_callable( + k_val: torch.Tensor, v_val: torch.Tensor + ) -> torch.Tensor: + return loaded_et_program.forward((k_val, v_val)) + + self._test_kv_cache(wrapped_callable) From 9e68531ac4846f6740bcd6d78f0c2fb6c0bb77d7 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 17 Dec 2024 21:37:55 -0800 Subject: [PATCH 06/15] Only add pass when vision model --- examples/models/llama/export_llama_lib.py | 12 ++++++++-- .../models/llama3_2_vision/runner/native.py | 2 -- exir/emit/_emitter.py | 1 - exir/passes/init_mutable_buffer_pass.py | 21 ------------------ exir/program/_program.py | 2 -- extension/llm/export/builder.py | 22 +++++++++++-------- runtime/executor/method.cpp | 5 ----- 7 files changed, 23 insertions(+), 42 deletions(-) delete mode 100644 exir/passes/init_mutable_buffer_pass.py diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index ea4296cc52..65bc8991a8 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -23,6 +23,9 @@ import torch from executorch.devtools.etrecord import generate_etrecord +from executorch.exir.passes.cache_pos_init_mutable_pass import ( + CachePosToInitializedMutableBufferPass, +) from executorch.extension.llm.export.builder import DType, LLMEdgeManager @@ -760,6 +763,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") + additional_passes = [] + if args.model in TORCHTUNE_DEFINED_MODELS: + additional_passes = [CachePosToInitializedMutableBufferPass()] if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: raise ValueError("Unable to generate etrecord due to missing edge manager.") @@ -774,7 +780,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch( + passes=additional_passes, + ) # Generate ETRecord if edge_manager_copy: @@ -792,7 +800,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 
canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch(passes=additional_passes) if args.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index 8180f1abbf..105ddf2054 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -19,7 +19,6 @@ ) from executorch.extension.pybindings.portable_lib import ( - _load_for_executorch, _load_for_executorch_from_buffer, ) @@ -50,7 +49,6 @@ def __init__(self, args): with open(args.pte, "rb") as f: self.model_bytes = f.read() self.model = _load_for_executorch_from_buffer(self.model_bytes) - # self.model = _load_for_executorch(args.pte) self.use_kv_cache = args.kv_cache def forward( diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 2ee6bb60b6..119fee3cc6 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1607,7 +1607,6 @@ def placeholder( if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) - print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}") # If the placeholder has a constant_tag, it is external to the PTE file # and requires a fqn and location=TensorDataLocation.EXTERNAL diff --git a/exir/passes/init_mutable_buffer_pass.py b/exir/passes/init_mutable_buffer_pass.py deleted file mode 100644 index 688410cc2f..0000000000 --- a/exir/passes/init_mutable_buffer_pass.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -from executorch.exir.passes.spec_prop_pass import make_spec - - -class InitMutableBufferPass(ExportPass): - def __init__(self) -> None: - super().__init__() - - def placeholder(self, name: str, arg, meta): - if "cache_pos" in name: - meta["et_init_buffer"] = True - - return super().placeholder(name, arg, meta) diff --git a/exir/program/_program.py b/exir/program/_program.py index e6247231f0..fd1d0aca3d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -34,7 +34,6 @@ OpReplacePass, ) from executorch.exir.passes.external_constants_pass import external_constants_pass -from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) @@ -707,7 +706,6 @@ def edge_to_executorch_passes( passes: List[PassType] = [ *config.passes, SpecPropPass(), - InitMutableBufferPass(), # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if # there exists an unbacked symint operation. 
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 8bb98ebeae..619d9782a7 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -25,6 +25,7 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass @@ -395,26 +396,29 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag return self - def to_executorch(self) -> "LLMEdgeManager": + def to_executorch(self, passes: Optional[List[PassType]]) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram. """ assert self.edge_manager, "Need to run export_to_edge() first" + to_executorch_passes = [ + # If there are Linear operations left in the graph, let's execute + # them with the optimized op_linear rather than materializing a + # transpose followed by a regular op_mm. + ConvertToLinearPass(), + QuantFusionPass(), + ] + if passes: + to_executorch_passes.extend(passes) + self.export_program = self.edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, - passes=[ - # If there are Linear operations left in the graph, let's execute - # them with the optimized op_linear rather than materializing a - # transpose followed by a regular op_mm. - ConvertToLinearPass(), - QuantFusionPass(), - ], + passes=to_executorch_passes, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) ) - print(self.export_program.dump_executorch_program(verbose=True)) logging.info( "Required memory for activation in bytes: {}".format( self.export_program._emitter_output.program.execution_plan[ diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 72110879b3..b1094ed122 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -11,7 +11,6 @@ #include // @donotremove #include #include -#include #include #include @@ -1181,10 +1180,6 @@ Error Method::execute_instruction() { if (err == Error::Ok) { step_state_.instr_idx = next_instr_idx; } - - // TODO: Print an EValue. - std::cout << "(" << values_[1] << " ) Printing kv_cache k_cache: " << executorch::extension::evalue_edge_items(9216) << values_[2] << std::endl; - return err; } From 925409d6aa0e9b2b74e26246280c08623659b733 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 17 Dec 2024 22:11:07 -0800 Subject: [PATCH 07/15] Add comments --- examples/models/llama3_2_vision/runner/native.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index 105ddf2054..2b4d709f9b 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -45,9 +45,16 @@ def __init__(self, args): use_kv_cache=args.kv_cache, vocab_size=params["vocab_size"], ) + # Save the loaded model bytes to prevent data from going out of + # scope after the `with` and getting cleaned up by Python's + # garbage collector. 
self.model_bytes = None with open(args.pte, "rb") as f: self.model_bytes = f.read() + # Need to use _load_for_executorch_from_buffer instead of + # _load_for_executorch because the latter uses MmapDataLoader, + # which doesn't have load_into() implemented, which is needed + # for loading initialized mutable buffers. self.model = _load_for_executorch_from_buffer(self.model_bytes) self.use_kv_cache = args.kv_cache From 2a3fe8b49c974059e455351a9ddd29dd4425023e Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 17 Dec 2024 22:12:10 -0800 Subject: [PATCH 08/15] Remove import --- runtime/executor/method.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index b1094ed122..eb0a8c8bb6 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -12,7 +12,6 @@ #include #include -#include #include #include #include From 61101c29c054e3466b08ef3bc866a36a3faa23d2 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Wed, 18 Dec 2024 12:06:25 -0800 Subject: [PATCH 09/15] Add pass --- exir/passes/cache_pos_init_mutable_pass.py | 27 ++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 exir/passes/cache_pos_init_mutable_pass.py diff --git a/exir/passes/cache_pos_init_mutable_pass.py b/exir/passes/cache_pos_init_mutable_pass.py new file mode 100644 index 0000000000..ccfc325d3e --- /dev/null +++ b/exir/passes/cache_pos_init_mutable_pass.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from executorch.exir.pass_base import ExportPass + + +class CachePosToInitializedMutableBufferPass(ExportPass): + """ + If the buffer has the name "cache_pos", such as in a kv_cache + module with `self.register_buffer("cache_pos", torch.arange(10))`, + mark it with a custom tag which is later used by the emitter to + flag spec.const to True, which provides the mutable buffer with + an initialized state.
+ """ + + def __init__(self) -> None: + super().__init__() + + def placeholder(self, name: str, arg, meta): + if "cache_pos" in name: + meta["et_init_buffer"] = True + + return super().placeholder(name, arg, meta) From 4ee95d3fc67f623044fd3a9bb5183683f1330dfa Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 20 Dec 2024 16:39:19 -0800 Subject: [PATCH 10/15] PR review --- examples/models/llama/export_llama_lib.py | 6 ++---- ...pos_init_mutable_pass.py => init_mutable_pass.py} | 12 ++++++++---- extension/llm/export/builder.py | 4 +++- 3 files changed, 13 insertions(+), 9 deletions(-) rename exir/passes/{cache_pos_init_mutable_pass.py => init_mutable_pass.py} (72%) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 65bc8991a8..e267f5d71d 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -23,9 +23,7 @@ import torch from executorch.devtools.etrecord import generate_etrecord -from executorch.exir.passes.cache_pos_init_mutable_pass import ( - CachePosToInitializedMutableBufferPass, -) +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.export.builder import DType, LLMEdgeManager @@ -765,7 +763,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 additional_passes = [] if args.model in TORCHTUNE_DEFINED_MODELS: - additional_passes = [CachePosToInitializedMutableBufferPass()] + additional_passes = [InitializedMutableBufferPass(["cache_pos"])] if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: raise ValueError("Unable to generate etrecord due to missing edge manager.") diff --git a/exir/passes/cache_pos_init_mutable_pass.py b/exir/passes/init_mutable_pass.py similarity index 72% rename from exir/passes/cache_pos_init_mutable_pass.py rename to exir/passes/init_mutable_pass.py index ccfc325d3e..72a67b765a 100644 --- a/exir/passes/cache_pos_init_mutable_pass.py +++ b/exir/passes/init_mutable_pass.py @@ -5,10 +5,12 @@ # LICENSE file in the root directory of this source tree. +from typing import List + from executorch.exir.pass_base import ExportPass -class CachePosToInitializedMutableBufferPass(ExportPass): +class InitializedMutableBufferPass(ExportPass): """ If the buffer has the name "cache_pos", such as in a kv_cache module with `self.register_buffer("cache_pos", torch.arange(10))`, mark it with a custom tag which is later used by the emitter to flag spec.const to True, which provides the mutable buffer with an initialized state. """ - def __init__(self) -> None: + def __init__(self, patterns: List[str]) -> None: super().__init__() + self.patterns = patterns def placeholder(self, name: str, arg, meta): - if "cache_pos" in name: - meta["et_init_buffer"] = True + for pattern in self.patterns: + if pattern in name: + meta["et_init_buffer"] = True return super().placeholder(name, arg, meta) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 619d9782a7..390e2e47c0 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -396,7 +396,9 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag return self - def to_executorch(self, passes: Optional[List[PassType]]) -> "LLMEdgeManager": + def to_executorch( + self, passes: Optional[List[PassType]] = None + ) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram.
""" From e297c9b1d19d51a79a667b6592102fd6dd4e2276 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 20 Dec 2024 17:53:09 -0800 Subject: [PATCH 11/15] Fix test --- extension/llm/modules/test/test_kv_cache.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py index b90f9257c1..6029a03882 100644 --- a/extension/llm/modules/test/test_kv_cache.py +++ b/extension/llm/modules/test/test_kv_cache.py @@ -10,6 +10,8 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.export_util.utils import save_pte_program from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache @@ -171,7 +173,10 @@ def test_kv_cache_executorch(self): _check_ir_validity=False, ), ) - et_program = edge_program.to_executorch() + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program = edge_program.to_executorch(config=et_config) runtime = Runtime.get() program = runtime.load_program(et_program.buffer) @@ -192,7 +197,10 @@ def test_kv_cache_executorch_from_file(self): _check_ir_validity=False, ), ) - et_program = edge_program.to_executorch() + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program = edge_program.to_executorch(config=et_config) with tempfile.TemporaryDirectory() as tempdir: pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir) From 8145cdaf0e21df8642c91e9ccf6f211bdbf56033 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Mon, 23 Dec 2024 11:06:21 -0800 Subject: [PATCH 12/15] Last changes --- exir/emit/_emitter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 119fee3cc6..1cd286b02f 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder( warnings.warn( "Mutation on a buffer in the model is detected. ExecuTorch assumes " "buffers that are mutated in the graph have a meaningless initial state, " - "only the shape and dtype will be serialized.", + "only the shape and dtype will be serialized, unless a pass which marks " + "spec.const=True such as InitializedMutableBufferPass is run.", UserWarning, stacklevel=1, ) From a2b7ee3db0938642c291175f5cfff50b65bc0e26 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Wed, 8 Jan 2025 11:08:19 -0800 Subject: [PATCH 13/15] Update attention test --- extension/llm/modules/test/test_attention.py | 34 +++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py index 82ee1febf4..6cd05b4bf6 100644 --- a/extension/llm/modules/test/test_attention.py +++ b/extension/llm/modules/test/test_attention.py @@ -11,6 +11,8 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.extension.llm.modules.attention import ( MultiHeadAttention as ETMultiHeadAttention, ) @@ -114,7 +116,7 @@ def test_attention_eager(self): et_res = self.et_mha(self.x, self.x) # Self attention. tt_res = self.tt_mha(self.x, self.x) # Self attention. 
- self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) self.et_mha.reset_cache() self.tt_mha.reset_cache() @@ -125,7 +127,7 @@ def test_attention_eager(self): self.x, self.x, input_pos=self.input_pos ) # Self attention with input pos. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) # test kv cache read. Input pos can be [10, 11, ..., 19] next_input_pos = torch.arange(10, 20).unsqueeze(0) @@ -187,9 +189,8 @@ def test_attention_aoti(self): def test_attention_executorch(self): # Self attention. - # TODO: Fix kv cache - # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) - # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) + self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100) with torch.no_grad(): et_mha_ep = torch.export.export( @@ -202,9 +203,15 @@ def test_attention_executorch(self): et_program = to_edge( et_mha_ep, compile_config=EdgeCompileConfig( - _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg] + _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg], + _check_ir_validity=False, ), - ).to_executorch() + ).to_executorch( + config=ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + ) + runtime = Runtime.get() program = runtime.load_program(et_program.buffer) method = program.load_method("forward") @@ -219,9 +226,8 @@ def test_attention_torch_cond_eager(self): self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len) - # mask mask = self.causal_mask[self.input_pos, :] - # First run + # First run. et_res = self.et_mha( self.x, self.x, mask=mask, input_pos=self.input_pos ) # Self attention with input pos. @@ -229,18 +235,14 @@ def test_attention_torch_cond_eager(self): self.x, self.x, mask=mask, input_pos=self.input_pos ) # Self attention with input pos. - self.assertTrue(torch.allclose(et_res, tt_res)) + assert_close(et_res, tt_res) # Second run test kv cache read. Input pos is [10, 11, ..., 19] next_input_pos = torch.arange(10, 20).unsqueeze(0) empty_y = torch.full_like(self.x, torch.nan) mask = self.causal_mask[next_input_pos, :] - et_res = self.et_mha( - self.x, empty_y, mask=mask, input_pos=next_input_pos - ) # Self attention with input pos. - tt_res = self.tt_mha( - self.x, None, mask=mask, input_pos=next_input_pos - ) # Self attention with input pos. + et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos) + tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos) assert_close(et_res, tt_res) From 93f99ad0041d64548b953c93d0ebf90b2addecdf Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Thu, 9 Jan 2025 13:11:08 -0800 Subject: [PATCH 14/15] Tests --- exir/emit/_emitter.py | 3 +-- exir/emit/test/test_emit.py | 52 +++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 1cd286b02f..3bbcb2128a 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1657,8 +1657,7 @@ def placeholder( spec.storage = real_tensor.untyped_storage() # User inputs and mutable buffers are not constants, other buffers or parameters are. 
- if initialize_buffer: - assert is_mutable_buffer + if initialize_buffer and is_mutable_buffer: spec.const = True else: spec.const = not (is_user_input or is_mutable_buffer) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index a645fa5377..e3850bc96d 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -9,6 +9,7 @@ import typing import unittest from contextlib import contextmanager +from copy import deepcopy from typing import List, Optional, Tuple import executorch.exir as exir @@ -31,6 +32,7 @@ from executorch.exir.error import InternalError from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.constant_prop_pass import constant_prop_pass +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.print_program import pretty_print, print_program # noqa from executorch.exir.schema import ( @@ -56,6 +58,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) +from executorch.runtime import Runtime from functorch.experimental import control_flow from torch import nn @@ -243,6 +246,55 @@ def forward(self, x): ) self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null) + def test_initialized_mutable_buffer(self): + """Test that mutable buffers can hold meaningful initialized state.""" + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Mutable buffer with non-empty initial state. + self.register_buffer("cache_pos", torch.arange(0, 10)) + + def forward(self, x): + self.cache_pos.add_(1) + return self.cache_pos + + m = TestModule() + example_inputs = (torch.ones(10),) + ep = torch.export.export(m, example_inputs) + edge = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + ), + ) + + # Save a copy of the edge program since to_executorch is + # stateful to sombe degree. + edge_copy = deepcopy(edge) + et_config = ExecutorchBackendConfig( + passes=[InitializedMutableBufferPass(["cache_pos"])], + ) + et_program_init_pass = edge.to_executorch(config=et_config) + et_program_regular = edge_copy.to_executorch() + + runtime = Runtime.get() + program_init_pass = runtime.load_program(et_program_init_pass.buffer) + method_init_pass = program_init_pass.load_method("forward") + + program_regular = runtime.load_program(et_program_regular.buffer) + method_regular = program_regular.load_method("forward") + + # Test that the mutable buffer is initialized. + torch.allclose( + method_init_pass.execute((example_inputs))[0], torch.arange(1, 11) + ) + # Test that the mutable buffer is uninitialized and starts with default zeros. + torch.allclose( + method_regular.execute((example_inputs))[0], + torch.ones(10, dtype=torch.int64), + ) + def test_int_list_input(self): class M(torch.nn.Module): def forward(self, x, y, z): From 69e36fb0f676b9128f47d397bc1d4fbc59625065 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Thu, 9 Jan 2025 14:48:39 -0800 Subject: [PATCH 15/15] Dave pr comment --- exir/emit/test/test_emit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index e3850bc96d..0da4085914 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -270,7 +270,7 @@ def forward(self, x): ) # Save a copy of the edge program since to_executorch is - # stateful to sombe degree. 
+ # stateful to some degree. edge_copy = deepcopy(edge) et_config = ExecutorchBackendConfig( passes=[InitializedMutableBufferPass(["cache_pos"])], @@ -289,7 +289,8 @@ def forward(self, x): torch.allclose( method_init_pass.execute((example_inputs))[0], torch.arange(1, 11) ) - # Test that the mutable buffer is uninitialized and starts with default zeros. + # Test that the mutable buffer is uninitialized and starts with default zeros, + # we test equality with torch.ones because of the mutation += 1 in the model forward. torch.allclose( method_regular.execute((example_inputs))[0], torch.ones(10, dtype=torch.int64),