diff --git a/config/config.linux b/config/config.linux index 873f747835..15e9c88214 100755 --- a/config/config.linux +++ b/config/config.linux @@ -111,7 +111,7 @@ function get_build_configs() { BUILD_CONFIGS="FF_CUDA_ARCH=${FF_CUDA_ARCH} FF_HIP_ARCH=${FF_HIP_ARCH} CUDA_DIR=${CUDA_DIR} CUDNN_DIR=${CUDNN_DIR} CUBLAS_DIR=${CUBLAS_DIR} CURAND_DIR=${CURAND_DIR} NCCL_DIR=${NCCL_DIR} FF_USE_PYTHON=${FF_USE_PYTHON} BUILD_LEGION_ONLY=${BUILD_LEGION_ONLY} FF_GASNET_CONDUIT=${FF_GASNET_CONDUIT} UCX_DIR=${UCX_DIR} FF_LEGION_NETWORKS=${FF_LEGION_NETWORKS} FF_BUILD_ALL_EXAMPLES=${FF_BUILD_ALL_EXAMPLES} FF_BUILD_ALL_INFERENCE_EXAMPLES=${FF_BUILD_ALL_INFERENCE_EXAMPLES} FF_BUILD_UNIT_TESTS=${FF_BUILD_UNIT_TESTS} FF_USE_PREBUILT_NCCL=${FF_USE_PREBUILT_NCCL} FF_USE_PREBUILT_LEGION=${FF_USE_PREBUILT_LEGION} FF_USE_ALL_PREBUILT_LIBRARIES=${FF_USE_ALL_PREBUILT_LIBRARIES} FF_USE_AVX2=${FF_USE_AVX2} FF_MAX_DIM=${FF_MAX_DIM} ROCM_PATH=${ROCM_PATH} FF_GPU_BACKEND=${FF_GPU_BACKEND} INSTALL_DIR=${INSTALL_DIR}" } -patch -p0 $(dirname $0)/../deps/raft/cpp/include/raft/matrix/detail/select_radix.cuh $(dirname $0)/../config/raft.patch +patch -p0 --batch $(dirname $0)/../deps/raft/cpp/include/raft/matrix/detail/select_radix.cuh $(dirname $0)/../config/raft.patch if [[ -n "$1" && ( "$1" == "CMAKE_FLAGS" || "$1" == "CUDA_PATH" ) ]]; then . $(dirname $0)/config.inc diff --git a/include/flexflow/attention_config.h b/include/flexflow/attention_config.h index 63b0112e71..7144b7ab3a 100644 --- a/include/flexflow/attention_config.h +++ b/include/flexflow/attention_config.h @@ -20,6 +20,11 @@ namespace FlexFlow { constexpr uint32_t kPagesize = 64; + +inline int round_up_pages(int const num_elements) { + return (num_elements + kPagesize - 1) / kPagesize; +} + #define DISPATCH_HEADDIM(head_dim, HEAD_DIM, ...) \ switch (head_dim) { \ case 64: { \ @@ -93,9 +98,8 @@ class AttentionMetaData { } size_t batch_size = BatchConfig::max_requests_per_batch(); size_t max_num_pages = - (BatchConfig::max_spec_tree_token_num() + - BatchConfig::max_sequence_length() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_spec_tree_token_num() + + BatchConfig::max_sequence_length()); size_t indices_size = std::max( (batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024); size_t custom_mask_size = BatchConfig::max_requests_per_batch() * @@ -132,9 +136,8 @@ class AttentionMetaData { "Insufficient memory size for attention metadata"); size_t batch_size = BatchConfig::max_requests_per_batch(); size_t max_num_pages = - (BatchConfig::max_spec_tree_token_num() + - BatchConfig::max_sequence_length() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_spec_tree_token_num() + + BatchConfig::max_sequence_length()); size_t indices_size = std::max( (batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024); size_t custom_mask_size = BatchConfig::max_requests_per_batch() * diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index fef5d0b736..d56f4e2455 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -27,6 +27,35 @@ class InferenceResult; using BatchConfigFuture = Legion::Future; using InferenceResultFuture = Legion::Future; +/* + * StreamingCacheInfo is a class that manages the streaming kv cache for + * attention operator (https://arxiv.org/abs/2309.17453), and we use it in the + * draft model. It maintains a fixed-content *sink* cache and a fixed-size + * *window* cache. 
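// [Editor's sketch] The sink + rolling-window bookkeeping described in this
// comment, reduced to a self-contained toy. The sizes, names, and mapping
// below are illustrative assumptions, not the actual StreamingCacheInfo
// implementation.
#include <cassert>

struct ToyStreamingCache {
  int sink_size = 4;    // fixed-content sink slots (foremost tokens)
  int window_size = 8;  // fixed-size rolling-window slots (backmost tokens)
  int committed = 0;    // number of tokens committed so far
  int window_back = 0;  // next slot to overwrite inside the window

  // Record that `len` more tokens have been committed.
  void commit(int len) {
    for (int i = 0; i < len; ++i) {
      if (committed >= sink_size + window_size) {
        window_back = (window_back + 1) % window_size; // evict the oldest
      }
      ++committed;
    }
  }

  // Map a token's global position onto a slot of the fixed-size cache.
  int slot_of(int global_pos) const {
    assert(global_pos < committed);
    if (global_pos < sink_size) {
      return global_pos; // sink tokens keep their slot
    }
    if (committed <= sink_size + window_size) {
      return global_pos; // window not yet full: filled in order
    }
    int oldest = committed - window_size; // oldest position still cached
    assert(global_pos >= oldest);         // earlier positions were evicted
    return sink_size + (global_pos - oldest + window_back) % window_size;
  }
};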
The *sink* cache is the foremost part of the original kv + * cache, while the *window* cache is the backmost part of the original kv cache + * and is rolling updated. The information is per-request. Note that the + * position encoding of the q&k alters each iteration (relative position), so we + * store the *pre-pos-encoding* kv value in the cache. + */ +class StreamingCacheInfo { +public: + StreamingCacheInfo(); + StreamingCacheInfo(int sink_cache_size, int window_cache_size); + StreamingCacheInfo(StreamingCacheInfo const &other); + + StreamingCacheInfo &operator=(StreamingCacheInfo const &other); + + void commit_cache(int len); + void reset_cache(); + int global_2_cache_index(int global_index); + +public: + int sink_cache_size, window_cache_size; + // the meta info of the window cache, commit_len helps to determine if we fill + // up the window. + int window_back, commit_len; +}; + class BatchConfig { public: using RequestGuid = size_t; @@ -41,6 +70,7 @@ class BatchConfig { static int max_verify_tokens_per_batch(); static int max_spec_tree_token_num(); static int max_sequence_length(); + static int get_max_tree_depth(); friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc); void print() const; void save_to_file(std::string const &filename) const; @@ -50,7 +80,7 @@ class BatchConfig { // Maximum possible values for different parameters // These maximum values are used for copying BatchConfig // across workers - inline static int const MAX_NUM_REQUESTS = 64; + inline static int const MAX_NUM_REQUESTS = 8; inline static int const MAX_NUM_TOKENS = 1024; inline static int const MAX_SPEC_TREE_TOKEN_NUM = 128; inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 4; @@ -58,6 +88,11 @@ class BatchConfig { inline static int const MAX_TREE_WIDTH = 64; inline static int const MAX_K_LOGITS = 16; + // The Constants for the Streaming KVCache + inline static int const SINK_SIZE = 4; + // size_SINK + size_WINDOW + depth_DRAFT shouldn't exceed this value + inline static int const MAX_STREAMING_POS = 2048; + int num_tokens = 0; int num_available_requests = 0; bool prompt_phase = false; @@ -69,6 +104,7 @@ class BatchConfig { int first_token_index_in_request = -1; int first_token_offset_in_batch = -1; int num_tokens_in_batch = 0; + int padding = 0; // Padding for memory pointer alignment }; struct PerTokenInfo { @@ -150,6 +186,7 @@ class BatchConfig { BitMask causalMask[MAX_NUM_REQUESTS]; PerRequestInfo requestsInfo[MAX_NUM_REQUESTS]; + StreamingCacheInfo streamingCacheInfo[MAX_NUM_REQUESTS]; PerTokenInfo tokensInfo[MAX_NUM_TOKENS]; CommittedTokensInfo committed_tokens[MAX_NUM_TOKENS]; bool request_available[MAX_NUM_REQUESTS]; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index 0e15fc089e..48b0450b61 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -86,6 +86,7 @@ struct FFHandler { size_t batch_config_metadata_size = sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BatchConfig::request_available) + sizeof(BatchConfig::causalMask) + + sizeof(BatchConfig::streamingCacheInfo) + sizeof(BatchConfig::committed_tokens) + sizeof(int); void *offload_reserve_space; diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index ddf9c7e8a5..9bc2c69734 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -448,6 +448,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char 
const *name); flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( @@ -468,6 +469,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name); flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( @@ -509,6 +511,7 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name); flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( @@ -530,6 +533,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name); flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 948feb3645..6618fdaf89 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -724,6 +724,7 @@ class FFModel { float scaling_factor = 1.0f, bool qk_prod_scaling = true, bool position_bias = false, + bool streaming_cache = false, char const *name = NULL); Tensor spec_inc_multihead_self_attention(Tensor const input, @@ -742,6 +743,7 @@ class FFModel { float scaling_factor = 1.0f, bool qk_prod_scaling = true, bool position_bias = false, + bool streaming_cache = false, char const *name = NULL); Tensor inc_multihead_self_attention_verify( Tensor const input, @@ -778,6 +780,7 @@ class FFModel { float scaling_factor = 1.0f, bool qk_prod_scaling = true, bool position_bias = false, + bool streaming_cache = false, char const *name = NULL); Tensor spec_inc_multiquery_self_attention(Tensor const input, @@ -797,6 +800,7 @@ class FFModel { float scaling_factor = 1.0f, bool qk_prod_scaling = true, bool position_bias = false, + bool streaming_cache = false, char const *name = NULL); Tensor inc_multiquery_self_attention_verify( Tensor const input, diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 5a90dd61be..8db1c072d4 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_INC_MULTIHEAD_SELF_ATTENTION_H #include "flexflow/accessor.h" +#include "flexflow/batch_config.h" #include "flexflow/device.h" #include "flexflow/fftype.h" #include "flexflow/inference.h" @@ -47,6 +48,7 @@ class IncMultiHeadSelfAttention : public Op { bool allocate_weights, DataType _quantization_type, bool _offload, + bool _streaming_cache, int _tensor_parallelism_degree, char const *name); IncMultiHeadSelfAttention(FFModel &model, @@ -69,6 +71,7 @@ class IncMultiHeadSelfAttention : public Op { bool allocate_weights, DataType _quantization_type, bool _offload, + bool _streaming_cache, int _tensor_parallelism_degree, char const *name); IncMultiHeadSelfAttention(FFModel &model, @@ -131,7 +134,7 @@ class IncMultiHeadSelfAttention : public Op { int hidden_size, qk_dim, v_dim, o_dim; int qoSeqLength, kvSeqLength; DataType quantization_type; - bool offload; + bool offload, streaming_cache; }; class IncMultiHeadSelfAttentionMeta : public OpMeta { @@ -165,7 +168,8 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { int _num_q_heads, int _num_kv_heads, DataType _quantization_type, - bool _offload); + bool _offload, + bool _streaming_cache); ~IncMultiHeadSelfAttentionMeta(void); public: @@ -184,14 
+188,20 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { bool *position_bias; float scaling_factor; void *weight_ptr, *bias_ptr; // for weight offload - void *devQKVProjArray, *queryTmp, *kvCache; + void *devQKVProjArray, *queryTmp; half *outputTmp; - void *qk_prods, *qk_prods_softmax; + void *kvCache; + bool streaming_cache; + // When the streaming cache is enabled, we alter the relative position each + // iteration, so we need the memory buffer below to store the pre-pos-encoding + // key values in the sink and window. + void *streamingPrePosEncBuf; void *attn_heads; char *quantized_weight_ptr; BatchConfig::PerTokenInfo *token_infos; BatchConfig::PerRequestInfo *request_infos; bool *request_available; + StreamingCacheInfo *streaming_cache_infos; DataType quantization_type; bool offload; #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) diff --git a/include/flexflow/ops/inc_multihead_self_attention_params.h b/include/flexflow/ops/inc_multihead_self_attention_params.h index 58681069e2..7c259a0a92 100644 --- a/include/flexflow/ops/inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/inc_multihead_self_attention_params.h @@ -15,7 +15,7 @@ struct IncMultiHeadSelfAttentionParams { bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, position_bias; DataType quantization_type; - bool offload; + bool offload, streaming_cache; char name[MAX_OPNAME]; bool is_valid(ParallelTensorShape const &) const; }; diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index 29d2cd1dd3..8f69ad3805 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -44,6 +44,8 @@ void pre_build_weight(IncMultiHeadSelfAttentionMeta const *m, DataType data_type, ffStream_t stream); +// [For the tokens in batch] +// Compute the qkv projection for the tokens in the batch. template void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -54,10 +56,56 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, DT const *bias_ptr, ffStream_t stream); +// [For the tokens in batch] +// Apply the position embedding to q&k. +// Note that this is only used for tokens in the current batch. +// For other key tokens, such as those in the streaming cache, we need another +// kernel to apply the position embedding. template -void update_qkv_cache(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - cudaStream_t stream); +void apply_pos_encoding_to_tokens_in_batch( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream); + +// [For the tokens in streaming cache] +// Apply the position embedding to the k projection in the streaming cache. +// Note that before the position encoding, the projection is moved *in order* to +// the kv memory used by the attention kernel. So this operation is applied where +// kvCache points to. +template +void apply_pos_encoding_to_streaming_proj( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +// [For the tokens in batch] +// Update the kv cache, and compact the q array. +// Source: the qkv projection array of the tokens in the batch. +// Destination: the q&kv pointers used by the attention kernel. +// Note that the q&k here are the values after the position encoding has been +// applied.
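// [Editor's sketch] These kernels address the paged kv cache through
// get_k_entry_offset()/get_v_entry_offset(), which are not shown in this diff.
// The helpers below only illustrate one assumed "request-major, pages of
// kPagesize tokens" layout so the indexing arithmetic is concrete; they are
// not the actual FlexFlow layout.
#include <cstddef>

constexpr int kToyPageSize = 64; // mirrors kPagesize in attention_config.h

inline int toy_round_up_pages(int num_elements) {
  return (num_elements + kToyPageSize - 1) / kToyPageSize;
}

// Element offset of token `token_idx` of request `req_idx`, assuming each
// request reserves `max_num_pages` pages and each token stores
// `num_kv_heads * head_dim` elements.
inline size_t toy_kv_entry_offset(
    int req_idx, int token_idx, int max_num_pages, int num_kv_heads, int head_dim) {
  size_t tokens_per_request = static_cast<size_t>(max_num_pages) * kToyPageSize;
  return (req_idx * tokens_per_request + token_idx) *
         static_cast<size_t>(num_kv_heads) * head_dim;
}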
+template +void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +// [For the tokens in streaming cache] +// Convert the out-of-order cache to in-order relative position. +// Source: pre-pos-encoding kv values in the streaming cache. +// Destination: kv ptr took by the attention kernel. +template +void update_kv_in_streaming_cache(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +// [For the tokens in batch] +// Commit the kv values to the streaming cache. +// Source: qkv projeciton array of tokens in the batch. +// Destination: pre-pos-encoding kv values in the streaming cache. +template +void commit_kv(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); template void produce_output(IncMultiHeadSelfAttentionMeta const *m, diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index 617263a051..b08e161c5e 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -42,6 +42,7 @@ class SpecIncMultiHeadSelfAttention : public Op { bool _qk_prod_scaling, bool _position_bias, bool allocate_weights, + bool _streaming_cache, char const *name); SpecIncMultiHeadSelfAttention(FFModel &model, ParallelTensor const _input, @@ -61,6 +62,7 @@ class SpecIncMultiHeadSelfAttention : public Op { bool _qk_prod_scaling, bool _position_bias, bool allocate_weights, + bool _streaming_cache, char const *name); SpecIncMultiHeadSelfAttention(FFModel &model, SpecIncMultiHeadSelfAttention const &other, @@ -124,6 +126,7 @@ class SpecIncMultiHeadSelfAttention : public Op { qk_prod_scaling, position_bias; int hidden_size, qk_dim, v_dim, o_dim; int qoSeqLength, kvSeqLength; + bool streaming_cache; }; class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h index 1461224ba9..2def2a51cb 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h @@ -13,6 +13,7 @@ struct SpecIncMultiHeadSelfAttentionParams { float dropout, scaling_factor; bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, position_bias; + bool streaming_cache; char name[MAX_OPNAME]; bool is_valid(ParallelTensorShape const &) const; }; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 5c6f6b6e08..ac951e6262 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -68,6 +68,8 @@ struct Request { int batch_index = -1; int ssm_cache_size = 0; int llm_cache_size = 0; + int ssm_prefill_len = 0; + int llm_prefill_len = 0; int first_token_offset_in_batch = 0; int num_tokens_in_batch = 0; @@ -120,6 +122,20 @@ struct Request { : from_index(from_index), to_index(to_index), token_id(token_id) {} }; std::vector committed_tokens; + + // Enabling Streaming KVCache means we doesn't store the whole KV sequence of + // the tokens in a request. Instead, we only store the sink cache (a few + // foremost tokens) and the window cache (rolling-updated backmost tokens + // through decoding). Currently, we only use streaming cache in the *draft + // model* calculation. 
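// [Editor's sketch] A runnable toy showing which token positions survive in a
// streaming kv cache. The sizes (sink = 4, window = 8) and the 20-token
// request are hypothetical, chosen only to make the eviction pattern visible.
#include <cstdio>

int main() {
  int const sink = 4, window = 8, committed = 20;
  for (int pos = 0; pos < committed; ++pos) {
    bool in_sink = pos < sink;
    bool in_window = pos >= committed - window; // only the last `window` tokens
    if (in_sink || in_window) {
      std::printf("%d ", pos); // prints: 0 1 2 3 12 13 14 15 16 17 18 19
    }
  }
  std::printf("\n");
  return 0;
}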
+ // - Maintain the streaming cache: During inference, we + // first fill up the sink cache then the window cache. After the window cache + // is full, we move back to the beginning of the window cache and commit the + // tokens in replace there. + // - When to update the streaming cache: + // 1. Prefilling phase + // 2. Committing phase after the target model verification + StreamingCacheInfo streaming_cache_info; }; class TokenTreeNode { @@ -244,6 +260,7 @@ class RequestManager { int get_max_tree_width(); void set_max_tree_width(int max_tree_width); void set_speculative_sampling(bool speculative_sampling); + void set_streaming_cache(bool streaming_cache); int register_ssm_model(FFModel *model); void register_tokenizer(ModelType model_type, int bos_token_id, @@ -318,6 +335,8 @@ class RequestManager { DecodingMode decoding_mode; PrefillModel prefill_model; bool speculative_sampling = false; + // specify if enable streaming cache for incremental decoding or draft model + bool streaming_cache = false; std::unique_ptr tokenizer_; bool verbose; diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index adf51aa309..83f2ba6327 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -48,7 +48,8 @@ void parse_input_args(char **argv, int &max_requests_per_batch, int &max_tokens_per_batch, int &max_sequence_length, - int &sampling_seed) { + int &sampling_seed, + bool &streaming_cache) { for (int i = 1; i < argc; i++) { // llm model type if (!strcmp(argv[i], "-llm-model")) { @@ -110,6 +111,10 @@ void parse_input_args(char **argv, sampling_seed = std::stoi(argv[++i]); continue; } + if (!strcmp(argv[i], "--enable-streaming-cache")) { + streaming_cache = true; + continue; + } } if (paths.cache_folder_path.empty()) { char const *ff_cache_path = std::getenv("FF_CACHE_PATH"); @@ -144,6 +149,7 @@ void FlexFlow::top_level_task(Task const *task, RequestManager::DecodingMode decoding_mode = RequestManager::INCREMENTAL_DECODING; int sampling_seed = 0; + bool streaming_cache = false; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -160,7 +166,8 @@ void FlexFlow::top_level_task(Task const *task, max_requests_per_batch, max_tokens_per_batch, max_sequence_length, - sampling_seed); + sampling_seed, + streaming_cache); assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree * ffconfig.pipeline_parallelism_degree == @@ -226,6 +233,7 @@ void FlexFlow::top_level_task(Task const *task, rm->set_max_tree_depth(8); rm->set_max_tree_width(16); rm->set_verbose(verbose); + rm->set_streaming_cache(streaming_cache); rm->register_tokenizer( model_type, bos_token_id, eos_token_id, tokenizer_filepath); rm->register_output_filepath(file_paths.output_file_path); @@ -237,6 +245,7 @@ void FlexFlow::top_level_task(Task const *task, weights_filepath, INC_DECODING_MODE, generationConfig, + streaming_cache, use_full_precision); } else if (model_type == ModelType::OPT) { OPT::create_opt_model(model, diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index a9805bf8ec..96e85177cc 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -116,6 +116,7 @@ void FALCON::create_falcon_model(FFModel &ff, 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ false, /*position_bias*/ + false, /*streaming_cache*/ std::string("layers_" + std::to_string(i) + "_attention") .c_str() /*name*/ ); @@ -166,6 +167,7 @@ void FALCON::create_falcon_model(FFModel 
&ff, 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ false, /*position_bias*/ + false, /*streaming_cache*/ std::string("layers_" + std::to_string(i) + "_attention") .c_str() /*name*/ ); diff --git a/inference/models/llama.cc b/inference/models/llama.cc index a1f4d370f3..16dc2441ff 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -25,6 +25,7 @@ void LLAMA::create_llama_model(FFModel &ff, std::string const &weight_file_path, InferenceMode mode, GenerationConfig generation_config, + bool streaming_cache, bool use_full_precision) { // do not apply cpu offload in beam search model. LLAMAConfig llama_config(model_config_file_path); @@ -112,6 +113,7 @@ void LLAMA::create_llama_model(FFModel &ff, 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ false, /*position_bias*/ + streaming_cache, std::string("layers_" + std::to_string(i) + "_attention") .c_str() /*name*/ ); @@ -149,17 +151,18 @@ void LLAMA::create_llama_model(FFModel &ff, llama_config.num_key_value_heads, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, - 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ - false, /*add_zero_attn*/ - DT_NONE, /*data_type*/ - nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + 0.0f, /*dropout*/ + false, /*qkv_bias*/ + false, /*final_bias*/ + false, /*add_zero_attn*/ + DT_NONE, /*data_type*/ + nullptr, /*kernel_initializer*/ + true, /*apply_rotary_embedding*/ + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ + streaming_cache, /*streaming_cache*/ std::string("layers_" + std::to_string(i) + "_attention") .c_str() /*name*/ ); diff --git a/inference/models/llama.h b/inference/models/llama.h index 1a6a9114e6..a5b2c4a401 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -86,6 +86,7 @@ class LLAMA { std::string const &weight_file_path, InferenceMode mode, GenerationConfig generation_config, + bool streaming_cache, bool use_full_precision = false); }; diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index 8251ef71c8..f531fe9884 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -124,6 +124,7 @@ void STARCODER::create_starcoder_model( 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ false, /*position_bias*/ + false, /*streaming_cache*/ std::string("layers_" + std::to_string(i) + "_attention") .c_str() /*name*/ ); diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 1cdb2e8e94..cc48d9c866 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -69,7 +69,8 @@ void parse_input_args(char **argv, int &expansion_degree, bool &spec_sampling, bool &do_sample, - int &sampling_seed) { + int &sampling_seed, + bool &streaming_cache) { for (int i = 1; i < argc; i++) { // llm model name if (!strcmp(argv[i], "-llm-model")) { @@ -153,6 +154,10 @@ void parse_input_args(char **argv, do_sample = true; continue; } + if (!strcmp(argv[i], "--enable-streaming-cache")) { + streaming_cache = true; + continue; + } } if (paths.cache_folder_path.empty()) { char const *ff_cache_path = std::getenv("FF_CACHE_PATH"); @@ -317,6 +322,7 @@ void FlexFlow::top_level_task(Task const *task, bool spec_sampling = false; bool do_sample = false; int sampling_seed = 0; + bool streaming_cache = false; InputArgs 
const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -336,7 +342,8 @@ void FlexFlow::top_level_task(Task const *task, expansion_degree, spec_sampling, do_sample, - sampling_seed); + sampling_seed, + streaming_cache); get_model_meta(file_paths, model_metadata, use_full_precision); @@ -356,6 +363,7 @@ void FlexFlow::top_level_task(Task const *task, rm->set_max_tree_depth(max_tree_depth); rm->set_max_tree_width(max_tree_width); rm->set_verbose(verbose); + rm->set_streaming_cache(streaming_cache); rm->register_tokenizer(model_metadata.llm_model_type, model_metadata.bos_token_id, model_metadata.eos_token_id, @@ -371,6 +379,7 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.llm_weights_path, TREE_VERIFY_MODE, generationConfig, + false, use_full_precision); } else if (model_metadata.llm_model_type == ModelType::OPT) { OPT::create_opt_model(tree_model, @@ -418,6 +427,7 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.ssm_model_weights_paths[ssm_id], TREE_SEARCH_MODE, generationConfig, + streaming_cache, use_full_precision); } else if (model_metadata.ssm_model_types[ssm_id] == ModelType::OPT) { OPT::create_opt_model(beam_model, diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 403f2cba52..dcdda6698f 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -2784,6 +2784,7 @@ def spec_inc_multihead_self_attention( scaling_factor=1.0, qk_prod_scaling=True, position_bias=False, + streaming_cache=False, name=None, ): """Defines the MultiHead Attention operation as described in Attention Is All You Need @@ -2864,6 +2865,7 @@ def spec_inc_multihead_self_attention( scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, c_name, ) self.add_layer(OpType.SPEC_INC_MULTIHEAD_SELF_ATTENTION, name) @@ -2991,6 +2993,7 @@ def groupquery_self_attention( scaling_factor=1.0, qk_prod_scaling=True, position_bias=False, + streaming_cache=False, name=None, ): """Defines the multi-query head attention, which allows a different number of Q and KV heads, @@ -3075,6 +3078,7 @@ def groupquery_self_attention( scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, c_name, ) self.add_layer(OpType.INC_MULTIHEAD_ATTENTION, name) diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index d086d6d16a..a398b54ca9 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1201,6 +1201,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); Tensor input = FFCObjectWrapper::unwrap(input_); @@ -1222,6 +1223,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, name); return FFCObjectWrapper::wrap(tensor); } @@ -1244,6 +1246,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); Tensor input = FFCObjectWrapper::unwrap(input_); @@ -1266,6 +1269,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, name); return FFCObjectWrapper::wrap(tensor); } @@ -1333,6 +1337,7 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( 
float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); Tensor input = FFCObjectWrapper::unwrap(input_); @@ -1355,6 +1360,7 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, name); return FFCObjectWrapper::wrap(tensor); } @@ -1378,6 +1384,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); Tensor input = FFCObjectWrapper::unwrap(input_); @@ -1401,6 +1408,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, name); return FFCObjectWrapper::wrap(tensor); } diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index c35a07a4ec..54d71ea0b8 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -70,6 +70,7 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { return groupquery_self_attention(input, embed_dim, @@ -88,6 +89,7 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, name); } @@ -108,6 +110,7 @@ Tensor FFModel::groupquery_self_attention(const Tensor input, float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { if (data_type == DT_NONE) { data_type = input->data_type; @@ -204,6 +207,7 @@ Tensor FFModel::groupquery_self_attention(const Tensor input, li->add_int_property("position_bias", position_bias); li->add_int_property("quantization_type", quantization_type); li->add_int_property("offload", offload); + li->add_int_property("streaming_cache", streaming_cache); li->add_int_property("tensor_parallelism_degree", config.tensor_parallelism_degree); layers.push_back(li); @@ -249,6 +253,8 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( DataType quantization_type = (DataType)value; layer->get_int_property("offload", value); bool offload = (bool)value; + layer->get_int_property("streaming_cache", value); + bool streaming_cache = (bool)value; layer->get_int_property("tensor_parallelism_degree", value); int tensor_parallelism_degree = (int)value; @@ -272,6 +278,7 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( false /*allocate_weights*/, quantization_type, offload, + streaming_cache, tensor_parallelism_degree, layer->name); } @@ -297,6 +304,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( bool allocate_weights, DataType _quantization_type, bool _offload, + bool _streaming_cache, int _tensor_parallelism_degree, char const *name) // Initializer* _bias_initializer) @@ -317,7 +325,8 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), scaling_factor(_scaling_factor), qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias), quantization_type(_quantization_type), - offload(_offload), tensor_parallelism_degree(_tensor_parallelism_degree) { + offload(_offload), streaming_cache(_streaming_cache), + tensor_parallelism_degree(_tensor_parallelism_degree) { // overwrite layer_guid layer_guid = 
_layer_guid; numOutputs = 1; @@ -408,6 +417,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( bool allocate_weights, DataType _quantization_type, bool _offload, + bool _streaming_cache, int _tensor_parallelism_degree, char const *name) // Initializer* _bias_initializer) @@ -429,7 +439,8 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), scaling_factor(_scaling_factor), qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias), quantization_type(_quantization_type), - offload(_offload), tensor_parallelism_degree(_tensor_parallelism_degree) + offload(_offload), streaming_cache(_streaming_cache), + tensor_parallelism_degree(_tensor_parallelism_degree) // bias_initializer(_bias_initializer) { numOutputs = 1; @@ -526,6 +537,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( allocate_weights, other.quantization_type, other.offload, + other.streaming_cache, other.tensor_parallelism_degree, other.name) {} @@ -555,6 +567,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( allocate_weights, params.quantization_type, params.offload, + params.streaming_cache, params.tensor_parallelism_degree, params.name) {} @@ -897,7 +910,8 @@ bool operator==(IncMultiHeadSelfAttentionParams const &lhs, lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && - lhs.position_bias == rhs.position_bias; + lhs.position_bias == rhs.position_bias && + lhs.streaming_cache == rhs.streaming_cache; } IncMultiHeadSelfAttentionParams IncMultiHeadSelfAttention::get_params() const { @@ -919,6 +933,7 @@ IncMultiHeadSelfAttentionParams IncMultiHeadSelfAttention::get_params() const { params.tensor_parallelism_degree = this->tensor_parallelism_degree, params.quantization_type = this->quantization_type; params.offload = this->offload; + params.streaming_cache = this->streaming_cache; params.num_kv_heads = this->num_kv_heads; if (this->name != nullptr) { strcpy(params.name, this->name); @@ -950,6 +965,7 @@ size_t hash::operator()( hash_combine(key, params.position_bias); hash_combine(key, params.quantization_type); hash_combine(key, params.offload); + hash_combine(key, params.streaming_cache); hash_combine(key, params.tensor_parallelism_degree); return key; } diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 83ff630a61..81e4ec3f78 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -246,7 +246,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, bias_ptr = static_cast
<DT *>(m->bias_ptr); } - // phase 1: Implement kernel to compute KQV for input tokens + // phase 1: Compute QKV projections for the tokens in the batch compute_qkv(m, bc, shard_id, @@ -255,52 +255,31 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, static_cast<DT *>(m->devQKVProjArray), bias_ptr, stream); - // phase 2: Update key/val cache - update_qkv_cache<DT>(m, bc, stream); - // cudaEventRecord(t_end, stream); - // checkCUDA(cudaEventSynchronize(t_end)); - // float elapsed = 0; - // checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); - // cudaEventDestroy(t_start); - // cudaEventDestroy(t_end); - // std::cout << "Prepare attn time: " << elapsed << " ms\n"; + // phase 2: First maintain the streaming cache, because it needs the + // pre-pos-encoding values + if (m->streaming_cache) { + // Move the pre-pos-encoding cache to the location read by the attention kernel + update_kv_in_streaming_cache<DT>(m, bc, stream); + // Apply pos-encoding to those k values + apply_pos_encoding_to_streaming_proj<DT>(m, bc, stream); + // Commit to the streaming cache + commit_kv<DT>(m, bc, stream); + } - // cudaEventCreate(&t_start); - // cudaEventCreate(&t_end); - // cudaEventRecord(t_start, stream); + // phase 3: Process the tokens in the batch + { + // Apply pos-encoding to the batch + apply_pos_encoding_to_tokens_in_batch( + m, bc, static_cast<DT *>(m->devQKVProjArray), stream); + // Move the batch qkv values to the location read by the attention kernel + update_qkv_in_batch<DT>(m, bc, stream); + } - // phase 3: Compute attention score + // phase 4: Attention computation incr_attention<DT>(m, bc, static_cast<DT *>
(m->attn_heads), stream); - // cudaEventRecord(t_end, stream); - // checkCUDA(cudaEventSynchronize(t_end)); - // elapsed = 0; - // checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); - // cudaEventDestroy(t_start); - // cudaEventDestroy(t_end); - // std::cout << "Attn time: " << elapsed << " ms\n"; - - // Debug output: - // int size = m->local_hidden_size * BatchConfig::max_tokens_per_batch(); - // float *temp_output = new float[size]; - // cudaDeviceSynchronize(); - // cudaMemcpy( - // temp_output, m->attn_heads, size * sizeof(float), - // cudaMemcpyDeviceToHost); - // printf("Output: "); - // float temp = 0; - // for (int i = 0; i < 1; ++i) { - // for (int j = 0; j < m->local_hidden_size; ++j) { - // temp += temp_output[i * m->local_hidden_size + j]; - // } - // printf("%.6f ", temp); - // } - // printf("\n"); - - // delete[] temp_output; - - // compute output production and bias together for all tokens + // phase 5: Compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); compute_o_prod_bias( m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); @@ -409,7 +388,8 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( _num_q_heads, _num_kv_heads, attn->quantization_type, - attn->offload) {} + attn->offload, + attn->streaming_cache) {} IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( FFHandler handler, @@ -434,7 +414,8 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _num_q_heads, int _num_kv_heads, DataType _quantization_type, - bool _offload) + bool _offload, + bool _streaming_cache) : OpMeta(handler, attn), weight_ptr(nullptr), bias_ptr(nullptr) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -447,6 +428,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_t size_of_dt = data_type_size(attn->data_type); quantization_type = _quantization_type; offload = _offload; + streaming_cache = _streaming_cache; global_num_q_heads = _global_num_q_heads; global_num_kv_heads = _global_num_kv_heads; @@ -498,16 +480,15 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_t qkv_max_proj_size = max_tokens_per_batch * (qk_dim * num_q_heads + qk_dim * num_q_heads + v_dim * num_q_heads); - size_t query_tmp_size = 0, key_cache_size = 0, value_cache_size = 0, - qk_prod_size = 0; + size_t query_tmp_size = 0, key_cache_size = 0, value_cache_size = 0; + size_t streaming_pre_pos_enc_size = 0; // assert((BatchConfig::max_sequence_length() + // BatchConfig::max_spec_tree_token_num()) % // kPagesize == // 0); size_t max_num_pages = - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); switch (infer_mode) { case INC_DECODING_MODE: case TREE_SEARCH_MODE: @@ -515,14 +496,31 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( query_tmp_size = num_q_heads * qk_dim * BatchConfig::max_tokens_per_batch(); // a K-ary tree max node is (k^n - 1) / 2 - key_cache_size = num_q_heads * qk_dim * + key_cache_size = num_kv_heads * qk_dim * BatchConfig::max_requests_per_batch() * max_num_pages * kPagesize; - value_cache_size = num_q_heads * v_dim * + value_cache_size = num_kv_heads * v_dim * BatchConfig::max_requests_per_batch() * max_num_pages * kPagesize; - qk_prod_size = BatchConfig::max_sequence_length() * max_num_pages * - kPagesize * num_q_heads; + if (streaming_cache) { + size_t max_post_pos_enc_pages 
= + round_up_pages(BatchConfig::MAX_STREAMING_POS - + BatchConfig::get_max_tree_depth() + + max(BatchConfig::max_tokens_per_batch(), + BatchConfig::max_spec_tree_token_num())); + key_cache_size = num_kv_heads * qk_dim * + BatchConfig::max_requests_per_batch() * + max_post_pos_enc_pages * kPagesize; + value_cache_size = num_kv_heads * v_dim * + BatchConfig::max_requests_per_batch() * + max_post_pos_enc_pages * kPagesize; + streaming_pre_pos_enc_size = + num_kv_heads * (qk_dim + v_dim) * + BatchConfig::max_requests_per_batch() * + round_up_pages(BatchConfig::MAX_STREAMING_POS - + BatchConfig::get_max_tree_depth()) * + kPagesize; + } break; } default: @@ -535,7 +533,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( 2; size_t totalSize = (qkv_max_proj_size + query_tmp_size + key_cache_size + - value_cache_size + 2 * qk_prod_size + attn_heads_size) * + value_cache_size + streaming_pre_pos_enc_size + attn_heads_size) * size_of_dt + output_tmp_size * data_type_size(DT_HALF) + complex_size * sizeof(cuFloatComplex); // more components will @@ -544,19 +542,21 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // assert that we have enough reserved work space left size_t totalSharedSize = infer_mode == TREE_VERIFY_MODE - ? totalSize - (query_tmp_size + key_cache_size + - value_cache_size + qkv_max_proj_size) * - size_of_dt - : totalSize - - (query_tmp_size + key_cache_size + value_cache_size) * - size_of_dt; + ? totalSize - + (query_tmp_size + key_cache_size + value_cache_size + + streaming_pre_pos_enc_size + qkv_max_proj_size) * + size_of_dt + : totalSize - (query_tmp_size + key_cache_size + + value_cache_size + streaming_pre_pos_enc_size) * + size_of_dt; size_t instance_size = size_of_dt * (infer_mode == TREE_VERIFY_MODE ? query_tmp_size + key_cache_size + value_cache_size + - qkv_max_proj_size - : query_tmp_size + key_cache_size + value_cache_size); + streaming_pre_pos_enc_size + qkv_max_proj_size + : query_tmp_size + key_cache_size + value_cache_size + + streaming_pre_pos_enc_size); if (quantization_type != DT_NONE) { totalSharedSize += quantized_weightSize; @@ -585,6 +585,10 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( } kvCache = gpu_mem_allocator.allocate_instance_untyped( (key_cache_size + value_cache_size) * size_of_dt); + if (streaming_pre_pos_enc_size > 0) { + streamingPrePosEncBuf = gpu_mem_allocator.allocate_instance_untyped( + streaming_pre_pos_enc_size * size_of_dt); + } outputTmp = gpu_mem_allocator.allocate_instance(output_tmp_size); token_infos = @@ -595,18 +599,17 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( request_available = reinterpret_cast( reinterpret_cast(handler.batch_config_metadata) + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo)); + streaming_cache_infos = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BatchConfig::request_available) + + sizeof(BatchConfig::causalMask)); if (offload) { // token_infos = // gpu_mem_allocator.allocate_reserved( // tokeninfo_size); // offset += sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size; - qk_prods = gpu_mem_allocator.allocate_reserved_untyped(qk_prod_size * - size_of_dt); - // offset += qk_prod_size * size_of_dt; - qk_prods_softmax = gpu_mem_allocator.allocate_reserved_untyped( - qk_prod_size * size_of_dt); - // offset += qk_prod_size * size_of_dt; attn_heads = gpu_mem_allocator.allocate_reserved_untyped(attn_heads_size * size_of_dt); // 
offset += attn_heads_size * size_of_dt; @@ -620,10 +623,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // token_infos = // gpu_mem_allocator.allocate_instance( // tokeninfo_size); - qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size * - size_of_dt); - qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped( - qk_prod_size * size_of_dt); attn_heads = gpu_mem_allocator.allocate_instance_untyped(attn_heads_size * size_of_dt); complex_input = diff --git a/src/ops/kernels/inc_multihead_self_attention_kernels.cu b/src/ops/kernels/inc_multihead_self_attention_kernels.cu index 57a02e6f8f..e65f2c0609 100644 --- a/src/ops/kernels/inc_multihead_self_attention_kernels.cu +++ b/src/ops/kernels/inc_multihead_self_attention_kernels.cu @@ -12,9 +12,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "flexflow/batch_config.h" +#include #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) #include "cuComplex.h" #endif +#include "flashinfer/pos_enc.cuh" +#include "flexflow/attention_config.h" #include "flexflow/ffconst_utils.h" #include "flexflow/ops/inc_multihead_self_attention.h" #include "flexflow/ops/kernels/decompress_kernels.h" @@ -28,6 +32,9 @@ namespace FlexFlow { using Legion::coord_t; using Legion::Memory; +using flashinfer::BatchQKApplyLlama31Rotary; +using flashinfer::BatchQKApplyRotary; + #define WARP_SIZE 32 namespace Kernels { @@ -175,59 +182,6 @@ __global__ void } } -template -__global__ void - apply_rotary_embedding_hf(DT *input_ptr, - cuFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int qk_dim, - int num_tokens, - size_t q_array_size, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qk_dim : qk_dim; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int token_idx = real_i / (hidden_size / 2); - int idx = real_i % (proj_size / 2); - int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); - - int real_part_index = idx + head_idx * proj_size + - token_idx * hidden_size * QKV_WEIGHT_NUM + - hidden_size * (q_tensor ? 
0 : 1); - int complex_part_index = real_part_index + (proj_size / 2); - - // complex_input[i] = {input_ptr[real_part_index], - // input_ptr[complex_part_index]}; - cuFloatComplex cii = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - // get the freq_cis: shape 1 * (qk_dim/2) = 1 * 64 - // apply a Cartesian coordinate transformation - // multiple with input & /copy back to q/k - - // get position of token - - // size_t pos = id_map[token_idx].token_position; - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; - - // complex_input[i] = cuCmulf(complex_input[i], complex_pos); - // input_ptr[real_part_index] = complex_input[i].x; - // input_ptr[complex_part_index] = complex_input[i].y; - - cii = cuCmulf(cii, complex_pos); - input_ptr[real_part_index] = cii.x; - input_ptr[complex_part_index] = cii.y; - } -} - template void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -311,8 +265,10 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, // } int num_tokens = bc->num_active_tokens(); + if (num_tokens == 0) { + return; + } int parallelism = m->qk_dim * num_tokens * m->num_q_heads; - size_t q_array_size = m->qk_dim * num_tokens * m->num_q_heads; // Step 2: apply bias for QKV, or scale the query if (*m->qkv_bias) { @@ -341,49 +297,174 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, m->scaling_factor, m->local_hidden_size); } +} - // checkCUDA(cudaEventCreate(&t_start)); - // checkCUDA(cudaEventCreate(&t_end)); - // checkCUDA(cudaEventRecord(t_start, stream)); +template +__global__ void apply_pos_encoding_to_tokens_in_batch_kernel( + DT *input_ptr, + BatchConfig::PerTokenInfo const *tokenInfos, + int qk_dim, + int num_tokens, + size_t q_array_size, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + // create complex number + bool q_tensor = i < (q_array_size / 2); + int proj_size = q_tensor ? qk_dim : qk_dim; + int real_i = q_tensor ? i : i - q_array_size / 2; + + int token_idx = real_i / (hidden_size / 2); + int idx = real_i % (proj_size / 2); + int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); + + int real_part_index = idx + head_idx * proj_size + + token_idx * hidden_size * QKV_WEIGHT_NUM + + hidden_size * (q_tensor ? 
0 : 1); + int complex_part_index = real_part_index + (proj_size / 2); - // Step 3: apply rotary embedding if needed - if (*m->apply_rotary_embedding) { - /*q&k*/ - parallelism = num_tokens * m->local_hidden_size; - apply_rotary_embedding_hf<<>>(output_ptr, - m->complex_input, - m->token_infos, - m->qk_dim, - num_tokens, - q_array_size, - m->local_hidden_size); + cuFloatComplex cii = {input_ptr[real_part_index], + input_ptr[complex_part_index]}; + + // get the freq_cis: shape 1 * (qk_dim/2) = 1 * 64 + // apply a Cartesian coordinate transformation + // multiple with input & /copy back to q/k + + // get position of token + + size_t pos = tokenInfos[token_idx].abs_depth_in_request; + + float freq = pos * (1.0 / pow(10000.0, (float)2 * idx / proj_size)); + cuFloatComplex complex_pos = {cos(freq), sin(freq)}; + + cii = cuCmulf(cii, complex_pos); + input_ptr[real_part_index] = cii.x; + input_ptr[complex_part_index] = cii.y; } - // checkCUDA(cudaEventRecord(t_end, stream)); - // checkCUDA(cudaEventSynchronize(t_end)); - // elapsed = 0; - // checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); - // cudaEventDestroy(t_start); - // cudaEventDestroy(t_end); - // if (bc->inference_mode == TREE_VERIFY_MODE and device == 0) { - // std::cout << "Rotary time: " << elapsed << " ms\n"; - // } +} + +template +void apply_pos_encoding_to_tokens_in_batch( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + // apply rotary embedding if needed + if (!*m->apply_rotary_embedding) { + return; + } + int num_tokens = bc->num_active_tokens(); + if (num_tokens == 0) { + return; + } + int parallelism = num_tokens * m->local_hidden_size; + size_t q_array_size = m->qk_dim * num_tokens * m->num_q_heads; + apply_pos_encoding_to_tokens_in_batch_kernel<<>>( + output_ptr, + m->token_infos, + m->qk_dim, + num_tokens, + q_array_size, + m->local_hidden_size); +} + +__global__ void apply_pos_encoding_to_streaming_proj_kernel( + half *kv_cache, + BatchConfig::PerRequestInfo const *requestInfos, + bool const *request_available, + int const max_num_pages, + int num_kv_heads, + int head_dim, + StreamingCacheInfo const *streaming_cache_infos, + uint32_t const max_num_requests) { + int const kv_hidden_size = num_kv_heads * head_dim; + int const thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int token_idx = thread_idx / (kv_hidden_size / 2); + // Each complex is consist of (i, i + head_dim / 2) wuthin the same head. + int const head_idx = (thread_idx % (kv_hidden_size / 2)) / (head_dim / 2); + int const offset_in_head = thread_idx % (head_dim / 2); + // Get the corresponding request index and token index in the request. + int request_idx = 0; + while (token_idx >= 0 && request_idx < max_num_requests) { + if (request_available[request_idx]) { + token_idx -= streaming_cache_infos[request_idx].commit_len; + } + request_idx++; + } + if (token_idx >= 0) { + return; + } + request_idx--; + token_idx += streaming_cache_infos[request_idx].commit_len; + + // Get the real and complex part index for the current complex. + int const real_part_idx = + get_k_entry_offset( + request_idx, token_idx, max_num_pages, num_kv_heads, head_dim) + + head_idx * head_dim + offset_in_head; + int const complex_part_idx = real_part_idx + head_dim / 2; + + // Apply the rotary position encoding. 
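  // [Editor's note] For reference, the rotation applied below is standard RoPE:
  // with theta = pos / 10000^(2 * offset_in_head / head_dim), the pair
  // (x_re, x_im) becomes
  //   x_re' = x_re * cos(theta) - x_im * sin(theta)
  //   x_im' = x_re * sin(theta) + x_im * cos(theta)
  // which is exactly the complex multiply cuCmulf(cii, {cos(theta), sin(theta)}).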
+ cuFloatComplex cii = {kv_cache[real_part_idx], kv_cache[complex_part_idx]}; + size_t pos = token_idx; + float freq = pos * (1.0 / pow(10000.0, (float)2 * offset_in_head / head_dim)); + cuFloatComplex complex_pos = {cos(freq), sin(freq)}; + cii = cuCmulf(cii, complex_pos); + kv_cache[real_part_idx] = cii.x; + kv_cache[complex_part_idx] = cii.y; +} + +template +void apply_pos_encoding_to_streaming_proj( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream) { + assert(m->streaming_cache); + int const kv_hidden_size = m->num_kv_heads * m->qk_dim; + int num_tokens = 0; + for (int req_idx = 0; req_idx < BatchConfig::max_requests_per_batch(); + req_idx++) { + if (!bc->request_available[req_idx]) { + continue; + } + num_tokens += bc->streamingCacheInfo[req_idx].commit_len; + } + if (num_tokens == 0) { + return; + } + int parallelism = num_tokens * kv_hidden_size / 2; + int const max_num_pages = round_up_pages( + BatchConfig::MAX_STREAMING_POS - BatchConfig::get_max_tree_depth() + + BatchConfig::max_spec_tree_token_num()); + apply_pos_encoding_to_streaming_proj_kernel<<>>( + static_cast(m->kvCache), + m->request_infos, + m->request_available, + max_num_pages, + m->num_kv_heads, + m->qk_dim, + m->streaming_cache_infos, + bc->max_requests_per_batch()); } template __global__ void - update_qkv_cache_kernel(DT *devQKVProjArray, - half *qTmp_ptr, - half *kCache_ptr, - BatchConfig::PerTokenInfo const *tokenInfos, - BatchConfig::PerRequestInfo *request_infos, - int const max_num_pages, - int num_q_heads, - int num_kv_heads, - int head_dim, - int num_new_tokens) { + update_qkv_in_batch_kernel(DT *qkv_proj_array, + half *qTmp_ptr, + half *kvCache_ptr, + BatchConfig::PerTokenInfo const *tokenInfos, + int const max_num_pages, + int num_q_heads, + int num_kv_heads, + int head_dim, + int num_new_tokens) { int const q_hidden_size = num_q_heads * head_dim; int const temp_kv_hidden_size = num_q_heads * head_dim; // temporary hard code int const kv_hidden_size = num_kv_heads * head_dim; @@ -395,11 +476,11 @@ __global__ void } int const req_idx = tokenInfos[token_idx].request_index; - int const token_abs_idx = tokenInfos[token_idx].abs_index_in_request; + int token_abs_idx = tokenInfos[token_idx].abs_index_in_request; size_t from_idx = token_idx * (q_hidden_size + temp_kv_hidden_size * 2); qTmp_ptr[token_idx * q_hidden_size + offset] = - static_cast(devQKVProjArray[from_idx + offset]); + static_cast(qkv_proj_array[from_idx + offset]); if (offset < kv_hidden_size) { size_t to_k_idx = get_k_entry_offset( @@ -410,38 +491,236 @@ __global__ void int const stride = num_q_heads / num_kv_heads; int const kv_offset = offset / head_dim * stride * head_dim + offset % head_dim; - kCache_ptr[to_k_idx + offset] = static_cast( - devQKVProjArray[from_idx + q_hidden_size + kv_offset]); - kCache_ptr[to_v_idx + offset] = - static_cast(devQKVProjArray[from_idx + q_hidden_size + - temp_kv_hidden_size + kv_offset]); + kvCache_ptr[to_k_idx + offset] = + static_cast(qkv_proj_array[from_idx + q_hidden_size + kv_offset]); + kvCache_ptr[to_v_idx + offset] = + static_cast(qkv_proj_array[from_idx + q_hidden_size + + temp_kv_hidden_size + kv_offset]); } } template -void update_qkv_cache(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - cudaStream_t stream) { - // update the kv cache, compact the q array +void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream) { int num_new_tokens = bc->num_active_tokens(); + if (num_new_tokens == 0) { + 
return; + } int parallelism = m->local_hidden_size * num_new_tokens; int const max_num_pages = - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / - kPagesize; - update_qkv_cache_kernel<<>>(static_cast
(m->devQKVProjArray), - static_cast(m->queryTmp), - static_cast(m->kvCache), - m->token_infos, - m->request_infos, - max_num_pages, - m->num_q_heads, - m->num_kv_heads, - m->qk_dim, - num_new_tokens); + round_up_pages(BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); + update_qkv_in_batch_kernel<<>>(static_cast
(m->devQKVProjArray), + static_cast(m->queryTmp), + static_cast(m->kvCache), + m->token_infos, + max_num_pages, + m->num_q_heads, + m->num_kv_heads, + m->qk_dim, + num_new_tokens); +} + +__global__ void update_kv_in_streaming_cache_kernel( + half *pre_pos_enc_buf, + half *kv_cache, + BatchConfig::PerRequestInfo const *requestInfos, + bool const *request_available, + int const max_num_pages_pre_pos_enc_buf, + int const max_num_pages_kv_cache, + int num_kv_heads, + int head_dim, + StreamingCacheInfo const *streaming_cache_infos, + uint32_t const max_num_requests) { + int const kv_hidden_size = num_kv_heads * head_dim; + int const thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int token_idx = thread_idx / kv_hidden_size; + int const offset = thread_idx % kv_hidden_size; + int request_idx = 0; + while (token_idx >= 0 && request_idx < max_num_requests) { + if (request_available[request_idx]) { + token_idx -= streaming_cache_infos[request_idx].commit_len; + } + request_idx++; + } + if (token_idx >= 0) { + return; + } + request_idx--; + token_idx += streaming_cache_infos[request_idx].commit_len; + + size_t from_k_idx = get_k_entry_offset(request_idx, + token_idx, + max_num_pages_pre_pos_enc_buf, + num_kv_heads, + head_dim), + from_v_idx = get_v_entry_offset(request_idx, + token_idx, + max_num_pages_pre_pos_enc_buf, + num_kv_heads, + head_dim); + + // to_idx should consider the rolling property of the window cache + int to_idx = token_idx; + StreamingCacheInfo const &info = streaming_cache_infos[request_idx]; + if (info.commit_len >= info.sink_cache_size + info.window_cache_size && + to_idx >= info.sink_cache_size) { + to_idx -= info.sink_cache_size; + to_idx = (to_idx + info.window_cache_size - info.window_back) % + info.window_cache_size; + to_idx += info.sink_cache_size; + } + + size_t to_k_idx = get_k_entry_offset(request_idx, + to_idx, + max_num_pages_kv_cache, + num_kv_heads, + head_dim), + to_v_idx = get_v_entry_offset(request_idx, + to_idx, + max_num_pages_kv_cache, + num_kv_heads, + head_dim); + + kv_cache[to_k_idx + offset] = pre_pos_enc_buf[from_k_idx + offset]; + kv_cache[to_v_idx + offset] = pre_pos_enc_buf[from_v_idx + offset]; +} + +template +void update_kv_in_streaming_cache(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream) { + assert(m->streaming_cache); + int const kv_hidden_size = m->num_kv_heads * m->qk_dim; + int num_tokens = 0; + for (int req_idx = 0; req_idx < BatchConfig::max_requests_per_batch(); + req_idx++) { + if (!bc->request_available[req_idx]) { + continue; + } + num_tokens += bc->streamingCacheInfo[req_idx].commit_len; + } + if (num_tokens == 0) { + return; + } + int parallelism = kv_hidden_size * num_tokens; + int const max_num_pages_pre_pos_enc_buf = round_up_pages( + BatchConfig::MAX_STREAMING_POS - BatchConfig::get_max_tree_depth()); + int const max_num_pages_kv_cache = round_up_pages( + BatchConfig::MAX_STREAMING_POS - BatchConfig::get_max_tree_depth() + + BatchConfig::max_spec_tree_token_num()); + + update_kv_in_streaming_cache_kernel<<>>( + static_cast(m->streamingPrePosEncBuf), + static_cast(m->kvCache), + m->request_infos, + m->request_available, + max_num_pages_pre_pos_enc_buf, + max_num_pages_kv_cache, + m->num_kv_heads, + m->qk_dim, + m->streaming_cache_infos, + bc->max_requests_per_batch()); +} + +template +__global__ void + commit_kv_kernel(DT const *qkv_proj_array, + half *pre_pos_enc_buf, + BatchConfig::PerTokenInfo const *tokenInfos, + BatchConfig::PerRequestInfo const *requestInfos, + int const 
max_num_pages, + int num_q_heads, + int num_kv_heads, + int head_dim, + StreamingCacheInfo const *streaming_cache_infos, + int num_new_tokens) { + int const q_hidden_size = num_q_heads * head_dim; + int const temp_kv_hidden_size = num_q_heads * head_dim; // temporary hard code + int const kv_hidden_size = num_kv_heads * head_dim; + int const thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + int const token_idx = thread_idx / kv_hidden_size; + int const offset = thread_idx % kv_hidden_size; + if (token_idx >= num_new_tokens) { + return; + } + int const request_idx = tokenInfos[token_idx].request_index; + + StreamingCacheInfo const &info = streaming_cache_infos[request_idx]; + int to_idx = tokenInfos[token_idx].abs_index_in_request; + // cases that get over the boundary: + // 1. commit_len < sink_cache_size: commit to sink, window, window_back is + // after commit_len. + // 2. sink_cache_size <= commit_len < sink_cache_size + window_cache_size: + // commit to window, window_back + sink_cache_size = commit_len, pointing to + // the same position. + // 3. commit_len >= sink_cache_size + window_cache_size: commit to window, + // window is full before this commit, window_back is pointing to the real + // position. + if (to_idx >= info.sink_cache_size + info.window_cache_size) { + to_idx = to_idx - info.commit_len + info.window_back; + if (info.commit_len < info.sink_cache_size) { + // For case 1, compensating for sink offset, because window_back is + // someway back from commit_len. + to_idx -= info.sink_cache_size - info.commit_len; + } + to_idx = info.sink_cache_size + to_idx % info.window_cache_size; + } + // TODO: For now don't consider the case that the commit tokens roll over the + // for more than once. In this case, we should only count the last tokens in + // the same window position. + + size_t from_idx = token_idx * (q_hidden_size + temp_kv_hidden_size * 2); + size_t to_k_idx = get_k_entry_offset( + request_idx, to_idx, max_num_pages, num_kv_heads, head_dim), + to_v_idx = get_v_entry_offset( + request_idx, to_idx, max_num_pages, num_kv_heads, head_dim); + + int const stride = num_q_heads / num_kv_heads; + int const kv_offset = + offset / head_dim * stride * head_dim + offset % head_dim; + + pre_pos_enc_buf[to_k_idx + offset] = + static_cast(qkv_proj_array[from_idx + q_hidden_size + kv_offset]); + pre_pos_enc_buf[to_v_idx + offset] = + static_cast(qkv_proj_array[from_idx + q_hidden_size + + temp_kv_hidden_size + kv_offset]); +} + +template +void commit_kv(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream) { + assert(m->streaming_cache); + int const kv_hidden_size = m->num_kv_heads * m->qk_dim; + int const num_new_tokens = bc->num_active_tokens(); + if (num_new_tokens == 0) { + return; + } + int parallelism = kv_hidden_size * num_new_tokens; + int const max_num_pages = round_up_pages(BatchConfig::MAX_STREAMING_POS - + BatchConfig::get_max_tree_depth()); + + commit_kv_kernel<<>>(static_cast
(m->devQKVProjArray), + static_cast(m->streamingPrePosEncBuf), + m->token_infos, + m->request_infos, + max_num_pages, + m->num_q_heads, + m->num_kv_heads, + m->qk_dim, + m->streaming_cache_infos, + num_new_tokens); } template @@ -458,7 +737,11 @@ void produce_output(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, DT *output_ptr, cudaStream_t stream) { - int parallelism = m->v_dim * m->num_q_heads * bc->num_active_tokens(); + int const num_tokens = bc->num_active_tokens(); + if (num_tokens == 0) { + return; + } + int parallelism = m->v_dim * m->num_q_heads * num_tokens; produce_output_kernel<<output_type[0]); assert(data_type_size(m->output_type[0]) == sizeof(DT)); #if CUDA_VERSION >= 11000 - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + // TODO: currently set the default to CUBLAS_COMPUTE_16F for best + // performance cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; #else cudaDataType_t compute_type = cublas_data_type; @@ -635,12 +919,60 @@ template void Kernels::IncMultiHeadAttention::compute_qkv( half const *bias_ptr, cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::update_qkv_cache( +template void + Kernels::IncMultiHeadAttention::apply_pos_encoding_to_tokens_in_batch< + float>(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + float *output_ptr, + cudaStream_t stream); + +template void + Kernels::IncMultiHeadAttention::apply_pos_encoding_to_tokens_in_batch( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + half *output_ptr, + cudaStream_t stream); + +template void + Kernels::IncMultiHeadAttention::apply_pos_encoding_to_streaming_proj( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +template void + Kernels::IncMultiHeadAttention::apply_pos_encoding_to_streaming_proj( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +template void Kernels::IncMultiHeadAttention::update_qkv_in_batch( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +template void Kernels::IncMultiHeadAttention::update_qkv_in_batch( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +template void + Kernels::IncMultiHeadAttention::update_kv_in_streaming_cache( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +template void + Kernels::IncMultiHeadAttention::update_kv_in_streaming_cache( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream); + +template void Kernels::IncMultiHeadAttention::commit_kv( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::update_qkv_cache( +template void Kernels::IncMultiHeadAttention::commit_kv( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, cudaStream_t stream); diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index cd937f1651..cfcf783e93 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -69,6 +69,7 @@ Tensor float scaling_factor, bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { return spec_inc_multiquery_self_attention(input, embed_dim, @@ -87,6 +88,7 @@ Tensor scaling_factor, qk_prod_scaling, position_bias, + streaming_cache, name); } @@ -108,6 +110,7 @@ Tensor float scaling_factor, 
bool qk_prod_scaling, bool position_bias, + bool streaming_cache, char const *name) { if (data_type == DT_NONE) { data_type = input->data_type; @@ -190,6 +193,7 @@ Tensor li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling", qk_prod_scaling); li->add_int_property("position_bias", position_bias); + li->add_int_property("streaming_cache", streaming_cache); layers.push_back(li); return li->outputs[0]; } @@ -229,6 +233,8 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( bool qk_prod_scaling = (bool)value; layer->get_int_property("position_bias", value); bool position_bias = (bool)value; + layer->get_int_property("streaming_cache", value); + bool streaming_cache = (bool)value; return new SpecIncMultiHeadSelfAttention(model, layer->layer_guid, @@ -248,6 +254,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( qk_prod_scaling, position_bias, false /*allocate_weights*/, + streaming_cache, layer->name); } @@ -270,6 +277,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( bool _qk_prod_scaling, bool _position_bias, bool allocate_weights, + bool _streaming_cache, char const *name) // Initializer* _bias_initializer) : Op(model, @@ -288,7 +296,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), scaling_factor(_scaling_factor), qk_prod_scaling(_qk_prod_scaling), - position_bias(_position_bias) { + position_bias(_position_bias), streaming_cache(_streaming_cache) { // overwrite layer_guid layer_guid = _layer_guid; @@ -370,6 +378,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( bool _qk_prod_scaling, bool _position_bias, bool allocate_weights, + bool _streaming_cache, char const *name) // Initializer* _bias_initializer) : Op(model, @@ -389,7 +398,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), scaling_factor(_scaling_factor), qk_prod_scaling(_qk_prod_scaling), - position_bias(_position_bias) + position_bias(_position_bias), streaming_cache(_streaming_cache) // bias_initializer(_bias_initializer) { numOutputs = 1; @@ -478,6 +487,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( other.qk_prod_scaling, other.position_bias, allocate_weights, + other.streaming_cache, other.name) {} SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( @@ -504,6 +514,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( params.qk_prod_scaling, params.position_bias, allocate_weights, + params.streaming_cache, params.name) {} void SpecIncMultiHeadSelfAttention::init_inference( @@ -825,7 +836,8 @@ bool operator==(SpecIncMultiHeadSelfAttentionParams const &lhs, lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && - lhs.position_bias == rhs.position_bias; + lhs.position_bias == rhs.position_bias && + lhs.streaming_cache == rhs.streaming_cache; } SpecIncMultiHeadSelfAttentionParams @@ -846,6 +858,7 @@ SpecIncMultiHeadSelfAttentionParams params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; params.position_bias = this->position_bias; + params.streaming_cache = this->streaming_cache; if (this->name != nullptr) { strcpy(params.name, this->name); } @@ -874,6 +887,7 @@ size_t hash::operator()( 
hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); hash_combine(key, params.position_bias); + hash_combine(key, params.streaming_cache); return key; } }; // namespace std diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 5010851b26..16dbe74767 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -239,9 +239,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta *m, DT *output_ptr, DT const *bias_ptr, cudaStream_t stream) { - // phase 1: Implement kernel to compute KQV for input tokens - - // long long time_1 = Realm::Clock::current_time_in_microseconds(), time_2; + // phase 1: Compute QKV Projections of the batch compute_qkv(m, bc, shard_id, @@ -250,11 +248,30 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta *m, static_cast
(m->devQKVProjArray),
               bias_ptr,
               stream);
-  // phase 2: Update key/val cache
-  update_qkv_cache<DT>(m, bc, stream);
-  // phase 3: Compute attention score
-  // 3 kernels for pahse 3: matmul1 - softmax - matmal2
+  // phase 2: First maintain the streaming cache, because it needs the
+  // pre-pos-encoding values
+  if (m->streaming_cache) {
+    // Move the pre-pos-encoding cache to where attention reads it
+    update_kv_in_streaming_cache<DT>(m, bc, stream);
+    // Apply pos-encoding to those k values
+    apply_pos_encoding_to_streaming_proj<DT>(m, bc, stream);
+    // Commit to the streaming cache
+    if (bc->prompt_phase) {
+      commit_kv<DT>(m, bc, stream);
+    }
+  }
+
+  // phase 3: Take care of the batch
+  {
+    // Apply pos-encoding to the batch
+    apply_pos_encoding_to_tokens_in_batch(
+        m, bc, static_cast<DT *>(m->devQKVProjArray), stream);
+    // Move the batch qkv values to where attention reads them
+    update_qkv_in_batch<DT>(m, bc, stream);
+  }
+
+  // phase 4: Attention computation
   tree_search_attention<DT>(m, bc, static_cast<DT *>
(m->attn_heads), stream); // Debug output: @@ -277,14 +294,10 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta *m, // delete[] temp_output; - // compute output production and bias together for all tokens + // phase 5: Compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); - compute_o_prod_bias( m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); - // time_2 = Realm::Clock::current_time_in_microseconds(); - // std::cout << "SpecIncMultiHeadSelfAttention kernel time: " - // << (time_2 - time_1) << "us" << std::endl; } } // namespace SpecIncMultiHeadSelfAttention @@ -388,7 +401,8 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( _num_q_heads, _num_kv_heads, DT_NONE, - false) { + false, + attn->streaming_cache) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); checkCUDNN(cudnnSetStream(handler.dnn, stream)); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index cb545ec845..8c384c1b05 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -111,9 +111,8 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, // cudaEventRecord(t_start, stream); int const max_num_pages = - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); int const num_requests = bc->num_active_requests(); int parallelism = m->num_kv_heads * m->qk_dim * num_requests; commit_tokens_kernel<<(m->devQKVProjArray), stream); + // cudaEventRecord(t_end, stream); // checkCUDA(cudaEventSynchronize(t_end)); // elapsed = 0; @@ -416,7 +418,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // cudaEventRecord(t_start, stream); // Update key-val cache, compact q array - update_qkv_cache
(m, bc, stream);
+  update_qkv_in_batch<DT>
(m, bc, stream); // cudaEventRecord(t_end, stream); // checkCUDA(cudaEventSynchronize(t_end)); @@ -613,7 +615,8 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( _num_q_heads, _num_kv_heads, attn->quantization_type, - attn->offload), + attn->offload, + false), num_active_tokens(0) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -633,7 +636,8 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BatchConfig::request_available) + - sizeof(BatchConfig::causalMask)); + sizeof(BatchConfig::causalMask) + + sizeof(BatchConfig::streamingCacheInfo)); num_tokens_to_commit = reinterpret_cast( reinterpret_cast(committed_token_infos) + sizeof(BatchConfig::committed_tokens)); diff --git a/src/parallel_ops/kernels/allreduce_kernels.cu b/src/parallel_ops/kernels/allreduce_kernels.cu index 02fb760fd5..60a1afaefa 100644 --- a/src/parallel_ops/kernels/allreduce_kernels.cu +++ b/src/parallel_ops/kernels/allreduce_kernels.cu @@ -166,7 +166,7 @@ void inference_kernel_wrapper(AllReduceMeta *m, ncclComm, const_cast(input.ptr), stream); - params.barrier_flag = (*comm_buffer->barrier_flag)++; + params.barrier_flag = ++(*comm_buffer->barrier_flag); for (int i = 0; i < num_devices; ++i) { params.peer_comm_buffer_ptrs[i] = comm_buffer->comm_ptrs[i]; } diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index d74f8084c3..308f468f53 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -16,6 +16,7 @@ #include "flexflow/batch_config.h" #include "flexflow/request_manager.h" #include "legion.h" +#include #include #include @@ -48,6 +49,7 @@ BatchConfig::BatchConfig(BatchConfig const &rhs) { if (rhs.request_available[request_idx]) { request_available[request_idx] = true; requestsInfo[request_idx] = rhs.requestsInfo[request_idx]; + streamingCacheInfo[request_idx] = rhs.streamingCacheInfo[request_idx]; causalMask[request_idx] = rhs.causalMask[request_idx]; } } @@ -101,6 +103,10 @@ int BatchConfig::max_spec_tree_token_num() { return RequestManager::get_request_manager()->get_max_spec_tree_token_num(); } +int BatchConfig::get_max_tree_depth() { + return RequestManager::get_request_manager()->get_max_tree_depth(); +} + // Overloading the << operator for the Bitset class std::ostream &operator<<(std::ostream &os, BatchConfig::BitMask::Bitset const &bitset) { @@ -155,6 +161,22 @@ std::ostream &operator<<(std::ostream &os, BatchConfig const &bc) { } } + // Streaming cache info + os << "Streaming cache info:\n"; + for (int i = 0; i < bc.max_requests_per_batch(); i++) { + if (bc.request_available[i]) { + os << " Request " << i << ":\n"; + os << " Sink cache size: " << bc.streamingCacheInfo[i].sink_cache_size + << std::endl; + os << " Window cache size: " + << bc.streamingCacheInfo[i].window_cache_size << std::endl; + os << " Window back: " << bc.streamingCacheInfo[i].window_back + << std::endl; + os << " Commit len: " << bc.streamingCacheInfo[i].commit_len + << std::endl; + } + } + // Per-token info os << "Per-token info:\n"; for (int i = 0; i < bc.num_tokens; i++) { @@ -232,4 +254,50 @@ InferenceResult::InferenceResult(InferenceResult const &other) { gumbel_logits); } +StreamingCacheInfo::StreamingCacheInfo() : StreamingCacheInfo(0, 0) {} + +StreamingCacheInfo::StreamingCacheInfo(int sink_cache_size, + int window_cache_size) + : sink_cache_size(sink_cache_size), window_cache_size(window_cache_size), + window_back(0), commit_len(0) {} + 
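
The sink-plus-rolling-window bookkeeping that the StreamingCacheInfo methods below implement can be hard to follow from the modular arithmetic alone. The following standalone C++ sketch (illustration only, not part of the patch) mirrors the commit_cache and global_2_cache_index bodies added in src/runtime/batch_config.cc; ToyStreamingCacheInfo, the driver in main, and the sizes sink=4 / window=8 are assumptions chosen purely for demonstration.

    // Standalone illustration (not part of the patch): mirrors the sink/window
    // bookkeeping of StreamingCacheInfo::commit_cache and global_2_cache_index
    // from src/runtime/batch_config.cc, with toy sizes sink=4, window=8.
    #include <algorithm>
    #include <cstdio>

    struct ToyStreamingCacheInfo {
      int sink_cache_size, window_cache_size;
      int window_back, commit_len;

      void commit_cache(int len) {
        commit_len += len;
        if (commit_len <= sink_cache_size + window_cache_size) {
          // Window not yet full: the write cursor simply trails the commit length.
          window_back = std::max(0, commit_len - sink_cache_size);
        } else {
          // Window full: commit_len saturates and the cursor rolls around.
          commit_len = sink_cache_size + window_cache_size;
          window_back = (window_back + len - 1) % window_cache_size + 1;
        }
      }

      int global_2_cache_index(int global_index) const {
        if (global_index < sink_cache_size) {
          return global_index; // sink tokens keep their position
        }
        // window tokens wrap around behind the sink
        return (global_index - sink_cache_size) % window_cache_size +
               sink_cache_size;
      }
    };

    int main() {
      ToyStreamingCacheInfo info{/*sink=*/4, /*window=*/8, /*back=*/0, /*len=*/0};
      info.commit_cache(10); // prefill 10 tokens: 4 go to the sink, 6 to the window
      printf("commit_len=%d window_back=%d\n", info.commit_len, info.window_back);
      info.commit_cache(5);  // 5 more tokens: the window saturates and starts rolling
      printf("commit_len=%d window_back=%d\n", info.commit_len, info.window_back);
      printf("token 17 lives at cache slot %d\n", info.global_2_cache_index(17));
      return 0;
    }

With these sizes, a 10-token prefill fills the sink and six window slots (commit_len=10, window_back=6); five more tokens saturate the window (commit_len pins at 12) and advance the cursor to slot 3, which is the window_back/commit_len state that update_kv_in_streaming_cache_kernel and commit_kv_kernel read when remapping token positions.
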
+StreamingCacheInfo::StreamingCacheInfo(StreamingCacheInfo const &other) + : sink_cache_size(other.sink_cache_size), + window_cache_size(other.window_cache_size), + window_back(other.window_back), commit_len(other.commit_len) {} + +StreamingCacheInfo & + StreamingCacheInfo::operator=(StreamingCacheInfo const &other) { + sink_cache_size = other.sink_cache_size; + window_cache_size = other.window_cache_size; + window_back = other.window_back; + commit_len = other.commit_len; + return *this; +} + +// For draft model, we only update the cache when prefill or +// commit the verified result from target model; +// For incremental decoding, we update the cache both in prefill and decoding +void StreamingCacheInfo::commit_cache(int len) { + commit_len += len; + if (commit_len <= sink_cache_size + window_cache_size) { + window_back = std::max(0, commit_len - sink_cache_size); + } else { + commit_len = sink_cache_size + window_cache_size; + window_back = (window_back + len - 1) % window_cache_size + 1; + } +} + +void StreamingCacheInfo::reset_cache() { + window_back = 0; + commit_len = 0; +} + +int StreamingCacheInfo::global_2_cache_index(int global_index) { + if (global_index < sink_cache_size) { + return global_index; + } + return (global_index - sink_cache_size) % window_cache_size + sink_cache_size; +} + }; // namespace FlexFlow diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 8cae8e0592..ca8e51d40f 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -2342,6 +2342,7 @@ GraphOptimalViewSerialized sez.serialize(attn->position_bias); sez.serialize(attn->quantization_type); sez.serialize(attn->offload); + sez.serialize(attn->streaming_cache); sez.serialize(attn->num_kv_heads); sez.serialize(attn->tensor_parallelism_degree); sez.serialize(strlen(attn->name)); @@ -2367,6 +2368,7 @@ GraphOptimalViewSerialized sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); sez.serialize(attn->position_bias); + sez.serialize(attn->streaming_cache); sez.serialize(attn->num_kv_heads); sez.serialize(strlen(attn->name)); sez.serialize(attn->name, strlen(attn->name)); @@ -2807,7 +2809,8 @@ void FFModel::deserialize_graph_optimal_view( tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, offload, position_bias; + scaling_query, qk_prod_scaling, offload, streaming_cache, + position_bias; DataType quantization_type; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); @@ -2829,6 +2832,7 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(position_bias); dez.deserialize(quantization_type); dez.deserialize(offload); + dez.deserialize(streaming_cache); dez.deserialize(num_kv_heads); dez.deserialize(tensor_parallelism_degree); size_t name_len; @@ -2853,6 +2857,7 @@ void FFModel::deserialize_graph_optimal_view( params.position_bias = position_bias; params.quantization_type = quantization_type; params.offload = offload; + params.streaming_cache = streaming_cache; params.num_kv_heads = num_kv_heads; params.tensor_parallelism_degree = tensor_parallelism_degree; strcpy(params.name, name); @@ -2864,7 +2869,7 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads; float dropout, scaling_factor; bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + scaling_query, qk_prod_scaling, position_bias, streaming_cache; size_t id, 
transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); @@ -2883,6 +2888,7 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); dez.deserialize(position_bias); + dez.deserialize(streaming_cache); dez.deserialize(num_kv_heads); size_t name_len; char name[MAX_OPNAME] = {0}; @@ -2904,6 +2910,7 @@ void FFModel::deserialize_graph_optimal_view( params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; params.position_bias = position_bias; + params.streaming_cache = streaming_cache; params.num_kv_heads = num_kv_heads; strcpy(params.name, name); node = get_or_create_node(inputs[0], diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 3a6619b37f..eaddb255c9 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -206,6 +206,10 @@ void RequestManager::set_speculative_sampling(bool speculative_sampling_) { speculative_sampling = speculative_sampling_; } +void RequestManager::set_streaming_cache(bool streaming_cache_) { + streaming_cache = streaming_cache_; +} + void RequestManager::register_tokenizer(ModelType type, int bos_token_id, int eos_token_id, @@ -303,6 +307,11 @@ RequestManager::RequestGuid init_token_tree(request.guid); } + request.streaming_cache_info = StreamingCacheInfo( + BatchConfig::SINK_SIZE, + BatchConfig::MAX_STREAMING_POS - BatchConfig::SINK_SIZE - + BatchConfig::get_max_tree_depth()); + pending_request_queue.push(request); all_requests[request.guid] = request; { @@ -362,6 +371,11 @@ RequestManager::RequestGuid init_token_tree(request.guid); } + request.streaming_cache_info = StreamingCacheInfo( + BatchConfig::SINK_SIZE, + BatchConfig::MAX_STREAMING_POS - BatchConfig::SINK_SIZE - + BatchConfig::get_max_tree_depth()); + pending_request_queue.push(request); all_requests[request.guid] = request; { @@ -723,9 +737,17 @@ void RequestManager::update_inference_results(InferenceResult const &result) { bool RequestManager::update_llm_prefill_results(InferenceResult const &result) { bool prefill_completed = false; - prefill_request->llm_cache_size += prefill_request->num_tokens_in_batch; + if (decoding_mode == INCREMENTAL_DECODING && streaming_cache) { + prefill_request->streaming_cache_info.commit_cache( + prefill_request->num_tokens_in_batch); + prefill_request->llm_cache_size = + prefill_request->streaming_cache_info.commit_len; + } else { + prefill_request->llm_cache_size += prefill_request->num_tokens_in_batch; + } + prefill_request->llm_prefill_len += prefill_request->num_tokens_in_batch; - if (prefill_request->llm_cache_size == prefill_request->tokens.size()) { + if (prefill_request->llm_prefill_len == prefill_request->tokens.size()) { // Indicates that the LLM prefilling phase finishes prefill_request->tokens.push_back( result.token_ids[prefill_request->num_tokens_in_batch - 1]); @@ -767,7 +789,12 @@ bool RequestManager::update_llm_decode_results(InferenceResult const &result) { int guid = guid_of_requests[request_index]; Request &request = all_requests[guid]; assert(request.status == Request::RUNNING); - request.llm_cache_size++; + if (streaming_cache) { + request.streaming_cache_info.commit_cache(1); + request.llm_cache_size = request.streaming_cache_info.commit_len; + } else { + request.llm_cache_size++; + } request.tokens.push_back( result.token_ids[request.first_token_offset_in_batch]); @@ -799,7 +826,15 @@ void RequestManager::update_ssm_prefill_results( // This function is called by 
update_inference_results when the // request_manager_status is PREFILLING and the prefill_model is SSM. // There's no results to update, but we should update ssm_cache_size. - prefill_request->ssm_cache_size += prefill_request->num_tokens_in_batch; + if (streaming_cache) { + prefill_request->streaming_cache_info.commit_cache( + prefill_request->num_tokens_in_batch); + prefill_request->ssm_cache_size = + prefill_request->streaming_cache_info.commit_len; + } else { + prefill_request->ssm_cache_size += prefill_request->num_tokens_in_batch; + } + prefill_request->ssm_prefill_len += prefill_request->num_tokens_in_batch; profiling_requests[prefill_request->guid].ssm_prefilling_steps++; } @@ -877,25 +912,27 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() { bc.requestsInfo[request_index].first_token_offset_in_batch = 0; bc.requestsInfo[request_index].first_token_index_in_request = prefill_request->llm_cache_size; - bc.requestsInfo[request_index].num_tokens_in_batch = std::min( - get_max_tokens_per_batch(), - (int)prefill_request->tokens.size() - prefill_request->llm_cache_size); + int num_tokens_in_batch = std::min(get_max_tokens_per_batch(), + (int)prefill_request->tokens.size() - + prefill_request->llm_prefill_len); + bc.requestsInfo[request_index].num_tokens_in_batch = num_tokens_in_batch; + + // Copy the streaming cache info + bc.streamingCacheInfo[request_index] = prefill_request->streaming_cache_info; prefill_request->first_token_offset_in_batch = 0; - prefill_request->num_tokens_in_batch = - bc.requestsInfo[request_index].num_tokens_in_batch; + prefill_request->num_tokens_in_batch = num_tokens_in_batch; // Token Info - for (int token_idx = 0; - token_idx < bc.requestsInfo[request_index].num_tokens_in_batch; - token_idx++) { + for (int token_idx = 0; token_idx < num_tokens_in_batch; token_idx++) { int abs_idx = prefill_request->llm_cache_size + token_idx; assert(abs_idx < prefill_request->tokens.size()); bc.tokensInfo[token_idx].request_index = request_index; bc.tokensInfo[token_idx].abs_index_in_request = abs_idx; bc.tokensInfo[token_idx].abs_depth_in_request = abs_idx; - bc.tokensInfo[token_idx].token_id = prefill_request->tokens[abs_idx]; + bc.tokensInfo[token_idx].token_id = + prefill_request->tokens[prefill_request->llm_prefill_len + token_idx]; bc.num_tokens++; } @@ -931,25 +968,27 @@ BatchConfig RequestManager::prepare_ssm_prefilling_batch() { bc.requestsInfo[request_index].first_token_offset_in_batch = 0; bc.requestsInfo[request_index].first_token_index_in_request = prefill_request->ssm_cache_size; - bc.requestsInfo[request_index].num_tokens_in_batch = std::min( - get_max_tokens_per_batch(), - (int)prefill_request->tokens.size() - prefill_request->ssm_cache_size); + int num_tokens_in_batch = std::min(get_max_tokens_per_batch(), + (int)prefill_request->tokens.size() - + prefill_request->ssm_prefill_len); + bc.requestsInfo[request_index].num_tokens_in_batch = num_tokens_in_batch; + + // Copy the streaming cache info + bc.streamingCacheInfo[request_index] = prefill_request->streaming_cache_info; prefill_request->first_token_offset_in_batch = 0; - prefill_request->num_tokens_in_batch = - bc.requestsInfo[request_index].num_tokens_in_batch; + prefill_request->num_tokens_in_batch = num_tokens_in_batch; // Token Info - for (int token_idx = 0; - token_idx < bc.requestsInfo[request_index].num_tokens_in_batch; - token_idx++) { + for (int token_idx = 0; token_idx < num_tokens_in_batch; token_idx++) { int abs_idx = prefill_request->ssm_cache_size + token_idx; assert(abs_idx < 
prefill_request->tokens.size()); bc.tokensInfo[token_idx].request_index = request_index; bc.tokensInfo[token_idx].abs_index_in_request = abs_idx; bc.tokensInfo[token_idx].abs_depth_in_request = abs_idx; - bc.tokensInfo[token_idx].token_id = prefill_request->tokens[abs_idx]; + bc.tokensInfo[token_idx].token_id = + prefill_request->tokens[prefill_request->ssm_prefill_len + token_idx]; bc.num_tokens++; } @@ -992,6 +1031,9 @@ BatchConfig RequestManager::prepare_decoding_batch() { bc.requestsInfo[request_index].first_token_offset_in_batch = bc.num_tokens; bc.requestsInfo[request_index].num_tokens_in_batch = 1; + // Copy the streaming cache info + bc.streamingCacheInfo[request_index] = request.streaming_cache_info; + request.first_token_offset_in_batch = bc.num_tokens; request.num_tokens_in_batch = 1; @@ -1064,13 +1106,22 @@ BatchConfig RequestManager::prepare_first_spec_batch_config() { if (num_committed_tokens == 1) { new_bc.requestsInfo[request_index].num_tokens_in_batch = 1; // The case where the prefilling is just finished. Although the last - // token's kv cache is already there, the we need to decode the last token - // because it's the root of the token tree. + // token's kv cache is already there, the we need to decode the last + // token because it's the root of the token tree. new_bc.tokensInfo[new_bc.num_tokens].request_index = request_index; - new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request = - committed_tokens[0].to_index; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = - committed_tokens[0].to_index; + if (streaming_cache) { + new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request = + request.streaming_cache_info.global_2_cache_index( + committed_tokens[0].to_index); + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + request.streaming_cache_info.global_2_cache_index( + committed_tokens[0].to_index); + } else { + new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request = + committed_tokens[0].to_index; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + committed_tokens[0].to_index; + } new_bc.tokensInfo[new_bc.num_tokens].token_id = committed_tokens[0].token_id; new_bc.num_tokens++; @@ -1079,10 +1130,19 @@ BatchConfig RequestManager::prepare_first_spec_batch_config() { committed_token_index < committed_tokens.size(); committed_token_index++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = request_index; - new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request = - committed_tokens[committed_token_index].to_index; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = - committed_tokens[committed_token_index].to_index; + if (streaming_cache) { + new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request = + request.streaming_cache_info.global_2_cache_index( + committed_tokens[committed_token_index].to_index); + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + request.streaming_cache_info.global_2_cache_index( + committed_tokens[committed_token_index].to_index); + } else { + new_bc.tokensInfo[new_bc.num_tokens].abs_index_in_request = + committed_tokens[committed_token_index].to_index; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = + committed_tokens[committed_token_index].to_index; + } new_bc.tokensInfo[new_bc.num_tokens].token_id = committed_tokens[committed_token_index].token_id; new_bc.num_tokens++; @@ -1099,6 +1159,13 @@ BatchConfig RequestManager::prepare_first_spec_batch_config() { // Copy the causal mask, it should already been updated in // update_llm_verify_results 
new_bc.causalMask[request_index] = request.causal_mask; + if (streaming_cache) { + new_bc.causalMask[request_index].non_tree_cache_size = + request.ssm_cache_size - 1; + } + + // Copy the streaming cache info + new_bc.streamingCacheInfo[request_index] = request.streaming_cache_info; if (profiling_requests[guid].ssm_decoding_steps == 0) { profiling_requests[guid].start_decoding_time = @@ -1149,9 +1216,9 @@ BatchConfig RequestManager::prepare_next_spec_batch_config() { // This request has no token to decode in this and the following small // model inference steps new_bc.requestsInfo[request_index].num_tokens_in_batch = 0; + // non_tree_cache_size = ssm_cache_size - 1 new_bc.requestsInfo[request_index].first_token_index_in_request = - request.causal_mask.non_tree_cache_size + - request.causal_mask.tree_or_prompt_size - + request.ssm_cache_size - 1 + request.causal_mask.tree_or_prompt_size - request.causal_mask.current_layer_size; request.num_tokens_in_batch = 0; request.first_token_offset_in_batch = new_bc.num_tokens; @@ -1161,9 +1228,9 @@ BatchConfig RequestManager::prepare_next_spec_batch_config() { token_tree.tree_layers.back(); // Exclude the current layer from the token tree, because we want the // start index + // non_tree_cache_size = ssm_cache_size - 1 new_bc.requestsInfo[request_index].first_token_index_in_request = - request.causal_mask.non_tree_cache_size + - request.causal_mask.tree_or_prompt_size - + request.ssm_cache_size - 1 + request.causal_mask.tree_or_prompt_size - request.causal_mask.current_layer_size; new_bc.requestsInfo[request_index].num_tokens_in_batch = request.causal_mask.current_layer_size; @@ -1179,7 +1246,7 @@ BatchConfig RequestManager::prepare_next_spec_batch_config() { new_bc.requestsInfo[request_index].first_token_index_in_request + child_index; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = - request.tokens.size() - 1 + current_ssm_step; + request.ssm_cache_size - 1 + current_ssm_step; new_bc.tokensInfo[new_bc.num_tokens].token_id = node_ptr->id; new_bc.num_tokens++; @@ -1190,6 +1257,13 @@ BatchConfig RequestManager::prepare_next_spec_batch_config() { // Copy the causal mask, it should already been updated by // update_ssm_inference_results new_bc.causalMask[request_index] = request.causal_mask; + if (streaming_cache) { + new_bc.causalMask[request_index].non_tree_cache_size = + request.ssm_cache_size - 1; + } + + // Copy the streaming cache info + new_bc.streamingCacheInfo[request_index] = request.streaming_cache_info; } if (verbose) { @@ -1292,6 +1366,9 @@ BatchConfig RequestManager::prepare_verify_batch_config() { // Create the causal mask for the large model based on the small model // causal mask. 
new_bc.causalMask[request_index] = create_llm_bitmask(guid); + + // Copy the streaming cache info + new_bc.streamingCacheInfo[request_index] = request.streaming_cache_info; } if (verbose) { @@ -1429,7 +1506,12 @@ bool RequestManager::update_ssm_inference_results( assert(request.status == Request::RUNNING); if (current_ssm_step == 1) { - request.ssm_cache_size = request.tokens.size(); + if (streaming_cache) { + request.streaming_cache_info.commit_cache(request.num_tokens_in_batch); + request.ssm_cache_size = request.streaming_cache_info.commit_len; + } else { + request.ssm_cache_size = request.tokens.size(); + } } if (current_ssm_step == 1) { diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index bb027d5862..48d79ea5c4 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -78,7 +78,17 @@ void RequestManager::load_tokens_task( } } -// NOTE: qk_indptr is accumulative `ceil(qk_len / 8)` +// q_indptr: the start offset of q in the batch for each request, +// the length is `num_requests + 1`: [0, num_q_0, num_q_0 + num_q_1, +// ..., num_q_0 + num_q_1 + ... + num_q_{num_requests - 1}] +// kv_indptr: the start offset of kv page_indices for each request, +// the length is `num_requests + 1`. +// kv_indices: the page indices for kv, the length is `num_kv_pages`. +// kv_last_page_len: the cache length in the last page for each request, +// the length is `num_requests`. +// qk_indptr: the start offset of custom_mask in the flattened mask for each +// request, the length is `num_requests + 1`. It can be calculated as +// accumulative `ceil(qk_len / 8)`. __global__ void prepare_inference_params_kernel(int const num_requests, BatchConfig::PerRequestInfo *request_infos, @@ -130,6 +140,11 @@ __global__ void #define test_bit_orig(bit_mask, idx, pos) \ (((bit_mask)[idx].bits[(pos) / 64] & (1ULL << ((pos) % 64))) != 0) +// Passing the CPU-side causalMask, then output the bit-packed custom_mask for +// attention forward. 
+// Layout of causalMask: [num_requests][tree_size][tree_size] +// Layout of custom_mask: [num_requests][q_length][kv_length] (bit-packed) +// Note that for spec-decoding, q_length == last_layer_length != tree_size __global__ void update_custom_mask_kernel(uint8_t *custom_mask, int32_t const *qk_indptr, @@ -160,21 +175,25 @@ __global__ void } } + BatchConfig::BitMask &causal_mask = causalMask[requext_idx_in_batch]; + int const q_length = request_infos[requext_idx_in_batch].num_tokens_in_batch, q_start = request_infos[requext_idx_in_batch] - .first_token_index_in_request; + .first_token_index_in_request - + causal_mask.non_tree_cache_size, + non_tree_cache_size = causal_mask.non_tree_cache_size; uint8_t packed_bits = 0; for (int bit_idx = 0; bit_idx < 8; bit_idx++) { int const bit_offset = byte_idx * 8 + bit_idx, - q_idx = bit_offset / (q_start + q_length), - kv_idx = bit_offset % (q_start + q_length); - if (kv_idx < q_start || q_idx >= q_length) { + q_idx = bit_offset / (non_tree_cache_size + q_start + q_length), + kv_idx = bit_offset % (non_tree_cache_size + q_start + q_length); + if (kv_idx < non_tree_cache_size || q_idx >= q_length) { packed_bits |= 1 << bit_idx; } else { - if (test_bit_orig(causalMask[requext_idx_in_batch].bit_mask, - q_idx, - kv_idx - q_start)) { + if (test_bit_orig(causal_mask.bit_mask, + q_start + q_idx, + kv_idx - non_tree_cache_size)) { packed_bits |= 1 << bit_idx; } } @@ -231,6 +250,47 @@ void RequestManager::load_batch_config_task( stream)); total_copy_size += sizeof(BatchConfig::request_available); + for (int request_idx = 0; request_idx < BatchConfig::max_requests_per_batch(); + request_idx++) { + if (batch_config->request_available[request_idx]) { + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size + + request_idx * sizeof(BatchConfig::BitMask), + &(batch_config->causalMask[request_idx]), + sizeof(BatchConfig::BitMask), + cudaMemcpyHostToDevice, + stream)); + } + } + total_copy_size += sizeof(BatchConfig::causalMask); + + checkCUDA(cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(batch_config->streamingCacheInfo), + sizeof(BatchConfig::streamingCacheInfo), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::streamingCacheInfo); + + if (batch_config->num_tokens_to_commit > 0) { + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(batch_config->committed_tokens), + batch_config->num_tokens_to_commit * + sizeof(BatchConfig::CommittedTokensInfo), + cudaMemcpyHostToDevice, + stream)); + } + total_copy_size += sizeof(BatchConfig::committed_tokens); + + checkCUDA(cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(batch_config->num_tokens_to_commit), + sizeof(int), + cudaMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(int); + // load attention metadata if (batch_config->get_mode() == INC_DECODING_MODE) { if (handle.incr_attention_metadata->enabled()) { @@ -246,9 +306,8 @@ void RequestManager::load_batch_config_task( sizeof(BatchConfig::requestsInfo)); int batch_size = batch_config->num_active_requests(); uint32_t const max_num_pages = - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); int parallelism = batch_size; prepare_inference_params_kernel<<get_mode() == TREE_SEARCH_MODE) { if 
(handle.tree_search_attention_metadata->enabled()) { - for (int request_idx = 0; - request_idx < BatchConfig::max_requests_per_batch(); - request_idx++) { - if (batch_config->request_available[request_idx]) { - checkCUDA(cudaMemcpyAsync( - static_cast(handle.batch_config_metadata) + - total_copy_size + request_idx * sizeof(BatchConfig::BitMask), - &(batch_config->causalMask[request_idx]), - sizeof(BatchConfig::BitMask), - cudaMemcpyHostToDevice, - stream)); - } - } - total_copy_size += sizeof(BatchConfig::causalMask); - // calculate the attention meta data { BatchConfig::PerRequestInfo *request_infos = @@ -392,9 +436,8 @@ void RequestManager::load_batch_config_task( sizeof(BatchConfig::request_available)); int batch_size = batch_config->num_active_requests(); uint32_t const max_num_pages = - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); int parallelism = batch_size; prepare_inference_params_kernel<<get_mode() == TREE_VERIFY_MODE) { if (handle.tree_verify_attention_metadata->enabled()) { - for (int request_idx = 0; - request_idx < BatchConfig::max_requests_per_batch(); - request_idx++) { - if (batch_config->request_available[request_idx]) { - checkCUDA(cudaMemcpyAsync( - static_cast(handle.batch_config_metadata) + - total_copy_size + request_idx * sizeof(BatchConfig::BitMask), - &(batch_config->causalMask[request_idx]), - sizeof(BatchConfig::BitMask), - cudaMemcpyHostToDevice, - stream)); - } - } - total_copy_size += sizeof(BatchConfig::causalMask); - - if (batch_config->num_tokens_to_commit > 0) { - checkCUDA(cudaMemcpyAsync( - static_cast(handle.batch_config_metadata) + total_copy_size, - &(batch_config->committed_tokens), - batch_config->num_tokens_to_commit * - sizeof(BatchConfig::CommittedTokensInfo), - cudaMemcpyHostToDevice, - stream)); - } - total_copy_size += sizeof(BatchConfig::committed_tokens); - - checkCUDA(cudaMemcpyAsync( - static_cast(handle.batch_config_metadata) + total_copy_size, - &(batch_config->num_tokens_to_commit), - sizeof(int), - cudaMemcpyHostToDevice, - stream)); - total_copy_size += sizeof(int); - // calculate the attention meta data { BatchConfig::PerRequestInfo *request_infos = @@ -558,9 +567,8 @@ void RequestManager::load_batch_config_task( sizeof(BatchConfig::request_available)); int batch_size = batch_config->num_active_requests(); uint32_t const max_num_pages = - (BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / - kPagesize; + round_up_pages(BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); int parallelism = batch_size; prepare_inference_params_kernel<<
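
As a reference for the mask layout documented above, the host-side sketch below (illustration only, not part of the patch) reproduces the byte/bit indexing of update_custom_mask_kernel in src/runtime/request_manager.cu: q_length rows of kv_length bits per request, flattened row-major and packed 8 bits per byte, with the committed (non-tree) cache always visible. pack_custom_mask and tree_mask_bit are hypothetical names, and the lower-triangular toy mask stands in for the real token-tree bitmask.

    // Standalone illustration (not part of the patch) of how the custom mask is
    // bit-packed: padding rows and the committed (non-tree) cache always get a
    // set bit; positions inside the speculation tree follow the causal bitmask.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for BatchConfig::BitMask::bit_mask on the host:
    // a plain lower-triangular mask used only for illustration.
    static bool tree_mask_bit(int q_idx_in_tree, int kv_idx_in_tree) {
      return kv_idx_in_tree <= q_idx_in_tree;
    }

    std::vector<uint8_t> pack_custom_mask(int non_tree_cache_size,
                                          int q_start,    // tree offset of first query
                                          int q_length) { // queries in this step
      int const kv_length = non_tree_cache_size + q_start + q_length;
      int const num_bits = q_length * kv_length;
      std::vector<uint8_t> packed((num_bits + 7) / 8, 0);
      for (int byte_idx = 0; byte_idx < (int)packed.size(); byte_idx++) {
        uint8_t bits = 0;
        for (int bit_idx = 0; bit_idx < 8; bit_idx++) {
          int const bit_offset = byte_idx * 8 + bit_idx;
          int const q_idx = bit_offset / kv_length;
          int const kv_idx = bit_offset % kv_length;
          if (q_idx >= q_length || kv_idx < non_tree_cache_size) {
            bits |= 1 << bit_idx; // padding rows and committed cache: bit set
          } else if (tree_mask_bit(q_start + q_idx,
                                   kv_idx - non_tree_cache_size)) {
            bits |= 1 << bit_idx; // visibility inside the speculation tree
          }
        }
        packed[byte_idx] = bits;
      }
      return packed;
    }

    int main() {
      // 5 committed tokens, queries start at tree depth 2, 3 queries this step.
      std::vector<uint8_t> mask = pack_custom_mask(5, 2, 3);
      for (uint8_t b : mask) {
        printf("%02x ", b);
      }
      printf("\n");
      return 0;
    }

The per-request byte count, (q_length * kv_length + 7) / 8, accumulated across requests, corresponds to the qk_indptr bookkeeping described in the comment above prepare_inference_params_kernel.
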