From ee2eb1570f6f942a0f6d07b21bae7d460a1551f5 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Mon, 16 Dec 2024 13:06:34 -0800
Subject: [PATCH] Initialize mutable buffers via et_init_buffer node meta

InitMutableBufferPass now tags cache_pos placeholders with an
"et_init_buffer" node meta entry instead of mutating their tensor specs
directly, and to_executorch() no longer special-cases the pass. The
emitter marks flagged mutable buffers as constant so their initial
values are serialized into the program.
---
 .../models/llama3_2_vision/runner/native.py |  5 ++-
 exir/emit/_emitter.py                       |  8 +++--
 exir/passes/init_mutable_buffer_pass.py     | 34 +++----------
 exir/passes/spec_prop_pass.py               |  4 +--
 exir/program/_program.py                    |  2 --
 extension/llm/export/builder.py             |  2 +-
 extension/llm/modules/kv_cache.py           |  2 +-
 7 files changed, 18 insertions(+), 39 deletions(-)

diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py
index ae36fa9cf82..a510db5c294 100644
--- a/examples/models/llama3_2_vision/runner/native.py
+++ b/examples/models/llama3_2_vision/runner/native.py
@@ -19,6 +19,7 @@
 )
 
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
@@ -43,7 +44,9 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        with open(args.pte, "rb") as f:
+            model_bytes = f.read()
+        self.model = _load_for_executorch_from_buffer(model_bytes)
         self.use_kv_cache = args.kv_cache
 
     def forward(
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index f55a57b3ff1..2ee6bb60b67 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -1566,7 +1566,6 @@ def _find_fqn_for_placeholder(
             fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
 
         elif target in self.exported_program.graph_signature.inputs_to_buffers:
-            breakpoint()
             fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
 
             # if the buffer is mutated then record that
@@ -1603,6 +1602,7 @@
         """
        spec = self.node.meta["spec"]
         constant_tag = self.node.meta.get("constant_tag", None)
+        initialize_buffer = self.node.meta.get("et_init_buffer", None)
         is_user_input = True
 
         if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1657,7 +1657,11 @@
             spec.storage = real_tensor.untyped_storage()
 
         # User inputs and mutable buffers are not constants, other buffers or parameters are.
-        spec.const = not is_user_input
+        if initialize_buffer:
+            assert is_mutable_buffer
+            spec.const = True
+        else:
+            spec.const = not (is_user_input or is_mutable_buffer)
 
         evalue = (
             self._tensor_spec_to_evalue(spec, constant_tag)
diff --git a/exir/passes/init_mutable_buffer_pass.py b/exir/passes/init_mutable_buffer_pass.py
index a3d7bcb5203..5ffb223165e 100644
--- a/exir/passes/init_mutable_buffer_pass.py
+++ b/exir/passes/init_mutable_buffer_pass.py
@@ -13,35 +13,9 @@ class InitMutableBufferPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()
 
-    def update_placeholder_tensor_specs(
-        self,
-        exported_program: torch.export.ExportedProgram,
-        graph_module: torch.fx.GraphModule,
-    ) -> None:
-        """
-        Update the tensor specs for all placeholder nodes such that
-        placeholders that are parameters are marked as constant.
-        """
-        for node in graph_module.graph.nodes:
-            if node.op != "placeholder":
-                continue
-            if "spec" not in node.meta:
-                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
-            # print(node)
-            spec = node.meta["spec"]
-            if (isinstance(node.target, str) and
-                    node.target in exported_program.graph_signature.inputs_to_buffers and exported_program.graph_signature.inputs_to_buffers[node.target] in exported_program.state_dict):
-                # print(f"Setting {node.target}.const = True")
-                # breakpoint()
-                # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]])
-                spec.const = True
-
-    # pyre-ignore
 
     def placeholder(self, name: str, arg, meta):
-        # print(name)
-        meta["spec"] = make_spec(arg, const=meta.data['spec'].const)
-        # if name == "b_kv_cache_cache_pos":
-        #     print("breakpoint")
-        #     breakpoint()
-
+        if "cache_pos" in name:
+            meta["et_init_buffer"] = True
+
+        return super().placeholder(name, arg, meta)
diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py
index b979a266c6c..5c563105510 100644
--- a/exir/passes/spec_prop_pass.py
+++ b/exir/passes/spec_prop_pass.py
@@ -18,9 +18,9 @@
 
 # pyre-ignore
-def make_spec(x, const=False):
+def make_spec(x):
     if isinstance(x, torch.Tensor):
-        return TensorSpec.from_tensor(x, const)
+        return TensorSpec.from_tensor(x)
     elif isinstance(x, (int, bool, float)):
         return x
     else:
         return None
diff --git a/exir/program/_program.py b/exir/program/_program.py
index fafe99dcfcf..7b6686b8267 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -1354,8 +1354,6 @@ def to_executorch(
         gm, new_signature = insert_write_back_for_buffers_pass(program)
         new_gm = program.graph_module
         for p in edge_to_executorch_passes(config, name):
-            if isinstance(p, InitMutableBufferPass):
-                p.update_placeholder_tensor_specs(program, new_gm)
             new_gm_res = p(new_gm)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index e4f950d6826..8bb98ebeaeb 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -414,7 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager":
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
-        print(self.export_program.to_executorch_program(verbose=True))
+        print(self.export_program.dump_executorch_program(verbose=True))
         logging.info(
             "Required memory for activation in bytes: {}".format(
                 self.export_program._emitter_output.program.execution_plan[
diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
index 477a3b6f771..db940bca3f8 100644
--- a/extension/llm/modules/kv_cache.py
+++ b/extension/llm/modules/kv_cache.py
@@ -56,7 +56,7 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
         self.register_buffer(
-            "cache_pos", torch.arange(0, self.max_seq_len), persistent=True
+            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
         )
         self.batch_size = batch_size
 
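-- 
Reviewer note, not part of the patch: a minimal sketch of the intended
et_init_buffer flow end to end. TinyCache, its sizes, and the expected
output are illustrative assumptions rather than code from this repo; the
sketch also assumes InitMutableBufferPass sits in the default
edge_to_executorch_passes list, as the _program.py hunk above implies.

    import torch

    from executorch.exir import to_edge
    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch_from_buffer,
    )

    class TinyCache(torch.nn.Module):
        """Hypothetical stand-in for the torchtune KVCache: a mutable,
        non-persistent cache_pos buffer whose arange initialization must
        survive export."""

        def __init__(self, max_seq_len: int = 4):
            super().__init__()
            self.register_buffer(
                "cache_pos", torch.arange(0, max_seq_len), persistent=False
            )

        def forward(self, x):
            # Read the initialized positions, then mutate the buffer so
            # torch.export records cache_pos as a mutable buffer.
            out = x + self.cache_pos.to(x.dtype)
            self.cache_pos.add_(1)
            return out

    ep = torch.export.export(TinyCache(), (torch.zeros(4),))
    # InitMutableBufferPass tags the cache_pos placeholder with
    # meta["et_init_buffer"]; the emitter then serializes its initial values.
    pte_bytes = to_edge(ep).to_executorch().buffer
    module = _load_for_executorch_from_buffer(pte_bytes)
    # With the patch, the first call should see [0, 1, 2, 3] rather than zeros.
    print(module.forward((torch.zeros(4),)))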