Skip to content

Commit

Permalink
Fixes test but not model
Browse files Browse the repository at this point in the history
  • Loading branch information
dvorjackz committed Dec 13, 2024
1 parent aac90a0 commit 917fb0d
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/models/llama3_2_vision/runner/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa


Expand Down
4 changes: 3 additions & 1 deletion exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,7 @@ def _find_fqn_for_placeholder(
fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

elif target in self.exported_program.graph_signature.inputs_to_buffers:
breakpoint()
fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

# if the buffer is mutated then record that
Expand Down Expand Up @@ -1606,6 +1607,7 @@ def placeholder(

if isinstance(target, str) and isinstance(spec, TensorSpec):
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}")

# If the placeholder has a constant_tag, it is external to the PTE file
# and requires a fqn and location=TensorDataLocation.EXTERNAL
Expand Down Expand Up @@ -1655,7 +1657,7 @@ def placeholder(
spec.storage = real_tensor.untyped_storage()

# User inputs and mutable buffers are not constants, other buffers or parameters are.
spec.const = not (is_user_input or is_mutable_buffer)
spec.const = not is_user_input

evalue = (
self._tensor_spec_to_evalue(spec, constant_tag)
Expand Down
47 changes: 47 additions & 0 deletions exir/passes/init_mutable_buffer_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.passes.spec_prop_pass import make_spec

class InitMutableBufferPass(ExportPass):
    """
    Pass that rebuilds placeholder specs while preserving their ``const`` flag.

    Intended usage (as visible from this class): call
    ``update_placeholder_tensor_specs`` first so that buffer placeholders
    backed by the program's ``state_dict`` are flagged constant, then run the
    pass so ``placeholder`` re-creates each spec with that flag carried over.
    """

    def __init__(self) -> None:
        super().__init__()

    def update_placeholder_tensor_specs(
        self,
        exported_program: torch.export.ExportedProgram,
        graph_module: torch.fx.GraphModule,
    ) -> None:
        """
        Mark buffer placeholders that have an entry in the state_dict as constant.

        A placeholder is updated only when its target is a string, the target
        is listed in ``graph_signature.inputs_to_buffers``, and the buffer name
        it maps to is a key of ``exported_program.state_dict``. All other
        placeholders keep their existing spec untouched.

        Args:
            exported_program: Program supplying the graph signature and
                state_dict used for the lookup.
            graph_module: Graph module whose placeholder nodes are updated
                in place (``node.meta["spec"].const`` is set to ``True``).

        Raises:
            RuntimeError: If any placeholder node has no ``"spec"`` entry in
                its metadata.
        """
        # Loop-invariant lookups, resolved once instead of once per node.
        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers
        state_dict = exported_program.state_dict

        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            if "spec" not in node.meta:
                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")

            target = node.target
            if (
                isinstance(target, str)
                and target in inputs_to_buffers
                and inputs_to_buffers[target] in state_dict
            ):
                node.meta["spec"].const = True

    # pyre-ignore
    def placeholder(self, name: str, arg, meta):
        """
        Rebuild the placeholder's spec from ``arg``, keeping the prior ``const`` flag.

        The previous spec is read from ``meta.data["spec"]``; a missing
        ``"spec"`` entry still raises ``KeyError`` as before.
        """
        # make_spec returns int/bool/float values unchanged, so an existing
        # spec is not guaranteed to carry a `const` attribute. Fall back to
        # make_spec's own default (False) instead of raising AttributeError.
        is_const = getattr(meta.data["spec"], "const", False)
        meta["spec"] = make_spec(arg, const=is_const)

        return super().placeholder(name, arg, meta)
4 changes: 2 additions & 2 deletions exir/passes/spec_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@


# pyre-ignore
def make_spec(x):
def make_spec(x, const=False):
if isinstance(x, torch.Tensor):
return TensorSpec.from_tensor(x)
return TensorSpec.from_tensor(x, const)
elif isinstance(x, (int, bool, float)):
return x
else:
Expand Down
4 changes: 4 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from executorch.exir.passes.replace_view_copy_with_view_pass import (
ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
from executorch.exir.print_program import pretty_print, print_program
Expand Down Expand Up @@ -706,6 +707,7 @@ def edge_to_executorch_passes(
passes: List[PassType] = [
*config.passes,
SpecPropPass(),
InitMutableBufferPass(),
# ExecuTorch backend ops are unable to handle unbacked symints. So after
# this pass, passes cannot be Interpreter-based, because it will fail if
# there exists an unbacked symint operation.
Expand Down Expand Up @@ -1352,6 +1354,8 @@ def to_executorch(
gm, new_signature = insert_write_back_for_buffers_pass(program)
new_gm = program.graph_module
for p in edge_to_executorch_passes(config, name):
if isinstance(p, InitMutableBufferPass):
p.update_placeholder_tensor_specs(program, new_gm)
new_gm_res = p(new_gm)
assert new_gm_res is not None
new_gm = new_gm_res.graph_module
Expand Down
1 change: 1 addition & 0 deletions extension/llm/export/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager":
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
)
print(self.export_program.to_executorch_program(verbose=True))
logging.info(
"Required memory for activation in bytes: {}".format(
self.export_program._emitter_output.program.execution_plan[
Expand Down
2 changes: 1 addition & 1 deletion extension/llm/modules/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
self.register_buffer(
"cache_pos", torch.arange(0, self.max_seq_len), persistent=False
"cache_pos", torch.arange(0, self.max_seq_len), persistent=True
)
self.batch_size = batch_size

Expand Down
20 changes: 19 additions & 1 deletion extension/llm/modules/test/test_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,21 @@ def _test_kv_cache(self, et_cache_module: Callable):
prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
)

print()
print("Prefilling...")
print()

et_res = et_cache_module(k_val, v_val)
tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

print()
print("Final tt kv_cache.cache_pos")
print(self.tt_kv_cache.cache_pos)
print("Final tt kv_cache.k_cache")
print(self.tt_kv_cache.k_cache)
print()

# Check torchtune matches executorch.
assert_close(et_res, tt_res_transposed)

Expand All @@ -89,17 +100,19 @@ def _test_kv_cache(self, et_cache_module: Callable):

et_res = et_cache_module(k_val, v_val)
tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

# Check torchtune matches executorch.
tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
assert_close(tt_res_transposed, et_res)

# All rows should be filled with 1s up to 3 + 1th row.
et_k_cache = et_res[0]
for i in range(prefill_seq_len + 1):
self.assertTrue(et_k_cache[0][i][0][0] == 1)

self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)


def export_kv_cache(
self,
kv_cache: torch.nn.Module,
Expand Down Expand Up @@ -165,6 +178,10 @@ def test_kv_cache_executorch(self):
),
)
et_program = edge_program.to_executorch()

"""DEBUG the executorch program"""
et_program.dump_executorch_program(verbose=True)

runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
Expand All @@ -174,3 +191,4 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
return method.execute((k_val, v_val))

self._test_kv_cache(wrapped_callable)

6 changes: 6 additions & 0 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
#include <cinttypes> // @donotremove
#include <cstdint>
#include <cstdio>
#include <iostream>

#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
Expand Down Expand Up @@ -1179,6 +1181,10 @@ Error Method::execute_instruction() {
if (err == Error::Ok) {
step_state_.instr_idx = next_instr_idx;
}

// TODO: Print an EValue.
std::cout << "(" << values_[1] << " ) Printing kv_cache k_cache: " << executorch::extension::evalue_edge_items(9216) << values_[2] << std::endl;

return err;
}

Expand Down

0 comments on commit 917fb0d

Please sign in to comment.