Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paged attention #1460

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
1b9c1f4
save current work
Bob-Chen222 Jul 25, 2024
62f1d45
WIP pagemanager
Bob-Chen222 Jul 27, 2024
a28d343
new ckpt
Bob-Chen222 Aug 3, 2024
05e6f73
add workable solution
Bob-Chen222 Aug 3, 2024
eb31ddb
style
Bob-Chen222 Aug 3, 2024
d452929
ckpt
Bob-Chen222 Aug 4, 2024
cf294a2
before debugging
Bob-Chen222 Aug 5, 2024
b599cad
ckpt: done cpu side implementation and seems workable solution //// n…
Bob-Chen222 Aug 5, 2024
bcebb43
ckpt
Bob-Chen222 Aug 7, 2024
76502af
ckpt before adding tmp block to batchconfig
Bob-Chen222 Aug 10, 2024
7424020
ckpt
Bob-Chen222 Aug 15, 2024
6baa2cc
fix compile error
Bob-Chen222 Aug 15, 2024
a3fc691
small modification
Bob-Chen222 Aug 16, 2024
cecb3d4
fix
Bob-Chen222 Aug 16, 2024
040da43
memory access error
Bob-Chen222 Aug 21, 2024
60eb6ff
fix index error on batchfuture, continue debugging
Bob-Chen222 Aug 22, 2024
0b7550d
init spec_num
Bob-Chen222 Aug 23, 2024
8e2b08c
workable solution!
Bob-Chen222 Aug 23, 2024
01b9534
commented printf version
Bob-Chen222 Aug 23, 2024
eaef494
commented out print
Bob-Chen222 Aug 23, 2024
d383de2
ckpt
Bob-Chen222 Aug 23, 2024
869f99a
ckpt
Bob-Chen222 Aug 24, 2024
37c0024
Revert "commented out print"
Bob-Chen222 Aug 26, 2024
bd14e8c
Revert "ckpt"
Bob-Chen222 Aug 26, 2024
222b2b3
Revert "ckpt"
Bob-Chen222 Aug 26, 2024
69f8ec4
fix error in index calculation
Bob-Chen222 Aug 26, 2024
6c61838
fix request manager erase page
Bob-Chen222 Aug 26, 2024
06e8469
erased page will get back
Bob-Chen222 Aug 27, 2024
5f68dbd
ckpt
Bob-Chen222 Aug 28, 2024
0f9a86f
script
Bob-Chen222 Aug 28, 2024
f2cdfcc
fix launch
Bob-Chen222 Aug 29, 2024
bd4f2c9
ckpt
Bob-Chen222 Aug 29, 2024
3e8a355
fix small error
Bob-Chen222 Sep 2, 2024
32a33f4
ckpt for being correct on single prompt
Bob-Chen222 Sep 3, 2024
59951a1
miner cleaning
Bob-Chen222 Sep 5, 2024
e712e6a
ckpt for getting correct code but still have weird error
Bob-Chen222 Sep 6, 2024
4da0ea2
rm some package
Bob-Chen222 Sep 23, 2024
311ed13
documentation attempt
Bob-Chen222 Sep 25, 2024
1297687
some further cleaning
Bob-Chen222 Sep 25, 2024
eb19d8e
cleaning
Bob-Chen222 Sep 25, 2024
3f703ac
cleaning
Bob-Chen222 Sep 25, 2024
17d54ae
cleaning
Bob-Chen222 Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/flexflow/batch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ class BatchConfig {
int first_token_index_in_request = -1;
int first_token_offset_in_batch = -1;
int num_tokens_in_batch = 0;

// page attention: we need some additional attention information here to allocate physical blocks in load_batch_config
int32_t num_kv_pages; //number of kv pages used
int32_t kv_last_page_len; //last page length of kv
RequestGuid request_guid;
};

struct PerTokenInfo {
Expand All @@ -82,6 +87,8 @@ class BatchConfig {
int request_index = -1;
};

std::vector<int32_t> page_indices; //the physical block indices for each page

struct CommittedTokensInfo {
int index_in_kv_cache = -1; // the index in the temporary key-value cache
int request_index = -1; // request index in the batch
Expand Down Expand Up @@ -150,6 +157,7 @@ class BatchConfig {

BitMask causalMask[MAX_NUM_REQUESTS];
PerRequestInfo requestsInfo[MAX_NUM_REQUESTS];
std::vector<int32_t> requestsIndices[MAX_NUM_REQUESTS]; //for kv cache
PerTokenInfo tokensInfo[MAX_NUM_TOKENS];
CommittedTokensInfo committed_tokens[MAX_NUM_TOKENS];
bool request_available[MAX_NUM_REQUESTS];
Expand Down
3 changes: 2 additions & 1 deletion include/flexflow/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ struct FFHandler {
size_t batch_config_metadata_size =
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
sizeof(BatchConfig::request_available) + sizeof(BatchConfig::causalMask) +
sizeof(BatchConfig::committed_tokens);
sizeof(BatchConfig::committed_tokens) + sizeof(int);

void *offload_reserve_space;
size_t offload_reserve_space_size;
DataType quantization_type;
Expand Down
141 changes: 141 additions & 0 deletions include/flexflow/page_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/* Copyright 2023 CMU, Stanford, Facebook, LANL
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "flexflow/batch_config.h"
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/config.h"
#include "flexflow/utils/file_loader.h"
#include <future>
#include <mutex>
#include <tokenizers_cpp.h>
#include <deque>

namespace FlexFlow {

using TokenId = BatchConfig::TokenId;

/**
 * @class LogicalTokenBlock
 * @brief A fixed-capacity logical block of tokens, analogous to a virtual
 * page in a virtual-memory system: it records which token slots of a
 * request are in use, without knowing where those tokens physically live
 * in the GPU KV cache (that mapping is kept by the PageManager).
 *
 * Tokens in a block are either "committed" (accepted into the sequence) or
 * "speculative" (tentatively appended during speculation and possibly rolled
 * back later). num_tokens == num_commit_tokens + num_spec_tokens is the
 * intended invariant — TODO confirm against the implementation.
 */
class LogicalTokenBlock {
public:
using TokenId = BatchConfig::TokenId;
// Construct an empty block with the given logical index and slot capacity.
LogicalTokenBlock(int block_number, uint32_t block_size);

// True if no tokens are stored in the block.
bool is_empty() const;

// Number of unused slots remaining (block_size - num_tokens).
int get_num_empty_slots() const;

// Number of occupied slots.
// NOTE(review): not marked const, unlike its sibling accessors — presumably
// an oversight; making it const would also require updating the definition.
int get_num_alloc_slots();

// True if every slot in the block is occupied.
bool is_full() const;

// Append tokens to the block; `committed` selects whether they are counted
// as committed or as speculative. Callers are expected to ensure the tokens
// fit (see get_num_empty_slots) — TODO confirm overflow handling in the .cc.
void append_tokens(const std::vector<TokenId>& token_ids_to_append, bool committed);

// Used to clean up the spec tokens in a block since these spec tokens may
// not be committed after use (i.e. roll back speculation after verification).
void reset_num_spec_tokens();

// Returns the token ids currently stored in the block.
std::vector<TokenId> get_token_ids() const;

int block_number; // the index of the logical token block
uint32_t block_size; // the capacity (in tokens) of the block
int num_tokens; // the number of tokens currently stored in the block
int num_commit_tokens; // the number of tokens inside this block that are already committed
int num_spec_tokens; // the number of tokens inside this block that are speculative tokens, which is stored temporarily

std::vector<TokenId> token_ids; // token ids stored in the order they appear in the inference sequence
};

/**
 * @class PhysicalTokenBlock
 * @brief A physical block of tokens, analogous to a physical page frame:
 * block_number identifies the slot in GPU memory where the tokens of one
 * logical block are stored.
 */
class PhysicalTokenBlock {
public:
// Construct a block with the given physical index and slot capacity.
PhysicalTokenBlock(int block_number, uint32_t block_size);

int ref_count; // reference count — presumably the number of logical owners; verify how BlockAllocator::free uses it
int block_number; // the index of the physical token block in GPU memory
uint32_t block_size; // the capacity (in tokens) of the block
};

/**
 * @class BlockAllocator
 * @brief A Block Manager that is responsible for maintaining a pool of free
 * physical blocks: hands one out on allocate() and returns it on free().
 */
class BlockAllocator {
public:
// Create a pool of num_blocks physical blocks, each holding block_size tokens.
BlockAllocator(uint32_t block_size, int num_blocks);

// Take a free block from the pool. Behavior when the pool is exhausted is
// defined by the implementation — TODO confirm (assert vs. error return).
PhysicalTokenBlock allocate();

// Return a block to the free pool.
void free(PhysicalTokenBlock& block);

// Number of blocks currently available for allocation.
size_t get_num_free_blocks() const;

private:
uint32_t block_size; // capacity (in tokens) of every block in the pool
int num_blocks; // total number of blocks managed by this allocator
std::deque<PhysicalTokenBlock> free_blocks; // pool of currently free blocks
};

/*
 * @class PageManager
 * @brief Singleton that tracks the KV-cache page allocation status for every
 * active request, mapping each request's logical blocks to physical blocks.
 * Notice that all the layers of the model share the same page manager,
 * because a token occupies the same KV-cache position in every layer.
 */
class PageManager {
public:
// Get the singleton instance of the PageManager as it will be shared in multiple places.
// NOTE(review): copy constructor/assignment are not deleted, so the
// singleton type can still be copied accidentally — consider deleting them.
static PageManager *get_page_manager();
using BlockTable = std::vector<PhysicalTokenBlock>;
using RequestGuid = BatchConfig::RequestGuid;
PageManager(uint32_t block_size, int num_total_blocks);

// Allocate pages for the request's prompt token ids during the LLM prefill
// stage. Returns whether allocation succeeded — TODO confirm failure semantics.
bool prefill(const RequestGuid& request_guid, const std::vector<int>& token_ids);
// Allocate one additional page for the request — presumably called during
// decode when the last page fills up; verify against the implementation.
bool allocate(const RequestGuid& request_guid);
// Release all pages owned by the request back to the allocator.
void free(const RequestGuid& request_guid);

// Number of physical blocks still available in the underlying allocator.
size_t get_num_free_blocks() const;
// Physical block indices of the request's block table, in logical order.
std::vector<int32_t> get_block_table_indices(const RequestGuid& request_guid) const;
// Number of physical blocks currently assigned to the request.
int get_num_allocated_blocks(const RequestGuid& request_guid) const;

// Release the last num_pages pages of the request (e.g. to roll back
// speculative tokens that were not committed).
void erase_last_pages(const RequestGuid& request_guid, int num_pages);

private:
uint32_t block_size; // the capacity (in tokens) of each block
int num_total_blocks; // the total number of blocks managed
BlockAllocator block_allocator; // pool of free physical blocks
// Per-request block tables.
// NOTE(review): keyed by int while the public API uses RequestGuid — this
// compiles only if RequestGuid converts to int, and risks guid truncation;
// consider keying by RequestGuid instead.
std::unordered_map<int, BlockTable> block_tables;
};

}; // namespace FlexFlow
11 changes: 11 additions & 0 deletions include/flexflow/request_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/utils/file_loader.h"
#include "flexflow/page_manager.h"
#include <future>
#include <mutex>
#include <tokenizers_cpp.h>
Expand Down Expand Up @@ -75,6 +76,10 @@ struct Request {
Status status = PENDING;
std::vector<BatchConfig::TokenId> tokens;

// Used for keeping track of the block information
std::vector<LogicalTokenBlock> blocks;
int32_t page_id_commit;

// TokenTree speculative_token_tree;
std::vector<TokenTree> speculative_token_trees;
// To make request manager stateful, we need to store the causal mask here
Expand Down Expand Up @@ -393,6 +398,12 @@ class RequestManager {
double total_request_run_time;
void load_pending_request_to_batch();
void request_complete_clean_up(int batch_index);

/* ---------- Page Attention Helper Functions ---------- */
void _append_logical_block_to_request(Request &request, bool is_commit);
void _append_tokens_to_blocks(Request &request, std::vector<TokenId> const &tokens, bool is_commit, int start = 0, int end = -1);
/* ---------- Page Attention Helper Functions ---------- */

/* ---------- Incremental Decoding Helper Functions ---------- */
bool update_llm_prefill_results(InferenceResult const &result);
bool update_llm_decode_results(InferenceResult const &result);
Expand Down
81 changes: 54 additions & 27 deletions src/ops/tree_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,27 @@ using flashinfer::PageStorage;
using flashinfer::PosEncodingMode;
using flashinfer::QKVLayout;

__device__ __forceinline__ size_t get_k_entry_offset(int const req_idx,
int const token_idx,
int const max_num_pages,
// page_idx: the assigned physical block index for this token
// token_idx: can be absolute index in the sequence but in there we just use it as an offset
// hidden_size: the size of the hidden dimension
__device__ __forceinline__ size_t get_k_entry_offset(int const token_idx,
int const page_idx,
int const hidden_size) {
return ((req_idx * max_num_pages + token_idx / kPagesize) * kPagesize * 2 +
token_idx % kPagesize) *
hidden_size;
size_t index = ((page_idx) * kPagesize * 2 + (token_idx % kPagesize)) * hidden_size;
return index;
}

__device__ __forceinline__ size_t get_v_entry_offset(int const req_idx,
int const token_idx,
int const max_num_pages,
__device__ __forceinline__ size_t get_v_entry_offset(int const token_idx,
int const page_idx,
int const hidden_size) {
return ((req_idx * max_num_pages + token_idx / kPagesize) * kPagesize * 2 +
kPagesize + token_idx % kPagesize) *
hidden_size;
size_t index = ((page_idx) * kPagesize * 2 + kPagesize + (token_idx % kPagesize)) * hidden_size;
return index;
}

__global__ void commit_tokens_kernel(
half *kCache_ptr,
int32_t *kv_indptr,
int32_t *kv_page_indices,
BatchConfig::CommittedTokensInfo const *committedTokenInfos,
bool const *request_available,
int num_requests,
Expand All @@ -135,26 +136,28 @@ __global__ void commit_tokens_kernel(
cnt_1++;
}
}

// get the starting index of kv page
int start = kv_indptr[requext_idx_in_batch];
int end = kv_indptr[requext_idx_in_batch + 1] - 1;
for (int i = 0; i < num_committed_tokens; i++) {
if (committedTokenInfos[i].request_index == requext_idx_in_batch) {
int const index_in_kv_cache = committedTokenInfos[i].index_in_kv_cache;
if (index_in_kv_cache == -1) {
continue;
}

int const req_id = committedTokenInfos[i].request_index;
// int const req_id = committedTokenInfos[i].request_index;
int const tok_id = committedTokenInfos[i].token_depth;
int const page_to_idx = committedTokenInfos[i].token_depth / kPagesize;
int const page_from_idx = kv_page_indices[start + (tok_id / kPagesize)];

// page attention: since we cannot store temporary tokens in the cache, we need to figure out another way
size_t from_k_idx = get_k_entry_offset(index_in_kv_cache, page_from_idx, hidden_size),
from_v_idx = get_v_entry_offset(index_in_kv_cache, page_from_idx, hidden_size);

size_t from_k_idx = get_k_entry_offset(
req_id, index_in_kv_cache, max_num_pages, hidden_size),
from_v_idx = get_v_entry_offset(
req_id, index_in_kv_cache, max_num_pages, hidden_size);
size_t to_k_idx =
get_k_entry_offset(req_id, tok_id, max_num_pages, hidden_size),
to_v_idx =
get_v_entry_offset(req_id, tok_id, max_num_pages, hidden_size);
assert(to_k_idx <= from_k_idx);
// page attention: copy the token to the new position
size_t to_k_idx =get_k_entry_offset(tok_id, page_to_idx, hidden_size),
to_v_idx =get_v_entry_offset(tok_id, page_to_idx, hidden_size);

kCache_ptr[to_k_idx + offset] = kCache_ptr[from_k_idx + offset];
kCache_ptr[to_v_idx + offset] = kCache_ptr[from_v_idx + offset];
Expand All @@ -181,6 +184,8 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m,
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(static_cast<half *>(m->keyCache),
m->handle.tree_verify_attention_metadata->kv_indptr,
m->handle.tree_verify_attention_metadata->kv_indices,
m->committed_token_infos,
m->request_available,
num_requests,
Expand Down Expand Up @@ -277,6 +282,8 @@ __global__ void
update_qkv_cache_kernel(DT *devQKVProjArray,
half *qTmp_ptr,
half *kCache_ptr,
int32_t *kv_indptr,
int32_t *kv_page_indices,
BatchConfig::PerTokenInfo const *tokenInfos,
BatchConfig::PerRequestInfo *request_infos,
int const max_num_pages,
Expand All @@ -292,11 +299,12 @@ __global__ void
int const req_idx = tokenInfos[token_idx].request_index;
int const token_abs_idx = tokenInfos[token_idx].abs_index_in_request;

// compute the starting index of kv page
int start = kv_indptr[req_idx];
int page_idx = kv_page_indices[start + (token_abs_idx / kPagesize)];
size_t from_idx = token_idx * QKV_WEIGHT_NUM * hidden_size;
size_t to_k_idx = get_k_entry_offset(
req_idx, token_abs_idx, max_num_pages, hidden_size),
to_v_idx = get_v_entry_offset(
req_idx, token_abs_idx, max_num_pages, hidden_size);
size_t to_k_idx = get_k_entry_offset(token_abs_idx, page_idx, hidden_size),
to_v_idx = get_v_entry_offset(token_abs_idx, page_idx, hidden_size);

// key and value cache should be stored interleaved
kCache_ptr[to_k_idx + offset] =
Expand Down Expand Up @@ -324,6 +332,8 @@ void update_qkv_cache(TreeIncMultiHeadSelfAttentionMeta const *m,
stream>>>(static_cast<DT *>(m->devQKVProjArray),
static_cast<half *>(m->queryTmp),
static_cast<half *>(m->keyCache),
m->handle.tree_verify_attention_metadata->kv_indptr,
m->handle.tree_verify_attention_metadata->kv_indices,
m->token_infos,
m->request_infos,
max_num_pages,
Expand Down Expand Up @@ -439,6 +449,23 @@ void tree_verify_attention(TreeIncMultiHeadSelfAttentionMeta const *m,
half *q = static_cast<half *>(m->queryTmp),
*kv = static_cast<half *>(m->keyCache),
*o = static_cast<half *>(m->outputTmp);

static int32_t kv_indices_tmp[BatchConfig::MAX_NUM_REQUESTS * BatchConfig::MAX_NUM_TOKENS];
static int32_t kv_indptr_tmp[BatchConfig::MAX_NUM_REQUESTS + 1];
static int32_t kv_last_page_len_tmp[BatchConfig::MAX_NUM_REQUESTS];
// copy data from device to host
cudaMemcpy(kv_indices_tmp,
m->handle.tree_verify_attention_metadata->kv_indices,
sizeof(int32_t) * BatchConfig::MAX_NUM_REQUESTS * BatchConfig::MAX_NUM_TOKENS,
cudaMemcpyDeviceToHost);
cudaMemcpy(kv_indptr_tmp,
m->handle.tree_verify_attention_metadata->kv_indptr,
sizeof(int32_t) * (BatchConfig::MAX_NUM_REQUESTS + 1),
cudaMemcpyDeviceToHost);
cudaMemcpy(kv_last_page_len_tmp,
m->handle.tree_verify_attention_metadata->kv_last_page_len,
sizeof(int32_t) * BatchConfig::MAX_NUM_REQUESTS,
cudaMemcpyDeviceToHost);
paged_kv_t<PageStorage::kIndices, QKVLayout::kNHD, half, int32_t> paged_kv(
num_kv_heads,
kPagesize,
Expand Down
Loading
Loading