From 22aebb3c393052eb3482977fa214229cc5e62333 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sun, 29 Sep 2024 06:28:22 +0000 Subject: [PATCH] llama3.1 support --- .gitignore | 2 + include/flexflow/flexflow_c.h | 36 ++++ include/flexflow/inference.h | 39 +++- include/flexflow/layer.h | 3 + include/flexflow/model.h | 150 +++++++------- include/flexflow/operator.h | 8 +- .../ops/inc_multihead_self_attention.h | 12 +- .../ops/inc_multihead_self_attention_params.h | 6 +- .../ops/spec_inc_multihead_self_attention.h | 8 +- ...spec_inc_multihead_self_attention_params.h | 5 +- .../ops/tree_inc_multihead_self_attention.h | 8 +- ...tree_inc_multihead_self_attention_params.h | 5 +- inference/models/falcon.cc | 30 +-- inference/models/falcon.h | 29 ++- inference/models/llama.cc | 30 +-- inference/models/llama.h | 29 ++- inference/models/mpt.cc | 6 +- inference/models/mpt.h | 2 + inference/models/opt.cc | 12 +- inference/models/opt.h | 9 +- inference/models/starcoder.cc | 22 +-- inference/models/starcoder.h | 4 +- python/flexflow/core/flexflow_cffi.py | 101 +++++++--- python/flexflow/serve/models/falcon.py | 22 ++- python/flexflow/serve/models/llama.py | 22 ++- python/flexflow/serve/models/mpt.py | 12 +- python/flexflow/serve/models/opt.py | 12 +- python/flexflow/serve/models/starcoder.py | 10 +- src/c/flexflow_c.cc | 90 ++++++++- src/ops/inc_multihead_self_attention.cc | 137 ++++++++----- src/ops/inc_multihead_self_attention.cpp | 184 ++++++++++-------- src/ops/inc_multihead_self_attention.cu | 164 +++++++++------- src/ops/spec_inc_multihead_self_attention.cc | 139 ++++++++----- src/ops/spec_inc_multihead_self_attention.cpp | 2 +- src/ops/spec_inc_multihead_self_attention.cu | 6 +- src/ops/tree_inc_multihead_self_attention.cc | 71 +++++-- src/ops/tree_inc_multihead_self_attention.cpp | 2 +- src/ops/tree_inc_multihead_self_attention.cu | 4 +- src/runtime/graph.cc | 90 +++++++-- src/runtime/layer.cc | 17 ++ tests/fine_grained_alignment_test.sh | 31 ++- 41 files changed, 1042 insertions(+), 529 deletions(-) diff --git a/.gitignore b/.gitignore index cc34c1a7b6..27264b8fbf 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,5 @@ lora_training_logs Untitled-1.ipynb Untitled-2.ipynb tests/inference/python_test_configs/*.json + +core.* diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 52b4b3d362..afe6bc4573 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -451,6 +451,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -471,6 +477,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -491,6 +503,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + 
float high_freq_factor,
+    int original_max_position_embeddings,
     bool scaling_query,
     float scaling_factor,
     bool qk_prod_scaling,
@@ -512,6 +530,12 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention(
     enum DataType data_type,
     flexflow_initializer_t kernel_initializer_,
     bool apply_rotary_embedding,
+    float rope_theta,
+    char const *rope_type,
+    float rope_factor,
+    float low_freq_factor,
+    float high_freq_factor,
+    int original_max_position_embeddings,
     bool scaling_query,
     float scaling_factor,
     bool qk_prod_scaling,
@@ -533,6 +557,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention(
     enum DataType data_type,
     flexflow_initializer_t kernel_initializer_,
     bool apply_rotary_embedding,
+    float rope_theta,
+    char const *rope_type,
+    float rope_factor,
+    float low_freq_factor,
+    float high_freq_factor,
+    int original_max_position_embeddings,
     bool scaling_query,
     float scaling_factor,
     bool qk_prod_scaling,
@@ -554,6 +584,12 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify(
     enum DataType data_type,
     flexflow_initializer_t kernel_initializer_,
     bool apply_rotary_embedding,
+    float rope_theta,
+    char const *rope_type,
+    float rope_factor,
+    float low_freq_factor,
+    float high_freq_factor,
+    int original_max_position_embeddings,
     bool scaling_query,
     float scaling_factor,
     bool qk_prod_scaling,
diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h
index ba4101c173..755df9f5cb 100644
--- a/include/flexflow/inference.h
+++ b/include/flexflow/inference.h
@@ -43,8 +43,43 @@ struct GenerationResult {
   std::vector<float> finetuning_losses;
 };
 
-#include
-#include
+struct RotaryEmbeddingMeta {
+  bool apply_rotary_embedding = false;
+  float rope_theta = 10000.0f;
+  std::string rope_type = "default";
+  float factor = 8.0f;
+  float low_freq_factor = 1.0f;
+  float high_freq_factor = 4.0f;
+  int original_max_position_embeddings = 8192;
+
+  RotaryEmbeddingMeta(bool apply_rotary_embedding_ = false,
+                      float rope_theta_ = 10000.0f,
+                      std::string rope_type_ = "default",
+                      float factor_ = 8.0f,
+                      float low_freq_factor_ = 1.0f,
+                      float high_freq_factor_ = 4.0f,
+                      int original_max_position_embeddings_ = 8192)
+      : apply_rotary_embedding(apply_rotary_embedding_),
+        rope_theta(rope_theta_), rope_type(rope_type_), factor(factor_),
+        low_freq_factor(low_freq_factor_), high_freq_factor(high_freq_factor_),
+        original_max_position_embeddings(original_max_position_embeddings_) {}
+
+  friend std::ostream &operator<<(std::ostream &os,
+                                  RotaryEmbeddingMeta const &meta) {
+    os << std::boolalpha // To print bool as true/false instead of 1/0
+       << "RotaryEmbeddingMeta {\n"
+       << "  apply_rotary_embedding: " << meta.apply_rotary_embedding << ",\n"
+       << "  rope_theta: " << meta.rope_theta << ",\n"
+       << "  rope_type: \"" << meta.rope_type << "\",\n"
+       << "  factor: " << meta.factor << ",\n"
+       << "  low_freq_factor: " << meta.low_freq_factor << ",\n"
+       << "  high_freq_factor: " << meta.high_freq_factor << ",\n"
+       << "  original_max_position_embeddings: "
+       << meta.original_max_position_embeddings << "\n"
+       << "}";
+    return os;
+  }
+};
 
 std::string join_path(std::vector<std::string> const &paths);
 
diff --git a/include/flexflow/layer.h b/include/flexflow/layer.h
index c3dbcac422..e18bad3982 100644
--- a/include/flexflow/layer.h
+++ b/include/flexflow/layer.h
@@ -32,11 +32,13 @@ class Layer {
   void add_float_property(std::string const &key, float value);
   void add_int_vector_property(std::string const &key,
                                std::vector<int> const &value);
+  void add_string_property(std::string const &key, std::string
const &value); void add_initializer(std::string const &key, Initializer *initializer); bool get_int_property(std::string const &key, long long &value) const; bool get_float_property(std::string const &key, float &value) const; bool get_int_vector_property(std::string const &key, std::vector &value) const; + bool get_string_property(std::string const &key, std::string &value) const; bool get_initializer(std::string const &key, Initializer *&initializer) const; Tensor get_parameter(int index); void print(); @@ -59,6 +61,7 @@ class Layer { std::unordered_map float_properties; std::unordered_map initializers; std::unordered_map> int_vector_properties; + std::unordered_map string_properties; }; }; // namespace FlexFlow diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 4ad735ef7d..a42d3ab36d 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -733,41 +733,42 @@ class FFModel { DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, char const *name = NULL); - Tensor inc_multihead_self_attention(const Tensor input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); - Tensor - spec_inc_multihead_self_attention(const Tensor input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); + Tensor inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); + Tensor spec_inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); Tensor inc_multihead_self_attention_verify( const Tensor input, int embed_dim, @@ -780,49 +781,50 @@ class FFModel { bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); + Tensor 
inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); + Tensor spec_inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, bool position_bias = false, char const *name = NULL); - Tensor inc_multiquery_self_attention(const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); - Tensor - spec_inc_multiquery_self_attention(const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); Tensor inc_multiquery_self_attention_verify( const Tensor input, int embed_dim, @@ -836,7 +838,7 @@ class FFModel { bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, diff --git a/include/flexflow/operator.h b/include/flexflow/operator.h index 1a5af67b36..007314797a 100644 --- a/include/flexflow/operator.h +++ b/include/flexflow/operator.h @@ -335,7 +335,13 @@ class Op { // only dump the weights in the forward pass, at the first step // note that we do not save the weight gradients, since we only support // finetuning LoRA weights, which are not FF tensors. 
- if (fwd_pass && m->decoding_step == 0) { + // Set FF_DEBG_NO_WEIGHTS=1 or to FF_DEBG_NO_WEIGHTS=true to disable saving + // weights + bool do_not_save_weights = + (std::getenv("FF_DEBG_NO_WEIGHTS") && + (std::string(std::getenv("FF_DEBG_NO_WEIGHTS")) == "1" || + std::string(std::getenv("FF_DEBG_NO_WEIGHTS")) == "true")); + if (fwd_pass && m->decoding_step == 0 && !do_not_save_weights) { fs::path dst_filepath_weights = get_dst_folder("weights", m->decoding_step, shard_id, before_kernel) / layername; diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 5d639623fe..a361909d8d 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -39,7 +39,7 @@ class IncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -61,7 +61,7 @@ class IncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -138,8 +138,8 @@ class IncMultiHeadSelfAttention : public Op { int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool final_bias, add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int qoSeqLength, kvSeqLength; DataType quantization_type; @@ -165,7 +165,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _qkv_bias, bool _scaling_query, bool _qk_prod_scaling, @@ -191,7 +191,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads, hidden_size; bool *has_load_weights; - bool *apply_rotary_embedding; + RotaryEmbeddingMeta *rotary_embedding_meta; bool *qkv_bias; bool *final_bias; bool *scaling_query; diff --git a/include/flexflow/ops/inc_multihead_self_attention_params.h b/include/flexflow/ops/inc_multihead_self_attention_params.h index 58681069e2..6ce32e0779 100644 --- a/include/flexflow/ops/inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/inc_multihead_self_attention_params.h @@ -3,6 +3,7 @@ #include "flexflow/ffconst.h" #include "flexflow/fftype.h" +#include "flexflow/inference.h" #include "flexflow/parallel_tensor.h" namespace FlexFlow { @@ -12,8 +13,9 @@ struct IncMultiHeadSelfAttentionParams { int embed_dim, num_q_heads, kdim, vdim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; bool offload; char name[MAX_OPNAME]; diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index 85279860cf..58be153458 100644 --- 
a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -36,7 +36,7 @@ class SpecIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -55,7 +55,7 @@ class SpecIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -119,8 +119,8 @@ class SpecIncMultiHeadSelfAttention : public Op { int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool final_bias, add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int qoSeqLength, kvSeqLength; }; diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h index 1461224ba9..3f173dfcf7 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h @@ -11,8 +11,9 @@ struct SpecIncMultiHeadSelfAttentionParams { LayerID layer_guid; int embed_dim, num_q_heads, num_kv_heads, kdim, vdim; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; char name[MAX_OPNAME]; bool is_valid(ParallelTensorShape const &) const; }; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index b4eb339201..120e63053a 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -36,7 +36,7 @@ class TreeIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -58,7 +58,7 @@ class TreeIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -121,8 +121,8 @@ class TreeIncMultiHeadSelfAttention : public Op { int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool final_bias, add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int qoSeqLength, kvSeqLength; DataType quantization_type; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention_params.h b/include/flexflow/ops/tree_inc_multihead_self_attention_params.h index d1a51b8b8f..3906210d40 100644 --- 
a/include/flexflow/ops/tree_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention_params.h @@ -12,8 +12,9 @@ struct TreeIncMultiHeadSelfAttentionParams { int embed_dim, num_q_heads, kdim, vdim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; bool offload; char name[MAX_OPNAME]; diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index e6eb72701e..46a55c6559 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -130,11 +130,11 @@ void FALCON::create_falcon_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + falcon_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attention") .c_str() /*name*/ ); @@ -155,11 +155,11 @@ void FALCON::create_falcon_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + falcon_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attention") .c_str() /*name*/ ); @@ -180,11 +180,11 @@ void FALCON::create_falcon_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + falcon_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." 
+ std::to_string(i) + ".self_attention") .c_str() /*name*/ ); diff --git a/inference/models/falcon.h b/inference/models/falcon.h index fce2dade3f..565d7e5419 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -50,6 +50,26 @@ class FALCON { : model_config["num_hidden_layers"]; parallel_attn = model_config["parallel_attn"]; vocab_size = model_config["vocab_size"]; + rotary_embedding_meta.apply_rotary_embedding = true; + if (model_config.find("rope_theta") != model_config.end()) { + rotary_embedding_meta.rope_theta = model_config["rope_theta"]; + } else { + rotary_embedding_meta.rope_theta = 10000.0f; + } + if (model_config.find("scaling_factor") != model_config.end() && + !model_config["scaling_factor"].is_null()) { + rotary_embedding_meta.rope_type = + model_config["scaling_factor"]["rope_type"]; + rotary_embedding_meta.factor = + model_config["scaling_factor"]["factor"]; + rotary_embedding_meta.low_freq_factor = + model_config["scaling_factor"]["low_freq_factor"]; + rotary_embedding_meta.high_freq_factor = + model_config["scaling_factor"]["high_freq_factor"]; + rotary_embedding_meta.original_max_position_embeddings = + model_config["scaling_factor"] + ["original_max_position_embeddings"]; + } } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -59,8 +79,6 @@ class FALCON { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -76,9 +94,8 @@ class FALCON { std::cout << "\tn_layer: " << n_layer << std::endl; std::cout << "\tparallel_attn: " << parallel_attn << std::endl; std::cout << "\tvocab_size: " << vocab_size << std::endl; - - // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; - // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; std::cout << "\tmax_beam_width: " << max_beam_width << std::endl; std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl; } @@ -86,8 +103,8 @@ class FALCON { bool bias, multi_query, parallel_attn; int hidden_size, n_head, n_head_kv, n_layer, vocab_size; float layer_norm_epsilon; - // int max_seq_len, max_num_tokens; int max_beam_width, max_beam_depth; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_falcon_model(FFModel &ff, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 48f319d409..c157ac4ed1 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -123,11 +123,11 @@ void LLAMA::create_llama_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + llama_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." 
+ std::to_string(i) + ".self_attn") .c_str() /*name*/ ); @@ -147,11 +147,11 @@ void LLAMA::create_llama_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + llama_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attn") .c_str() /*name*/ ); @@ -171,11 +171,11 @@ void LLAMA::create_llama_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + llama_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attn") .c_str() /*name*/ ); diff --git a/inference/models/llama.h b/inference/models/llama.h index edb78f1300..853a51a999 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -44,6 +44,26 @@ class LLAMA { hidden_size = model_config["hidden_size"]; rms_norm_eps = model_config["rms_norm_eps"]; intermediate_size = model_config["intermediate_size"]; + rotary_embedding_meta.apply_rotary_embedding = true; + if (model_config.find("rope_theta") != model_config.end()) { + rotary_embedding_meta.rope_theta = model_config["rope_theta"]; + } else { + rotary_embedding_meta.rope_theta = 10000.0f; + } + if (model_config.find("scaling_factor") != model_config.end() && + !model_config["scaling_factor"].is_null()) { + rotary_embedding_meta.rope_type = + model_config["scaling_factor"]["rope_type"]; + rotary_embedding_meta.factor = + model_config["scaling_factor"]["factor"]; + rotary_embedding_meta.low_freq_factor = + model_config["scaling_factor"]["low_freq_factor"]; + rotary_embedding_meta.high_freq_factor = + model_config["scaling_factor"]["high_freq_factor"]; + rotary_embedding_meta.original_max_position_embeddings = + model_config["scaling_factor"] + ["original_max_position_embeddings"]; + } } catch (json::exception const &e) { std::cerr << "Error parsing LLAMA config from JSON file: " << e.what() << std::endl; @@ -54,8 +74,6 @@ class LLAMA { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -71,18 +89,17 @@ class LLAMA { std::cout << "\thidden_size: " << hidden_size << std::endl; std::cout << "\trms_norm_eps: " << rms_norm_eps << std::endl; std::cout << "\tintermediate_size: " << intermediate_size << std::endl; - - // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; - // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; std::cout << "\tmax_beam_width: " << max_beam_width << std::endl; std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl; } - // int max_seq_len, max_num_tokens; int max_beam_width, max_beam_depth; int num_hidden_layers, vocab_size, num_attention_heads, num_key_value_heads, hidden_size, intermediate_size; float rms_norm_eps; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_llama_model(FFModel &ff, 
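The llama.h parsing above copies the checkpoint's RoPE scaling fields into the RotaryEmbeddingMeta struct introduced in include/flexflow/inference.h. Purely as an illustrative sketch (not code from this patch), the snippet below populates that struct with typical LLaMA 3.1 values (rope_theta 500000, "llama3" scaling, factor 8, original context 8192) and prints it through the operator<< added above; the FlexFlow namespace and the flexflow/inference.h include path are assumptions.

#include <iostream>
#include "flexflow/inference.h" // assumed include path for RotaryEmbeddingMeta

int main() {
  // Typical LLaMA 3.1 rope-scaling values; the real numbers come from the
  // checkpoint's config.json, not from this patch.
  FlexFlow::RotaryEmbeddingMeta meta(/*apply_rotary_embedding_=*/true,
                                     /*rope_theta_=*/500000.0f,
                                     /*rope_type_=*/"llama3",
                                     /*factor_=*/8.0f,
                                     /*low_freq_factor_=*/1.0f,
                                     /*high_freq_factor_=*/4.0f,
                                     /*original_max_position_embeddings_=*/8192);
  std::cout << meta << std::endl; // uses the operator<< defined in the struct
  return 0;
}

The same metadata then flows into the attention operators through the rotary_embedding_meta argument that replaces the old apply_rotary_embedding flag throughout the rest of this patch.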
diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index 64e5924753..f984551f38 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -123,7 +123,7 @@ void MPT::create_mpt_model(FFModel &ff, false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -147,7 +147,7 @@ void MPT::create_mpt_model(FFModel &ff, false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -171,7 +171,7 @@ void MPT::create_mpt_model(FFModel &ff, false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), diff --git a/inference/models/mpt.h b/inference/models/mpt.h index 08597e1d75..3001420ad0 100644 --- a/inference/models/mpt.h +++ b/inference/models/mpt.h @@ -37,6 +37,7 @@ class MPT { n_heads = model_config["n_heads"]; n_layers = model_config["n_layers"]; vocab_size = model_config["vocab_size"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -63,6 +64,7 @@ class MPT { // int max_seq_len, max_num_tokens; int max_beam_width, max_beam_depth; int hidden_size, n_heads, n_layers, vocab_size; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_mpt_model(FFModel &ff, diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 4aea36d3d7..d84410980f 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -132,8 +132,8 @@ void OPT::create_opt_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -156,8 +156,8 @@ void OPT::create_opt_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -180,8 +180,8 @@ void OPT::create_opt_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ diff --git a/inference/models/opt.h b/inference/models/opt.h index 7c736a26d1..8b85f81aa6 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -45,6 +45,7 @@ class OPT { num_hidden_layers = model_config["num_hidden_layers"]; vocab_size = model_config["vocab_size"]; word_embed_proj_dim = model_config["word_embed_proj_dim"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -54,8 +55,6 @@ class OPT { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = 
BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -78,9 +77,8 @@ class OPT { std::cout << "\tvocab_size: " << vocab_size << std::endl; std::cout << "\tword_embed_proj_dim: " << word_embed_proj_dim << std::endl; - - // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; - // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; std::cout << "\tmax_beam_width: " << max_beam_width << std::endl; std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl; } @@ -91,6 +89,7 @@ class OPT { float dropout; int ffn_dim, hidden_size, max_position_embeddings, num_attention_heads, num_hidden_layers, vocab_size, word_embed_proj_dim; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_opt_model(FFModel &ff, diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index 887696ff31..47dd6b2030 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -131,17 +131,17 @@ void STARCODER::create_starcoder_model( startcoder_config.num_attention_heads, startcoder_config.hidden_size / startcoder_config.num_attention_heads, - startcoder_config.dropout_p, /*dropout*/ - true, /*bias*/ - false, /*add_bias_kv*/ - false, /*add_zero_attn*/ - DT_NONE, /*data_type*/ - nullptr, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + startcoder_config.dropout_p, /*dropout*/ + true, /*bias*/ + false, /*add_bias_kv*/ + false, /*add_zero_attn*/ + DT_NONE, /*data_type*/ + nullptr, /*kernel_initializer*/ + startcoder_config.rotary_embedding_meta, /*apply_rotary_embedding*/ + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." 
+ std::to_string(i) + ".attn.c_attn") .c_str() /*name*/ ); diff --git a/inference/models/starcoder.h b/inference/models/starcoder.h index 0e9577d569..7ff6f33770 100644 --- a/inference/models/starcoder.h +++ b/inference/models/starcoder.h @@ -41,6 +41,7 @@ class STARCODER { intermediate_size = model_config["n_inner"]; dropout_p = model_config["attn_pdrop"]; max_position_embeddings = model_config["n_positions"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing STARCODER config from JSON file: " << e.what() << std::endl; @@ -51,8 +52,6 @@ class STARCODER { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -64,6 +63,7 @@ class STARCODER { int num_hidden_layers, vocab_size, num_attention_heads, hidden_size, intermediate_size, max_position_embeddings; float layer_norm_epsilon, dropout_p; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_starcoder_model(FFModel &ff, diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 7692ccb88f..5e429fd08b 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -41,6 +41,7 @@ from typing import Union, List from peft import LoraConfig import json +from dataclasses import dataclass def ffc(): @@ -2070,6 +2071,22 @@ def __init__( self.max_training_steps = max_training_steps +# ----------------------------------------------------------------------- +# RotaryEmbeddingMeta +# ----------------------------------------------------------------------- + + +@dataclass +class RotaryEmbeddingMeta: + apply_rotary_embedding: bool = False + rope_theta: float = 10000.0 + rope_type: str = "default" + factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + + # ----------------------------------------------------------------------- # FFModel # ----------------------------------------------------------------------- @@ -3514,7 +3531,7 @@ def inc_multihead_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3558,8 +3575,8 @@ def inc_multihead_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. 
:type scaling_query: bool @@ -3594,7 +3611,13 @@ def inc_multihead_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3617,7 +3640,7 @@ def spec_inc_multihead_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3661,8 +3684,8 @@ def spec_inc_multihead_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -3697,7 +3720,13 @@ def spec_inc_multihead_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3720,7 +3749,7 @@ def inc_multihead_self_attention_verify( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3764,8 +3793,8 @@ def inc_multihead_self_attention_verify( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. 
:type scaling_query: bool @@ -3800,7 +3829,13 @@ def inc_multihead_self_attention_verify( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3824,7 +3859,7 @@ def inc_multiquery_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3871,8 +3906,8 @@ def inc_multiquery_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -3908,7 +3943,13 @@ def inc_multiquery_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3932,7 +3973,7 @@ def spec_inc_multiquery_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3979,8 +4020,8 @@ def spec_inc_multiquery_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. 
:type scaling_query: bool @@ -4016,7 +4057,13 @@ def spec_inc_multiquery_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -4040,7 +4087,7 @@ def inc_multiquery_self_attention_verify( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -4087,8 +4134,8 @@ def inc_multiquery_self_attention_verify( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -4124,7 +4171,13 @@ def inc_multiquery_self_attention_verify( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py index e2d1f56224..c98f9454c4 100644 --- a/python/flexflow/serve/models/falcon.py +++ b/python/flexflow/serve/models/falcon.py @@ -41,6 +41,17 @@ def __init__(self, hf_config): ) self.parallel_attn = hf_config.parallel_attn self.vocab_size = hf_config.vocab_size + self.rotary_embedding_meta = RotaryEmbeddingMeta( + apply_rotary_embedding=True, + rope_theta=hf_config.rope_theta if "rope_theta" in hf_config.__dict__ else 10000.0, + ) + if "rope_scaling" in hf_config.__dict__: + if hf_config.rope_scaling is not None: + self.rotary_embedding_meta.rope_type = hf_config.rope_scaling["rope_type"] + self.rotary_embedding_meta.factor = hf_config.rope_scaling["factor"] + self.rotary_embedding_meta.low_freq_factor = hf_config.rope_scaling["low_freq_factor"] + self.rotary_embedding_meta.high_freq_factor = hf_config.rope_scaling["high_freq_factor"] + self.rotary_embedding_meta.original_max_position_embeddings = hf_config.rope_scaling["original_max_position_embeddings"] # Standardized FlexFlow num heads fields below self.num_attention_heads = self.n_head self.num_key_value_heads = self.n_head_kv @@ -54,8 +65,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -63,11 +72,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.falcon_config = FalconConfig(hf_config) - # self.falcon_config.max_seq_length = max_seq_length - # 
self.falcon_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -160,7 +166,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers.{i}.self_attention", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: @@ -177,7 +183,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers.{i}.self_attention", ) elif self.mode == InferenceMode.INC_DECODING_MODE: @@ -194,7 +200,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers.{i}.self_attention", ) else: diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 47071a746e..53209298a5 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -19,8 +19,6 @@ class LLAMAConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -29,6 +27,17 @@ def __init__(self, hf_config): self.hidden_size = hf_config.hidden_size self.rms_norm_eps = hf_config.rms_norm_eps self.intermediate_size = hf_config.intermediate_size + self.rotary_embedding_meta = RotaryEmbeddingMeta( + apply_rotary_embedding=True, + rope_theta=hf_config.rope_theta if "rope_theta" in hf_config.__dict__ else 10000.0, + ) + if "rope_scaling" in hf_config.__dict__: + if hf_config.rope_scaling is not None: + self.rotary_embedding_meta.rope_type = hf_config.rope_scaling["rope_type"] + self.rotary_embedding_meta.factor = hf_config.rope_scaling["factor"] + self.rotary_embedding_meta.low_freq_factor = hf_config.rope_scaling["low_freq_factor"] + self.rotary_embedding_meta.high_freq_factor = hf_config.rope_scaling["high_freq_factor"] + self.rotary_embedding_meta.original_max_position_embeddings = hf_config.rope_scaling["original_max_position_embeddings"] # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.num_attention_heads self.num_key_value_heads = ( @@ -55,11 +64,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.llama_config = LLAMAConfig(hf_config) - # self.llama_config.max_seq_length = max_seq_length - # self.llama_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2 ** 31 - 1 @@ -152,7 +158,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers.{i}.self_attn", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: @@ -171,7 +177,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers.{i}.self_attn", ) elif self.mode == 
InferenceMode.INC_DECODING_MODE: @@ -190,7 +196,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers.{i}.self_attn", ) else: diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py index 1f012e405d..2dc3257807 100644 --- a/python/flexflow/serve/models/mpt.py +++ b/python/flexflow/serve/models/mpt.py @@ -19,8 +19,6 @@ class MPTConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -28,6 +26,7 @@ def __init__(self, hf_config): self.n_heads = hf_config.n_heads self.n_layers = hf_config.n_layers self.vocab_size = hf_config.vocab_size + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.n_heads self.num_key_value_heads = hf_config.n_heads @@ -50,11 +49,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.mpt_config = MPTConfig(hf_config) - # self.mpt_config.max_seq_length = max_seq_length - # self.mpt_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -150,7 +146,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -171,7 +167,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -192,7 +188,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py index d30b1fcd23..54c82bc491 100644 --- a/python/flexflow/serve/models/opt.py +++ b/python/flexflow/serve/models/opt.py @@ -34,6 +34,7 @@ def __init__(self, hf_config): self.num_hidden_layers = hf_config.num_hidden_layers self.vocab_size = hf_config.vocab_size self.word_embed_proj_dim = hf_config.word_embed_proj_dim + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.num_attention_heads self.num_key_value_heads = hf_config.num_attention_heads @@ -47,8 +48,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -56,11 +55,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.opt_config = OPTConfig(hf_config) - # 
self.opt_config.max_seq_length = max_seq_length - # self.opt_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -166,7 +162,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -186,7 +182,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -206,7 +202,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor diff --git a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py index 83d29a55e1..10b882357d 100644 --- a/python/flexflow/serve/models/starcoder.py +++ b/python/flexflow/serve/models/starcoder.py @@ -19,8 +19,6 @@ class STARCODERConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -32,6 +30,7 @@ def __init__(self, hf_config): self.vocab_size = hf_config.vocab_size self.intermediate_size = hf_config.n_inner self.n_head_kv = 1 if hf_config.multi_query else hf_config.n_head + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.n_head self.num_key_value_heads = self.n_head_kv @@ -45,8 +44,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -54,11 +51,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.starcoder_config = STARCODERConfig(hf_config) - # self.starcoder_config.max_seq_length = max_seq_length - # self.starcoder_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -166,7 +160,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.starcoder_config.rotary_embedding_meta, name=f"layers.{i}.attn.c_attn", ) diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index e39cb29037..5ae32b6516 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1211,6 +1211,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1220,6 
+1226,13 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multihead_self_attention(input, embed_dim, num_heads, @@ -1231,7 +1244,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1254,6 +1267,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1263,6 +1282,13 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->spec_inc_multihead_self_attention(input, embed_dim, @@ -1275,7 +1301,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1298,6 +1324,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1307,6 +1339,13 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multihead_self_attention_verify(input, embed_dim, @@ -1319,7 +1358,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1343,6 +1382,12 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1352,6 +1397,13 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); 
Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multiquery_self_attention(input, embed_dim, num_q_heads, @@ -1364,7 +1416,7 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1388,6 +1440,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1397,6 +1455,13 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->spec_inc_multiquery_self_attention(input, embed_dim, @@ -1410,7 +1475,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1434,6 +1499,12 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1443,6 +1514,13 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multiquery_self_attention_verify(input, embed_dim, @@ -1456,7 +1534,7 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index 1bea204601..b9a16d0177 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -54,23 +54,24 @@ bool IncMultiHeadSelfAttentionParams::is_valid( return is_valid; } -Tensor FFModel::inc_multihead_self_attention(const Tensor input, - int embed_dim, - int num_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - 
bool position_bias, - char const *name) { +Tensor FFModel::inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool qkv_bias, + bool final_bias, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { return inc_multiquery_self_attention(input, embed_dim, num_heads, @@ -83,7 +84,7 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -91,24 +92,25 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, name); } -Tensor FFModel::inc_multiquery_self_attention(const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim, + int vdim, + float dropout, + bool qkv_bias, + bool final_bias, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { if (data_type == DT_NONE) { data_type = input->data_type; } @@ -170,7 +172,17 @@ Tensor FFModel::inc_multiquery_self_attention(const Tensor input, li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling", qk_prod_scaling); @@ -207,8 +219,18 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + 
rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; @@ -237,7 +259,7 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( qkv_bias, final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -262,7 +284,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -284,7 +306,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -353,7 +375,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -376,7 +398,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -451,7 +473,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( other.qkv_bias, other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, @@ -480,7 +502,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( params.qkv_bias, params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, @@ -846,7 +868,19 @@ bool operator==(IncMultiHeadSelfAttentionParams const &lhs, lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + 
rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -864,7 +898,7 @@ IncMultiHeadSelfAttentionParams IncMultiHeadSelfAttention::get_params() const { params.qkv_bias = this->qkv_bias; params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -896,7 +930,14 @@ size_t hash::operator()( hash_combine(key, params.qkv_bias); hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 81a3401da3..01a64a983f 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -20,6 +20,7 @@ #include "flexflow/utils/hip_helper.h" #include "hip/hip_complex.h" #include +#include namespace FlexFlow { @@ -405,60 +406,17 @@ __global__ void scaling_query_kernel(DT *input_ptr, } } -template -__global__ void - apply_rotary_embedding_native(DT *input_ptr, - hipFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int num_q_heads, - int num_tokens, - int num_kv_heads, - int q_block_size, - int k_block_size, - int q_array_size) { - CUDA_KERNEL_LOOP( - i, - num_tokens * (qProjSize * num_q_heads + kProjSize * num_kv_heads) / 2) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int head_idx = real_i / (num_tokens * proj_size / 2); - int idx = real_i % (num_tokens * proj_size / 2); - int real_part_index = idx * 2 + - head_idx * (q_tensor ? q_block_size : k_block_size) + - (q_tensor ? 
0 : q_array_size); - - int complex_part_index = real_part_index + 1; - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - int token_idx = - (real_i - head_idx * (num_tokens * proj_size / 2)) / (proj_size / 2); - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - // complex_input[i].y; - - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - hipFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = hipCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - template __global__ void apply_rotary_embedding_hf(DT *input_ptr, hipFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int qProjSize, int kProjSize, int num_tokens, @@ -493,7 +451,29 @@ __global__ void // float before_real = complex_input[i].x, before_complex = int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); + + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + hipFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); @@ -507,6 +487,12 @@ __global__ void apply_rotary_embedding_bwd(DT *input_ptr, hipFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int proj_size, int num_tokens, int hidden_size) { @@ -533,7 +519,28 @@ __global__ void size_t pos = tokenInfos[token_idx].abs_depth_in_request; - float freq = pos * (1.0 / pow(10000.0, (float)2 * idx / proj_size)); + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + hipFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); @@ -664,22 +671,29 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, } // Step 3: apply rotary embedding if needed - if 
(*m->apply_rotary_embedding) { + if (m->rotary_embedding_meta->apply_rotary_embedding) { /*q&k*/ parallelism = num_tokens * m->hidden_size; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_rotary_embedding_hf), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - m->complex_input, - m->token_infos, - m->qProjSize, - m->kProjSize, - num_tokens, - q_array_size, - m->hidden_size); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(apply_rotary_embedding_hf), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + output_ptr, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + m->kProjSize, + num_tokens, + q_array_size, + m->hidden_size); } } @@ -1365,23 +1379,30 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, // Step 7: perform rotary position embeddings (RoPE) bwd { - if (*m->apply_rotary_embedding) { + if (m->rotary_embedding_meta->apply_rotary_embedding) { assert(m->hidden_size == m->qProjSize * m->num_q_heads); assert(m->qProjSize == m->kProjSize); /*q&k*/ int parallelism = num_tokens * m->hidden_size; DT *A = static_cast
(m->devQKVProjArray); - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_rotary_embedding_bwd), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - A, - m->complex_input, - m->token_infos, - m->qProjSize, - num_tokens, - m->hidden_size); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(apply_rotary_embedding_bwd), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + A, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + num_tokens, + m->hidden_size); DT *C = static_cast
(m->devQKVProjArray); if (m->inference_debugging) { std::string filename = @@ -1900,7 +1921,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, @@ -1928,7 +1949,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _qkv_bias, bool _scaling_query, bool _qk_prod_scaling, @@ -1989,8 +2010,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // has_load_weights = (bool *)calloc(1, sizeof(bool)); //*has_load_weights = false; - apply_rotary_embedding = (bool *)calloc(1, sizeof(bool)); - *apply_rotary_embedding = _apply_rotary_embedding; + rotary_embedding_meta = + (RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta)); + *rotary_embedding_meta = _rotary_embedding_meta; qkv_bias = (bool *)calloc(1, sizeof(bool)); *qkv_bias = _qkv_bias; scaling_query = (bool *)calloc(1, sizeof(bool)); diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 0ac8653b4a..43864b437b 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -19,6 +19,7 @@ #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/utils/cuda_helper.h" +#include namespace FlexFlow { @@ -384,60 +385,17 @@ __global__ void scaling_query_kernel(DT *input_ptr, } } -template -__global__ void - apply_rotary_embedding_native(DT *input_ptr, - cuFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int num_q_heads, - int num_tokens, - int num_kv_heads, - int q_block_size, - int k_block_size, - int q_array_size) { - CUDA_KERNEL_LOOP( - i, - num_tokens * (qProjSize * num_q_heads + kProjSize * num_kv_heads) / 2) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int head_idx = real_i / (num_tokens * proj_size / 2); - int idx = real_i % (num_tokens * proj_size / 2); - int real_part_index = idx * 2 + - head_idx * (q_tensor ? q_block_size : k_block_size) + - (q_tensor ? 
0 : q_array_size); - - int complex_part_index = real_part_index + 1; - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - int token_idx = - (real_i - head_idx * (num_tokens * proj_size / 2)) / (proj_size / 2); - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - // complex_input[i].y; - - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = cuCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - template __global__ void apply_rotary_embedding_hf(DT *input_ptr, cuFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int qProjSize, int kProjSize, int num_tokens, @@ -472,7 +430,29 @@ __global__ void // float before_real = complex_input[i].x, before_complex = int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); + + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + cuFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = cuCmulf(complex_input[i], complex_pos); @@ -486,6 +466,12 @@ __global__ void apply_rotary_embedding_bwd(DT *input_ptr, cuFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int proj_size, int num_tokens, int hidden_size) { @@ -512,7 +498,28 @@ __global__ void size_t pos = tokenInfos[token_idx].abs_depth_in_request; - float freq = pos * (1.0 / pow(10000.0, (float)2 * idx / proj_size)); + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + cuFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = cuCmulf(complex_input[i], complex_pos); @@ -578,20 +585,27 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, } // Step 3: apply rotary embedding if needed - if 
(*m->apply_rotary_embedding) { + if (m->rotary_embedding_meta->apply_rotary_embedding) { /*q&k*/ parallelism = num_tokens * m->hidden_size; apply_rotary_embedding_hf<<>>(output_ptr, - m->complex_input, - m->token_infos, - m->qProjSize, - m->kProjSize, - num_tokens, - q_array_size, - m->hidden_size); + stream>>>( + output_ptr, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + m->kProjSize, + num_tokens, + q_array_size, + m->hidden_size); } } @@ -1292,7 +1306,7 @@ void peft_bwd_kernel( // Step 7: perform rotary position embeddings (RoPE) bwd { - if (*m->apply_rotary_embedding) { + if (m->rotary_embedding_meta->apply_rotary_embedding) { assert(m->hidden_size == m->qProjSize * m->num_q_heads); assert(m->qProjSize == m->kProjSize); /*q&k*/ @@ -1301,12 +1315,19 @@ void peft_bwd_kernel( apply_rotary_embedding_bwd<<>>(A, - m->complex_input, - m->token_infos, - m->qProjSize, - num_tokens, - m->hidden_size); + stream>>>( + A, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + num_tokens, + m->hidden_size); DT *C = static_cast
(m->devQKVProjArray); if (m->inference_debugging) { std::string filename = @@ -1811,7 +1832,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, @@ -1839,7 +1860,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _qkv_bias, bool _scaling_query, bool _qk_prod_scaling, @@ -1900,8 +1921,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // has_load_weights = (bool *)calloc(1, sizeof(bool)); //*has_load_weights = false; - apply_rotary_embedding = (bool *)calloc(1, sizeof(bool)); - *apply_rotary_embedding = _apply_rotary_embedding; + rotary_embedding_meta = + (RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta)); + *rotary_embedding_meta = _rotary_embedding_meta; qkv_bias = (bool *)calloc(1, sizeof(bool)); *qkv_bias = _qkv_bias; scaling_query = (bool *)calloc(1, sizeof(bool)); diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index 954c28ad40..5a70b1baee 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -52,24 +52,24 @@ bool SpecIncMultiHeadSelfAttentionParams::is_valid( return is_valid; } -Tensor - FFModel::spec_inc_multihead_self_attention(Tensor const input, - int embed_dim, - int num_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::spec_inc_multihead_self_attention( + Tensor const input, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool qkv_bias, + bool final_bias, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { return spec_inc_multiquery_self_attention(input, embed_dim, num_heads, @@ -82,7 +82,7 @@ Tensor add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -90,25 +90,25 @@ Tensor name); } -Tensor - FFModel::spec_inc_multiquery_self_attention(Tensor const input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::spec_inc_multiquery_self_attention( + Tensor const input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim, + int vdim, + float dropout, + bool qkv_bias, + bool final_bias, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { if (data_type == DT_NONE) { data_type = input->data_type; } @@ -165,7 +165,17 @@ Tensor 
li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling", qk_prod_scaling); @@ -199,8 +209,18 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; @@ -222,7 +242,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( qkv_bias, final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -244,7 +264,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -263,7 +283,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -319,7 +339,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -339,7 +359,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), 
add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -399,7 +419,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( other.qkv_bias, other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, @@ -425,7 +445,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( params.qkv_bias, params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, @@ -688,7 +708,19 @@ bool operator==(SpecIncMultiHeadSelfAttentionParams const &lhs, lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -708,7 +740,7 @@ SpecIncMultiHeadSelfAttentionParams params.qkv_bias = this->qkv_bias; params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -736,7 +768,14 @@ size_t hash::operator()( hash_combine(key, params.qkv_bias); hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index 0bf2b3346e..aa123d9451 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -614,7 +614,7 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( attn->kProjSize, 
attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 4c65a8baa8..4d391ef0b8 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -749,7 +749,7 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; + // bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -761,7 +761,7 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( assert(input.data_type == output.data_type); if (input.data_type == DT_HALF) { - half const *bias_ptr = static_cast(nullptr); + // half const *bias_ptr = static_cast(nullptr); Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { @@ -803,7 +803,7 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index c2187b1ca2..13779e7c33 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -66,7 +66,7 @@ Tensor FFModel::inc_multihead_self_attention_verify( bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -84,7 +84,7 @@ Tensor FFModel::inc_multihead_self_attention_verify( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -105,7 +105,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -170,10 +170,19 @@ Tensor FFModel::inc_multiquery_self_attention_verify( li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); - li->add_int_property("qk_prod_scaling", qk_prod_scaling); 
li->add_int_property("position_bias", position_bias); li->add_int_property("quantization_type", quantization_type); li->add_int_property("offload", offload); @@ -206,9 +215,18 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; - layer->get_int_property("scaling_query", value); + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; bool scaling_query = (bool)value; float scaling_factor; layer->get_float_property("scaling_factor", scaling_factor); @@ -234,7 +252,7 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( qkv_bias, final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -259,7 +277,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -281,7 +299,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -351,7 +369,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -374,7 +392,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -449,7 +467,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( other.qkv_bias, other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, @@ -478,7 +496,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( params.qkv_bias, params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, 
params.scaling_factor, params.qk_prod_scaling, @@ -754,7 +772,19 @@ bool operator==(TreeIncMultiHeadSelfAttentionParams const &lhs, lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -774,7 +804,7 @@ TreeIncMultiHeadSelfAttentionParams params.qkv_bias = this->qkv_bias; params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -802,7 +832,14 @@ size_t hash::operator()( hash_combine(key, params.qkv_bias); hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index ff592ddccb..8a4c0f3b68 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -1062,7 +1062,7 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 43e8e46d49..a1d8c7000a 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -958,7 +958,7 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; + // bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -1020,7 +1020,7 @@ 
TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 1a38782e81..6a74979172 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -2334,7 +2334,16 @@ GraphOptimalViewSerialized sez.serialize(attn->qkv_bias); sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2361,7 +2370,16 @@ GraphOptimalViewSerialized sez.serialize(attn->qkv_bias); sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2385,7 +2403,16 @@ GraphOptimalViewSerialized sez.serialize(attn->qkv_bias); sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2817,8 +2844,9 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, offload, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, + qk_prod_scaling, offload, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; size_t id, transformer_layer_id, deserialized_model_id; 
dez.deserialize(id); @@ -2833,7 +2861,17 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(qkv_bias); dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2857,7 +2895,7 @@ void FFModel::deserialize_graph_optimal_view( params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; @@ -2874,8 +2912,9 @@ void FFModel::deserialize_graph_optimal_view( assert(num_inputs == 1); int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, + qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); @@ -2889,7 +2928,17 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(qkv_bias); dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2910,7 +2959,7 @@ void FFModel::deserialize_graph_optimal_view( params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; @@ -2926,8 +2975,9 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, offload, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, + qk_prod_scaling, offload, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType 
@@ -2926,8 +2975,9 @@ void FFModel::deserialize_graph_optimal_view(
         int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads,
             tensor_parallelism_degree;
         float dropout, scaling_factor;
-        bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding,
-            scaling_query, qk_prod_scaling, offload, position_bias;
+        bool qkv_bias, final_bias, add_zero_attn, scaling_query,
+            qk_prod_scaling, offload, position_bias;
+        RotaryEmbeddingMeta rotary_embedding_meta;
         DataType quantization_type;
         size_t id, transformer_layer_id, deserialized_model_id;
         dez.deserialize(id);
@@ -2942,7 +2992,17 @@ void FFModel::deserialize_graph_optimal_view(
         dez.deserialize(qkv_bias);
         dez.deserialize(final_bias);
         dez.deserialize(add_zero_attn);
-        dez.deserialize(apply_rotary_embedding);
+        dez.deserialize(rotary_embedding_meta.apply_rotary_embedding);
+        dez.deserialize(rotary_embedding_meta.rope_theta);
+        size_t rope_type_len;
+        char rope_type[1024] = {0};
+        dez.deserialize(rope_type_len);
+        dez.deserialize(rope_type, rope_type_len);
+        rotary_embedding_meta.rope_type = std::string(rope_type);
+        dez.deserialize(rotary_embedding_meta.factor);
+        dez.deserialize(rotary_embedding_meta.low_freq_factor);
+        dez.deserialize(rotary_embedding_meta.high_freq_factor);
+        dez.deserialize(rotary_embedding_meta.original_max_position_embeddings);
         dez.deserialize(scaling_query);
         dez.deserialize(scaling_factor);
         dez.deserialize(qk_prod_scaling);
@@ -2966,7 +3026,7 @@ void FFModel::deserialize_graph_optimal_view(
         params.final_bias = final_bias;
         params.add_zero_attn = add_zero_attn;
         params.layer_guid = layer_guid;
-        params.apply_rotary_embedding = apply_rotary_embedding;
+        params.rotary_embedding_meta = rotary_embedding_meta;
         params.scaling_query = scaling_query;
         params.scaling_factor = scaling_factor;
         params.qk_prod_scaling = qk_prod_scaling;
diff --git a/src/runtime/layer.cc b/src/runtime/layer.cc
index 8f33f6db87..72e71688c1 100644
--- a/src/runtime/layer.cc
+++ b/src/runtime/layer.cc
@@ -87,6 +87,11 @@ void Layer::add_int_vector_property(std::string const &key,
   int_vector_properties[key] = value;
 }
 
+void Layer::add_string_property(std::string const &key,
+                                std::string const &value) {
+  string_properties[key] = value;
+}
+
 void Layer::add_initializer(std::string const &key, Initializer *initializer) {
   initializers[key] = initializer;
 }
@@ -125,6 +130,18 @@ bool Layer::get_int_vector_property(std::string const &key,
   }
 }
 
+bool Layer::get_string_property(std::string const &key,
+                                std::string &value) const {
+  auto const &it = string_properties.find(key);
+  if (it == string_properties.end()) {
+    assert(false);
+    return false;
+  } else {
+    value = it->second;
+    return true;
+  }
+}
+
 bool Layer::get_initializer(std::string const &key,
                             Initializer *&initializer) const {
   auto const &it = initializers.find(key);
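Layer now carries string-valued properties alongside its existing scalar and int-vector ones, which is the mechanism that lets a front end hand a textual RoPE variant (e.g. "llama3" vs. "default") down to the attention operator without widening every constructor. The snippet below is a stand-alone mock of that key/value pattern with the same method signatures as the hunk above; the `"rope_type"` key and the `MockLayer` class are assumptions for illustration, not FlexFlow's actual call sites.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>

// Minimal stand-in for the string-property portion of FlexFlow's Layer class.
class MockLayer {
public:
  void add_string_property(std::string const &key, std::string const &value) {
    string_properties[key] = value;
  }
  bool get_string_property(std::string const &key, std::string &value) const {
    auto const &it = string_properties.find(key);
    if (it == string_properties.end()) {
      assert(false); // the real Layer also asserts on a missing key
      return false;
    }
    value = it->second;
    return true;
  }

private:
  std::unordered_map<std::string, std::string> string_properties;
};

int main() {
  MockLayer layer;
  // A model builder could stash the RoPE variant when the graph is built...
  layer.add_string_property("rope_type", "llama3");

  // ...and the operator factory reads it back when materializing the op.
  std::string rope_type;
  if (layer.get_string_property("rope_type", rope_type)) {
    std::cout << "rope_type = " << rope_type << std::endl;
  }
  return 0;
}
```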
diff --git a/tests/fine_grained_alignment_test.sh b/tests/fine_grained_alignment_test.sh
index 681a015600..a0ed718d25 100755
--- a/tests/fine_grained_alignment_test.sh
+++ b/tests/fine_grained_alignment_test.sh
@@ -6,6 +6,7 @@ MODEL_NAME=${MODEL_NAME:-"JackFram/llama-160m"}
 MEMORY_PER_GPU=${MEMORY_PER_GPU:-14000}
 ZCOPY_MEMORY=${ZCOPY_MEMORY:-40000}
 CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"}
+NUM_STEPS=${NUM_STEPS:-2}
 
 cleanup() {
     rm -rf ${CACHE_PATH}/debug ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt
@@ -26,8 +27,30 @@ mkdir -p ./inference/output
 
 # Enable backtrace in case we run into a segfault or assertion failure
 export LEGION_BACKTRACE=1
+export FF_DEBG_NO_WEIGHTS=1
 
-python ./tests/inference/huggingface_inference.py --model-name $MODEL_NAME --max-length 10 --prompt-file ../../inference/prompt/test.json --output-file ../../inference/output/fine_grained_alignment_test_hf.txt --use-full-precision --inference-debugging
+PROMPT_LENGTH=$(python -c "
+from transformers import AutoTokenizer
+import os
+tokenizer = AutoTokenizer.from_pretrained(\"$MODEL_NAME\")
+tokens = tokenizer.tokenize('Three tips for staying healthy are: ')
+print(len(tokens))
+")
+# Check if the Python code executed successfully
+if [ $? -ne 0 ]; then
+    echo "Error: Failed to execute Python code"
+    exit 1
+fi
+
+MAX_LENGTH=$((PROMPT_LENGTH + NUM_STEPS + 1))
+
+python ./tests/inference/huggingface_inference.py \
+    --model-name $MODEL_NAME \
+    --max-length $MAX_LENGTH \
+    --prompt-file ../../inference/prompt/test.json \
+    --output-file ../../inference/output/fine_grained_alignment_test_hf.txt \
+    --use-full-precision \
+    --inference-debugging
 
 json_config=$(cat <<-END
     {
@@ -46,7 +69,7 @@ json_config=$(cat <<-END
         "cache_path": "${CACHE_PATH}",
         "full_precision": true,
         "prompt": "./inference/prompt/test.json",
-        "max_length": 10,
+        "max_length": $MAX_LENGTH,
         "output_file": "./inference/output/fine_grained_alignment_test_ff.txt"
     }
 END
@@ -67,11 +90,11 @@ python ./inference/python/incr_decoding.py -config-file ./fine_grained_alignment
 # --inference-debugging
 
 # Check alignment
-python ./tests/inference/inference_alignment_test.py -m $MODEL_NAME -tp 2 -n 2
+python ./tests/inference/inference_alignment_test.py -m $MODEL_NAME -tp 2 -n $NUM_STEPS
 
 # Print succeess message
 echo ""
-echo "Inference alignment tests passed!"
+echo "Inference alignment tests passed (model ${MODEL_NAME})!"
 echo ""
 
 # Cleanup after the test
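For context on what the `rope_type`, `factor`, `low_freq_factor`, `high_freq_factor`, and `original_max_position_embeddings` fields threaded through this patch control: Llama 3.1 rescales the RoPE inverse frequencies following the "llama3" `rope_scaling` scheme published with the model and implemented in Hugging Face transformers. The sketch below reproduces that scheme as a reading aid only; it is an assumption that FlexFlow's attention kernels apply the same rule, and this is not the kernel code from this patch.

```cpp
#include <cmath>
#include <vector>

// Reference "llama3" RoPE frequency rescaling, parameterized by the same
// fields this patch serializes. Values in comments are Llama 3.1 defaults.
std::vector<float> llama3_scaled_inv_freq(int head_dim,
                                          float rope_theta,       // e.g. 500000
                                          float factor,           // e.g. 8
                                          float low_freq_factor,  // e.g. 1
                                          float high_freq_factor, // e.g. 4
                                          int original_max_position_embeddings) {
  float const kPi = 3.14159265358979f;
  float const old_ctx = static_cast<float>(original_max_position_embeddings);
  float const low_freq_wavelen = old_ctx / low_freq_factor;
  float const high_freq_wavelen = old_ctx / high_freq_factor;

  std::vector<float> inv_freq;
  for (int i = 0; i < head_dim; i += 2) {
    float freq = 1.0f / std::pow(rope_theta, float(i) / float(head_dim));
    float wavelen = 2.0f * kPi / freq;
    if (wavelen > low_freq_wavelen) {
      freq /= factor; // low-frequency band: fully rescaled
    } else if (wavelen >= high_freq_wavelen) {
      // medium band: smooth blend between rescaled and original frequency
      float smooth = (old_ctx / wavelen - low_freq_factor) /
                     (high_freq_factor - low_freq_factor);
      freq = (1.0f - smooth) * freq / factor + smooth * freq;
    } // high-frequency band: left unchanged
    inv_freq.push_back(freq);
  }
  return inv_freq;
}
```

A `rope_type` of "default" presumably keeps plain RoPE with only `rope_theta`, while "llama3" selects the rescaling above.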