Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make kv cache pos buffer name more specific #7635

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
Expand Down
6 changes: 3 additions & 3 deletions extension/llm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,13 @@ def calculate_kv(y):

def true_fn(y):
kv_cache = self.kv_cache.clone()
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.kv_cache_pos

def false_fn(y):
k, v = calculate_kv(y)
kv_cache = self.kv_cache.clone()
kv_cache.update(k, v)
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.kv_cache_pos

# If kv cache is None, we expect y to be provided
if self.kv_cache is None:
Expand All @@ -308,7 +308,7 @@ def false_fn(y):
# Update key-value cache
self.kv_cache.k_cache.copy_(k)
self.kv_cache.v_cache.copy_(v)
self.kv_cache.cache_pos.copy_(cache_pos)
self.kv_cache.kv_cache_pos.copy_(cache_pos)

output = self._sdpa(q, k, v, b, s_x, mask=mask)
return self.output_proj(output)
Expand Down
19 changes: 11 additions & 8 deletions extension/llm/modules/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ def __init__(
self.register_buffer(
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
)
# We use "kv_cache_pos" here instead of "cache_pos" since the latter is too generic, and we have
# an InitializedMutableBufferPass that needs to single out this buffer to initialize (and not the
# others), since it takes up space in the .pte file.
self.register_buffer(
"cache_pos", torch.arange(0, self.max_seq_len), persistent=False
"kv_cache_pos", torch.arange(0, self.max_seq_len), persistent=False
)
self.batch_size = batch_size

Expand Down Expand Up @@ -105,17 +108,17 @@ def update(
f", but found new key tensors with batch size {k_val.shape[0]}!"
)

assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
assert (self.kv_cache_pos[0] + seq_len) <= self.max_seq_len

k_out = self.k_cache
v_out = self.v_cache

if self.transpose_cache:
k_out[:, :, self.cache_pos[:seq_len]] = k_val
v_out[:, :, self.cache_pos[:seq_len]] = v_val
k_out[:, :, self.kv_cache_pos[:seq_len]] = k_val
v_out[:, :, self.kv_cache_pos[:seq_len]] = v_val
else:
k_out[:, self.cache_pos[:seq_len]] = k_val
v_out[:, self.cache_pos[:seq_len]] = v_val
k_out[:, self.kv_cache_pos[:seq_len]] = k_val
v_out[:, self.kv_cache_pos[:seq_len]] = v_val

# forward kv_cache_pos seq_len positions along
# kv_cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
Expand All @@ -124,7 +127,7 @@ def update(
# this allows us to track the current position in the cache
# after the last update in a compile-friendly way without any dynamism
# e.g. relying on an int size tracker, or re-creating kv_cache_pos every time
self.cache_pos.add_(seq_len)
self.kv_cache_pos.add_(seq_len)

return k_out, v_out

Expand All @@ -144,5 +147,5 @@ def clone(self) -> "KVCache":
)
clone.k_cache.copy_(self.k_cache)
clone.v_cache.copy_(self.v_cache)
clone.cache_pos.copy_(self.cache_pos)
clone.kv_cache_pos.copy_(self.kv_cache_pos)
return clone
4 changes: 2 additions & 2 deletions extension/llm/modules/test/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_attention_executorch(self):
),
).to_executorch(
config=ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
)
)

Expand Down Expand Up @@ -330,7 +330,7 @@ def test_attention_torch_cond_executorch(self):
),
).to_executorch(
config=ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
)
)

Expand Down
4 changes: 2 additions & 2 deletions extension/llm/modules/test/test_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_kv_cache_executorch(self):
),
)
et_config = ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
)
et_program = edge_program.to_executorch(config=et_config)

Expand All @@ -198,7 +198,7 @@ def test_kv_cache_executorch_from_file(self):
),
)
et_config = ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
)
et_program = edge_program.to_executorch(config=et_config)

Expand Down
Loading