From f369ada1fd6b5b5e79d4239bb9fee885f25c3b76 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Mon, 13 Jan 2025 14:25:12 -0800
Subject: [PATCH] Make kv cache pos buffer name more specific

---
 examples/models/llama/export_llama_lib.py    |  2 +-
 extension/llm/modules/attention.py           |  6 +++---
 extension/llm/modules/kv_cache.py            | 19 +++++++++++--------
 extension/llm/modules/test/test_attention.py |  4 ++--
 extension/llm/modules/test/test_kv_cache.py  |  4 ++--
 5 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 1c1cf82d19..a562bdf13f 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -778,7 +778,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901

     additional_passes = []
     if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 60183801b4..a138585e42 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -284,13 +284,13 @@ def calculate_kv(y):

         def true_fn(y):
             kv_cache = self.kv_cache.clone()
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.kv_cache_pos

         def false_fn(y):
             k, v = calculate_kv(y)
             kv_cache = self.kv_cache.clone()
             kv_cache.update(k, v)
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.kv_cache_pos

         # If kv cache is None, we expect y to be provided
         if self.kv_cache is None:
@@ -308,7 +308,7 @@ def false_fn(y):
             # Update key-value cache
             self.kv_cache.k_cache.copy_(k)
             self.kv_cache.v_cache.copy_(v)
-            self.kv_cache.cache_pos.copy_(cache_pos)
+            self.kv_cache.kv_cache_pos.copy_(cache_pos)

         output = self._sdpa(q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
index db940bca3f..2da11d286c 100644
--- a/extension/llm/modules/kv_cache.py
+++ b/extension/llm/modules/kv_cache.py
@@ -55,8 +55,11 @@ def __init__(
         self.register_buffer(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
+        # We use "kv_cache_pos" here instead of "cache_pos" since the latter is too generic, and we have
+        # an InitializedMutableBufferPass that needs to single out this buffer (and not others) to
+        # initialize, since it takes up space in the .pte file.
         self.register_buffer(
-            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
+            "kv_cache_pos", torch.arange(0, self.max_seq_len), persistent=False
         )

         self.batch_size = batch_size
@@ -105,17 +108,17 @@ def update(
             f", but found new key tensors with batch size {k_val.shape[0]}!"
         )

-        assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
+        assert (self.kv_cache_pos[0] + seq_len) <= self.max_seq_len

         k_out = self.k_cache
         v_out = self.v_cache

         if self.transpose_cache:
-            k_out[:, :, self.cache_pos[:seq_len]] = k_val
-            v_out[:, :, self.cache_pos[:seq_len]] = v_val
+            k_out[:, :, self.kv_cache_pos[:seq_len]] = k_val
+            v_out[:, :, self.kv_cache_pos[:seq_len]] = v_val
         else:
-            k_out[:, self.cache_pos[:seq_len]] = k_val
-            v_out[:, self.cache_pos[:seq_len]] = v_val
+            k_out[:, self.kv_cache_pos[:seq_len]] = k_val
+            v_out[:, self.kv_cache_pos[:seq_len]] = v_val

         # forward cache_pos seq_len positions along
         # cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
@@ -124,7 +127,7 @@ def update(
         # this allows us to track the current position in the cache
         # after the last update in a compile-friendly way without any dynamism
         # e.g. relying on an int size tracker, or re-creating cache_pos every time
-        self.cache_pos.add_(seq_len)
+        self.kv_cache_pos.add_(seq_len)

         return k_out, v_out

@@ -144,5 +147,5 @@ def clone(self) -> "KVCache":
         )
         clone.k_cache.copy_(self.k_cache)
         clone.v_cache.copy_(self.v_cache)
-        clone.cache_pos.copy_(self.cache_pos)
+        clone.kv_cache_pos.copy_(self.kv_cache_pos)
         return clone
diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 3ecf0b2b4b..296850e0fa 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -219,7 +219,7 @@ def test_attention_executorch(self):
             ),
         ).to_executorch(
             config=ExecutorchBackendConfig(
-                passes=[InitializedMutableBufferPass(["cache_pos"])],
+                passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
             )
         )

@@ -330,7 +330,7 @@ def test_attention_torch_cond_executorch(self):
             ),
         ).to_executorch(
             config=ExecutorchBackendConfig(
-                passes=[InitializedMutableBufferPass(["cache_pos"])],
+                passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
             )
         )

diff --git a/extension/llm/modules/test/test_kv_cache.py b/extension/llm/modules/test/test_kv_cache.py
index 6029a03882..4ed088c58f 100644
--- a/extension/llm/modules/test/test_kv_cache.py
+++ b/extension/llm/modules/test/test_kv_cache.py
@@ -174,7 +174,7 @@ def test_kv_cache_executorch(self):
             ),
         )
         et_config = ExecutorchBackendConfig(
-            passes=[InitializedMutableBufferPass(["cache_pos"])],
+            passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
         )
         et_program = edge_program.to_executorch(config=et_config)

@@ -198,7 +198,7 @@ def test_kv_cache_executorch_from_file(self):
             ),
         )
         et_config = ExecutorchBackendConfig(
-            passes=[InitializedMutableBufferPass(["cache_pos"])],
+            passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
         )
         et_program = edge_program.to_executorch(config=et_config)
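
Note (not part of the patch): below is a minimal, standalone sketch of the buffer pattern this change renames. TinyKVCache is a hypothetical toy module that only mirrors the kv_cache.py logic shown above (non-transposed layout, batch size 1); it is not the ExecuTorch KVCache and does not exercise InitializedMutableBufferPass.

# Sketch only: TinyKVCache is an assumed name, not part of ExecuTorch.
import torch
from torch import nn


class TinyKVCache(nn.Module):
    def __init__(self, max_seq_len: int = 8, dim: int = 4):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.register_buffer("k_cache", torch.zeros(1, max_seq_len, dim), persistent=False)
        self.register_buffer("v_cache", torch.zeros(1, max_seq_len, dim), persistent=False)
        # A specific name lets a name-based pass single out this buffer and no others.
        self.register_buffer("kv_cache_pos", torch.arange(0, max_seq_len), persistent=False)

    def update(self, k_val: torch.Tensor, v_val: torch.Tensor):
        seq_len = k_val.shape[1]
        assert (self.kv_cache_pos[0] + seq_len) <= self.max_seq_len
        # Write the new entries at the current positions (non-transposed layout).
        self.k_cache[:, self.kv_cache_pos[:seq_len]] = k_val
        self.v_cache[:, self.kv_cache_pos[:seq_len]] = v_val
        # Advance every position by seq_len so the next update lands after this one,
        # without any dynamic int bookkeeping.
        self.kv_cache_pos.add_(seq_len)
        return self.k_cache, self.v_cache


cache = TinyKVCache()
k = torch.ones(1, 2, 4)
cache.update(k, k)         # fills slots 0 and 1
print(cache.kv_cache_pos)  # tensor([2, 3, 4, 5, 6, 7, 8, 9])

The distinctive buffer name matters because InitializedMutableBufferPass, as configured in the tests above, selects buffers by name; a generic name like "cache_pos" risks matching buffers other than the one whose initial values must be serialized.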