From 96628b360efb6a0299dd9a3a652a91249b722231 Mon Sep 17 00:00:00 2001 From: Yingcheng <135535812+yingchen21@users.noreply.github.com> Date: Thu, 10 Oct 2024 06:27:49 +0800 Subject: [PATCH] Attention projections (QKV, O) disaggregation (#1436) * merged attn-qkv-proj into peft. commented out some alignment test, but should be equivalent to the oriinal test. * restored and passed the alignement test * linting * rebased onto inference * Bug fixes, uploaded missing cpp implmentation * Code cleanup * clean up * fixed problem with mpt. * update * llama3.1 support * fix * support llama3.2 * fix opt bias? * opt alignment test stub * fix bias * update * fix non-fusion opt * update * fix * cleanup * delete file * cleanup * shellcheck * hip cleanup * fix * hip fixes --------- Co-authored-by: Gabriele Oliaro Co-authored-by: zhihao Co-authored-by: Gabriele Oliaro --- .gitignore | 3 + .../ops/inc_multihead_self_attention.py | 6 - .../inc_multihead_self_attention_verify.py | 6 - .../ops/inc_multiquery_self_attention.py | 6 - .../inc_multiquery_self_attention_verify.py | 6 - .../ops/spec_inc_multihead_self_attention.py | 6 - .../ops/spec_inc_multiquery_self_attention.py | 6 - include/flexflow/flexflow_c.h | 48 +- include/flexflow/inference.h | 39 +- include/flexflow/layer.h | 3 + include/flexflow/model.h | 146 +- include/flexflow/operator.h | 8 +- .../ops/inc_multihead_self_attention.h | 54 +- .../ops/inc_multihead_self_attention_params.h | 5 +- .../inc_multihead_self_attention_kernels.h | 49 +- .../ops/spec_inc_multihead_self_attention.h | 25 +- ...spec_inc_multihead_self_attention_params.h | 4 +- .../ops/tree_inc_multihead_self_attention.h | 26 +- ...tree_inc_multihead_self_attention_params.h | 4 +- inference/models/falcon.cc | 81 +- inference/models/falcon.h | 29 +- inference/models/llama.cc | 72 +- inference/models/llama.h | 29 +- inference/models/mpt.cc | 54 +- inference/models/mpt.h | 2 + inference/models/opt.cc | 62 +- inference/models/opt.h | 9 +- inference/models/starcoder.cc | 55 +- inference/models/starcoder.h | 4 +- inference/python/incr_decoding.py | 10 +- python/flexflow/core/flexflow_cffi.py | 161 +- python/flexflow/serve/models/falcon.py | 56 +- python/flexflow/serve/models/llama.py | 56 +- python/flexflow/serve/models/mpt.py | 46 +- python/flexflow/serve/models/opt.py | 45 +- python/flexflow/serve/models/starcoder.py | 32 +- src/c/flexflow_c.cc | 114 +- src/ops/add_bias_residual_layer_norm.cc | 14 +- src/ops/fused.cpp | 48 +- src/ops/fused.cu | 55 +- src/ops/inc_multihead_self_attention.cc | 496 +-- src/ops/inc_multihead_self_attention.cpp | 1646 ++++----- src/ops/inc_multihead_self_attention.cu | 2972 ++++++++--------- src/ops/kernels/linear_kernels.cu | 1 + src/ops/linear.cc | 6 +- src/ops/residual_layer_norm.cc | 17 +- src/ops/spec_inc_multihead_self_attention.cc | 415 +-- src/ops/spec_inc_multihead_self_attention.cpp | 1056 +++--- src/ops/spec_inc_multihead_self_attention.cu | 101 +- src/ops/tree_inc_multihead_self_attention.cc | 385 +-- src/ops/tree_inc_multihead_self_attention.cpp | 411 +-- src/ops/tree_inc_multihead_self_attention.cu | 409 +-- src/parallel_ops/allreduce.cc | 2 +- src/runtime/file_loader.cc | 406 ++- src/runtime/graph.cc | 107 +- src/runtime/inference_manager.cc | 1 + src/runtime/layer.cc | 17 + src/runtime/model.cc | 51 +- src/runtime/operator.cc | 12 + src/runtime/substitution.cc | 5 +- tests/fine_grained_alignment_test.sh | 106 + tests/inference/huggingface_inference.py | 49 +- tests/inference/inference_alignment_test.py | 817 +++++ 
tests/peft/alignment/align_test_utils.py | 13 +- tests/peft/hf_finetune.py | 2 +- tests/peft/hf_utils.py | 15 +- tests/peft/peft_alignment_test.py | 39 +- 67 files changed, 5146 insertions(+), 5895 deletions(-) create mode 100755 tests/fine_grained_alignment_test.sh create mode 100644 tests/inference/inference_alignment_test.py diff --git a/.gitignore b/.gitignore index cc34c1a7b6..c1e22fcaba 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,6 @@ lora_training_logs Untitled-1.ipynb Untitled-2.ipynb tests/inference/python_test_configs/*.json + +core.* +fine_grained_alignment_config.json diff --git a/examples/python/native/ops/inc_multihead_self_attention.py b/examples/python/native/ops/inc_multihead_self_attention.py index dce7bd565d..ab80a5893c 100644 --- a/examples/python/native/ops/inc_multihead_self_attention.py +++ b/examples/python/native/ops/inc_multihead_self_attention.py @@ -11,8 +11,6 @@ def test_inc_multihead_self_attention( kdim: int = 0, vdim: int = 0, dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, add_zero_attn: bool = False, data_type: DataType = DataType.DT_NONE, kernel_initializer=None, @@ -34,8 +32,6 @@ def test_inc_multihead_self_attention( kdim=kdim, vdim=vdim, dropout=dropout, - bias=bias, - add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, data_type=data_type, kernel_initializer=kernel_initializer, @@ -85,8 +81,6 @@ def test_inc_multihead_self_attention( kdim=0, # Example value for kdim vdim=0, # Example value for vdim dropout=0.1, # Example value for dropout - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_FLOAT, kernel_initializer=None, # Example value for kernel_initializer diff --git a/examples/python/native/ops/inc_multihead_self_attention_verify.py b/examples/python/native/ops/inc_multihead_self_attention_verify.py index f6dc8e3933..bc2ba5e977 100644 --- a/examples/python/native/ops/inc_multihead_self_attention_verify.py +++ b/examples/python/native/ops/inc_multihead_self_attention_verify.py @@ -11,8 +11,6 @@ def test_inc_multihead_self_attention_verify( kdim: int = 0, vdim: int = 0, dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, add_zero_attn: bool = False, data_type: DataType = DataType.DT_NONE, kernel_initializer=None, @@ -34,8 +32,6 @@ def test_inc_multihead_self_attention_verify( kdim=kdim, vdim=vdim, dropout=dropout, - bias=bias, - add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, data_type=data_type, kernel_initializer=kernel_initializer, @@ -85,8 +81,6 @@ def test_inc_multihead_self_attention_verify( kdim=0, # Example value for kdim vdim=0, # Example value for vdim dropout=0.1, # Example value for dropout - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_FLOAT, kernel_initializer=None, # Example value for kernel_initializer diff --git a/examples/python/native/ops/inc_multiquery_self_attention.py b/examples/python/native/ops/inc_multiquery_self_attention.py index 33390ab1f6..424b46b0f4 100644 --- a/examples/python/native/ops/inc_multiquery_self_attention.py +++ b/examples/python/native/ops/inc_multiquery_self_attention.py @@ -12,8 +12,6 @@ def test_inc_multiquery_self_attention( kdim: int = 0, vdim: int = 0, dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, add_zero_attn: bool = False, data_type: DataType = DataType.DT_NONE, kernel_initializer=None, @@ -36,8 +34,6 @@ def test_inc_multiquery_self_attention( kdim=kdim, vdim=vdim, dropout=dropout, - bias=bias, - add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, 
data_type=data_type, kernel_initializer=kernel_initializer, @@ -89,8 +85,6 @@ def test_inc_multiquery_self_attention( kdim=0, # Example value for kdim vdim=0, # Example value for vdim dropout=0.1, # Example value for dropout - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_FLOAT, kernel_initializer=None, # Example value for kernel_initializer diff --git a/examples/python/native/ops/inc_multiquery_self_attention_verify.py b/examples/python/native/ops/inc_multiquery_self_attention_verify.py index 69a76f68bf..b2c0e7dcf5 100644 --- a/examples/python/native/ops/inc_multiquery_self_attention_verify.py +++ b/examples/python/native/ops/inc_multiquery_self_attention_verify.py @@ -12,8 +12,6 @@ def test_inc_multiquery_self_attention_verify( kdim: int = 0, vdim: int = 0, dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, add_zero_attn: bool = False, data_type: DataType = DataType.DT_NONE, kernel_initializer=None, @@ -36,8 +34,6 @@ def test_inc_multiquery_self_attention_verify( kdim=kdim, vdim=vdim, dropout=dropout, - bias=bias, - add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, data_type=data_type, kernel_initializer=kernel_initializer, @@ -89,8 +85,6 @@ def test_inc_multiquery_self_attention_verify( kdim=0, # Example value for kdim vdim=0, # Example value for vdim dropout=0.1, # Example value for dropout - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_FLOAT, kernel_initializer=None, # Example value for kernel_initializer diff --git a/examples/python/native/ops/spec_inc_multihead_self_attention.py b/examples/python/native/ops/spec_inc_multihead_self_attention.py index bd1aaa189b..d0fa5f7689 100644 --- a/examples/python/native/ops/spec_inc_multihead_self_attention.py +++ b/examples/python/native/ops/spec_inc_multihead_self_attention.py @@ -11,8 +11,6 @@ def test_spec_inc_multihead_self_attention( kdim: int = 0, vdim: int = 0, dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, add_zero_attn: bool = False, data_type: DataType = DataType.DT_NONE, kernel_initializer=None, @@ -34,8 +32,6 @@ def test_spec_inc_multihead_self_attention( kdim=kdim, vdim=vdim, dropout=dropout, - bias=bias, - add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, data_type=data_type, kernel_initializer=kernel_initializer, @@ -85,8 +81,6 @@ def test_spec_inc_multihead_self_attention( kdim=0, # Example value for kdim vdim=0, # Example value for vdim dropout=0.1, # Example value for dropout - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_FLOAT, kernel_initializer=None, # Example value for kernel_initializer diff --git a/examples/python/native/ops/spec_inc_multiquery_self_attention.py b/examples/python/native/ops/spec_inc_multiquery_self_attention.py index 0b731c99e0..0d04f639c9 100644 --- a/examples/python/native/ops/spec_inc_multiquery_self_attention.py +++ b/examples/python/native/ops/spec_inc_multiquery_self_attention.py @@ -12,8 +12,6 @@ def test_spec_inc_multiquery_self_attention( kdim: int = 0, vdim: int = 0, dropout: float = 0.0, - bias: bool = True, - add_bias_kv: bool = False, add_zero_attn: bool = False, data_type: DataType = DataType.DT_NONE, kernel_initializer=None, @@ -36,8 +34,6 @@ def test_spec_inc_multiquery_self_attention( kdim=kdim, vdim=vdim, dropout=dropout, - bias=bias, - add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, data_type=data_type, kernel_initializer=kernel_initializer, @@ -89,8 +85,6 @@ def test_spec_inc_multiquery_self_attention( kdim=0, # Example value for kdim 
vdim=0, # Example value for vdim dropout=0.1, # Example value for dropout - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_FLOAT, kernel_initializer=None, # Example value for kernel_initializer diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 52b4b3d362..c1e18e660b 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -445,12 +445,16 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -465,12 +469,16 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -485,12 +493,16 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -506,12 +518,16 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -527,12 +543,16 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -548,12 +568,16 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index ba4101c173..755df9f5cb 100644 --- a/include/flexflow/inference.h 
+++ b/include/flexflow/inference.h @@ -43,8 +43,43 @@ struct GenerationResult { std::vector finetuning_losses; }; -#include -#include +struct RotaryEmbeddingMeta { + bool apply_rotary_embedding = false; + float rope_theta = 10000.0f; + std::string rope_type = "default"; + float factor = 8.0f; + float low_freq_factor = 1.0f; + float high_freq_factor = 4.0f; + int original_max_position_embeddings = 8192; + + RotaryEmbeddingMeta(bool apply_rotary_embedding_ = false, + float rope_theta_ = 10000.0f, + std::string rope_type_ = "default", + float factor_ = 8.0f, + float low_freq_factor_ = 1.0f, + float high_freq_factor_ = 4.0f, + int original_max_position_embeddings_ = 8192) + : apply_rotary_embedding(apply_rotary_embedding_), + rope_theta(rope_theta_), rope_type(rope_type_), factor(factor_), + low_freq_factor(low_freq_factor_), high_freq_factor(high_freq_factor_), + original_max_position_embeddings(original_max_position_embeddings_) {} + + friend std::ostream &operator<<(std::ostream &os, + RotaryEmbeddingMeta const &meta) { + os << std::boolalpha // To print bool as true/false instead of 1/0 + << "RotaryEmbeddingMeta {\n" + << " apply_rotary_embedding: " << meta.apply_rotary_embedding << ",\n" + << " rope_theta: " << meta.rope_theta << ",\n" + << " rope_type: \"" << meta.rope_type << "\",\n" + << " factor: " << meta.factor << ",\n" + << " low_freq_factor: " << meta.low_freq_factor << ",\n" + << " high_freq_factor: " << meta.high_freq_factor << ",\n" + << " original_max_position_embeddings: " + << meta.original_max_position_embeddings << "\n" + << "}"; + return os; + } +}; std::string join_path(std::vector const &paths); diff --git a/include/flexflow/layer.h b/include/flexflow/layer.h index c3dbcac422..e18bad3982 100644 --- a/include/flexflow/layer.h +++ b/include/flexflow/layer.h @@ -32,11 +32,13 @@ class Layer { void add_float_property(std::string const &key, float value); void add_int_vector_property(std::string const &key, std::vector const &value); + void add_string_property(std::string const &key, std::string const &value); void add_initializer(std::string const &key, Initializer *initializer); bool get_int_property(std::string const &key, long long &value) const; bool get_float_property(std::string const &key, float &value) const; bool get_int_vector_property(std::string const &key, std::vector &value) const; + bool get_string_property(std::string const &key, std::string &value) const; bool get_initializer(std::string const &key, Initializer *&initializer) const; Tensor get_parameter(int index); void print(); @@ -59,6 +61,7 @@ class Layer { std::unordered_map float_properties; std::unordered_map initializers; std::unordered_map> int_vector_properties; + std::unordered_map string_properties; }; }; // namespace FlexFlow diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 4ad735ef7d..51b7950db8 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -733,41 +733,38 @@ class FFModel { DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, char const *name = NULL); - Tensor inc_multihead_self_attention(const Tensor input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = 
NULL); - Tensor - spec_inc_multihead_self_attention(const Tensor input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); + Tensor inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); + Tensor spec_inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); Tensor inc_multihead_self_attention_verify( const Tensor input, int embed_dim, @@ -775,54 +772,49 @@ class FFModel { int kdim = 0, int vdim = 0, float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); + Tensor inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); + Tensor spec_inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, bool position_bias = false, char const *name = NULL); - Tensor inc_multiquery_self_attention(const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); - Tensor - 
spec_inc_multiquery_self_attention(const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); Tensor inc_multiquery_self_attention_verify( const Tensor input, int embed_dim, @@ -831,12 +823,10 @@ class FFModel { int kdim = 0, int vdim = 0, float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, diff --git a/include/flexflow/operator.h b/include/flexflow/operator.h index 1a5af67b36..007314797a 100644 --- a/include/flexflow/operator.h +++ b/include/flexflow/operator.h @@ -335,7 +335,13 @@ class Op { // only dump the weights in the forward pass, at the first step // note that we do not save the weight gradients, since we only support // finetuning LoRA weights, which are not FF tensors. - if (fwd_pass && m->decoding_step == 0) { + // Set FF_DEBG_NO_WEIGHTS=1 or to FF_DEBG_NO_WEIGHTS=true to disable saving + // weights + bool do_not_save_weights = + (std::getenv("FF_DEBG_NO_WEIGHTS") && + (std::string(std::getenv("FF_DEBG_NO_WEIGHTS")) == "1" || + std::string(std::getenv("FF_DEBG_NO_WEIGHTS")) == "true")); + if (fwd_pass && m->decoding_step == 0 && !do_not_save_weights) { fs::path dst_filepath_weights = get_dst_folder("weights", m->decoding_step, shard_id, before_kernel) / layername; diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index f77df7c456..4519cf8215 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -36,49 +36,40 @@ class IncMultiHeadSelfAttention : public Op { int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, char const *name); IncMultiHeadSelfAttention(FFModel &model, ParallelTensor const _input, - ParallelTensor const _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, char const *name); IncMultiHeadSelfAttention(FFModel &model, IncMultiHeadSelfAttention const &other, - ParallelTensor const input, - bool allocate_weights); + ParallelTensor const input); IncMultiHeadSelfAttention(FFModel &model, Params const ¶ms, Input const &inputs, - bool allocate_weights = false, char const *name = nullptr); static Op * create_operator_from_layer(FFModel &model, @@ 
-125,24 +116,20 @@ class IncMultiHeadSelfAttention : public Op { BatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias); - static void peft_bwd_kernel_wrapper(IncMultiHeadSelfAttentionMeta *m, - BatchConfig const *bc, - int shard_id, - GenericTensorAccessorW const &input_grad, - GenericTensorAccessorR const &weight, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &bias); + GenericTensorAccessorW const &output); + static void + peft_bwd_kernel_wrapper(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad); Params get_params() const; public: int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int qoSeqLength, kvSeqLength; DataType quantization_type; @@ -153,7 +140,6 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { public: IncMultiHeadSelfAttentionMeta(FFHandler handler, IncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -168,14 +154,11 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, - bool _qkv_bias, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, bool _qk_prod_scaling, bool _position_bias, - bool _final_bias, float _scaling_factor, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _global_num_q_heads, @@ -188,30 +171,23 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { public: Realm::RegionInstance reserveInst; - size_t weights_params, weightSize, biasSize, reserveSpaceSize, - quantized_weightSize; + size_t reserveSpaceSize; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads, hidden_size; - bool *has_load_weights; - bool *apply_rotary_embedding; - bool *qkv_bias; - bool *final_bias; + RotaryEmbeddingMeta *rotary_embedding_meta; bool *scaling_query; bool *qk_prod_scaling; bool *position_bias; float scaling_factor; - void *weight_ptr, *bias_ptr; // for weight offload void *devQKVProjArray, *keyCache, *valueCache; void *qk_prods, *qk_prods_softmax; void *attn_heads; - char *quantized_weight_ptr; BatchConfig::PerTokenInfo *token_infos; BatchConfig::PerRequestInfo *request_infos; DataType quantization_type; bool offload; #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) - // cudaStream_t task_local_stream; cudnnTensorDescriptor_t qk_tensor; cuFloatComplex *complex_input; #elif defined(FF_USE_HIP_ROCM) diff --git a/include/flexflow/ops/inc_multihead_self_attention_params.h b/include/flexflow/ops/inc_multihead_self_attention_params.h index 58681069e2..9b0a26e5d7 100644 --- a/include/flexflow/ops/inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/inc_multihead_self_attention_params.h @@ -3,6 +3,7 @@ #include "flexflow/ffconst.h" #include "flexflow/fftype.h" +#include "flexflow/inference.h" #include 
"flexflow/parallel_tensor.h" namespace FlexFlow { @@ -12,8 +13,8 @@ struct IncMultiHeadSelfAttentionParams { int embed_dim, num_q_heads, kdim, vdim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; bool offload; char name[MAX_OPNAME]; diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index 26dcf12425..16d5915381 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -14,6 +14,11 @@ namespace FlexFlow { namespace Kernels { namespace IncMultiHeadAttention { +template +void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + ffStream_t stream); template void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -21,14 +26,11 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, ffStream_t stream); template -void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *weight_ptr, - DT const *bias_ptr, - int num_tokens, - ffStream_t stream); +void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT *output_ptr, + ffStream_t stream); template __global__ void apply_position_bias_qkprd(DT *input_ptr, @@ -38,27 +40,6 @@ __global__ void apply_position_bias_qkprd(DT *input_ptr, int global_num_q_heads, int shard_id); -template -__global__ void apply_proj_bias_w(DT *input_ptr, - DT const *bias_ptr, - int num_tokens, - int qkv_weight_size, - int oProjSize); - -template -__global__ void apply_proj_bias_qkv(DT *input_ptr, - DT const *bias_ptr, - int shard_id, - int num_tokens, - int qProjSize, - int kProjSize, - int vProjSize, - int num_heads, - int num_kv_heads, - bool scaling_query, - float scaling_factor, - int hidden_size); - #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) template __global__ void @@ -91,16 +72,6 @@ __global__ void bool q_tensor); #endif -template -void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT const *input_ptr, - DT const *weight_ptr, - DT *output_ptr, - DT const *bias_ptr, - ffStream_t stream); - template void pre_build_weight_kernel(IncMultiHeadSelfAttentionMeta const *m, GenericTensorAccessorR const weight, diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index a0d01092bf..155132a7fe 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -33,43 +33,34 @@ class SpecIncMultiHeadSelfAttention : public Op { int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, char const *name); SpecIncMultiHeadSelfAttention(FFModel &model, const ParallelTensor _input, - const ParallelTensor _weight, int 
_embed_dim, int _num_q_heads, int _num_kv_heads, int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, char const *name); SpecIncMultiHeadSelfAttention(FFModel &model, SpecIncMultiHeadSelfAttention const &other, - const ParallelTensor input, - bool allocate_weights); + const ParallelTensor input); SpecIncMultiHeadSelfAttention(FFModel &model, Params const ¶ms, Input const &inputs, - bool allocate_weights = false, char const *name = nullptr); static Op * create_operator_from_layer(FFModel &model, @@ -112,17 +103,14 @@ class SpecIncMultiHeadSelfAttention : public Op { BeamSearchBatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias); + GenericTensorAccessorW const &output); Params get_params() const; public: int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int qoSeqLength, kvSeqLength; }; @@ -131,7 +119,6 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { public: SpecIncMultiHeadSelfAttentionMeta(FFHandler handler, SpecIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h index 1461224ba9..a0ae3fc4f2 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h @@ -11,8 +11,8 @@ struct SpecIncMultiHeadSelfAttentionParams { LayerID layer_guid; int embed_dim, num_q_heads, num_kv_heads, kdim, vdim; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; char name[MAX_OPNAME]; bool is_valid(ParallelTensorShape const &) const; }; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index 168ad5f618..9755e62d42 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -33,49 +33,40 @@ class TreeIncMultiHeadSelfAttention : public Op { int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, char const *name); TreeIncMultiHeadSelfAttention(FFModel &model, const ParallelTensor _input, - const ParallelTensor _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, int _kdim, int _vdim, 
float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, char const *name); TreeIncMultiHeadSelfAttention(FFModel &model, TreeIncMultiHeadSelfAttention const &other, - const ParallelTensor input, - bool allocate_weights); + const ParallelTensor input); TreeIncMultiHeadSelfAttention(FFModel &model, Params const ¶ms, Input const &inputs, - bool allocate_weights = false, char const *name = nullptr); static Op * create_operator_from_layer(FFModel &model, @@ -114,18 +105,14 @@ class TreeIncMultiHeadSelfAttention : public Op { TreeVerifyBatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias); - + GenericTensorAccessorW const &output); Params get_params() const; public: int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; int qoSeqLength, kvSeqLength; DataType quantization_type; @@ -136,7 +123,6 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { public: TreeIncMultiHeadSelfAttentionMeta(FFHandler handler, TreeIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention_params.h b/include/flexflow/ops/tree_inc_multihead_self_attention_params.h index d1a51b8b8f..b49db2c10d 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention_params.h @@ -12,8 +12,8 @@ struct TreeIncMultiHeadSelfAttentionParams { int embed_dim, num_q_heads, kdim, vdim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; bool offload; char name[MAX_OPNAME]; diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index 195d6ba7e3..fd4da87b99 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -60,6 +60,7 @@ void FALCON::create_falcon_model(FFModel &ff, "word_embeddings"); Tensor mha = nullptr, mlp_output = nullptr; + Tensor qkv_proj = nullptr, o_proj = nullptr; Tensor res_ln_outputs[2] = {nullptr, nullptr}; for (int i = 0; i < falcon_config.n_layer; i++) { @@ -97,26 +98,41 @@ void FALCON::create_falcon_model(FFModel &ff, att_norm = res_ln_outputs[1]; } + qkv_proj = ff.dense( + att_norm, + falcon_config.hidden_size * + 3, // q, k, v. need to change if want to remove replication. + // (q_heads + 2 * kv_heads) * proj_size + AC_MODE_NONE, + false, // seems like it does not use bias + DT_NONE, // what is this + nullptr, // ? + nullptr, // ? + nullptr, // ? 
+ REG_MODE_NONE, // no regularization + 0.0f, // no dropout + std::string("layers." + std::to_string(i) + ".self_attention.qkv_proj") + .c_str()); + qkv_proj->print("qkv_proj"); + switch (mode) { case BEAM_SEARCH_MODE: { - mha = ff.spec_inc_multiquery_self_attention( - att_norm, + o_proj = ff.spec_inc_multiquery_self_attention( + qkv_proj, falcon_config.hidden_size, falcon_config.n_head, falcon_config.n_head_kv, falcon_config.hidden_size / falcon_config.n_head, falcon_config.hidden_size / falcon_config.n_head, 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + falcon_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attention") .c_str() /*name*/ ); @@ -124,24 +140,22 @@ void FALCON::create_falcon_model(FFModel &ff, } case TREE_VERIFY_MODE: { - mha = ff.inc_multiquery_self_attention_verify( - att_norm, + o_proj = ff.inc_multiquery_self_attention_verify( + qkv_proj, falcon_config.hidden_size, falcon_config.n_head, falcon_config.n_head_kv, falcon_config.hidden_size / falcon_config.n_head, falcon_config.hidden_size / falcon_config.n_head, 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + falcon_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attention") .c_str() /*name*/ ); @@ -149,24 +163,22 @@ void FALCON::create_falcon_model(FFModel &ff, } case INC_DECODING_MODE: { - mha = ff.inc_multiquery_self_attention( - att_norm, + o_proj = ff.inc_multiquery_self_attention( + qkv_proj, falcon_config.hidden_size, falcon_config.n_head, falcon_config.n_head_kv, falcon_config.hidden_size / falcon_config.n_head, falcon_config.hidden_size / falcon_config.n_head, 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + falcon_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attention") .c_str() /*name*/ ); @@ -177,6 +189,21 @@ void FALCON::create_falcon_model(FFModel &ff, } } + mha = ff.dense( + o_proj, + falcon_config.hidden_size, + AC_MODE_NONE, + false, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." 
+ std::to_string(i) + ".self_attention.o_proj") + .c_str()); + mha->print("mha"); + Tensor dense_h_to_4h = ff.dense( att_norm, falcon_config.hidden_size * 4, diff --git a/inference/models/falcon.h b/inference/models/falcon.h index fce2dade3f..565d7e5419 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -50,6 +50,26 @@ class FALCON { : model_config["num_hidden_layers"]; parallel_attn = model_config["parallel_attn"]; vocab_size = model_config["vocab_size"]; + rotary_embedding_meta.apply_rotary_embedding = true; + if (model_config.find("rope_theta") != model_config.end()) { + rotary_embedding_meta.rope_theta = model_config["rope_theta"]; + } else { + rotary_embedding_meta.rope_theta = 10000.0f; + } + if (model_config.find("scaling_factor") != model_config.end() && + !model_config["scaling_factor"].is_null()) { + rotary_embedding_meta.rope_type = + model_config["scaling_factor"]["rope_type"]; + rotary_embedding_meta.factor = + model_config["scaling_factor"]["factor"]; + rotary_embedding_meta.low_freq_factor = + model_config["scaling_factor"]["low_freq_factor"]; + rotary_embedding_meta.high_freq_factor = + model_config["scaling_factor"]["high_freq_factor"]; + rotary_embedding_meta.original_max_position_embeddings = + model_config["scaling_factor"] + ["original_max_position_embeddings"]; + } } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -59,8 +79,6 @@ class FALCON { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -76,9 +94,8 @@ class FALCON { std::cout << "\tn_layer: " << n_layer << std::endl; std::cout << "\tparallel_attn: " << parallel_attn << std::endl; std::cout << "\tvocab_size: " << vocab_size << std::endl; - - // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; - // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; std::cout << "\tmax_beam_width: " << max_beam_width << std::endl; std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl; } @@ -86,8 +103,8 @@ class FALCON { bool bias, multi_query, parallel_attn; int hidden_size, n_head, n_head_kv, n_layer, vocab_size; float layer_norm_epsilon; - // int max_seq_len, max_num_tokens; int max_beam_width, max_beam_depth; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_falcon_model(FFModel &ff, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index cf26194597..bd5243bd4b 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -91,28 +91,41 @@ void LLAMA::create_llama_model(FFModel &ff, token = token_att_norm[0]; att_norm = token_att_norm[1]; } + Tensor qkv_proj = ff.dense( + att_norm, + llama_config.hidden_size * + 3, // q, k, v. need to change if want to remove replication. + // (q_heads + 2 * kv_heads) * proj_size + AC_MODE_NONE, + false, // seems like llama does not use bias + DT_NONE, // what is this + nullptr, // ? + nullptr, // ? + nullptr, // ? + REG_MODE_NONE, // no regularization + 0.0f, // no dropout + std::string("layers." 
+ std::to_string(i) + ".self_attn.qkv_proj") + .c_str()); Tensor mha; switch (mode) { case BEAM_SEARCH_MODE: { mha = ff.spec_inc_multiquery_self_attention( - att_norm, + qkv_proj, llama_config.hidden_size, llama_config.num_attention_heads, llama_config.num_key_value_heads, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + llama_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attn") .c_str() /*name*/ ); @@ -120,23 +133,21 @@ void LLAMA::create_llama_model(FFModel &ff, } case TREE_VERIFY_MODE: { mha = ff.inc_multiquery_self_attention_verify( - att_norm, + qkv_proj, llama_config.hidden_size, llama_config.num_attention_heads, llama_config.num_key_value_heads, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + llama_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attn") .c_str() /*name*/ ); @@ -144,23 +155,21 @@ void LLAMA::create_llama_model(FFModel &ff, } case INC_DECODING_MODE: { mha = ff.inc_multiquery_self_attention( - att_norm, + qkv_proj, llama_config.hidden_size, llama_config.num_attention_heads, llama_config.num_key_value_heads, llama_config.hidden_size / llama_config.num_attention_heads, llama_config.hidden_size / llama_config.num_attention_heads, 0.0f, /*dropout*/ - false, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + llama_config.rotary_embedding_meta, + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." + std::to_string(i) + ".self_attn") .c_str() /*name*/ ); @@ -171,6 +180,21 @@ void LLAMA::create_llama_model(FFModel &ff, } } + Tensor mha_input = mha; + mha = ff.dense( + mha_input, + llama_config.hidden_size, + AC_MODE_NONE, + false, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." 
+ std::to_string(i) + ".self_attn.o_proj") + .c_str()); + // step 2: SILU activaion Tensor token_ff_norm[2] = {nullptr, nullptr}; ff.residual_rms_norm( diff --git a/inference/models/llama.h b/inference/models/llama.h index edb78f1300..853a51a999 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -44,6 +44,26 @@ class LLAMA { hidden_size = model_config["hidden_size"]; rms_norm_eps = model_config["rms_norm_eps"]; intermediate_size = model_config["intermediate_size"]; + rotary_embedding_meta.apply_rotary_embedding = true; + if (model_config.find("rope_theta") != model_config.end()) { + rotary_embedding_meta.rope_theta = model_config["rope_theta"]; + } else { + rotary_embedding_meta.rope_theta = 10000.0f; + } + if (model_config.find("scaling_factor") != model_config.end() && + !model_config["scaling_factor"].is_null()) { + rotary_embedding_meta.rope_type = + model_config["scaling_factor"]["rope_type"]; + rotary_embedding_meta.factor = + model_config["scaling_factor"]["factor"]; + rotary_embedding_meta.low_freq_factor = + model_config["scaling_factor"]["low_freq_factor"]; + rotary_embedding_meta.high_freq_factor = + model_config["scaling_factor"]["high_freq_factor"]; + rotary_embedding_meta.original_max_position_embeddings = + model_config["scaling_factor"] + ["original_max_position_embeddings"]; + } } catch (json::exception const &e) { std::cerr << "Error parsing LLAMA config from JSON file: " << e.what() << std::endl; @@ -54,8 +74,6 @@ class LLAMA { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -71,18 +89,17 @@ class LLAMA { std::cout << "\thidden_size: " << hidden_size << std::endl; std::cout << "\trms_norm_eps: " << rms_norm_eps << std::endl; std::cout << "\tintermediate_size: " << intermediate_size << std::endl; - - // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; - // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; std::cout << "\tmax_beam_width: " << max_beam_width << std::endl; std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl; } - // int max_seq_len, max_num_tokens; int max_beam_width, max_beam_depth; int num_hidden_layers, vocab_size, num_attention_heads, num_key_value_heads, hidden_size, intermediate_size; float rms_norm_eps; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_llama_model(FFModel &ff, diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index e4a7e0056d..d02c0f3b82 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -93,22 +93,35 @@ void MPT::create_mpt_model(FFModel &ff, layernorm_output = res_ln_outputs[1]; } - Tensor attn_outputs; + Tensor qkv_proj = ff.dense( + layernorm_output, + mpt_config.hidden_size * + 3, // q, k, v. need to change if want to remove replication. + // (q_heads + 2 * kv_heads) * proj_size + AC_MODE_NONE, + false, // seems like it does not use bias + DT_NONE, // what is this + nullptr, // ? + nullptr, // ? + nullptr, // ? + REG_MODE_NONE, // no regularization + 0.0f, // no dropout + std::string("layers." 
+ std::to_string(i) + ".attn.qkv_proj").c_str()); + + Tensor o_proj; switch (mode) { case BEAM_SEARCH_MODE: { - attn_outputs = ff.spec_inc_multihead_self_attention( - layernorm_output, + o_proj = ff.spec_inc_multihead_self_attention( + qkv_proj, mpt_config.hidden_size, mpt_config.n_heads, mpt_config.hidden_size / mpt_config.n_heads, mpt_config.hidden_size / mpt_config.n_heads, 0.0f, false, - false, - false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -120,19 +133,17 @@ void MPT::create_mpt_model(FFModel &ff, break; } case TREE_VERIFY_MODE: { - attn_outputs = ff.inc_multihead_self_attention_verify( - layernorm_output, + o_proj = ff.inc_multihead_self_attention_verify( + qkv_proj, mpt_config.hidden_size, mpt_config.n_heads, mpt_config.hidden_size / mpt_config.n_heads, mpt_config.hidden_size / mpt_config.n_heads, 0.0f, false, - false, - false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -144,19 +155,17 @@ void MPT::create_mpt_model(FFModel &ff, break; } case INC_DECODING_MODE: { - attn_outputs = ff.inc_multihead_self_attention( - layernorm_output, + o_proj = ff.inc_multihead_self_attention( + qkv_proj, mpt_config.hidden_size, mpt_config.n_heads, mpt_config.hidden_size / mpt_config.n_heads, mpt_config.hidden_size / mpt_config.n_heads, 0.0f, false, - false, - false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -172,6 +181,19 @@ void MPT::create_mpt_model(FFModel &ff, } } + Tensor attn_outputs = ff.dense( + o_proj, + mpt_config.hidden_size, + AC_MODE_NONE, + false, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." + std::to_string(i) + ".attn.o_proj").c_str()); + ff.residual_layer_norm( attn_outputs, hidden_states, diff --git a/inference/models/mpt.h b/inference/models/mpt.h index 08597e1d75..3001420ad0 100644 --- a/inference/models/mpt.h +++ b/inference/models/mpt.h @@ -37,6 +37,7 @@ class MPT { n_heads = model_config["n_heads"]; n_layers = model_config["n_layers"]; vocab_size = model_config["vocab_size"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -63,6 +64,7 @@ class MPT { // int max_seq_len, max_num_tokens; int max_beam_width, max_beam_depth; int hidden_size, n_heads, n_layers, vocab_size; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_mpt_model(FFModel &ff, diff --git a/inference/models/opt.cc b/inference/models/opt.cc index b3f2ef4e17..34a6bb0f02 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -101,23 +101,37 @@ void OPT::create_opt_model(FFModel &ff, Tensor residual = res_ln_outputs[0]; Tensor hidden_states = res_ln_outputs[1]; - Tensor mha; + Tensor qkv_proj = ff.dense( + hidden_states, + opt_config.hidden_size * + 3, // q, k, v. need to change if want to remove replication. + // (q_heads + 2 * kv_heads) * proj_size + AC_MODE_NONE, + true, // seems like it does not use bias + DT_NONE, // what is this + nullptr, // ? + nullptr, // ? + nullptr, // ? + REG_MODE_NONE, // no regularization + 0.0f, // no dropout + std::string("layers." 
+ std::to_string(i) + ".self_attn.qkv_proj") + .c_str()); + + Tensor o_proj; switch (mode) { case BEAM_SEARCH_MODE: { - mha = ff.spec_inc_multihead_self_attention( - hidden_states, + o_proj = ff.spec_inc_multihead_self_attention( + qkv_proj, opt_config.hidden_size, opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, 0.0f, /*dropout*/ - true, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -128,20 +142,18 @@ void OPT::create_opt_model(FFModel &ff, break; } case TREE_VERIFY_MODE: { - mha = ff.inc_multihead_self_attention_verify( - hidden_states, + o_proj = ff.inc_multihead_self_attention_verify( + qkv_proj, opt_config.hidden_size, opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, 0.0f, /*dropout*/ - true, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -152,20 +164,18 @@ void OPT::create_opt_model(FFModel &ff, break; } case INC_DECODING_MODE: { - mha = ff.inc_multihead_self_attention( - hidden_states, + o_proj = ff.inc_multihead_self_attention( + qkv_proj, opt_config.hidden_size, opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, 0.0f, /*dropout*/ - true, /*qkv_bias*/ - false, /*final_bias*/ false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -180,6 +190,20 @@ void OPT::create_opt_model(FFModel &ff, } } + Tensor mha = ff.dense( + o_proj, + opt_config.hidden_size, + AC_MODE_NONE, + false, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." 
+ std::to_string(i) + ".self_attn.o_proj") + .c_str()); + ff.add_bias_residual_layer_norm(mha, residual, res_ln_outputs, diff --git a/inference/models/opt.h b/inference/models/opt.h index 7c736a26d1..8b85f81aa6 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -45,6 +45,7 @@ class OPT { num_hidden_layers = model_config["num_hidden_layers"]; vocab_size = model_config["vocab_size"]; word_embed_proj_dim = model_config["word_embed_proj_dim"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -54,8 +55,6 @@ class OPT { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -78,9 +77,8 @@ class OPT { std::cout << "\tvocab_size: " << vocab_size << std::endl; std::cout << "\tword_embed_proj_dim: " << word_embed_proj_dim << std::endl; - - // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; - // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; std::cout << "\tmax_beam_width: " << max_beam_width << std::endl; std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl; } @@ -91,6 +89,7 @@ class OPT { float dropout; int ffn_dim, hidden_size, max_position_embeddings, num_attention_heads, num_hidden_layers, vocab_size, word_embed_proj_dim; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_opt_model(FFModel &ff, diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index cd8bf3a9a7..2429b1ec1b 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -102,11 +102,28 @@ void STARCODER::create_starcoder_model( Tensor hidden_states = res_ln_outputs[0]; Tensor ln_1 = res_ln_outputs[1]; + Tensor qkv_proj = ff.dense( + ln_1, + startcoder_config.hidden_size * + 3, // q, k, v. need to change if want to remove replication. + // (q_heads + 2 * kv_heads) * proj_size + AC_MODE_NONE, + false, // seems like it does not use bias + DT_NONE, // what is this + nullptr, // ? + nullptr, // ? + nullptr, // ? + REG_MODE_NONE, // no regularization + 0.0f, // no dropout + std::string("layers." + std::to_string(i) + ".self_attention.qkv_proj") + .c_str()); + Tensor mha; + Tensor o_proj; switch (mode) { case INC_DECODING_MODE: { - mha = ff.inc_multiquery_self_attention( - ln_1, + o_proj = ff.inc_multiquery_self_attention( + qkv_proj, startcoder_config.hidden_size, startcoder_config.num_attention_heads, 1, @@ -114,17 +131,15 @@ void STARCODER::create_starcoder_model( startcoder_config.num_attention_heads, startcoder_config.hidden_size / startcoder_config.num_attention_heads, - startcoder_config.dropout_p, /*dropout*/ - true, /*bias*/ - false, /*add_bias_kv*/ - false, /*add_zero_attn*/ - DT_NONE, /*data_type*/ - nullptr, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - false, /*scaling query*/ - 1.0f, /*scaling factor*/ - true, /*qk_prod_scaling*/ - false, /*position_bias*/ + startcoder_config.dropout_p, /*dropout*/ + false, /*add_zero_attn*/ + DT_NONE, /*data_type*/ + nullptr, /*kernel_initializer*/ + startcoder_config.rotary_embedding_meta, /*apply_rotary_embedding*/ + false, /*scaling query*/ + 1.0f, /*scaling factor*/ + true, /*qk_prod_scaling*/ + false, /*position_bias*/ std::string("layers." 
+ std::to_string(i) + ".attn.c_attn") .c_str() /*name*/ ); @@ -135,6 +150,20 @@ void STARCODER::create_starcoder_model( } } + mha = ff.dense( + o_proj, + startcoder_config.hidden_size, + AC_MODE_NONE, + true, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." + std::to_string(i) + ".self_attn.o_proj") + .c_str()); + ff.residual_layer_norm( hidden_states, mha, diff --git a/inference/models/starcoder.h b/inference/models/starcoder.h index 0e9577d569..7ff6f33770 100644 --- a/inference/models/starcoder.h +++ b/inference/models/starcoder.h @@ -41,6 +41,7 @@ class STARCODER { intermediate_size = model_config["n_inner"]; dropout_p = model_config["attn_pdrop"]; max_position_embeddings = model_config["n_positions"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing STARCODER config from JSON file: " << e.what() << std::endl; @@ -51,8 +52,6 @@ class STARCODER { << std::endl; assert(false); } - // max_seq_len = BatchConfig::MAX_SEQ_LENGTH; - // max_num_tokens = BatchConfig::MAX_NUM_TOKENS; max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH; max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH; } @@ -64,6 +63,7 @@ class STARCODER { int num_hidden_layers, vocab_size, num_attention_heads, hidden_size, intermediate_size, max_position_embeddings; float layer_norm_epsilon, dropout_p; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_starcoder_model(FFModel &ff, diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index f888982f2c..1df5a05a8f 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -111,9 +111,15 @@ def main(): if len(configs.prompt) > 0: prompts = [s for s in json.load(open(configs.prompt))] - results = llm.generate(prompts) + if "max_length" not in configs_dict: + results = llm.generate(prompts) + else: + results = llm.generate(prompts, max_length=configs.max_length) else: - result = llm.generate("Three tips for staying healthy are: ") + if "max_length" not in configs_dict: + result = llm.generate("Three tips for staying healthy are: ") + else: + result = llm.generate("Three tips for staying healthy are: ", max_length=configs.max_length) llm.stop_server() diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 7692ccb88f..a5aadc270e 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -41,6 +41,7 @@ from typing import Union, List from peft import LoraConfig import json +from dataclasses import dataclass def ffc(): @@ -2070,6 +2071,22 @@ def __init__( self.max_training_steps = max_training_steps +# ----------------------------------------------------------------------- +# RotaryEmbeddingMeta +# ----------------------------------------------------------------------- + + +@dataclass +class RotaryEmbeddingMeta: + apply_rotary_embedding: bool = False + rope_theta: float = 10000.0 + rope_type: str = "default" + factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + + # ----------------------------------------------------------------------- # FFModel # ----------------------------------------------------------------------- @@ -3509,12 +3526,10 @@ def inc_multihead_self_attention( kdim=0, vdim=0, dropout=0.0, - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - 
apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3543,12 +3558,6 @@ def inc_multihead_self_attention( :param dropout: a Dropout layer on attn_output_weights. Default is 0.0 :type dropout: float(0-1) - :param bias: Whether the dense layers use bias vectors. Default is True. - :type bias: bool - - :param add_bias_kv: add bias to the key and value sequences at dim=0. Default is False. - :type add_bias_kv: bool - :param add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default is False. :type add_zero_attn: bool @@ -3558,8 +3567,8 @@ def inc_multihead_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -3589,12 +3598,16 @@ def inc_multihead_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3612,12 +3625,10 @@ def spec_inc_multihead_self_attention( kdim=0, vdim=0, dropout=0.0, - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3646,12 +3657,6 @@ def spec_inc_multihead_self_attention( :param dropout: a Dropout layer on attn_output_weights. Default is 0.0 :type dropout: float(0-1) - :param bias: Whether the dense layers use bias vectors. Default is True. - :type bias: bool - - :param add_bias_kv: add bias to the key and value sequences at dim=0. Default is False. - :type add_bias_kv: bool - :param add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default is False. :type add_zero_attn: bool @@ -3661,8 +3666,8 @@ def spec_inc_multihead_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. 
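With the bias and add_bias_kv flags removed, RoPE configuration now travels through the single rotary_embedding_meta argument documented above. A minimal sketch of constructing it on the Python side; the field values and the flexflow.core import path are illustrative assumptions, not defaults taken from this patch:

from flexflow.core import RotaryEmbeddingMeta  # assumed to be re-exported alongside the other wrappers

# Llama-3.1-style settings, purely illustrative
rope_meta = RotaryEmbeddingMeta(
    apply_rotary_embedding=True,
    rope_theta=500000.0,
    rope_type="llama3",
    factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
)
# later passed to the attention wrappers as: rotary_embedding_meta=rope_meta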
:type scaling_query: bool @@ -3692,12 +3697,16 @@ def spec_inc_multihead_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3715,12 +3724,10 @@ def inc_multihead_self_attention_verify( kdim=0, vdim=0, dropout=0.0, - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3749,12 +3756,6 @@ def inc_multihead_self_attention_verify( :param dropout: a Dropout layer on attn_output_weights. Default is 0.0 :type dropout: float(0-1) - :param bias: Whether the dense layers use bias vectors. Default is True. - :type bias: bool - - :param add_bias_kv: add bias to the key and value sequences at dim=0. Default is False. - :type add_bias_kv: bool - :param add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default is False. :type add_zero_attn: bool @@ -3764,8 +3765,8 @@ def inc_multihead_self_attention_verify( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -3795,12 +3796,16 @@ def inc_multihead_self_attention_verify( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3819,12 +3824,10 @@ def inc_multiquery_self_attention( kdim=0, vdim=0, dropout=0.0, - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3856,12 +3859,6 @@ def inc_multiquery_self_attention( :param dropout: a Dropout layer on attn_output_weights. Default is 0.0 :type dropout: float(0-1) - :param bias: Whether the dense layers use bias vectors. Default is True. - :type bias: bool - - :param add_bias_kv: add bias to the key and value sequences at dim=0. Default is False. - :type add_bias_kv: bool - :param add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default is False. :type add_zero_attn: bool @@ -3871,8 +3868,8 @@ def inc_multiquery_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. 
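The factor, low_freq_factor, high_freq_factor, and original_max_position_embeddings fields only come into play when rope_type selects Llama-3-style frequency scaling. A rough NumPy sketch of what that scaling conventionally does to the rotary inverse frequencies, following the public Llama 3 reference rather than this patch's kernels (head_dim and rope_theta below are illustrative):

import math
import numpy as np

def llama3_scale_inv_freq(inv_freq, factor, low_freq_factor,
                          high_freq_factor, original_max_position_embeddings):
    # wavelength (in token positions) of each rotary frequency component
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    # keep high-frequency components, stretch low-frequency ones by `factor`,
    # and smoothly interpolate the band in between
    scaled = np.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    mid = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_mid = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return np.where(is_mid, mid, scaled)

head_dim, rope_theta = 128, 500000.0
inv_freq = 1.0 / (rope_theta ** (np.arange(0, head_dim, 2) / head_dim))
inv_freq = llama3_scale_inv_freq(inv_freq, 8.0, 1.0, 4.0, 8192)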
:type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -3903,12 +3900,16 @@ def inc_multiquery_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3927,12 +3928,10 @@ def spec_inc_multiquery_self_attention( kdim=0, vdim=0, dropout=0.0, - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3964,12 +3963,6 @@ def spec_inc_multiquery_self_attention( :param dropout: a Dropout layer on attn_output_weights. Default is 0.0 :type dropout: float(0-1) - :param bias: Whether the dense layers use bias vectors. Default is True. - :type bias: bool - - :param add_bias_kv: add bias to the key and value sequences at dim=0. Default is False. - :type add_bias_kv: bool - :param add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default is False. :type add_zero_attn: bool @@ -3979,8 +3972,8 @@ def spec_inc_multiquery_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -4011,12 +4004,16 @@ def spec_inc_multiquery_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -4035,12 +4032,10 @@ def inc_multiquery_self_attention_verify( kdim=0, vdim=0, dropout=0.0, - bias=True, - add_bias_kv=False, add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -4072,12 +4067,6 @@ def inc_multiquery_self_attention_verify( :param dropout: a Dropout layer on attn_output_weights. Default is 0.0 :type dropout: float(0-1) - :param bias: Whether the dense layers use bias vectors. Default is True. 
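All of these attention wrappers now consume a fused QKV tensor and return the pre-projection attention output, so every serving model wires attention as dense, then attention, then dense. A condensed sketch of that pattern; ffmodel, x, hidden_size, n_heads, and rope_meta are assumed to already exist, and the layer names simply mirror the LLaMA naming convention used further below:

from flexflow.core import ActiMode  # assumed export, as used by the serve models

# x: normalized hidden states of width hidden_size
qkv_proj = ffmodel.dense(x, 3 * hidden_size, ActiMode.AC_MODE_NONE, False,
                         name="layers.0.self_attn.qkv_proj")
attn_out = ffmodel.inc_multihead_self_attention(
    qkv_proj,
    hidden_size,
    n_heads,
    rotary_embedding_meta=rope_meta,
    name="layers.0.self_attn",
)
o_proj = ffmodel.dense(attn_out, hidden_size, ActiMode.AC_MODE_NONE, False,
                       name="layers.0.self_attn.o_proj")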
- :type bias: bool - - :param add_bias_kv: add bias to the key and value sequences at dim=0. Default is False. - :type add_bias_kv: bool - :param add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1. Default is False. :type add_zero_attn: bool @@ -4087,8 +4076,8 @@ def inc_multiquery_self_attention_verify( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -4119,12 +4108,16 @@ def inc_multiquery_self_attention_verify( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py index 0e8fbcbd7d..0c6102406f 100644 --- a/python/flexflow/serve/models/falcon.py +++ b/python/flexflow/serve/models/falcon.py @@ -41,6 +41,17 @@ def __init__(self, hf_config): ) self.parallel_attn = hf_config.parallel_attn self.vocab_size = hf_config.vocab_size + self.rotary_embedding_meta = RotaryEmbeddingMeta( + apply_rotary_embedding=True, + rope_theta=hf_config.rope_theta if "rope_theta" in hf_config.__dict__ else 10000.0, + ) + if "rope_scaling" in hf_config.__dict__: + if hf_config.rope_scaling is not None: + self.rotary_embedding_meta.rope_type = hf_config.rope_scaling["rope_type"] + self.rotary_embedding_meta.factor = hf_config.rope_scaling["factor"] + self.rotary_embedding_meta.low_freq_factor = hf_config.rope_scaling["low_freq_factor"] + self.rotary_embedding_meta.high_freq_factor = hf_config.rope_scaling["high_freq_factor"] + self.rotary_embedding_meta.original_max_position_embeddings = hf_config.rope_scaling["original_max_position_embeddings"] # Standardized FlexFlow num heads fields below self.num_attention_heads = self.n_head self.num_key_value_heads = self.n_head_kv @@ -54,8 +65,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -63,11 +72,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.falcon_config = FalconConfig(hf_config) - # self.falcon_config.max_seq_length = max_seq_length - # self.falcon_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -138,60 +144,70 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.input_layernorm", ) + qkv_proj = ffmodel.dense( + att_norm, + 3 * self.falcon_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.self_attention.qkv_proj", + ) + if self.mode == InferenceMode.BEAM_SEARCH_MODE: - mha = 
ffmodel.spec_inc_multiquery_self_attention( - att_norm, + o_proj = ffmodel.spec_inc_multiquery_self_attention( + qkv_proj, self.falcon_config.hidden_size, self.falcon_config.n_head, self.falcon_config.n_head_kv, self.falcon_config.hidden_size // self.falcon_config.n_head, self.falcon_config.hidden_size // self.falcon_config.n_head, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers.{i}.self_attention", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: - mha = ffmodel.inc_multiquery_self_attention_verify( - att_norm, + o_proj = ffmodel.inc_multiquery_self_attention_verify( + qkv_proj, self.falcon_config.hidden_size, self.falcon_config.n_head, self.falcon_config.n_head_kv, self.falcon_config.hidden_size // self.falcon_config.n_head, self.falcon_config.hidden_size // self.falcon_config.n_head, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers.{i}.self_attention", ) elif self.mode == InferenceMode.INC_DECODING_MODE: - mha = ffmodel.inc_multiquery_self_attention( - att_norm, + o_proj = ffmodel.inc_multiquery_self_attention( + qkv_proj, self.falcon_config.hidden_size, self.falcon_config.n_head, self.falcon_config.n_head_kv, self.falcon_config.hidden_size // self.falcon_config.n_head, self.falcon_config.hidden_size // self.falcon_config.n_head, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers.{i}.self_attention", ) else: assert False + mha = ffmodel.dense( + o_proj, + self.falcon_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.self_attention.o_proj" + ) + dense_h_to_4h = ffmodel.dense( att_norm, self.falcon_config.hidden_size * 4, diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 96f0258572..e149834603 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -19,8 +19,6 @@ class LLAMAConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -29,6 +27,17 @@ def __init__(self, hf_config): self.hidden_size = hf_config.hidden_size self.rms_norm_eps = hf_config.rms_norm_eps self.intermediate_size = hf_config.intermediate_size + self.rotary_embedding_meta = RotaryEmbeddingMeta( + apply_rotary_embedding=True, + rope_theta=hf_config.rope_theta if "rope_theta" in hf_config.__dict__ else 10000.0, + ) + if "rope_scaling" in hf_config.__dict__: + if hf_config.rope_scaling is not None: + self.rotary_embedding_meta.rope_type = hf_config.rope_scaling["rope_type"] + self.rotary_embedding_meta.factor = hf_config.rope_scaling["factor"] + self.rotary_embedding_meta.low_freq_factor = hf_config.rope_scaling["low_freq_factor"] + self.rotary_embedding_meta.high_freq_factor = hf_config.rope_scaling["high_freq_factor"] + self.rotary_embedding_meta.original_max_position_embeddings = hf_config.rope_scaling["original_max_position_embeddings"] # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.num_attention_heads 
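Both the Falcon and the LLaMA configs above read rope_theta and rope_scaling straight out of hf_config.__dict__. An equivalent, slightly more defensive reading, shown purely as an illustrative alternative with fallbacks matching the RotaryEmbeddingMeta defaults defined earlier:

rope_scaling = getattr(hf_config, "rope_scaling", None) or {}
self.rotary_embedding_meta = RotaryEmbeddingMeta(
    apply_rotary_embedding=True,
    rope_theta=getattr(hf_config, "rope_theta", 10000.0),
    rope_type=rope_scaling.get("rope_type", "default"),
    factor=rope_scaling.get("factor", 8.0),
    low_freq_factor=rope_scaling.get("low_freq_factor", 1.0),
    high_freq_factor=rope_scaling.get("high_freq_factor", 4.0),
    original_max_position_embeddings=rope_scaling.get(
        "original_max_position_embeddings", 8192
    ),
)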
self.num_key_value_heads = ( @@ -55,11 +64,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.llama_config = LLAMAConfig(hf_config) - # self.llama_config.max_seq_length = max_seq_length - # self.llama_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2 ** 31 - 1 @@ -128,9 +134,17 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.input_layernorm", ) + qkv_proj = ffmodel.dense( + attn_norm, + 3 * self.llama_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.self_attn.qkv_proj", + ) + if self.mode == InferenceMode.BEAM_SEARCH_MODE: mha = ffmodel.spec_inc_multiquery_self_attention( - attn_norm, + qkv_proj, self.llama_config.hidden_size, self.llama_config.num_attention_heads, self.llama_config.num_key_value_heads, @@ -139,17 +153,15 @@ def build_model(self, max_tokens_per_batch): self.llama_config.hidden_size // self.llama_config.num_attention_heads, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers.{i}.self_attn", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: mha = ffmodel.inc_multiquery_self_attention_verify( - attn_norm, + qkv_proj, self.llama_config.hidden_size, self.llama_config.num_attention_heads, self.llama_config.num_key_value_heads, @@ -158,17 +170,15 @@ def build_model(self, max_tokens_per_batch): self.llama_config.hidden_size // self.llama_config.num_attention_heads, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers.{i}.self_attn", ) elif self.mode == InferenceMode.INC_DECODING_MODE: mha = ffmodel.inc_multiquery_self_attention( - attn_norm, + qkv_proj, self.llama_config.hidden_size, self.llama_config.num_attention_heads, self.llama_config.num_key_value_heads, @@ -177,20 +187,26 @@ def build_model(self, max_tokens_per_batch): self.llama_config.hidden_size // self.llama_config.num_attention_heads, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers.{i}.self_attn", ) else: assert False + o_proj = ffmodel.dense( + mha, + self.llama_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.self_attn.o_proj" + ) + token, ff_norm = ffmodel.residual_rms_norm( token, - mha, + o_proj, self.llama_config.rms_norm_eps, self.llama_config.hidden_size, name=f"layers.{i}.post_attention_layernorm", @@ -259,3 +275,7 @@ def convert_hf_model(model, dst_folder): for name, params in model.named_parameters(): name = FlexFlowLLAMA.convert_hf_weight_name(name) params.detach().cpu().numpy().tofile(f"{dst_folder}/{name}") + # LM head weight + model.lm_head.weight.detach().cpu().numpy().tofile( + os.path.join(dst_folder, "lm_head.weight") + ) diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py index b350ae106d..a0e70b381a 100644 --- a/python/flexflow/serve/models/mpt.py +++ b/python/flexflow/serve/models/mpt.py @@ -19,8 +19,6 @@ class MPTConfig: def __init__(self, hf_config): - # 
self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -28,6 +26,7 @@ def __init__(self, hf_config): self.n_heads = hf_config.n_heads self.n_layers = hf_config.n_layers self.vocab_size = hf_config.vocab_size + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.n_heads self.num_key_value_heads = hf_config.n_heads @@ -50,11 +49,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.mpt_config = MPTConfig(hf_config) - # self.mpt_config.max_seq_length = max_seq_length - # self.mpt_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -129,20 +125,26 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.norm_1", ) + qkv_proj = ffmodel.dense( + layernorm_output, + 3 * self.mpt_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.attn.qkv_proj", + ) + if self.mode == InferenceMode.BEAM_SEARCH_MODE: - attn_outputs = ffmodel.spec_inc_multihead_self_attention( - layernorm_output, + o_proj = ffmodel.spec_inc_multihead_self_attention( + qkv_proj, self.mpt_config.hidden_size, self.mpt_config.n_heads, self.mpt_config.hidden_size // self.mpt_config.n_heads, self.mpt_config.hidden_size // self.mpt_config.n_heads, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -151,19 +153,17 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.attn", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: - attn_outputs = ffmodel.inc_multihead_self_attention_verify( - layernorm_output, + o_proj = ffmodel.inc_multihead_self_attention_verify( + qkv_proj, self.mpt_config.hidden_size, self.mpt_config.n_heads, self.mpt_config.hidden_size // self.mpt_config.n_heads, self.mpt_config.hidden_size // self.mpt_config.n_heads, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -172,19 +172,17 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.attn", ) elif self.mode == InferenceMode.INC_DECODING_MODE: - attn_outputs = ffmodel.inc_multihead_self_attention( - layernorm_output, + o_proj = ffmodel.inc_multihead_self_attention( + qkv_proj, self.mpt_config.hidden_size, self.mpt_config.n_heads, self.mpt_config.hidden_size // self.mpt_config.n_heads, self.mpt_config.hidden_size // self.mpt_config.n_heads, 0.0, # dropout - False, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -195,6 +193,14 @@ def build_model(self, max_tokens_per_batch): else: assert False + attn_outputs = ffmodel.dense( + 
o_proj, + self.mpt_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.attn.o_proj" + ) + hidden_states, layernorm_output = ffmodel.residual_layer_norm( attn_outputs, hidden_states, diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py index 02668abf59..ba2e21b690 100644 --- a/python/flexflow/serve/models/opt.py +++ b/python/flexflow/serve/models/opt.py @@ -34,6 +34,7 @@ def __init__(self, hf_config): self.num_hidden_layers = hf_config.num_hidden_layers self.vocab_size = hf_config.vocab_size self.word_embed_proj_dim = hf_config.word_embed_proj_dim + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.num_attention_heads self.num_key_value_heads = hf_config.num_attention_heads @@ -47,8 +48,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -56,11 +55,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.opt_config = OPTConfig(hf_config) - # self.opt_config.max_seq_length = max_seq_length - # self.opt_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -145,20 +141,26 @@ def build_model(self, max_tokens_per_batch): hidden_states = ffmodel.add(token, positional_embedding) residual = hidden_states + qkv_proj = ffmodel.dense( + hidden_states, + 3 * self.opt_config.hidden_size, + ActiMode.AC_MODE_NONE, + True, + name=f"layers.{i}.self_attn.qkv_proj", + ) + if self.mode == InferenceMode.BEAM_SEARCH_MODE: - mha = ffmodel.spec_inc_multihead_self_attention( - hidden_states, + o_proj = ffmodel.spec_inc_multihead_self_attention( + qkv_proj, self.opt_config.hidden_size, self.opt_config.num_attention_heads, self.opt_config.hidden_size // self.opt_config.num_attention_heads, self.opt_config.hidden_size // self.opt_config.num_attention_heads, 0.0, # dropout - True, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -166,19 +168,17 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.self_attn", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: - mha = ffmodel.inc_multihead_self_attention_verify( - hidden_states, + o_proj = ffmodel.inc_multihead_self_attention_verify( + qkv_proj, self.opt_config.hidden_size, self.opt_config.num_attention_heads, self.opt_config.hidden_size // self.opt_config.num_attention_heads, self.opt_config.hidden_size // self.opt_config.num_attention_heads, 0.0, # dropout - True, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -186,19 +186,17 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.self_attn", ) elif self.mode == InferenceMode.INC_DECODING_MODE: - mha = ffmodel.inc_multihead_self_attention( - hidden_states, + o_proj = 
ffmodel.inc_multihead_self_attention( + qkv_proj, self.opt_config.hidden_size, self.opt_config.num_attention_heads, self.opt_config.hidden_size // self.opt_config.num_attention_heads, self.opt_config.hidden_size // self.opt_config.num_attention_heads, 0.0, # dropout - True, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -208,6 +206,13 @@ def build_model(self, max_tokens_per_batch): else: assert False + mha = ffmodel.dense( + o_proj, + self.opt_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + name=f"layers.{i}.self_attn.o_proj" + ) # This is either a before or after attention LayerNorm. In both cases, we need to compute the LN here. residual, ff_norm = ffmodel.add_bias_residual_layer_norm( mha, diff --git a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py index 2d4471201f..dc5faf175f 100644 --- a/python/flexflow/serve/models/starcoder.py +++ b/python/flexflow/serve/models/starcoder.py @@ -19,8 +19,6 @@ class STARCODERConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -32,6 +30,7 @@ def __init__(self, hf_config): self.vocab_size = hf_config.vocab_size self.intermediate_size = hf_config.n_inner self.n_head_kv = 1 if hf_config.multi_query else hf_config.n_head + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.n_head self.num_key_value_heads = self.n_head_kv @@ -45,8 +44,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -54,11 +51,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.starcoder_config = STARCODERConfig(hf_config) - # self.starcoder_config.max_seq_length = max_seq_length - # self.starcoder_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -142,9 +136,17 @@ def build_model(self, max_tokens_per_batch): name=f"layers.{i}.ln_1", ) - assert self.mode == InferenceMode.INC_DECODING_MODE - mha = ffmodel.inc_multiquery_self_attention( + qkv_proj = ffmodel.dense( ln_1, + 3 * self.starcoder_config.hidden_size, + ActiMode.AC_MODE_NONE, + True, + name=f"layers.{i}.self_attn.qkv_proj", + ) + + assert self.mode == InferenceMode.INC_DECODING_MODE + o_proj = ffmodel.inc_multiquery_self_attention( + qkv_proj, self.starcoder_config.hidden_size, self.starcoder_config.num_attention_heads, self.starcoder_config.n_head_kv, @@ -153,15 +155,21 @@ def build_model(self, max_tokens_per_batch): self.starcoder_config.hidden_size // self.starcoder_config.num_attention_heads, 0.0, # dropout - True, # qkv_bias - False, # final_bias False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.starcoder_config.rotary_embedding_meta, name=f"layers.{i}.attn.c_attn", ) + mha = ffmodel.dense( + o_proj, + self.starcoder_config.hidden_size, + ActiMode.AC_MODE_NONE, + False, + 
name=f"layers.{i}.self_attn.o_proj" + ) + residual, l2_norm = ffmodel.residual_layer_norm( hidden_states, mha, diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 532dd00198..c6cf656ac0 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1209,12 +1209,16 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1224,18 +1228,23 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multihead_self_attention(input, embed_dim, num_heads, kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1252,12 +1261,16 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1267,6 +1280,13 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->spec_inc_multihead_self_attention(input, embed_dim, @@ -1274,12 +1294,10 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1296,12 +1314,16 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1311,6 +1333,13 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + 
low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multihead_self_attention_verify(input, embed_dim, @@ -1318,12 +1347,10 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1341,12 +1368,16 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1356,6 +1387,13 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multiquery_self_attention(input, embed_dim, num_q_heads, @@ -1363,12 +1401,10 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1386,12 +1422,16 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1401,6 +1441,13 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->spec_inc_multiquery_self_attention(input, embed_dim, @@ -1409,12 +1456,10 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1432,12 +1477,16 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( int kdim, int vdim, float dropout, - bool bias, - bool add_bias_kv, bool add_zero_attn, enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1447,6 +1496,13 @@ flexflow_tensor_t 
flexflow_model_add_inc_multiquery_self_attention_verify( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multiquery_self_attention_verify(input, embed_dim, @@ -1455,12 +1511,10 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( kdim, vdim, dropout, - bias, - add_bias_kv, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, diff --git a/src/ops/add_bias_residual_layer_norm.cc b/src/ops/add_bias_residual_layer_norm.cc index 7a1da2e974..7bfbe31aad 100644 --- a/src/ops/add_bias_residual_layer_norm.cc +++ b/src/ops/add_bias_residual_layer_norm.cc @@ -670,8 +670,18 @@ void AddBiasResidualLayerNorm::inference_task( AddBiasResidualLayerNormMeta *m = *((AddBiasResidualLayerNormMeta **)task->local_args); - assert(regions.size() == - 4 + (m->elementwise_affine ? (m->use_bias ? 2 : 1) : 0)); + int expected_regions = + 5; // input, attn_bias, residual (input), added_output, output + if (m->inplace_residual) { + expected_regions--; // input == added_output + } + if (m->elementwise_affine) { + expected_regions += 1; // gamma + if (m->use_bias) { + expected_regions += 1; // beta + } + } + assert(regions.size() == expected_regions); int rid = 0, tid = 0, did = 0; GenericTensorAccessorR input = diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index 9f826cd611..2cede662f3 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -439,21 +439,13 @@ __host__ void assert(fused->op_num_outputs[op] == 1); IncMultiHeadSelfAttentionMeta *m = (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } + assert(fused->op_num_weights[op] == 0); IncMultiHeadSelfAttention::inference_kernel_wrapper( m, bc, task->index_point.point_data[0], my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); + my_output_accessor[0]); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { @@ -463,21 +455,13 @@ __host__ void (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; TreeVerifyBatchConfig const &tree_bc = Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } + assert(fused->op_num_weights[op] == 0); TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &tree_bc, task->index_point.point_data[0], my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); + my_output_accessor[0]); break; } case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { @@ -489,21 +473,13 @@ __host__ void // (BeamSearchBatchConfig *)task->args; BeamSearchBatchConfig const &beam_bc = Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } + 
assert(fused->op_num_weights[op] == 0); SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &beam_bc, task->index_point.point_data[0], my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); + my_output_accessor[0]); break; } case OP_LAYERNORM: { @@ -1025,21 +1001,13 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, assert(fused->op_num_outputs[op] == 1); IncMultiHeadSelfAttentionMeta *m = (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } + assert(fused->op_num_weights[op] == 0); IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( m, bc, task->index_point.point_data[0], my_input_grad_accessor[0], - my_weight_accessor[0], - my_output_grad_accessor[0], - biases); + my_output_grad_accessor[0]); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 8f1212beb4..5aed2cd69a 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -448,73 +448,49 @@ __host__ void case OP_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); + assert(fused->op_num_weights[op] == 0); IncMultiHeadSelfAttentionMeta *m = (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } IncMultiHeadSelfAttention::inference_kernel_wrapper( m, bc, task->index_point.point_data[0], my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); + my_output_accessor[0]); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); + assert(fused->op_num_weights[op] == 0); TreeIncMultiHeadSelfAttentionMeta *m = (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; TreeVerifyBatchConfig const &tree_bc = Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &tree_bc, task->index_point.point_data[0], my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); + my_output_accessor[0]); break; } case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); + assert(fused->op_num_weights[op] == 0); SpecIncMultiHeadSelfAttentionMeta const *m = (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; // BeamSearchBatchConfig const *beam_bc = // (BeamSearchBatchConfig *)task->args; BeamSearchBatchConfig const &beam_bc = Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( m, &beam_bc, task->index_point.point_data[0], my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); + 
my_output_accessor[0]); break; } case OP_LAYERNORM: { @@ -666,12 +642,7 @@ __host__ void assert(false && "Fusion currently does not support type"); } } - if (metas->meta[op]->inference_debugging && - !(fused->op_op_type[op] == OP_ALLREDUCE || - fused->op_op_type[op] == OP_PARALLEL_IDENTITY || - fused->op_op_type[op] == OP_REPLICATE || - fused->op_op_type[op] == OP_REPARTITION || - fused->op_op_type[op] == OP_COMBINE)) { + if (metas->meta[op]->inference_debugging) { std::vector input_accessors_to_save; std::vector weight_accessors_to_save; std::vector output_accessors_to_save; @@ -1048,21 +1019,15 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, assert(fused->op_num_outputs[op] == 1); IncMultiHeadSelfAttentionMeta *m = (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); + assert(fused->op_num_weights[op] == 0); GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( m, bc, task->index_point.point_data[0], my_input_grad_accessor[0], - my_weight_accessor[0], - my_output_grad_accessor[0], - biases); + my_output_grad_accessor[0]); + // biases); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index 8219cf9e1f..8dbce00ebc 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -54,23 +54,22 @@ bool IncMultiHeadSelfAttentionParams::is_valid( return is_valid; } -Tensor FFModel::inc_multihead_self_attention(const Tensor input, - int embed_dim, - int num_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { return inc_multiquery_self_attention(input, embed_dim, num_heads, @@ -78,12 +77,10 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, kdim, vdim, dropout, - qkv_bias, - final_bias, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -91,31 +88,29 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, name); } -Tensor FFModel::inc_multiquery_self_attention(const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim, + int vdim, + float dropout, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + 
RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { if (data_type == DT_NONE) { data_type = input->data_type; } DataType quantization_type = cpu_offload ? config.quantization_type : DT_NONE; bool offload = cpu_offload; Layer *li = nullptr; - int weight_num = (qkv_bias || final_bias) ? 2 : 1; if (data_type != input->data_type) { Tensor casted_input = cast(input, data_type, "type cast for IncMHA"); li = new Layer(this, @@ -123,7 +118,7 @@ Tensor FFModel::inc_multiquery_self_attention(const Tensor input, data_type, name, 1 /*inputs*/, - weight_num /*weights*/, + 0, 1 /*outputs*/, casted_input); } else { @@ -132,7 +127,7 @@ Tensor FFModel::inc_multiquery_self_attention(const Tensor input, data_type, name, 1 /*inputs*/, - weight_num /*weights*/, + 0, 1 /*outputs*/, input); } @@ -142,65 +137,30 @@ Tensor FFModel::inc_multiquery_self_attention(const Tensor input, for (int i = 0; i < numdims; i++) { dims[i] = input->dims[i]; } - dims[0] = embed_dim; + dims[0] = vdim * num_q_heads; // we now output o_proj_dim * o_heads li->outputs[0] = create_tensor_legion_ordering( numdims, dims, data_type, li, 0, true /*create_grad*/); } - // Compute weight size - int qProjSize = kdim, kProjSize = kdim, vProjSize = kdim, - oProjSize = embed_dim; - int qSize = input->dims[0], kSize = input->dims[0], vSize = input->dims[0]; - int qParas = qProjSize * qSize; - int kParas = kProjSize * kSize; - int vParas = vProjSize * vSize; - int oParas = oProjSize * (vProjSize > 0 ? vProjSize : vSize); - // allocate num_q_heads for key, value for replication - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; - int one_head_size = qParas + kParas + vParas + oParas; - - { - // compress the weight size if quantization. - if (quantization_type != DT_NONE) { - one_head_size = get_quantization_to_byte_size( - data_type, quantization_type, one_head_size); - } - int dims[1] = {weight_size}; - li->weights[0] = create_weight_legion_ordering( - 1, - dims, - quantization_type == DT_NONE ? data_type : quantization_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - if (qkv_bias || final_bias) { - // q, k, v, o - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + - (final_bias ? 
oProjSize : 0)}; - li->weights[1] = create_weight_legion_ordering(1, - dims, - data_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } li->data_type = data_type; li->add_int_property("embed_dim", embed_dim); li->add_int_property("num_q_heads", num_q_heads); li->add_int_property("num_kv_heads", num_kv_heads); li->add_int_property("kdim", kdim); li->add_int_property("vdim", vdim); - li->add_int_property("qkv_bias", qkv_bias); - li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling", qk_prod_scaling); @@ -231,14 +191,20 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( int vdim = value; float dropout; layer->get_float_property("dropout", dropout); - layer->get_int_property("qkv_bias", value); - bool qkv_bias = (bool)value; - layer->get_int_property("final_bias", value); - bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; @@ -264,15 +230,12 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( kdim, vdim, dropout, - qkv_bias, - final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, position_bias, - false /*allocate_weights*/, quantization_type, offload, tensor_parallelism_degree, @@ -289,15 +252,12 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, @@ -308,13 +268,12 @@ 
IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( _input->data_type, name, 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1), /*weights*/ + 0, 1 /*outputs*/, _input), num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -334,86 +293,29 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( x *= _input->dims[i].size; } dims[0].size = _embed_dim; - // Currently require no parallelism along this dim - assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - - if (quantization_type != DT_NONE) { - dims[1].size = get_quantization_to_byte_size( - data_type, quantization_type, (qParas + kParas + vParas + oParas)); - } - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>( - dims, - quantization_type == DT_NONE ? this->data_type : quantization_type, - nullptr /*owner_op*/, - model.config.computationMode == COMP_MODE_INFERENCE - ? false - : true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } + // Removed restriction that no parallelism along this dim + // assert(dims[0].degree == 1); outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, dims, this->data_type, this); - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* // Check correctness */ /* assert(check_output_input_weight_parallel_dims()); */ } IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( FFModel &model, const ParallelTensor _input, - const ParallelTensor _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, @@ -424,14 +326,12 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( _input->data_type, name, 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1), /*weights*/ + 0, 1 /*outputs*/, - _input, - _weight), + _input), num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -439,9 +339,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( scaling_query(_scaling_query), scaling_factor(_scaling_factor), qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias), quantization_type(_quantization_type), offload(_offload), - tensor_parallelism_degree(_tensor_parallelism_degree) -// bias_initializer(_bias_initializer) -{ + tensor_parallelism_degree(_tensor_parallelism_degree) { numOutputs = 1; int numdim = _input->num_dims; ParallelDim dims[MAX_TENSOR_DIM]; @@ -451,63 +349,10 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( dims[0].size = _embed_dim; // Currently require no parallelism along this dim assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? 
this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - // dims[2].size = this->num_q_heads * (qParas + oParas) + this->num_kv_heads - // * (kParas + vParas); - if (quantization_type != DT_NONE) { - dims[1].size = get_quantization_to_byte_size( - data_type, quantization_type, (qParas + kParas + vParas + oParas)); - } - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>( - dims, - quantization_type == DT_NONE ? this->data_type : quantization_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, dims, this->data_type, this); - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* register_output_weight_parallel_dims(outputs[0], numdim-1, _weight, 1); */ - /* register_output_weight_parallel_dims(outputs[0], numdim-2, _weight, 2); */ // Check correctness /* assert(check_output_input_weight_parallel_dims()); */ } @@ -515,8 +360,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( FFModel &model, IncMultiHeadSelfAttention const &other, - const ParallelTensor input, - bool allocate_weights) + const ParallelTensor input) : IncMultiHeadSelfAttention(model, other.layer_guid, input, @@ -526,15 +370,12 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( other.qProjSize, other.vProjSize, other.dropout, - other.qkv_bias, - other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, other.position_bias, - allocate_weights, other.quantization_type, other.offload, other.tensor_parallelism_degree, @@ -544,7 +385,6 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( FFModel &model, IncMultiHeadSelfAttentionParams const ¶ms, ParallelTensor const &input, - bool allocate_weights, char const *name) : IncMultiHeadSelfAttention(model, params.layer_guid, @@ -555,15 +395,12 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( params.kdim, params.vdim, params.dropout, - params.qkv_bias, - params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, params.position_bias, - allocate_weights, params.quantization_type, params.offload, params.tensor_parallelism_degree, @@ -596,20 +433,12 @@ void IncMultiHeadSelfAttention::init_inference( EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement( - RegionRequirement(weights[0]->part, - 0 /*projection 
id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region, - ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(1, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); @@ -636,18 +465,12 @@ void IncMultiHeadSelfAttention::init(FFModel const &ff) { EXCLUSIVE, inputs[0]->region)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); launcher.add_region_requirement(RegionRequirement(outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap(ff, fm); @@ -655,8 +478,7 @@ void IncMultiHeadSelfAttention::init(FFModel const &ff) { /* regions[0](I): input - regions[1](I): weight - regions[2](O): output + regions[1](O): output */ OpMeta *IncMultiHeadSelfAttention::init_task( Task const *task, @@ -675,17 +497,10 @@ OpMeta *IncMultiHeadSelfAttention::init_task( FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = - helperGetGenericTensorAccessorRO(attn->weights[0]->data_type, - regions[1], - task->regions[1], - FID_DATA, - ctx, - runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(attn->outputs[0]->data_type, - regions[2], - task->regions[2], + regions[1], + task->regions[1], FID_DATA, ctx, runtime); @@ -698,8 +513,6 @@ OpMeta *IncMultiHeadSelfAttention::init_task( attn->num_kv_heads / attn->tensor_parallelism_degree + (attn->num_kv_heads % attn->tensor_parallelism_degree != 0); - assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + 1); - Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); MemoryAllocator gpu_mem_allocator(gpu_mem); if (attn->offload) { @@ -708,14 +521,8 @@ OpMeta *IncMultiHeadSelfAttention::init_task( gpu_mem_allocator.register_reserved_work_space( handle.offload_reserve_space, handle.offload_reserve_space_size); } - IncMultiHeadSelfAttentionMeta *m = - new IncMultiHeadSelfAttentionMeta(handle, - attn, - weight, - gpu_mem_allocator, - num_samples, - num_q_heads, - num_kv_heads); + IncMultiHeadSelfAttentionMeta *m = new IncMultiHeadSelfAttentionMeta( + handle, attn, gpu_mem_allocator, num_samples, num_q_heads, num_kv_heads); if (handle.offload_reserve_space == nullptr) { // assert that we didn't over allocate memory assert(gpu_mem_allocator.reserved_allocated_size == @@ -725,10 +532,6 @@ OpMeta *IncMultiHeadSelfAttention::init_task( m->inference_debugging = attn->inference_debugging; std::strcpy(m->op_name, attn->name); m->layer_guid = attn->layer_guid; - if (attn->quantization_type == DT_NONE) { - assert(weight.domain.get_volume() * data_type_size(weight.data_type) == - m->weightSize); - } return m; } @@ -770,14 +573,6 @@ FutureMap IncMultiHeadSelfAttention::inference( EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(idx++, FID_DATA); - launcher.add_region_requirement( - RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region, - ff.cpu_offload ? 
MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(idx++, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, @@ -785,23 +580,12 @@ FutureMap IncMultiHeadSelfAttention::inference( batch_outputs[0]->region)); launcher.add_field(idx++, FID_DATA); - if (qkv_bias || final_bias) { - launcher.add_region_requirement( - RegionRequirement(weights[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[1]->region, - ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(idx++, FID_DATA); - } return runtime->execute_index_space(ctx, launcher); } /* regions[0](I): input - regions[3](I): weight - regions[4](O): output + regions[1](O): output */ void IncMultiHeadSelfAttention::inference_task( Task const *task, @@ -822,54 +606,31 @@ void IncMultiHeadSelfAttention::inference_task( IncMultiHeadSelfAttentionMeta *m = *((IncMultiHeadSelfAttentionMeta **)task->local_args); - assert(((*m->qkv_bias || *m->final_bias) ? regions.size() == 4 - : regions.size() == 3)); + assert(regions.size() == 2); // input and output GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - biases = helperGetGenericTensorAccessorRO(m->weight_type[1], - regions[3], - task->regions[3], - FID_DATA, - ctx, - runtime); - Domain bias_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(bias_domain.get_dim() == 4); - } + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - Domain weight_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); + ctx, task->regions[1].region.get_index_space()); assert(input_domain.get_dim() == 4); - assert(weight_domain.get_dim() == 2); assert(output_domain.get_dim() == 4); assert(task->index_point.get_dim() == 1); IncMultiHeadSelfAttention::inference_kernel_wrapper( - m, bc, task->index_point.point_data[0], input, weight, output, biases); + m, bc, task->index_point.point_data[0], input, output); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; - std::vector weights_accessors; - weights_accessors.push_back(weight); - if (*m->qkv_bias || *m->final_bias) { - weights_accessors.push_back(biases); - } IncMultiHeadSelfAttention::save_inference_tensors_to_file( - m, shard_id, bc, {input}, weights_accessors, {output}); + m, shard_id, bc, {input}, {}, {output}); } } @@ -903,14 +664,6 @@ FutureMap IncMultiHeadSelfAttention::peft_bwd( EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(idx++, FID_DATA); - launcher.add_region_requirement( - RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region, - ff.cpu_offload ? 
MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(idx++, FID_DATA); launcher.add_region_requirement( RegionRequirement(batch_outputs[0]->part_grad, 0 /*projection id*/, @@ -918,23 +671,12 @@ FutureMap IncMultiHeadSelfAttention::peft_bwd( EXCLUSIVE, batch_outputs[0]->region_grad)); launcher.add_field(idx++, FID_DATA); - if (qkv_bias || final_bias) { - launcher.add_region_requirement( - RegionRequirement(weights[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[1]->region, - ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(idx++, FID_DATA); - } return runtime->execute_index_space(ctx, launcher); } /* regions[0](I): input - regions[3](I): weight - regions[4](O): output + regions[1](O): output */ void IncMultiHeadSelfAttention::peft_bwd_task( Task const *task, @@ -954,55 +696,31 @@ void IncMultiHeadSelfAttention::peft_bwd_task( IncMultiHeadSelfAttentionMeta *m = *((IncMultiHeadSelfAttentionMeta **)task->local_args); - assert(((*m->qkv_bias || *m->final_bias) ? regions.size() == 4 - : regions.size() == 3)); + assert(regions.size() == 2); // input grad, output grad GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorW output_grad = helperGetGenericTensorAccessorRW( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - biases = helperGetGenericTensorAccessorRO(m->weight_type[1], - regions[3], - task->regions[3], - FID_DATA, - ctx, - runtime); - Domain bias_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(bias_domain.get_dim() == 4); - } + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); Domain input_grad_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - Domain weight_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); Domain output_grad_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); + ctx, task->regions[1].region.get_index_space()); assert(input_grad_domain.get_dim() == 4); - assert(weight_domain.get_dim() == 2); assert(output_grad_domain.get_dim() == 4); assert(task->index_point.get_dim() == 1); IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( - m, - bc, - task->index_point.point_data[0], - input_grad, - weight, - output_grad, - biases); + m, bc, task->index_point.point_data[0], input_grad, output_grad); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; IncMultiHeadSelfAttention::save_inference_tensors_to_file( - m, shard_id, bc, {input_grad}, {weight}, {output_grad}, false); + m, shard_id, bc, {input_grad}, {}, {output_grad}, false); } } @@ -1032,9 +750,20 @@ bool operator==(IncMultiHeadSelfAttentionParams const &lhs, return lhs.layer_guid == rhs.layer_guid && lhs.embed_dim == rhs.embed_dim && lhs.num_q_heads == rhs.num_q_heads && lhs.kdim == rhs.kdim && lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && - lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + 
rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -1049,10 +778,8 @@ IncMultiHeadSelfAttentionParams IncMultiHeadSelfAttention::get_params() const { params.kdim = this->kProjSize; params.vdim = this->vProjSize; params.dropout = this->dropout; - params.qkv_bias = this->qkv_bias; - params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -1081,10 +808,15 @@ size_t hash::operator()( hash_combine(key, params.kdim); hash_combine(key, params.vdim); hash_combine(key, params.dropout); - hash_combine(key, params.qkv_bias); - hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 826fea4347..a4604a11a2 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -19,6 +19,7 @@ #include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/utils/hip_helper.h" #include "hip/hip_complex.h" +#include #include namespace FlexFlow { @@ -52,6 +53,339 @@ __device__ __forceinline__ T #endif } +template +__global__ void store_kv_cache(DT const *devQKVProjArray, + DT *kCache_ptr, + DT *vCache_ptr, + BatchConfig::PerTokenInfo const *tokenInfos, + int num_tokens, + int max_seq_len, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / hidden_size; + int offset = i % hidden_size; + + size_t val_idx = + token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; + + DT kVal = devQKVProjArray[val_idx]; + DT vVal = devQKVProjArray[val_idx + hidden_size]; + int const req_id = tokenInfos[token_idx].request_index; + int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + + // key cache + kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = 
vVal; + } +} + +template +__global__ void store_query_cache(DT const *devQKVProjArray, + DT *qCache_ptr, + int num_tokens, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / hidden_size; + int offset = i % hidden_size; + + size_t val_idx = token_idx * QKV_WEIGHT_NUM * hidden_size + offset; + + DT qVal = devQKVProjArray[val_idx]; + + // query cache + qCache_ptr[i] = qVal; + } +} + +template +__global__ void fill_entries_above_diagonal(DT *matrix, + size_t num_rows, + size_t num_cols, + size_t num_q_heads, + size_t entries_above_diagonal, + DT value) { + CUDA_KERNEL_LOOP(i, entries_above_diagonal * num_q_heads) { + size_t head_idx = i / entries_above_diagonal; + size_t entry_idx = i % entries_above_diagonal; + size_t y = (-1 + sqrt(8 * (float)entry_idx + 1)) / 2; + size_t x = entry_idx - y * (y + 1) / 2; + y += (num_cols - num_rows) + 1; + matrix[head_idx * num_rows * num_cols + num_cols * y + x] = value; + } +} + +template +void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + hipStream_t stream) { + checkCUDA(hipblasSetStream(m->handle.blas, stream)); + checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + hipblasDatatype_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); + miopenDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); + assert(data_type_size(m->output_type[0]) == sizeof(DT)); + hipblasDatatype_t compute_type = cublas_data_type; + + int num_tokens = bc->num_active_tokens(); + int tokens_previous_requests = 0; + int q_block_size = m->qProjSize; + int kt_block_size = m->kProjSize; + int kt_req_block_size = + kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_block_size = m->vProjSize; + int vt_req_block_size = + vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + assert(m->qProjSize == m->kProjSize); + + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i] || + (!bc->requestsInfo[i].prompt_phase && !bc->requestsInfo[i].peft_bwd)) { + continue; + } + int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch; + int max_peft_tokens = bc->requestsInfo[i].max_sequence_length; + // Copy query to m->query_activation_buffer if we need to compute + // PEFT backward + if (bc->requestsInfo[i].peft_bwd) { + size_t activation_size_needed = + sizeof(DT) * max_peft_tokens * m->num_q_heads * m->qProjSize; + if (activation_size_needed > m->allocated_peft_buffer_size1) { + MemoryAllocator *allocator = m->handle.peft_activation_allocator; + m->query_activation_buffer = + allocator->allocate_instance_untyped(activation_size_needed); + m->allocated_peft_buffer_size1 = activation_size_needed; + } + int parallelism = m->hidden_size * num_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(store_query_cache), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + static_cast
<DT *>(m->devQKVProjArray), + static_cast<DT *>
(m->query_activation_buffer), + num_tokens, + m->hidden_size); + } + // Step 1: compute query-key product QK.T/sqrt(d_k) + { + // Scale by sqrt(d_k) as per the original attention paper + DT alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize)); + } + // after transpositions + int m_ = num_new_tokens; + int n = total_tokens; + int k = m->qProjSize; + // before transpositions + int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, + ldc = m_; + // N.B. strides are applied before transpose operations + int strideA = q_block_size; + int strideB = kt_block_size; + int strideC = num_new_tokens * total_tokens; + + // matrix A: devQKVProjArray + // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens] + // To get query projection, skip over Q entries from previous requests + DT const *A = static_cast
(m->devQKVProjArray) + + bc->requestsInfo[i].first_token_offset_in_batch * + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; + // matrix B: key cache + // matrix B's layout: [kProjSize * num_heads, total_tokens] + // To get B, skip over K entries from previous requests (all heads + + // padding) + DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; + // matrix C: qk_prods + // matrix C's layout: [num_new_tokens, total_tokens, num_heads] + // To get C, skip over QK.T products from previous requests + DT *C = static_cast
(m->qk_prods); + checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT)); + } + // Step 2: Add alibi position bias to qk production + // matrix C: qk_prods + // matrix C's layout: [num_new_tokens, total_tokens, num_heads] + // To get C, skip over QK.T products from previous requests + DT *C = static_cast
(m->qk_prods); + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); + } + + // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods + // with -inf to force causal attention. + assert(num_new_tokens <= total_tokens); + size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; + if (entries_above_diagonal > 0) { + size_t parallelism = m->num_q_heads * entries_above_diagonal; + hipLaunchKernelGGL(HIP_KERNEL_NAME(fill_entries_above_diagonal), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + entries_above_diagonal, + static_cast
(-INFINITY)); + } + + // Step 4: Compute Softmax(QK.T/sqrt(d_k)) + { + // Before modifying the parameters below, make sure to read the following + // description of the HIPDNN_TENSOR_NCHW tensor layout, from + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#hipdnnTensorFormat_t: + // This tensor format specifies that the data is laid out in the following + // order: batch size, feature maps, rows, columns. The strides are + // implicitly defined in such a way that the data are contiguous in memory + // with no padding between images, feature maps, rows, and columns; the + // columns are the inner dimension and the images are the outermost + // dimension. + int n_param = m->num_q_heads; + int c_param = total_tokens; + int h_param = 1; + int w_param = num_new_tokens; + checkCUDNN(miopenSet4dTensorDescriptor( + m->qk_tensor, cudnn_data_type, n_param, c_param, h_param, w_param)); + float softmax_alpha = 1.0f, softmax_beta = 0.0f; + DT *C_softmax = static_cast
(m->qk_prods_softmax); + // The softmax operation below is executed according to the + // MIOPEN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax, + MIOPEN_SOFTMAX_ACCURATE, + MIOPEN_SOFTMAX_MODE_CHANNEL)); + } + // Copy C_softmax to m->softmax_activation_buffer if we need to compute + // PEFT backward + if (bc->requestsInfo[i].peft_bwd) { + DT *C_softmax = static_cast
(m->qk_prods_softmax); + size_t activation_size_needed = + sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; + if (activation_size_needed > m->allocated_peft_buffer_size2) { + MemoryAllocator *allocator = m->handle.peft_activation_allocator; + m->softmax_activation_buffer = + allocator->allocate_instance_untyped(activation_size_needed); + m->allocated_peft_buffer_size2 = activation_size_needed; + } + checkCUDA(hipMemcpyAsync(m->softmax_activation_buffer, + C_softmax, + sizeof(DT) * total_tokens * num_new_tokens * + m->num_q_heads, + hipMemcpyDeviceToDevice, + stream)); + } + // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ + // softmax(QK.T/sqrt(d_k)).T + { + DT alpha = 1.0f, beta = 0.0f; + // after transpositions + int m_ = m->vProjSize; + int n = num_new_tokens; + int k = total_tokens; + // before transpositions + int lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + // N.B. strides are applied before transpose operations + int strideA = vt_block_size; + int strideB = num_new_tokens * total_tokens; + int strideC = m->vProjSize; + // matrix A: value cache + // matrix A's layout: [vProjSize, num_heads, total_tokens] + // To get A, skip over V.T entries from previous requests (all heads + + // padding) + DT *A = static_cast
(m->valueCache) + i * vt_req_block_size; + // matrix B: qk_prods_softmax + // matrix B's layout: [num_new_tokens, total_tokens, num_heads] + // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous + // requests (all heads) + DT *B = static_cast
(m->qk_prods_softmax); + // matrix C: attn heads + // matrix C's layout: [vProjSize, num_heads, num_new_tokens] + // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous + // requests + // store the result attn heads, also skip the genration tokens + DT *C = static_cast
(m->attn_heads) + + (bc->requestsInfo[i].first_token_offset_in_batch) * + m->num_q_heads * m->vProjSize; + checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, + HIPBLAS_OP_N, + HIPBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT)); + } + tokens_previous_requests += num_new_tokens; + } + if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { + bc->print(); + printf("tokens_previous_requests: %i\n", tokens_previous_requests); + printf("num_tokens: %i\n", num_tokens); + printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); + } + assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); +} + // gridDim = num_heads // blockDim = num_tokens/num_request * head_size // QKV tensor layout: |QKV| * num_new_tokens. |Q=K=V=head_size * num_heads| @@ -334,63 +668,6 @@ __global__ void apply_position_bias_qkprd(DT *input_ptr, } } -template -__global__ void apply_proj_bias_w(DT *input_ptr, - DT const *bias_ptr, - int num_tokens, - int qkv_weight_size, - int oProjSize) { - CUDA_KERNEL_LOOP(i, num_tokens * oProjSize) { - int bias_idx = qkv_weight_size + i % oProjSize; - input_ptr[i] += bias_ptr[bias_idx]; - } -} - -template -__global__ void apply_proj_bias_qkv(DT *input_ptr, - DT const *bias_ptr, - int shard_id, - int num_tokens, - int qProjSize, - int kProjSize, - int vProjSize, - int global_num_q_heads, - int num_q_heads, - bool scaling_query, - float scaling_factor, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * QKV_WEIGHT_NUM) { - // for simplicity, assume q, k, v is in same shape - // 0->q, 1->k, 2->v - // int qkv_index = i / (num_tokens * qProjSize) % 3; - - int token_idx = i / (hidden_size * QKV_WEIGHT_NUM); - size_t in_token_idx = i - token_idx * hidden_size * QKV_WEIGHT_NUM; - - int qkv_index = in_token_idx / hidden_size; - - int proj_size = qkv_index == 0 ? qProjSize : kProjSize; - - int head_idx = - (in_token_idx - qkv_index * num_q_heads * proj_size) / proj_size; - int global_head_idx = head_idx + shard_id * num_q_heads; - - size_t pre_length = - qkv_index == 0 - ? 0 - : (qkv_index == 1 ? qProjSize * global_num_q_heads - : qProjSize * global_num_q_heads * KV_WEIGHT_NUM); - - size_t bias_idx = pre_length + global_head_idx * proj_size + i % proj_size; - - input_ptr[i] += bias_ptr[bias_idx]; - - if (scaling_query && qkv_index == 0) { - input_ptr[i] *= scaling_factor; - } - } -} - template __global__ void scaling_query_kernel(DT *input_ptr, int qProjSize, @@ -405,60 +682,17 @@ __global__ void scaling_query_kernel(DT *input_ptr, } } -template -__global__ void - apply_rotary_embedding_native(DT *input_ptr, - hipFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int num_q_heads, - int num_tokens, - int num_kv_heads, - int q_block_size, - int k_block_size, - int q_array_size) { - CUDA_KERNEL_LOOP( - i, - num_tokens * (qProjSize * num_q_heads + kProjSize * num_kv_heads) / 2) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int head_idx = real_i / (num_tokens * proj_size / 2); - int idx = real_i % (num_tokens * proj_size / 2); - int real_part_index = idx * 2 + - head_idx * (q_tensor ? q_block_size : k_block_size) + - (q_tensor ? 
0 : q_array_size); - - int complex_part_index = real_part_index + 1; - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - int token_idx = - (real_i - head_idx * (num_tokens * proj_size / 2)) / (proj_size / 2); - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - // complex_input[i].y; - - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - hipFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = hipCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - template __global__ void apply_rotary_embedding_hf(DT *input_ptr, hipFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int qProjSize, int kProjSize, int num_tokens, @@ -493,7 +727,29 @@ __global__ void // float before_real = complex_input[i].x, before_complex = int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); + + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + + if (llama3_rope) { + float pi = HIP_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + hipFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); @@ -507,6 +763,12 @@ __global__ void apply_rotary_embedding_bwd(DT *input_ptr, hipFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int proj_size, int num_tokens, int hidden_size) { @@ -533,7 +795,28 @@ __global__ void size_t pos = tokenInfos[token_idx].abs_depth_in_request; - float freq = pos * (1.0 / pow(10000.0, (float)2 * idx / proj_size)); + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i + + if (llama3_rope) { + float pi = HIP_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + hipFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); @@ -542,172 +825,59 @@ __global__ void } } -template -__global__ void fill_entries_above_diagonal(DT *matrix, - size_t num_rows, - size_t num_cols, - 
size_t num_q_heads, - size_t entries_above_diagonal, - DT value) { - CUDA_KERNEL_LOOP(i, entries_above_diagonal * num_q_heads) { - size_t head_idx = i / entries_above_diagonal; - size_t entry_idx = i % entries_above_diagonal; - size_t y = (-1 + sqrt(8 * (float)entry_idx + 1)) / 2; - size_t x = entry_idx - y * (y + 1) / 2; - y += (num_cols - num_rows) + 1; - matrix[head_idx * num_rows * num_cols + num_cols * y + x] = value; - } -} - template void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, - DT const *input_ptr, - DT const *weight_ptr, DT *output_ptr, - DT const *bias_ptr, hipStream_t stream) { checkCUDA(hipblasSetStream(m->handle.blas, stream)); checkCUDNN(miopenSetStream(m->handle.dnn, stream)); assert(m->qSize == m->vSize && m->qSize == m->kSize); - hipblasDatatype_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - hipblasDatatype_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // hipblasDatatype_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // hipblasDatatype_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - - // Step 1: Compute QKV projections - { - DT alpha = 1.0f, beta = 0.0f; - // after transpositions - int m_q = m->qProjSize * m->num_q_heads; - int m_k = m->kProjSize * m->num_q_heads; - int m_v = m->vProjSize * m->num_q_heads; - assert(m_q == m_k && m_k == m_v); // keep things simple for now - int n = bc->num_active_infr_tokens(); - int k = m->qSize; - int m_ = m_q * QKV_WEIGHT_NUM; - // before transpositions - int lda = k, ldb = k, ldc = m_; - // matrix A: QKV weights - // matrix A's layout: [qSize (hidden_dim), qProjSize, num_heads, 3] - // matrix B: input - // matrix B's layout: [qSize (hidden_dim), num_new_tokens] - // matrix C: devQKVProjArray - // matrix B's layout: [qProjSize, num_heads, 3, num_new_tokens] - checkCUDA(hipblasGemmEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n, - k, - &alpha, - weight_ptr, - cublas_data_type, - lda, - input_ptr, - cublas_data_type, - ldb, - &beta, - output_ptr, - cublas_data_type, - ldc, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - } int num_tokens = bc->num_active_tokens(); int parallelism = m->kProjSize * num_tokens * m->num_q_heads; size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads; - // Step 2: apply bias for QKV, or scale the query - if (*m->qkv_bias) { - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_proj_bias_qkv), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - bias_ptr, - shard_id, - num_tokens, - m->qProjSize, - m->kProjSize, - m->vProjSize, - m->global_num_q_heads, - m->num_q_heads, - *m->scaling_query, - m->scaling_factor, - m->hidden_size); - } else if (m->scaling_query) { + if (m->scaling_query) { hipLaunchKernelGGL(HIP_KERNEL_NAME(scaling_query_kernel), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream, output_ptr, + m->qProjSize, num_tokens, m->num_q_heads, - m->qProjSize, m->scaling_factor, m->hidden_size); } // Step 3: apply rotary embedding if needed - if (*m->apply_rotary_embedding) { + if (m->rotary_embedding_meta->apply_rotary_embedding) { /*q&k*/ parallelism = num_tokens * m->hidden_size; - 
hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_rotary_embedding_hf), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - m->complex_input, - m->token_infos, - m->qProjSize, - m->kProjSize, - num_tokens, - q_array_size, - m->hidden_size); - } -} - -template -__global__ void store_kv_cache(DT const *devQKVProjArray, - DT *kCache_ptr, - DT *vCache_ptr, - BatchConfig::PerTokenInfo const *tokenInfos, - int num_tokens, - int max_seq_len, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - int offset = i % hidden_size; - - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; - - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; - int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - - // key cache - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; + hipLaunchKernelGGL( + HIP_KERNEL_NAME(apply_rotary_embedding_hf), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + output_ptr, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + m->kProjSize, + num_tokens, + q_array_size, + m->hidden_size); } } @@ -723,91 +893,13 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, min(CUDA_NUM_THREADS, parallelism), 0, stream, - static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - num_tokens, - BatchConfig::max_sequence_length(), - m->hidden_size); - } -} - -template -void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *weight_ptr, - DT const *bias_ptr, - int num_tokens, - hipStream_t stream) { - hipblasDatatype_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - miopenDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); -#if CUDA_VERSION >= 11000 - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance - hipblasDatatype_t compute_type = HIPBLAS_R_16F; -#else - hipblasDatatype_t compute_type = cublas_data_type; -#endif - // Project to output, save result directly on output tensor - { - DT alpha = 1.0f, beta = 0.0f; - // after transpositions - int m_ = m->oProjSize; - int k = m->vProjSize * m->num_q_heads; - int n = num_tokens; - // before transpositions - int lda = k, ldb = k, ldc = m_; - // matrix A: output projection weight - // matrix A's layout: [vProjSize * num_heads, oProjSize] - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - // matrix B: attn heads - // matrix B's layout: [vProjSize * num_heads, num_new_tokens] - DT const *B = static_cast
<DT *>(m->attn_heads); - // matrix B: output - // matrix B's layout: [oProjSize, num_new_tokens] - DT *C = static_cast<DT *>
(output_ptr); - - checkCUDA(hipblasGemmEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - } - // Add final output bias - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * num_tokens; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_proj_bias_w), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - bias_ptr, + static_cast
<DT *>(m->devQKVProjArray), + static_cast<DT *>
(m->keyCache), + static_cast<DT *>
(m->valueCache), + m->token_infos, num_tokens, - qkv_weight_size, - m->oProjSize); + BatchConfig::max_sequence_length(), + m->hidden_size); } } @@ -856,93 +948,43 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, } } -template -void pre_build_weight_kernel(IncMultiHeadSelfAttentionMeta const *m, - GenericTensorAccessorR const weight, - DataType data_type, - hipStream_t stream) { - // additional processing for weight uploading - // Note that we update weight_ptr and bias_ptr when uploading weight and - // bias - if (m->quantization_type != DT_NONE) { - // copy weight_ptr to quantized_weight_ptr, do compression and store in - // m->weight_ptr - checkCUDA(hipMemcpyAsync(m->quantized_weight_ptr, - weight.get_byte_ptr(), - m->quantized_weightSize, - hipMemcpyHostToDevice, - stream)); - - if (m->quantization_type == DT_INT4) { - int parallelism = m->qProjSize * m->qSize * m->num_q_heads / 2; - hipLaunchKernelGGL(HIP_KERNEL_NAME(decompress_int4_attention_weights), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - m->quantized_weight_ptr, - static_cast
(m->weight_ptr), - m->qProjSize, - m->qSize, - m->num_q_heads); - } else { - assert(m->quantization_type == DT_INT8); - int parallelism = m->qProjSize * m->qSize * m->num_q_heads; - hipLaunchKernelGGL(HIP_KERNEL_NAME(decompress_int8_attention_weights), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - m->quantized_weight_ptr, - static_cast
(m->weight_ptr), - m->qProjSize, - m->qSize, - m->num_q_heads); - } - } else { - if (data_type == DT_FLOAT) { - checkCUDA(hipMemcpyAsync(m->weight_ptr, - weight.get_float_ptr(), - m->weightSize, - hipMemcpyHostToDevice, - stream)); - } else if (data_type == DT_HALF) { - checkCUDA(hipMemcpyAsync(m->weight_ptr, - weight.get_half_ptr(), - m->weightSize, - hipMemcpyHostToDevice, - stream)); - } else { - assert(false); - } +std::string get_fwd_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, + int shard_id) { + std::string op_name_without_uid = + IncMultiHeadSelfAttention::get_op_name_without_uid(m); + fs::path dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id); + if (m->layer_guid.model_id > 0) { + assert(false && "Model ID > 0 not supported yet"); } + std::string layername = "layers." + + std::to_string(m->layer_guid.transformer_layer_id) + + "." + op_name_without_uid; + dst_filepath /= layername; + return dst_filepath.string(); } template void inference_kernel(IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, int shard_id, - DT const *input_ptr, - DT const *weight_ptr, + DT const *qkv_ptr, DT *output_ptr, - DT const *bias_ptr, hipStream_t stream) { - if (m->offload && m->biasSize > 0) { - checkCUDA(hipMemcpyAsync( - m->bias_ptr, bias_ptr, m->biasSize, hipMemcpyHostToDevice, stream)); - bias_ptr = static_cast
(m->bias_ptr); - } + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, num_heads, 3, num_new_tokens] + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); - // phase 1: Implement kernel to compute KQV for input tokens - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast
(m->devQKVProjArray), - bias_ptr, - stream); + hipMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * sizeof(DT), + hipMemcpyDeviceToDevice, + stream); + + // phase 1: Implement kernel to apply rotary embedding and scaling + compute_qkv_kernel( + m, bc, shard_id, static_cast
<DT *>(m->devQKVProjArray), stream); update_kv_cache_kernel<DT>
(m, bc, stream); if (bc->num_generation_tokens > 0) { @@ -953,14 +995,16 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, if (bc->num_tokens > bc->num_generation_tokens) { // phase 4: Compute attention score for prompt tokens; - compute_attention_kernel_prompt( - m, bc, shard_id, bias_ptr, weight_ptr, stream); + compute_attention_kernel_prompt<DT>
(m, bc, shard_id, stream); } // compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); - compute_o_prod_bias( - m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); + hipMemcpyAsync(output_ptr, + m->attn_heads, + m->oProjSize * num_tokens * sizeof(DT), + hipMemcpyDeviceToDevice, + stream); } std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, @@ -978,14 +1022,75 @@ std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, return dst_filepath.string(); } +__global__ void transposeAdd_half_kernel( + half *out, half const *in, int width, int height, half alpha, half beta) { + int t_id = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (int i = t_id; i < width * height; i += num_threads) { + int row = i / width; + int col = i % width; + out[col * height + row] = + alpha * in[row * width + col] + beta * out[col * height + row]; + } +} + +__global__ void transposeAdd_float_kernel(float *out, + float const *in, + int width, + int height, + float alpha, + float beta) { + int t_id = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (int i = t_id; i < width * height; i += num_threads) { + int row = i / width; + int col = i % width; + out[col * height + row] = + alpha * in[row * width + col] + beta * out[col * height + row]; + } +} + +template +void transposeAdd(DT *out, + const DT *in, + int width, + int height, + float alpha, + float beta, + hipStream_t stream) { + assert(false && "Unsupported data type"); +} + +template <> +void transposeAdd(float *out, + float const *in, + int width, + int height, + float alpha, + float beta, + hipStream_t stream) { + transposeAdd_float_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, alpha, beta); +} + +template <> +void transposeAdd(half *out, + half const *in, + int width, + int height, + float alpha, + float beta, + hipStream_t stream) { + transposeAdd_half_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, __float2half(alpha), __float2half(beta)); +} + template void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, DT *input_grad_ptr, - DT const *weight_ptr, DT const *output_grad_ptr, - DT const *bias_ptr, hipStream_t stream) { assert(!m->offload); checkCUDA(hipblasSetStream(m->handle.blas, stream)); @@ -994,17 +1099,6 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, miopenDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); assert(data_type_size(m->output_type[0]) == sizeof(DT)); hipblasDatatype_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // hipblasDatatype_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // hipblasDatatype_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { @@ -1026,47 +1120,18 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, int vt_req_block_size = vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); - // Step 1: compute gradients before final projection + // Step 1: copy 
gradient before final projection into workspace { int m_ = m->vProjSize * m->num_q_heads; int n_ = num_tokens; - int k_ = m->oProjSize; - int lda = m_; - int ldb = k_; - int ldc = m_; - float alpha = 1.0f, beta = 0.0f; - // matrix A: output projection weight - // matrix A's layout: [vProjSize * num_heads, oProjSize] - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - // matrix B: output gradients - // matrix B's layout: [oProjSize, num_new_tokens] - DT const *B = - output_grad_ptr + - bc->requestsInfo[i].first_token_offset_in_batch * m->oProjSize; - // matrix C: attn_heads gradients - // matrix C's layout: [vProjSize * num_heads, num_new_tokens] DT *C = static_cast
(m->handle.workSpace); - checkCUDA(hipblasGemmEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_N, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - HIPBLAS_GEMM_DEFAULT)); + hipMemcpyAsync(C, + output_grad_ptr + + bc->requestsInfo[i].first_token_offset_in_batch * + m->oProjSize, + m_ * n_ * sizeof(DT), + hipMemcpyDeviceToDevice, + stream); if (m->inference_debugging) { // save result to file for checking std::string filename = @@ -1331,264 +1396,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, int lda = num_tokens; // num_new_tokens int ldb = m->qProjSize * m->num_q_heads; int ldc = num_tokens; - int strideA = num_tokens * num_tokens; - int strideB = m->qProjSize; - int strideC = num_tokens * m->qProjSize; - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray_pre"; - save_tensor(C, - num_tokens * m->qProjSize * m->num_q_heads * 3, - filename.c_str()); - } - } - - // Step 7: perform rotary position embeddings (RoPE) bwd - { - if (*m->apply_rotary_embedding) { - assert(m->hidden_size == m->qProjSize * m->num_q_heads); - assert(m->qProjSize == m->kProjSize); - /*q&k*/ - int parallelism = num_tokens * m->hidden_size; - DT *A = static_cast
(m->devQKVProjArray); - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_rotary_embedding_bwd), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - A, - m->complex_input, - m->token_infos, - m->qProjSize, - num_tokens, - m->hidden_size); - DT *C = static_cast
(m->devQKVProjArray); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray"; - save_tensor(C, - num_tokens * m->qProjSize * m->num_q_heads * 3, - filename.c_str()); - } - } - - // matrix C: gradients for key (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] - DT *C = - static_cast
(m->devQKVProjArray) + - num_tokens * - (m->qProjSize * - m->num_q_heads); // skip over regions reserved for Q gradients - if (m->inference_debugging) { - std::string filename = get_peft_dbg_folder(m, shard_id) + ".devkproj"; - save_tensor( - C, num_tokens * (m->qProjSize * m->num_q_heads), filename.c_str()); - } - } - - // Step 8: compute gradients w.r.t. input - { - float alpha = 1.0f, beta = 0.0f; - if (!m->reset_input_grads[0]) { - beta = 1.0f; - } - // matrix A: QKV projection weights - // matrix A's layout: [qSize, qProjSize * num_q_heads, 3] - DT const *A = weight_ptr; - // matrix B: gradients w.r.t. QKV (concatenated in devQKVArray) - // matrix B's layout: [num_tokens, qProjsize * num_heads, 3] - DT const *B = static_cast
(m->devQKVProjArray); - // matrix C: gradients w.r.t. input - // matrix C's layout: [m->qSize, num_tokens] - DT *C = input_grad_ptr + - bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; - int m_ = m->qSize; - int n_ = num_tokens; - int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); - int lda = m_; - int ldb = n_; - int ldc = m_; - checkCUDA(hipblasGemmEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0"; - save_tensor(C, num_tokens * m->qSize, filename.c_str()); - } - } - } -} - -} // namespace IncMultiHeadAttention -} // namespace Kernels - -using namespace Kernels::IncMultiHeadAttention; - -template -__global__ void store_query_cache(DT const *devQKVProjArray, - DT *qCache_ptr, - int num_tokens, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - int offset = i % hidden_size; - - size_t val_idx = token_idx * QKV_WEIGHT_NUM * hidden_size + offset; - - DT qVal = devQKVProjArray[val_idx]; - - // query cache - qCache_ptr[i] = qVal; - } -} - -template -void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, - BatchConfig const *bc, - int shard_id, - DT const *bias_ptr, - DT const *weight_ptr, - hipStream_t stream) { - checkCUDA(hipblasSetStream(m->handle.blas, stream)); - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - hipblasDatatype_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - miopenDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); - hipblasDatatype_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // hipblasDatatype_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // hipblasDatatype_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - // int num_requests = bc->num_active_requests(); - int num_tokens = bc->num_active_tokens(); - int tokens_previous_requests = 0; - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i] || - (!bc->requestsInfo[i].prompt_phase && !bc->requestsInfo[i].peft_bwd)) { - continue; - } - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; - int max_peft_tokens = bc->requestsInfo[i].max_sequence_length; - // Copy query to m->query_activation_buffer if we need to compute - // PEFT backward - if (bc->requestsInfo[i].peft_bwd) { - size_t activation_size_needed = - sizeof(DT) * max_peft_tokens * m->num_q_heads * m->qProjSize; - if (activation_size_needed > 
m->allocated_peft_buffer_size1) { - MemoryAllocator *allocator = m->handle.peft_activation_allocator; - m->query_activation_buffer = - allocator->allocate_instance_untyped(activation_size_needed); - m->allocated_peft_buffer_size1 = activation_size_needed; - } - int parallelism = m->hidden_size * num_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(store_query_cache), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->query_activation_buffer), - num_tokens, - m->hidden_size); - } - // Step 1: compute query-key product QK.T/sqrt(d_k) - { - // Scale by sqrt(d_k) as per the original attention paper - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
(1.0f / sqrt(m->kProjSize)); - } - // after transpositions - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; - // before transpositions - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - // N.B. strides are applied before transpose operations - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // matrix A: devQKVProjArray - // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens] - // To get query projection, skip over Q entries from previous requests - DT const *A = static_cast
(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; - // matrix B: key cache - // matrix B's layout: [kProjSize * num_heads, total_tokens] - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; - // matrix C: qk_prods - // matrix C's layout: [num_new_tokens, total_tokens, num_heads] - // To get C, skip over QK.T products from previous requests - DT *C = static_cast
(m->qk_prods); + int strideA = num_tokens * num_tokens; + int strideB = m->qProjSize; + int strideC = num_tokens * m->qProjSize; checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_T, HIPBLAS_OP_N, + HIPBLAS_OP_T, m_, - n, - k, + n_, + k_, &alpha, A, cublas_data_type, @@ -1606,177 +1422,111 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, m->num_q_heads, compute_type, HIPBLAS_GEMM_DEFAULT)); - } - // Step 2: Add alibi position bias to qk production - // matrix C: qk_prods - // matrix C's layout: [num_new_tokens, total_tokens, num_heads] - // To get C, skip over QK.T products from previous requests - DT *C = static_cast
(m->qk_prods); - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - - // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods - // with -inf to force causal attention. - assert(num_new_tokens <= total_tokens); - size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - hipLaunchKernelGGL(HIP_KERNEL_NAME(fill_entries_above_diagonal), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - entries_above_diagonal, - static_cast
(-INFINITY)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray_pre"; + save_tensor(C, + num_tokens * m->qProjSize * m->num_q_heads * 3, + filename.c_str()); + } } - // Step 4: Compute Softmax(QK.T/sqrt(d_k)) + // Step 7: perform rotary position embeddings (RoPE) bwd { - // Before modifying the parameters below, make sure to read the following - // description of the HIPDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#hipdnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(miopenSet4dTensorDescriptor( - m->qk_tensor, cudnn_data_type, n_param, c_param, h_param, w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax); - // The softmax operation below is executed according to the - // MIOPEN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax, - MIOPEN_SOFTMAX_ACCURATE, - MIOPEN_SOFTMAX_MODE_CHANNEL)); - } - // Copy C_softmax to m->softmax_activation_buffer if we need to compute - // PEFT backward - if (bc->requestsInfo[i].peft_bwd) { - DT *C_softmax = static_cast
(m->qk_prods_softmax); - size_t activation_size_needed = - sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; - if (activation_size_needed > m->allocated_peft_buffer_size2) { - MemoryAllocator *allocator = m->handle.peft_activation_allocator; - m->softmax_activation_buffer = - allocator->allocate_instance_untyped(activation_size_needed); - m->allocated_peft_buffer_size2 = activation_size_needed; + if (m->rotary_embedding_meta->apply_rotary_embedding) { + assert(m->hidden_size == m->qProjSize * m->num_q_heads); + assert(m->qProjSize == m->kProjSize); + /*q&k*/ + int parallelism = num_tokens * m->hidden_size; + DT *A = static_cast
(m->devQKVProjArray); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(apply_rotary_embedding_bwd), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + A, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + num_tokens, + m->hidden_size); + DT *C = static_cast
(m->devQKVProjArray); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray"; + save_tensor(C, + num_tokens * m->qProjSize * m->num_q_heads * 3, + filename.c_str()); + } + } + + // matrix C: gradients for key (saved as part of m->devQKVProjArray) + // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + DT *C = + static_cast
(m->devQKVProjArray) + + num_tokens * + (m->qProjSize * + m->num_q_heads); // skip over regions reserved for Q gradients + if (m->inference_debugging) { + std::string filename = get_peft_dbg_folder(m, shard_id) + ".devkproj"; + save_tensor( + C, num_tokens * (m->qProjSize * m->num_q_heads), filename.c_str()); } - checkCUDA(hipMemcpyAsync(m->softmax_activation_buffer, - C_softmax, - sizeof(DT) * total_tokens * num_new_tokens * - m->num_q_heads, - hipMemcpyDeviceToDevice, - stream)); } - // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ - // softmax(QK.T/sqrt(d_k)).T + + // Step 8: compute gradients w.r.t. input { - DT alpha = 1.0f, beta = 0.0f; - // after transpositions - int m_ = m->vProjSize; - int n = num_new_tokens; - int k = total_tokens; - // before transpositions - int lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - // N.B. strides are applied before transpose operations - int strideA = vt_block_size; - int strideB = num_new_tokens * total_tokens; - int strideC = m->vProjSize; - // matrix A: value cache - // matrix A's layout: [vProjSize, num_heads, total_tokens] - // To get A, skip over V.T entries from previous requests (all heads + - // padding) - DT *A = static_cast
(m->valueCache) + i * vt_req_block_size; - // matrix B: qk_prods_softmax - // matrix B's layout: [num_new_tokens, total_tokens, num_heads] - // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous - // requests (all heads) - DT *B = static_cast
(m->qk_prods_softmax); - // matrix C: attn heads - // matrix C's layout: [vProjSize, num_heads, num_new_tokens] - // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous - // requests - // store the result attn heads, also skip the genration tokens - DT *C = static_cast
(m->attn_heads) + - (bc->requestsInfo[i].first_token_offset_in_batch) * - m->num_q_heads * m->vProjSize; - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); + float alpha = 1.0f, beta = 0.0f; + if (!m->reset_input_grads[0]) { + beta = 1.0f; + } + // matrix B: gradients w.r.t. QKV (concatenated in devQKVArray) + // matrix B's layout: [num_tokens, qProjsize * num_heads, 3] + DT const *B = static_cast
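// The QKV projection now lives in a separate Linear operator, so this op's
// input is already the fused QKV tensor; its input gradient is therefore just
// the QKV gradient written back in the input layout. The transposeAdd call
// below does that re-layout (the two buffers differ by a transpose) and, with
// beta = 1, accumulates when input gradients are not reset. A host reference
// mirroring the device kernel's index arithmetic (illustration only):
#include <vector>
// 'in' is read as [height x width] row-major, 'out' as [width x height]
// row-major; out = alpha * in^T + beta * out.
template <typename T>
void transpose_add_reference(std::vector<T> &out, std::vector<T> const &in,
                             int width, int height, T alpha, T beta) {
  for (int row = 0; row < height; row++) {
    for (int col = 0; col < width; col++) {
      out[col * height + row] =
          alpha * in[row * width + col] + beta * out[col * height + row];
    }
  }
}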
(m->devQKVProjArray); + // matrix C: gradients w.r.t. input + // matrix C's layout: [m->qSize, num_tokens] + DT *C = input_grad_ptr + + bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; + // int m_ = m->qSize; + int n_ = num_tokens; + int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); + + // The original version uses existing result and attention's projection to + // do further calculation in a way different than the usual dense layer, + // they are off by a transpose. So an explicit transpose is needed here. + // The add here is just for gradient accumulation. + transposeAdd(C, B, n_, k_, alpha, beta, stream); + + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0"; + save_tensor(C, num_tokens * m->qSize, filename.c_str()); + } } - tokens_previous_requests += num_new_tokens; - } - if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { - bc->print(); - printf("tokens_previous_requests: %i\n", tokens_previous_requests); - printf("num_tokens: %i\n", num_tokens); - printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); } - assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); } +} // namespace IncMultiHeadAttention +} // namespace Kernels + +using namespace Kernels::IncMultiHeadAttention; + /*static*/ void IncMultiHeadSelfAttention::inference_kernel_wrapper( IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorW const &output) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; hipEvent_t t_start, t_end; if (m->profiling) { @@ -1785,43 +1535,14 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(hipEventRecord(t_start, stream)); } - // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } if (input.data_type == DT_HALF) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - m->offload ? static_cast(m->weight_ptr) : weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - float const *bias_ptr = - use_bias ? bias.get_float_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - m->offload ? 
static_cast(m->weight_ptr) - : weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -1843,12 +1564,9 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( BatchConfig const *bc, int shard_id, GenericTensorAccessorW const &input_grad, - GenericTensorAccessorR const &weight, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorR const &output_grad) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; hipEvent_t t_start, t_end; if (m->profiling) { @@ -1857,35 +1575,23 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( checkCUDA(hipEventRecord(t_start, stream)); } - // assert(input.data_type == weight.data_type); assert(input_grad.data_type == output_grad.data_type); - if (use_bias) { - assert(input_grad.data_type == bias.data_type); - } if (input_grad.data_type == DT_HALF) { assert(!m->offload); - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, bc, shard_id, input_grad.get_half_ptr(), - weight.get_half_ptr(), output_grad.get_half_ptr(), - bias_ptr, stream); } else if (input_grad.data_type == DT_FLOAT) { assert(!m->offload); - float const *bias_ptr = - use_bias ? bias.get_float_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, bc, shard_id, input_grad.get_float_ptr(), - weight.get_float_ptr(), output_grad.get_float_ptr(), - bias_ptr, stream); } else { assert(false && "Unspported data type"); @@ -1904,7 +1610,6 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( FFHandler handler, IncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -1919,14 +1624,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, + attn->rotary_embedding_meta, attn->scaling_query, attn->qk_prod_scaling, attn->position_bias, - attn->final_bias, attn->scaling_factor, - weight, gpu_mem_allocator, num_samples, attn->num_q_heads, @@ -1947,14 +1649,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, - bool _qkv_bias, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, bool _qk_prod_scaling, bool _position_bias, - bool _final_bias, float _scaling_factor, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _global_num_q_heads, @@ -1963,7 +1662,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _num_kv_heads, DataType _quantization_type, bool _offload) - : OpMeta(handler, attn), weight_ptr(nullptr), bias_ptr(nullptr) { + : OpMeta(handler, attn) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); checkCUDNN(miopenSetStream(handler.dnn, stream)); @@ -1989,29 +1688,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( num_kv_heads = _num_kv_heads; hidden_size = num_q_heads * qProjSize; - weightSize = - ((qSize * qProjSize + oProjSize * (vProjSize > 0 ? 
vProjSize : vSize)) * - num_q_heads + - (kSize * kProjSize + vSize * vProjSize) * num_q_heads) * - size_of_dt; - if (quantization_type != DT_NONE) { - quantized_weightSize = get_quantization_to_byte_size( - attn->data_type, quantization_type, weightSize); - } - // biasSize = _bias ? oProjSize * size_of_dt * 4 : 0; - - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - int final_bias_size = oProjSize; - biasSize = - (_qkv_bias ? qkv_bias_size : 0) + (final_bias ? final_bias_size : 0); - - // has_load_weights = (bool *)calloc(1, sizeof(bool)); - //*has_load_weights = false; - apply_rotary_embedding = (bool *)calloc(1, sizeof(bool)); - *apply_rotary_embedding = _apply_rotary_embedding; - qkv_bias = (bool *)calloc(1, sizeof(bool)); - *qkv_bias = _qkv_bias; + rotary_embedding_meta = + (RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta)); + *rotary_embedding_meta = _rotary_embedding_meta; scaling_query = (bool *)calloc(1, sizeof(bool)); *scaling_query = _scaling_query; scaling_factor = _scaling_factor; @@ -2019,14 +1698,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( *qk_prod_scaling = _qk_prod_scaling; position_bias = (bool *)calloc(1, sizeof(bool)); *position_bias = _position_bias; - final_bias = (bool *)calloc(1, sizeof(bool)); - *final_bias = _final_bias; - - // allocate weight and bias in the reserve space for cpu offloading - if (offload) { - weight_ptr = gpu_mem_allocator.allocate_reserved_untyped(weightSize); - bias_ptr = gpu_mem_allocator.allocate_reserved_untyped(biasSize); - } // allocate memory for the seqArray and reserve space { @@ -2092,9 +1763,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( ? key_cache_size + value_cache_size + qkv_max_proj_size : key_cache_size + value_cache_size); - if (quantization_type != DT_NONE) { - totalSharedSize += quantized_weightSize; - } assert(gpu_mem_allocator.reserved_total_size - gpu_mem_allocator.reserved_allocated_size >= totalSharedSize); @@ -2125,29 +1793,15 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( handler.batch_config_metadata->requestsInfo); if (offload) { - // token_infos = - // gpu_mem_allocator.allocate_reserved( - // tokeninfo_size); - // offset += sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size; qk_prods = gpu_mem_allocator.allocate_reserved_untyped(qk_prod_size * size_of_dt); - // offset += qk_prod_size * size_of_dt; qk_prods_softmax = gpu_mem_allocator.allocate_reserved_untyped( qk_prod_size * size_of_dt); - // offset += qk_prod_size * size_of_dt; attn_heads = gpu_mem_allocator.allocate_reserved_untyped(attn_heads_size * size_of_dt); - // offset += attn_heads_size * size_of_dt; complex_input = gpu_mem_allocator.allocate_reserved(complex_size); - // offset += complex_size * sizeof(hipFloatComplex); - // request_infos = - // gpu_mem_allocator.allocate_reserved( - // requestinfo_size); } else { - // token_infos = - // gpu_mem_allocator.allocate_instance( - // tokeninfo_size); qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size * size_of_dt); qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped( @@ -2156,16 +1810,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_of_dt); complex_input = gpu_mem_allocator.allocate_instance(complex_size); - // request_infos = - // gpu_mem_allocator.allocate_instance( - // requestinfo_size); } // allocate more size for quantization data if (quantization_type != DT_NONE) { assert(offload); - quantized_weight_ptr = - 
gpu_mem_allocator.allocate_reserved(quantized_weightSize); } if (!offload) { assert(gpu_mem_allocator.reserved_total_size == @@ -2183,49 +1832,32 @@ IncMultiHeadSelfAttentionMeta::~IncMultiHeadSelfAttentionMeta(void) { } } -template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( - IncMultiHeadSelfAttentionMeta const *m, - GenericTensorAccessorR const weight, - DataType data_type, - hipStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + float *output_ptr, + hipStream_t stream); -template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( - IncMultiHeadSelfAttentionMeta const *m, - GenericTensorAccessorR const weight, - DataType data_type, - hipStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + half *output_ptr, + hipStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( +template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, float *output_ptr, - float const *weight_ptr, - float const *bias_ptr, - int num_tokens, hipStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( +template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, half *output_ptr, - half const *weight_ptr, - half const *bias_ptr, - int num_tokens, hipStream_t stream); -template void - Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( - IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - float *output_ptr, - hipStream_t stream); - -template void - Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( - IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - half *output_ptr, - hipStream_t stream); }; // namespace FlexFlow diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index b278611b60..2802dd41b6 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -19,6 +19,7 @@ #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/utils/cuda_helper.h" +#include namespace FlexFlow { @@ -31,1075 +32,162 @@ using Legion::Memory; namespace Kernels { namespace IncMultiHeadAttention { -// gridDim = num_heads -// blockDim = num_tokens/num_request * head_size -// QKV tensor layout: |QKV| * num_new_tokens. |Q=K=V=head_size * num_heads| -// one thread process one head_size -template -__global__ void compute_attention_kernel_generation_kernel( - DT const *query, - DT const *key_cache, - DT const *value_cache, - DT *output_ptr, - float const scale, - int max_seq_length, - int per_head_size, - int hidden_size, - BatchConfig::PerRequestInfo *request_infos) { - - // q, k - using Q_vec = typename VEC_K::Type; - using K_vec = typename VEC_K::Type; - using V_vec = typename VEC_V
::Type; - using Out_sum = typename Vec_fp32_::Type; - - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // eg. if head_size = 128, thread_per_key = 4, with float32 precision - // then K_VEC_SIZE = 1, QK_VEC_SIZE = 4 - // K_ELTS_PER_THREAD = 128 / 4 = 32 - // K_VECS_PER_THREAD = 32 / 1 = 32 - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); - // constexpr int QK_VEC_SIZE = 16 / sizeof(DT); - // // constexpr int QK_VEC_SIZE = sizeof(Qk_vec_k) / sizeof(DT); - constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); - - // thread id - int const tidx = threadIdx.x; - // head id - int const head_idx = blockIdx.x; - // request idx - int const request_idx = blockIdx.y; - - int const batch_config_request_id = - request_infos[request_idx].batch_config_request_id; - - int const first_step = 0; +template +__global__ void store_kv_cache(DT const *devQKVProjArray, + DT *kCache_ptr, + DT *vCache_ptr, + BatchConfig::PerTokenInfo const *tokenInfos, + int num_tokens, + int max_seq_len, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / hidden_size; + int offset = i % hidden_size; - int const tlength = - request_infos[batch_config_request_id].first_token_depth_in_request + - request_infos[batch_config_request_id].num_tokens_in_batch; + size_t val_idx = + token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; - // shared memory objects - extern __shared__ char smem_[]; + DT kVal = devQKVProjArray[val_idx]; + DT vVal = devQKVProjArray[val_idx + hidden_size]; + int const req_id = tokenInfos[token_idx].request_index; + int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - float *qk_smem = reinterpret_cast(smem_); - float *out_smem = reinterpret_cast(smem_); + // key cache + kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = vVal; + } +} - float qk_max = -FLT_MAX; +template +__global__ void store_query_cache(DT const *devQKVProjArray, + DT *qCache_ptr, + int num_tokens, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / hidden_size; + int offset = i % hidden_size; - // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + size_t val_idx = token_idx * QKV_WEIGHT_NUM * hidden_size + offset; - const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM + - head_idx * per_head_size; - __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; - // DT const *q_ptr = - // query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size; + DT qVal = devQKVProjArray[val_idx]; - // q tensor in this thread - // if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total - // K_VECS_PER_THREAD elements - // QK_vec_k: 32->1, 64->2, 128->4... head_size - // K_vec_k: 4->1, 2->2, 1->4 threads_per_key + // query cache + qCache_ptr[i] = qVal; + } +} - // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - int ki_o = tidx % THREADS_PER_KEY; - // the first key's offset for this thread - // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... 
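// The store_kv_cache / store_query_cache kernels added above read from the
// fused per-token [Q | K | V] block produced by the upstream QKV projection:
// each token occupies QKV_WEIGHT_NUM (= 3) * hidden_size values, with Q at
// offset 0, K at hidden_size and V at 2 * hidden_size. A small host-side
// indexing helper mirroring that layout (hypothetical name, illustration
// only):
#include <cstddef>
// slot: 0 = Q, 1 = K, 2 = V.
inline size_t fused_qkv_index(size_t token_idx, size_t slot, size_t offset,
                              size_t hidden_size) {
  return token_idx * 3 * hidden_size + slot * hidden_size + offset;
}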
- int ko = tidx / THREADS_PER_KEY; - // load q tensor - Q_vec q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vecs[ki_o][ii] = *reinterpret_cast( - q_ptr + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); +template +__global__ void fill_entries_above_diagonal(DT *matrix, + size_t num_rows, + size_t num_cols, + size_t num_q_heads, + size_t entries_above_diagonal, + DT value) { + CUDA_KERNEL_LOOP(i, entries_above_diagonal * num_q_heads) { + size_t head_idx = i / entries_above_diagonal; + size_t entry_idx = i % entries_above_diagonal; + size_t y = (-1 + sqrt(8 * (float)entry_idx + 1)) / 2; + size_t x = entry_idx - y * (y + 1) / 2; + y += (num_cols - num_rows) + 1; + matrix[head_idx * num_rows * num_cols + num_cols * y + x] = value; } - __syncthreads(); - // first iter = 128 / 4 = 32 - // K_VECS_PER_THREAD = 32 - // K_PER_ITER how many keys in this loop - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; +} - DT const *k_cache_batch = - key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; +template +void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + cudaStream_t stream) { + checkCUDA(cublasSetStream(m->handle.blas, stream)); + checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); + cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); + assert(data_type_size(m->output_type[0]) == sizeof(DT)); + cudaDataType_t compute_type = cublas_data_type; - int ti_end = - div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - // get k, perform qk proj + int num_tokens = bc->num_active_tokens(); + int tokens_previous_requests = 0; + int q_block_size = m->qProjSize; + int kt_block_size = m->kProjSize; + int kt_req_block_size = + kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_block_size = m->vProjSize; + int vt_req_block_size = + vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + assert(m->qProjSize == m->kProjSize); - for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { - K_vec k[K_VECS_PER_THREAD]; - int const ti_circ = ti % max_seq_length; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; - if (ti < tlength) { - k[ii] = *reinterpret_cast(k_cache_batch + - ti_circ * hidden_size + - head_idx * per_head_size + jj); - } - // Compute dot product. - // This includes a reduction across the threads in the same thread group. + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i] || + (!bc->requestsInfo[i].prompt_phase && !bc->requestsInfo[i].peft_bwd)) { + continue; } - float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); - // // todo add positional embedding to the qk production - // // Store the product to shared memory. There's one qk value per - // timestep. - // // Update the max. 
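// The qk_max bookkeeping below, together with the __shfl_xor_sync loops that
// follow, performs an XOR-butterfly all-reduce so every thread ends up with
// the maximum score before the exponentials are taken. A host simulation of
// the full-warp variant of that butterfly, for illustration (power-of-two
// lane count assumed):
#include <algorithm>
#include <vector>
inline float butterfly_max_reference(std::vector<float> vals) {
  int n = static_cast<int>(vals.size()); // assumed to be a power of two
  for (int mask = n / 2; mask >= 1; mask /= 2) {
    std::vector<float> next(vals);
    for (int lane = 0; lane < n; lane++) {
      next[lane] = std::max(vals[lane], vals[lane ^ mask]);
    }
    vals.swap(next);
  }
  return vals[0]; // every lane now holds the same maximum
}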
- if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - // todo add alobi here - bool const mask = ti_circ >= tlength; - if (mask) { - assert(false); + int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch; + int max_peft_tokens = bc->requestsInfo[i].max_sequence_length; + // Copy query to m->query_activation_buffer if we need to compute + // PEFT backward + if (bc->requestsInfo[i].peft_bwd) { + size_t activation_size_needed = + sizeof(DT) * max_peft_tokens * m->num_q_heads * m->qProjSize; + if (activation_size_needed > m->allocated_peft_buffer_size1) { + MemoryAllocator *allocator = m->handle.peft_activation_allocator; + m->query_activation_buffer = + allocator->allocate_instance_untyped(activation_size_needed); + m->allocated_peft_buffer_size1 = activation_size_needed; } - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = mask ? 0.f : qk; + int parallelism = m->hidden_size * num_tokens; + store_query_cache<<>>( + static_cast
<DT *>(m->devQKVProjArray), + static_cast<DT *>
(m->query_activation_buffer), + num_tokens, + m->hidden_size); } - } - - __syncthreads(); - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - int const warp = tidx / WARP_SIZE; - int const lane = tidx % WARP_SIZE; + // Step 1: compute query-key product QK.T/sqrt(d_k) + { + // Scale by sqrt(d_k) as per the original attention paper + DT alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = static_cast
(1.0f / sqrt(m->kProjSize)); + } + // after transpositions + int m_ = num_new_tokens; + int n = total_tokens; + int k = m->qProjSize; + // before transpositions + int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, + ldc = m_; + // N.B. strides are applied before transpose operations + int strideA = q_block_size; + int strideB = kt_block_size; + int strideC = num_new_tokens * total_tokens; - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - float exp_sum = 0.f; - for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - float logit = __expf(qk_smem[ti - first_step] - qk_max); - exp_sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); - - // softmax - float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); - for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - qk_smem[ti - first_step] *= inv_sum; - } - - __syncthreads(); - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("softmax %.10f\n", qk_smem[0]); - // } - - // value projection - constexpr int V_VEC_SIZE = 16 / sizeof(DT); - // A vector of V elements for the current timestep. - // using V_vec_k = typename V_vec_k_::Type; - // using V_vec_acum = typename V_vec_acum_fp32_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - Out_sum out; - zero(out); - - // The base pointer for the value in the cache buffer. - DT const *v_cache_batch = - value_cache + batch_config_request_id * max_seq_length * hidden_size + vi; - - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - // Load the values from the cache. - int const ti_circ = ti % max_seq_length; - - V_vec v = *reinterpret_cast( - v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); - float logit = qk_smem[ti - first_step]; - out = FlexFlow::fma(logit, cast_to_float(v), out); - } - } - - // // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different - // partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; - active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { - *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = - out; - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(out_smem + vo * Dh + vi), - out); - } - __syncthreads(); - } - } - - // Output the final values. 
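// Per head and per generation token, the kernel above computes
// out = sum_t softmax(scale * q . k_t) * v_t over the request's cached keys
// and values. A compact host reference of the same computation (flat
// row-major [seq_len x head_dim] caches assumed, illustration only):
#include <algorithm>
#include <cmath>
#include <vector>
void single_query_attention_reference(std::vector<float> const &q,
                                      std::vector<float> const &k_cache,
                                      std::vector<float> const &v_cache,
                                      std::vector<float> &out,
                                      int seq_len, int head_dim, float scale) {
  std::vector<float> scores(seq_len);
  float max_s = -INFINITY;
  for (int t = 0; t < seq_len; t++) {
    float dot = 0.0f;
    for (int d = 0; d < head_dim; d++) {
      dot += q[d] * k_cache[t * head_dim + d];
    }
    scores[t] = scale * dot;
    max_s = std::fmax(max_s, scores[t]);
  }
  float sum = 0.0f;
  for (int t = 0; t < seq_len; t++) {
    scores[t] = std::exp(scores[t] - max_s);
    sum += scores[t];
  }
  std::fill(out.begin(), out.end(), 0.0f);
  for (int t = 0; t < seq_len; t++) {
    float p = scores[t] / (sum + 1.e-6f); // same epsilon as the kernel
    for (int d = 0; d < head_dim; d++) {
      out[d] += p * v_cache[t * head_dim + d];
    }
  }
}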
- if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - convert_from_float( - *reinterpret_cast(output_ptr + request_idx * hidden_size + - head_idx * per_head_size + vi), - out); - } -} - -// only used by MPT model. https://arxiv.org/abs/2108.12409 -template -__global__ void apply_position_bias_qkprd(DT *input_ptr, - int num_tokens, - int num_total_tokens, - int num_heads, - int global_num_q_heads, - int shard_id) { - CUDA_KERNEL_LOOP(i, num_tokens * num_total_tokens * num_heads) { - // get head_idx, - int head_idx = i / (num_tokens * num_total_tokens) + (num_heads * shard_id); - int position_idx = (i / num_tokens) % num_total_tokens; - position_idx = position_idx + 1 - num_total_tokens; - // 8 is alibi_bias_max in - // https://huggingface.co/mosaicml/mpt-30b/blob/main/config.json - float base = (float)(head_idx + 1) * 8 / global_num_q_heads; - float slopes = 1.0 / pow(2, base); - // if(i == 0){ - // printf("see position: %d, %f, %f, %f\n", position_idx, base, slopes, - // position_idx * slopes); - // } - input_ptr[i] += static_cast
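// The apply_position_bias_qkprd kernel here adds the ALiBi bias used by MPT:
// head h of H global heads gets slope 2^(-8 * (h + 1) / H), and each score
// receives slope * (position - total_tokens + 1), a non-positive penalty that
// grows with distance. A host sketch of the per-head slopes (illustration
// only):
#include <cmath>
#include <vector>
inline std::vector<float> alibi_slopes_reference(int num_heads) {
  std::vector<float> slopes(num_heads);
  for (int h = 0; h < num_heads; h++) {
    float base = 8.0f * (h + 1) / num_heads; // alibi_bias_max = 8, as in MPT
    slopes[h] = 1.0f / std::pow(2.0f, base);
  }
  return slopes;
}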
(position_idx * slopes); - } -} - -template -__global__ void apply_proj_bias_w(DT *input_ptr, - DT const *bias_ptr, - int num_tokens, - int qkv_weight_size, - int oProjSize) { - CUDA_KERNEL_LOOP(i, num_tokens * oProjSize) { - int bias_idx = qkv_weight_size + i % oProjSize; - input_ptr[i] += bias_ptr[bias_idx]; - } -} - -template -__global__ void apply_proj_bias_qkv(DT *input_ptr, - DT const *bias_ptr, - int shard_id, - int num_tokens, - int qProjSize, - int kProjSize, - int vProjSize, - int global_num_q_heads, - int num_q_heads, - bool scaling_query, - float scaling_factor, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * QKV_WEIGHT_NUM) { - // for simplicity, assume q, k, v is in same shape - // 0->q, 1->k, 2->v - // int qkv_index = i / (num_tokens * qProjSize) % 3; - - int token_idx = i / (hidden_size * QKV_WEIGHT_NUM); - size_t in_token_idx = i - token_idx * hidden_size * QKV_WEIGHT_NUM; - - int qkv_index = in_token_idx / hidden_size; - - int proj_size = qkv_index == 0 ? qProjSize : kProjSize; - - int head_idx = - (in_token_idx - qkv_index * num_q_heads * proj_size) / proj_size; - int global_head_idx = head_idx + shard_id * num_q_heads; - - size_t pre_length = - qkv_index == 0 - ? 0 - : (qkv_index == 1 ? qProjSize * global_num_q_heads - : qProjSize * global_num_q_heads * KV_WEIGHT_NUM); - - size_t bias_idx = pre_length + global_head_idx * proj_size + i % proj_size; - - input_ptr[i] += bias_ptr[bias_idx]; - - if (scaling_query && qkv_index == 0) { - input_ptr[i] *= scaling_factor; - } - } -} - -template -__global__ void scaling_query_kernel(DT *input_ptr, - int qProjSize, - int num_tokens, - int num_q_heads, - float scaling_factor, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - input_ptr[i % hidden_size + token_idx * hidden_size * QKV_WEIGHT_NUM] *= - scaling_factor; - } -} - -template -__global__ void - apply_rotary_embedding_native(DT *input_ptr, - cuFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int num_q_heads, - int num_tokens, - int num_kv_heads, - int q_block_size, - int k_block_size, - int q_array_size) { - CUDA_KERNEL_LOOP( - i, - num_tokens * (qProjSize * num_q_heads + kProjSize * num_kv_heads) / 2) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int head_idx = real_i / (num_tokens * proj_size / 2); - int idx = real_i % (num_tokens * proj_size / 2); - int real_part_index = idx * 2 + - head_idx * (q_tensor ? q_block_size : k_block_size) + - (q_tensor ? 
0 : q_array_size); - - int complex_part_index = real_part_index + 1; - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - int token_idx = - (real_i - head_idx * (num_tokens * proj_size / 2)) / (proj_size / 2); - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - // complex_input[i].y; - - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = cuCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - -template -__global__ void - apply_rotary_embedding_hf(DT *input_ptr, - cuFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int num_tokens, - size_t q_array_size, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int token_idx = real_i / (hidden_size / 2); - int idx = real_i % (proj_size / 2); - int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); - - int real_part_index = idx + head_idx * proj_size + - token_idx * hidden_size * QKV_WEIGHT_NUM + - hidden_size * (q_tensor ? 0 : 1); - int complex_part_index = real_part_index + (proj_size / 2); - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - // get the freq_cis: shape 1 * (qProjSize/2) = 1 * 64 - // apply a Cartesian coordinate transformation - // multiple with input & /copy back to q/k - - // get position of token - - // size_t pos = id_map[token_idx].token_position; - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - // float before_real = complex_input[i].x, before_complex = - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = cuCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - -template -__global__ void - apply_rotary_embedding_bwd(DT *input_ptr, - cuFloatComplex *complex_input, - BatchConfig::PerTokenInfo const *tokenInfos, - int proj_size, - int num_tokens, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - // compute indexes to visit first half proj_size of each of q/k tensor. - // devQKVProj has shape [num_tokens, qProjSize, num_heads, 3] in peft_bwd - bool q_tensor = i < (num_tokens * hidden_size / 2); - int real_i = q_tensor ? i : i - num_tokens * hidden_size / 2; - assert(hidden_size % proj_size == 0); - int num_heads = hidden_size / proj_size; - - int token_idx = real_i % num_tokens; - int idx = (real_i / num_tokens) % (proj_size / 2); - int head_idx = real_i / (num_tokens * proj_size / 2); - assert(head_idx < num_heads); - - int complex_part_index = (q_tensor ? 
0 : 1) * num_tokens * hidden_size + - head_idx * num_tokens * proj_size + - idx * num_tokens + token_idx; - int real_part_index = complex_part_index + (proj_size / 2) * num_tokens; - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - - float freq = pos * (1.0 / pow(10000.0, (float)2 * idx / proj_size)); - cuFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = cuCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - -template -__global__ void fill_entries_above_diagonal(DT *matrix, - size_t num_rows, - size_t num_cols, - size_t num_q_heads, - size_t entries_above_diagonal, - DT value) { - CUDA_KERNEL_LOOP(i, entries_above_diagonal * num_q_heads) { - size_t head_idx = i / entries_above_diagonal; - size_t entry_idx = i % entries_above_diagonal; - size_t y = (-1 + sqrt(8 * (float)entry_idx + 1)) / 2; - size_t x = entry_idx - y * (y + 1) / 2; - y += (num_cols - num_rows) + 1; - matrix[head_idx * num_rows * num_cols + num_cols * y + x] = value; - } -} - -template -void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT const *input_ptr, - DT const *weight_ptr, - DT *output_ptr, - DT const *bias_ptr, - cudaStream_t stream) { - - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - assert(m->qSize == m->vSize && m->qSize == m->kSize); - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - cudaDataType_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - - // Step 1: Compute QKV projections - { - DT alpha = 1.0f, beta = 0.0f; - // after transpositions - int m_q = m->qProjSize * m->num_q_heads; - int m_k = m->kProjSize * m->num_q_heads; - int m_v = m->vProjSize * m->num_q_heads; - assert(m_q == m_k && m_k == m_v); // keep things simple for now - int n = bc->num_active_infr_tokens(); - int k = m->qSize; - int m_ = m_q * QKV_WEIGHT_NUM; - // before transpositions - int lda = k, ldb = k, ldc = m_; - // matrix A: QKV weights - // matrix A's layout: [qSize (hidden_dim), qProjSize, num_heads, 3] - // matrix B: input - // matrix B's layout: [qSize (hidden_dim), num_new_tokens] - // matrix C: devQKVProjArray - // matrix B's layout: [qProjSize, num_heads, 3, num_new_tokens] - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - weight_ptr, - cublas_data_type, - lda, - input_ptr, - cublas_data_type, - ldb, - &beta, - output_ptr, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - - int num_tokens = bc->num_active_tokens(); - int parallelism = m->kProjSize * num_tokens * m->num_q_heads; - size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads; - - // Step 2: apply bias for QKV, or scale the query - if (*m->qkv_bias) { - apply_proj_bias_qkv<<>>(output_ptr, - bias_ptr, - shard_id, - num_tokens, - m->qProjSize, - m->kProjSize, - m->vProjSize, - 
m->global_num_q_heads, - m->num_q_heads, - *m->scaling_query, - m->scaling_factor, - m->hidden_size); - } else if (m->scaling_query) { - scaling_query_kernel<<>>(output_ptr, - num_tokens, - m->num_q_heads, - m->qProjSize, - m->scaling_factor, - m->hidden_size); - } - - // Step 3: apply rotary embedding if needed - if (*m->apply_rotary_embedding) { - /*q&k*/ - parallelism = num_tokens * m->hidden_size; - apply_rotary_embedding_hf<<>>(output_ptr, - m->complex_input, - m->token_infos, - m->qProjSize, - m->kProjSize, - num_tokens, - q_array_size, - m->hidden_size); - } -} - -template -void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - cudaStream_t stream) { - int num_tokens = bc->num_active_infr_tokens(); - if (num_tokens > 0) { - int parallelism = m->hidden_size * num_tokens; - store_kv_cache<<>>(static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - num_tokens, - BatchConfig::max_sequence_length(), - m->hidden_size); - } -} - -template -void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *weight_ptr, - DT const *bias_ptr, - int num_tokens, - cudaStream_t stream) { - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); -#if CUDA_VERSION >= 11000 - // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance - cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; -#else - cudaDataType_t compute_type = cublas_data_type; -#endif - // Project to output, save result directly on output tensor - { - DT alpha = 1.0f, beta = 0.0f; - // after transpositions - int m_ = m->oProjSize; - int k = m->vProjSize * m->num_q_heads; - int n = num_tokens; - // before transpositions - int lda = k, ldb = k, ldc = m_; - // matrix A: output projection weight - // matrix A's layout: [vProjSize * num_heads, oProjSize] - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - // matrix B: attn heads - // matrix B's layout: [vProjSize * num_heads, num_new_tokens] - DT const *B = static_cast
<DT *>(m->attn_heads); - // matrix B: output - // matrix B's layout: [oProjSize, num_new_tokens] - DT *C = static_cast<DT *>
(output_ptr); - - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - // Add final output bias - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * num_tokens; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - apply_proj_bias_w<<>>( - output_ptr, bias_ptr, num_tokens, qkv_weight_size, m->oProjSize); - } -} - -#define LAUNCH_ATTENTION_SCORE_KERNEL( \ - DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ - smem_sz = smem_size_in_bytes
(m->qProjSize, \ - BatchConfig::max_sequence_length(), \ - THREADS_PER_VALUE, \ - THDS_PER_BLOCK); \ - compute_attention_kernel_generation_kernel \ - <<>>( \ - static_cast
<DT *>(m->devQKVProjArray), \ - static_cast<DT *>
(m->keyCache), \ - static_cast<DT *>
(m->valueCache), \ - output_ptr, \ - scale, \ - BatchConfig::max_sequence_length(), \ - m->qProjSize, \ - m->hidden_size, \ - m->request_infos) - -template -void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - DT *output_ptr, - cudaStream_t stream) { - dim3 grid(m->num_q_heads, bc->num_generation_tokens); - int const per_head_size = m->qProjSize; - float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; - size_t smem_sz; - if (per_head_size == 64) { - constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; - LAUNCH_ATTENTION_SCORE_KERNEL( - DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); - } else if (per_head_size == 128) { - constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; - LAUNCH_ATTENTION_SCORE_KERNEL( - DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); - } else { - assert(false && "a unsupported head size"); - } -} - -template -void pre_build_weight_kernel(IncMultiHeadSelfAttentionMeta const *m, - GenericTensorAccessorR const weight, - DataType data_type, - cudaStream_t stream) { - // additional processing for weight uploading - // Note that we update weight_ptr and bias_ptr when uploading weight and - // bias - if (m->quantization_type != DT_NONE) { - // copy weight_ptr to quantized_weight_ptr, do compression and store in - // m->weight_ptr - cudaMemcpyAsync(m->quantized_weight_ptr, - weight.get_byte_ptr(), - m->quantized_weightSize, - cudaMemcpyHostToDevice, - stream); - - if (m->quantization_type == DT_INT4) { - int parallelism = m->qProjSize * m->qSize * m->num_q_heads / 2; - decompress_int4_attention_weights<<>>( - m->quantized_weight_ptr, - static_cast
<DT *>(m->weight_ptr), - m->qProjSize, - m->qSize, - m->num_q_heads); - } else { - assert(m->quantization_type == DT_INT8); - int parallelism = m->qProjSize * m->qSize * m->num_q_heads; - decompress_int8_attention_weights<<<GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream>>>( - m->quantized_weight_ptr, - static_cast<DT *>
(m->weight_ptr), - m->qProjSize, - m->qSize, - m->num_q_heads); - } - } else { - if (data_type == DT_FLOAT) { - cudaMemcpyAsync(m->weight_ptr, - weight.get_float_ptr(), - m->weightSize, - cudaMemcpyHostToDevice, - stream); - } else if (data_type == DT_HALF) { - cudaMemcpyAsync(m->weight_ptr, - weight.get_half_ptr(), - m->weightSize, - cudaMemcpyHostToDevice, - stream); - } else { - assert(false); - } - } -} - -template -void inference_kernel(IncMultiHeadSelfAttentionMeta *m, - BatchConfig const *bc, - int shard_id, - DT const *input_ptr, - DT const *weight_ptr, - DT *output_ptr, - DT const *bias_ptr, - cudaStream_t stream) { - - if (m->offload && m->biasSize > 0) { - cudaMemcpyAsync( - m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); - bias_ptr = static_cast
<DT *>(m->bias_ptr); - } - - // phase 1: Implement kernel to compute KQV for input tokens - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast<DT *>(m->devQKVProjArray), - bias_ptr, - stream); - update_kv_cache_kernel<DT>(m, bc, stream); - - if (bc->num_generation_tokens > 0) { - // phase 3: Compute attention score for generation tokens - compute_attention_kernel_generation<DT>( - m, bc, static_cast<DT *>
(m->attn_heads), stream); - } - - if (bc->num_tokens > bc->num_generation_tokens) { - // phase 4: Compute attention score for prompt tokens; - compute_attention_kernel_prompt( - m, bc, shard_id, bias_ptr, weight_ptr, stream); - } - - // compute output production and bias together for all tokens - int num_tokens = bc->num_active_tokens(); - compute_o_prod_bias( - m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); -} - -std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, - int shard_id) { - std::string op_name_without_uid = - IncMultiHeadSelfAttention::get_op_name_without_uid(m); - fs::path dst_filepath = get_dst_folder("bwd", m->bwd_step, shard_id); - if (m->layer_guid.model_id > 0) { - assert(false && "Model ID > 0 not supported yet"); - } - std::string layername = "layers." + - std::to_string(m->layer_guid.transformer_layer_id) + - "." + op_name_without_uid; - dst_filepath /= layername; - return dst_filepath.string(); -} - -template -void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *input_grad_ptr, - DT const *weight_ptr, - DT const *output_grad_ptr, - DT const *bias_ptr, - cudaStream_t stream) { - assert(!m->offload); - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); - cudaDataType_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - if (!bc->requestsInfo[i].peft_bwd) { - continue; - } - int num_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int num_total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; - // Currently assume we are calculating gradients for all tokens - // of a request - assert(num_tokens == num_total_tokens); - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); - // Step 1: compute gradients before final projection - { - int m_ = m->vProjSize * m->num_q_heads; - int n_ = num_tokens; - int k_ = m->oProjSize; - int lda = m_; - int ldb = k_; - int ldc = m_; - float alpha = 1.0f, beta = 0.0f; - // matrix A: output projection weight - // matrix A's layout: [vProjSize * num_heads, oProjSize] - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - // matrix B: output gradients - // matrix B's layout: [oProjSize, num_new_tokens] - DT const *B = - output_grad_ptr + - bc->requestsInfo[i].first_token_offset_in_batch * m->oProjSize; - // matrix 
C: attn_heads gradients - // matrix C's layout: [vProjSize * num_heads, num_new_tokens] - DT *C = static_cast<DT *>
(m->handle.workSpace); - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_N, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (m->inference_debugging) { - // save result to file for checking - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".o_proj.input_gradient_0"; - save_tensor(C, m_ * n_, filename.c_str()); - } - } - // Step 2: compute gradients w.r.t. value - { - float alpha = 1.0f, beta = 0.0f; - // matrix A: qk_prods_softmax - // matrix A's layout: [num_new_tokens, total_tokens, num_heads] - DT const *A = static_cast
<DT const *>(m->qk_prods_softmax); - // matrix B: attn_heads gradients - // matrix B's layout: [vProjSize * num_heads, num_new_tokens] - DT const *B = static_cast<DT const *>(m->handle.workSpace); - // matrix C: gradients for value (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] - DT *C = static_cast<DT *>
(m->devQKVProjArray) + - 2 * num_tokens * - (m->qProjSize * m->num_q_heads); // skip over regions reserved - // for Q and K gradients - // after transpositions - int m_ = num_tokens; // total_tokens - int n_ = m->vProjSize; // num_new_tokens - int k_ = num_tokens; // num_new_tokens - // before transpositions - int lda = num_tokens; // num_new_tokens - int ldb = m->vProjSize * m->num_q_heads; - int ldc = num_tokens; // total_tokens - // N.B. strides are applied before transpose operations - int strideA = num_tokens * num_tokens; // num_new_tokens * total_tokens - int strideB = m->vProjSize; - int strideC = num_tokens * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // save result to file for checking - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".v_proj.input_gradient_0"; - save_tensor(C, m_ * n_ * m->num_q_heads, filename.c_str()); - std::string filename2 = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax"; - save_tensor(A, m_ * k_ * m->num_q_heads, filename2.c_str()); - } - } - // Step 3: compute gradients w.r.t. the qk_prods_softmax tensor - { - float alpha = 1.0f, beta = 0.0f; - // matrix A: attn_heads gradients - // matrix A's layout: [vProjSize * num_heads, num_new_tokens] - DT const *A = static_cast
<DT const *>(m->handle.workSpace); - // matrix B: value cache - // matrix B's layout: [vProjSize * num_heads, max_num_tokens, num_req] - DT const *B = static_cast<DT const *>
(m->valueCache) + i * vt_req_block_size; - // matrix C: qk_prods_softmax gradients + // matrix A: devQKVProjArray + // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens] + // To get query projection, skip over Q entries from previous requests + DT const *A = static_cast
(m->devQKVProjArray) + + bc->requestsInfo[i].first_token_offset_in_batch * + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; + // matrix B: key cache + // matrix B's layout: [kProjSize * num_heads, total_tokens] + // To get B, skip over K entries from previous requests (all heads + + // padding) + DT const *B = static_cast
<DT const *>(m->keyCache) + i * kt_req_block_size; + // matrix C: qk_prods // matrix C's layout: [num_new_tokens, total_tokens, num_heads] - DT *C = static_cast<DT *>
(m->qk_prods_softmax); - // after transposition & striding - int m_ = num_tokens; // num_new_tokens - int n_ = num_tokens; - int k_ = m->vProjSize; - // before transposition and striding - int lda = m->vProjSize * m->num_q_heads; - int ldb = m->vProjSize * m->num_q_heads; - int ldc = num_tokens; // num_new_tokens - int strideA = m->vProjSize; - int strideB = m->vProjSize; - int strideC = num_tokens * num_tokens; // num_new_tokens * total_tokens - + // To get C, skip over QK.T products from previous requests + DT *C = static_cast
(m->qk_prods); checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_T, CUBLAS_OP_N, m_, - n_, - k_, + n, + k, &alpha, A, cublas_data_type, @@ -1117,23 +205,57 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad"; - save_tensor( - C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); - std::string filename2 = get_peft_dbg_folder(m, shard_id) + ".vcache"; - save_tensor( - B, m->vProjSize * m->num_q_heads * num_tokens, filename2.c_str()); - } } - // Step 4: softmax backpropagation + // Step 2: Add alibi position bias to qk production + // matrix C: qk_prods + // matrix C's layout: [num_new_tokens, total_tokens, num_heads] + // To get C, skip over QK.T products from previous requests + DT *C = static_cast
(m->qk_prods); + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + apply_position_bias_qkprd<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); + } + + // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods + // with -inf to force causal attention. + assert(num_new_tokens <= total_tokens); + size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; + if (entries_above_diagonal > 0) { + size_t parallelism = m->num_q_heads * entries_above_diagonal; + fill_entries_above_diagonal<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + entries_above_diagonal, + static_cast
(-INFINITY)); + } + + // Step 4: Compute Softmax(QK.T/sqrt(d_k)) { - float alpha = 1.0f, beta = 0.0f; + // Before modifying the parameters below, make sure to read the following + // description of the CUDNN_TENSOR_NCHW tensor layout, from + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: + // This tensor format specifies that the data is laid out in the following + // order: batch size, feature maps, rows, columns. The strides are + // implicitly defined in such a way that the data are contiguous in memory + // with no padding between images, feature maps, rows, and columns; the + // columns are the inner dimension and the images are the outermost + // dimension. int n_param = m->num_q_heads; - int c_param = num_tokens; + int c_param = total_tokens; int h_param = 1; - int w_param = num_tokens; + int w_param = num_new_tokens; checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, CUDNN_TENSOR_NCHW, cudnn_data_type, @@ -1141,85 +263,79 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, c_param, h_param, w_param)); - checkCUDNN(cudnnSoftmaxBackward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &alpha, - m->qk_tensor, - m->softmax_activation_buffer, - m->qk_tensor, - m->qk_prods_softmax, - &beta, - m->qk_tensor, - m->qk_prods)); - - if (m->inference_debugging) { - DT *C = static_cast
(m->qk_prods); - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad_in"; - save_tensor( - C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); - } - - // TODO: fill all elements above diagonal to force causal attention - size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - fill_entries_above_diagonal<<>>(static_cast
<DT *>(m->qk_prods), - num_tokens, - num_tokens, - m->num_q_heads, - entries_above_diagonal, - DT(0.0f)); - } - if (m->inference_debugging) { - DT *C = static_cast<DT *>
(m->qk_prods); - std::string filename = get_peft_dbg_folder(m, shard_id) + - ".qk_prods.softmax_grad_in.masked"; - save_tensor( - C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + float softmax_alpha = 1.0f, softmax_beta = 0.0f; + DT *C_softmax = static_cast
(m->qk_prods_softmax); + // The softmax operation below is executed according to the + // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax)); + } + // Copy C_softmax to m->softmax_activation_buffer if we need to compute + // PEFT backward + if (bc->requestsInfo[i].peft_bwd) { + DT *C_softmax = static_cast
(m->qk_prods_softmax); + size_t activation_size_needed = + sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; + if (activation_size_needed > m->allocated_peft_buffer_size2) { + MemoryAllocator *allocator = m->handle.peft_activation_allocator; + m->softmax_activation_buffer = + allocator->allocate_instance_untyped(activation_size_needed); + m->allocated_peft_buffer_size2 = activation_size_needed; } + checkCUDA(cudaMemcpyAsync(m->softmax_activation_buffer, + C_softmax, + sizeof(DT) * total_tokens * num_new_tokens * + m->num_q_heads, + cudaMemcpyDeviceToDevice, + stream)); } - // Step 5: compute gradients w.r.t. key + // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ + // softmax(QK.T/sqrt(d_k)).T { - float alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = 1.0f / sqrt(m->kProjSize); - } - // matrix A: gradients w.r.t. qk_prods - // matrix A's layout: [num_new_tokens, num_tokens, num_heads] - DT const *A = static_cast
<DT const *>(m->qk_prods); - // matrix B: query activation (in query_activation_buffer) - // matrix B's layout: [m->qProjSize * num_heads, num_new_tokens] - DT const *B = static_cast<DT const *>(m->query_activation_buffer); - // matrix C: gradients for key (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] - DT *C = - static_cast<DT *>
(m->devQKVProjArray) + - num_tokens * - (m->qProjSize * - m->num_q_heads); // skip over regions reserved for Q gradients - // after transposition & striding - int m_ = num_tokens; - int n_ = m->kProjSize; - int k_ = num_tokens; // num_new_tokens - // before transposition and striding - int lda = num_tokens; // num_new_tokens - int ldb = m->kProjSize * m->num_q_heads; - int ldc = num_tokens; - int strideA = num_tokens * num_tokens; - int strideB = m->kProjSize; - int strideC = num_tokens * m->kProjSize; + DT alpha = 1.0f, beta = 0.0f; + // after transpositions + int m_ = m->vProjSize; + int n = num_new_tokens; + int k = total_tokens; + // before transpositions + int lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + // N.B. strides are applied before transpose operations + int strideA = vt_block_size; + int strideB = num_new_tokens * total_tokens; + int strideC = m->vProjSize; + // matrix A: value cache + // matrix A's layout: [vProjSize, num_heads, total_tokens] + // To get A, skip over V.T entries from previous requests (all heads + + // padding) + DT *A = static_cast
<DT *>(m->valueCache) + i * vt_req_block_size; + // matrix B: qk_prods_softmax + // matrix B's layout: [num_new_tokens, total_tokens, num_heads] + // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous + // requests (all heads) + DT *B = static_cast<DT *>(m->qk_prods_softmax); + // matrix C: attn heads + // matrix C's layout: [vProjSize, num_heads, num_new_tokens] + // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous + // requests + // store the resulting attn heads, also skip the generation tokens + DT *C = static_cast<DT *>
(m->attn_heads) + + (bc->requestsInfo[i].first_token_offset_in_batch) * + m->num_q_heads * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, + CUBLAS_OP_N, CUBLAS_OP_T, m_, - n_, - k_, + n, + k, &alpha, A, cublas_data_type, @@ -1237,323 +353,797 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".query_activation"; - save_tensor( - B, m->qProjSize * m->num_q_heads * num_tokens, filename.c_str()); - std::string filename2 = - get_peft_dbg_folder(m, shard_id) + ".devkproj_pre"; - save_tensor( - C, num_tokens * (m->qProjSize * m->num_q_heads), filename2.c_str()); + } + tokens_previous_requests += num_new_tokens; + } + if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { + bc->print(); + printf("tokens_previous_requests: %i\n", tokens_previous_requests); + printf("num_tokens: %i\n", num_tokens); + printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); + } + assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); +} + +// gridDim = num_heads +// blockDim = num_tokens/num_request * head_size +// QKV tensor layout: |QKV| * num_new_tokens. |Q=K=V=head_size * num_heads| +// one thread process one head_size +template +__global__ void compute_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = typename VEC_K::Type; + using V_vec = typename VEC_V
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // eg. if head_size = 128, thread_per_key = 4, with float32 precision + // then K_VEC_SIZE = 1, QK_VEC_SIZE = 4 + // K_ELTS_PER_THREAD = 128 / 4 = 32 + // K_VECS_PER_THREAD = 32 / 1 = 32 + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + // constexpr int QK_VEC_SIZE = 16 / sizeof(DT); + // // constexpr int QK_VEC_SIZE = sizeof(Qk_vec_k) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // request idx + int const request_idx = blockIdx.y; + + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + + int const first_step = 0; + + int const tlength = + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + // DT const *q_ptr = + // query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size; + + // q tensor in this thread + // if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total + // K_VECS_PER_THREAD elements + // QK_vec_k: 32->1, 64->2, 128->4... head_size + // K_vec_k: 4->1, 2->2, 1->4 threads_per_key + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + __syncthreads(); + // first iter = 128 / 4 = 32 + // K_VECS_PER_THREAD = 32 + // K_PER_ITER how many keys in this loop + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + + int ti_end = + div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + // get k, perform qk proj + + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < tlength) { + k[ii] = *reinterpret_cast(k_cache_batch + + ti_circ * hidden_size + + head_idx * per_head_size + jj); + } + // Compute dot product. + // This includes a reduction across the threads in the same thread group. 
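// Note: each group of THREADS_PER_KEY consecutive threads loads complementary
// fragments of one timestep's key, so the Qk_dot call below reduces the
// per-thread partial products across that group and every thread in the group
// ends up with the full q.k score for its timestep. The running qk_max is what
// allows the softmax further down to be evaluated in the numerically stable
// form exp(qk - qk_max) / sum_j exp(qk_j - qk_max).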
+ } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + // // todo add positional embedding to the qk production + // // Store the product to shared memory. There's one qk value per + // timestep. + // // Update the max. + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + // todo add alobi here + bool const mask = ti_circ >= tlength; + if (mask) { + assert(false); } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = mask ? 0.f : qk; + } + } + + __syncthreads(); + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + float exp_sum = 0.f; + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + float logit = __expf(qk_smem[ti - first_step] - qk_max); + exp_sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + qk_smem[ti - first_step] *= inv_sum; + } + + __syncthreads(); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("softmax %.10f\n", qk_smem[0]); + // } + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + zero(out); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + batch_config_request_id * max_seq_length * hidden_size + vi; + + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + + V_vec v = *reinterpret_cast( + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + float logit = qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); } - // Step 6: compute gradients w.r.t query - { - float alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = 1.0f / sqrt(m->kProjSize); + } + + // // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. 
+ int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; } - // matrix A: gradients w.r.t. qk_prods - // matrix A's layout: [num_new_tokens, num_tokens, num_heads] - DT const *A = static_cast
<DT const *>(m->qk_prods); - // matrix B: key cache - // matrix B's layout: [vProjSize * num_heads, max_num_tokens, num_req] - DT const *B = static_cast<DT const *>(m->keyCache) + i * kt_req_block_size; - // matrix C: gradients for query (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] - DT *C = static_cast<DT *>
(m->devQKVProjArray); - // after transposition & striding - int m_ = num_tokens; // num_new_tokens - int n_ = m->qProjSize; - int k_ = num_tokens; - // before transposition and striding - int lda = num_tokens; // num_new_tokens - int ldb = m->qProjSize * m->num_q_heads; - int ldc = num_tokens; - int strideA = num_tokens * num_tokens; - int strideB = m->qProjSize; - int strideC = num_tokens * m->qProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray_pre"; - save_tensor(C, - num_tokens * m->qProjSize * m->num_q_heads * 3, - filename.c_str()); + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); } + __syncthreads(); } + } - // Step 7: perform rotary position embeddings (RoPE) bwd - { - if (*m->apply_rotary_embedding) { - assert(m->hidden_size == m->qProjSize * m->num_q_heads); - assert(m->qProjSize == m->kProjSize); - /*q&k*/ - int parallelism = num_tokens * m->hidden_size; - DT *A = static_cast
<DT *>(m->devQKVProjArray); - apply_rotary_embedding_bwd<<<GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, stream>>>(A, - m->complex_input, - m->token_infos, - m->qProjSize, - num_tokens, - m->hidden_size); - DT *C = static_cast<DT *>
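// A minimal host-side sketch (not from this codebase) of the rotation that the
// rotary-embedding kernels in this file apply to each (real, imaginary) pair of
// a head's projection; pos, pair_idx, proj_size and rope_theta stand in for the
// values the kernels read from the token info and the rotary-embedding meta:
//
//   #include <cmath>
//   #include <complex>
//
//   std::complex<float> rope_rotate(std::complex<float> x, int pos,
//                                   int pair_idx, int proj_size,
//                                   float rope_theta) {
//     // theta_i = pos / rope_theta^(2*i/d), matching `freq` in the kernels
//     float freq = pos * (1.0f / std::pow(rope_theta,
//                                         2.0f * pair_idx / proj_size));
//     return x * std::complex<float>(std::cos(freq), std::sin(freq));
//   }
//
// The backward variant applies the same complex multiplication; only the
// indexing into devQKVProjArray differs.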
(m->devQKVProjArray); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray"; - save_tensor(C, - num_tokens * m->qProjSize * m->num_q_heads * 3, - filename.c_str()); - } - } + // Output the final values. + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float( + *reinterpret_cast(output_ptr + request_idx * hidden_size + + head_idx * per_head_size + vi), + out); + } +} - // matrix C: gradients for key (saved as part of m->devQKVProjArray) - // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] - DT *C = - static_cast
(m->devQKVProjArray) + - num_tokens * - (m->qProjSize * - m->num_q_heads); // skip over regions reserved for Q gradients - if (m->inference_debugging) { - std::string filename = get_peft_dbg_folder(m, shard_id) + ".devkproj"; - save_tensor( - C, num_tokens * (m->qProjSize * m->num_q_heads), filename.c_str()); +// only used by MPT model. https://arxiv.org/abs/2108.12409 +template +__global__ void apply_position_bias_qkprd(DT *input_ptr, + int num_tokens, + int num_total_tokens, + int num_heads, + int global_num_q_heads, + int shard_id) { + CUDA_KERNEL_LOOP(i, num_tokens * num_total_tokens * num_heads) { + // get head_idx, + int head_idx = i / (num_tokens * num_total_tokens) + (num_heads * shard_id); + int position_idx = (i / num_tokens) % num_total_tokens; + position_idx = position_idx + 1 - num_total_tokens; + // 8 is alibi_bias_max in + // https://huggingface.co/mosaicml/mpt-30b/blob/main/config.json + float base = (float)(head_idx + 1) * 8 / global_num_q_heads; + float slopes = 1.0 / pow(2, base); + // if(i == 0){ + // printf("see position: %d, %f, %f, %f\n", position_idx, base, slopes, + // position_idx * slopes); + // } + input_ptr[i] += static_cast
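// A minimal sketch (not from this codebase) of the ALiBi slope computed in
// apply_position_bias_qkprd above, assuming a hypothetical
// global_num_q_heads = 32 (the constant 8 is alibi_bias_max from the MPT
// config referenced in the comment):
//
//   #include <cmath>
//   float alibi_slope(int head_idx, int global_num_q_heads) {
//     float base = (float)(head_idx + 1) * 8 / global_num_q_heads;
//     return 1.0f / std::pow(2.0f, base); // head 0 -> 2^-0.25, head 31 -> 2^-8
//   }
//
// position_idx runs from (1 - num_total_tokens) up to 0, so the added bias is
// zero for the most recent key and increasingly negative for keys further in
// the past.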
(position_idx * slopes); + } +} + +template +__global__ void scaling_query_kernel(DT *input_ptr, + int qProjSize, + int num_tokens, + int num_q_heads, + float scaling_factor, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / hidden_size; + input_ptr[i % hidden_size + token_idx * hidden_size * QKV_WEIGHT_NUM] *= + scaling_factor; + } +} + +template +__global__ void + apply_rotary_embedding_hf(DT *input_ptr, + cuFloatComplex *complex_input, + BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, + int qProjSize, + int kProjSize, + int num_tokens, + size_t q_array_size, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + // create complex number + bool q_tensor = i < (q_array_size / 2); + int proj_size = q_tensor ? qProjSize : kProjSize; + int real_i = q_tensor ? i : i - q_array_size / 2; + + int token_idx = real_i / (hidden_size / 2); + int idx = real_i % (proj_size / 2); + int head_idx = (real_i - (token_idx * (hidden_size / 2))) / (proj_size / 2); + + int real_part_index = idx + head_idx * proj_size + + token_idx * hidden_size * QKV_WEIGHT_NUM + + hidden_size * (q_tensor ? 0 : 1); + int complex_part_index = real_part_index + (proj_size / 2); + + complex_input[i] = {input_ptr[real_part_index], + input_ptr[complex_part_index]}; + + // get the freq_cis: shape 1 * (qProjSize/2) = 1 * 64 + // apply a Cartesian coordinate transformation + // multiple with input & /copy back to q/k + + // get position of token + + // size_t pos = id_map[token_idx].token_position; + size_t pos = tokenInfos[token_idx].abs_depth_in_request; + + // float before_real = complex_input[i].x, before_complex = + int pos_i = real_i % (proj_size / 2); + + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); } } - // Step 8: compute gradients w.r.t. input - { - float alpha = 1.0f, beta = 0.0f; - if (!m->reset_input_grads[0]) { - beta = 1.0f; - } - // matrix A: QKV projection weights - // matrix A's layout: [qSize, qProjSize * num_q_heads, 3] - DT const *A = weight_ptr; - // matrix B: gradients w.r.t. QKV (concatenated in devQKVArray) - // matrix B's layout: [num_tokens, qProjsize * num_heads, 3] - DT const *B = static_cast
(m->devQKVProjArray); - // matrix C: gradients w.r.t. input - // matrix C's layout: [m->qSize, num_tokens] - DT *C = input_grad_ptr + - bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; - int m_ = m->qSize; - int n_ = num_tokens; - int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); - int lda = m_; - int ldb = n_; - int ldc = m_; - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n_, - k_, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (m->inference_debugging) { - std::string filename = - get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0"; - save_tensor(C, num_tokens * m->qSize, filename.c_str()); + cuFloatComplex complex_pos = {cos(freq), sin(freq)}; + + complex_input[i] = cuCmulf(complex_input[i], complex_pos); + input_ptr[real_part_index] = complex_input[i].x; + input_ptr[complex_part_index] = complex_input[i].y; + } +} + +template +__global__ void + apply_rotary_embedding_bwd(DT *input_ptr, + cuFloatComplex *complex_input, + BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, + int proj_size, + int num_tokens, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + // compute indexes to visit first half proj_size of each of q/k tensor. + // devQKVProj has shape [num_tokens, qProjSize, num_heads, 3] in peft_bwd + bool q_tensor = i < (num_tokens * hidden_size / 2); + int real_i = q_tensor ? i : i - num_tokens * hidden_size / 2; + assert(hidden_size % proj_size == 0); + int num_heads = hidden_size / proj_size; + + int token_idx = real_i % num_tokens; + int idx = (real_i / num_tokens) % (proj_size / 2); + int head_idx = real_i / (num_tokens * proj_size / 2); + assert(head_idx < num_heads); + + int complex_part_index = (q_tensor ? 
0 : 1) * num_tokens * hidden_size + + head_idx * num_tokens * proj_size + + idx * num_tokens + token_idx; + int real_part_index = complex_part_index + (proj_size / 2) * num_tokens; + + complex_input[i] = {input_ptr[real_part_index], + input_ptr[complex_part_index]}; + + size_t pos = tokenInfos[token_idx].abs_depth_in_request; + + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * idx / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); } } + + cuFloatComplex complex_pos = {cos(freq), sin(freq)}; + + complex_input[i] = cuCmulf(complex_input[i], complex_pos); + input_ptr[real_part_index] = complex_input[i].x; + input_ptr[complex_part_index] = complex_input[i].y; } } -} // namespace IncMultiHeadAttention -} // namespace Kernels +template +void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT *output_ptr, + cudaStream_t stream) { -using namespace Kernels::IncMultiHeadAttention; + checkCUDA(cublasSetStream(m->handle.blas, stream)); + checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + assert(m->qSize == m->vSize && m->qSize == m->kSize); + + int num_tokens = bc->num_active_tokens(); + int parallelism = m->kProjSize * num_tokens * m->num_q_heads; + size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads; + + if (m->scaling_query) { + scaling_query_kernel<<>>(output_ptr, + m->qProjSize, + num_tokens, + m->num_q_heads, + m->scaling_factor, + m->hidden_size); + } + + // Step 3: apply rotary embedding if needed + if (m->rotary_embedding_meta->apply_rotary_embedding) { + /*q&k*/ + parallelism = num_tokens * m->hidden_size; + apply_rotary_embedding_hf<<>>( + output_ptr, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + m->kProjSize, + num_tokens, + q_array_size, + m->hidden_size); + } +} template -__global__ void store_kv_cache(DT const *devQKVProjArray, - DT *kCache_ptr, - DT *vCache_ptr, - BatchConfig::PerTokenInfo const *tokenInfos, - int num_tokens, - int max_seq_len, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - int offset = i % hidden_size; +void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + cudaStream_t stream) { + int num_tokens = bc->num_active_infr_tokens(); + if (num_tokens > 0) { + int parallelism = m->hidden_size * num_tokens; + store_kv_cache<<>>(static_cast
<DT *>(m->devQKVProjArray), + static_cast<DT *>(m->keyCache), + static_cast<DT *>
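// A minimal sketch (not from this codebase) of the flat K/V-cache addressing
// used by store_kv_cache: the caches are laid out as
// [num_requests, max_seq_len, hidden_size] buffers, so token `tok_id` of
// request `req_id` is written at
//   req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + offset
//
//   #include <cstddef>
//   inline std::size_t kv_cache_index(int req_id, int tok_id, int offset,
//                                     int max_seq_len, int hidden_size) {
//     return (std::size_t)req_id * hidden_size * max_seq_len +
//            (std::size_t)tok_id * hidden_size + offset;
//   }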
(m->valueCache), + m->token_infos, + num_tokens, + BatchConfig::max_sequence_length(), + m->hidden_size); + } +} - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; +#define LAUNCH_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_sz = smem_size_in_bytes
<DT>(m->qProjSize, \ + BatchConfig::max_sequence_length(), \ + THREADS_PER_VALUE, \ + THDS_PER_BLOCK); \ + compute_attention_kernel_generation_kernel \ + <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>( \ + static_cast<DT *>(m->devQKVProjArray), \ + static_cast<DT *>(m->keyCache), \ + static_cast<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length(), \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos) - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; - int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; +template +void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + dim3 grid(m->num_q_heads, bc->num_generation_tokens); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + size_t smem_sz; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); + } +} - // key cache - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; +std::string get_fwd_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, + int shard_id) { + std::string op_name_without_uid = + IncMultiHeadSelfAttention::get_op_name_without_uid(m); + fs::path dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id); + if (m->layer_guid.model_id > 0) { + assert(false && "Model ID > 0 not supported yet"); } + std::string layername = "layers." + + std::to_string(m->layer_guid.transformer_layer_id) + + "." + op_name_without_uid; + dst_filepath /= layername; + return dst_filepath.string(); } template -__global__ void store_query_cache(DT const *devQKVProjArray, - DT *qCache_ptr, - int num_tokens, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / hidden_size; - int offset = i % hidden_size; +void inference_kernel(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + DT const *qkv_ptr, + DT *output_ptr, + cudaStream_t stream) { - size_t val_idx = token_idx * QKV_WEIGHT_NUM * hidden_size + offset; + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, num_heads, 3, num_new_tokens] + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); - DT qVal = devQKVProjArray[val_idx]; + cudaMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * sizeof(DT), + cudaMemcpyDeviceToDevice, + stream); - // query cache - qCache_ptr[i] = qVal; + // phase 1: Implement kernel to apply rotary embedding and scaling + compute_qkv_kernel( + m, bc, shard_id, static_cast
<DT *>(m->devQKVProjArray), stream); + update_kv_cache_kernel<DT>(m, bc, stream); + + if (bc->num_generation_tokens > 0) { + // phase 3: Compute attention score for generation tokens + compute_attention_kernel_generation<DT>( + m, bc, static_cast<DT *>(m->attn_heads), stream); + } + + if (bc->num_tokens > bc->num_generation_tokens) { + // phase 4: Compute attention score for prompt tokens + compute_attention_kernel_prompt<DT>
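// Note: with the QKV and output projections split out of this operator,
// inference_kernel no longer takes weight or bias pointers. Phase 0 above only
// copies the already-projected QKV (produced by a separate upstream projection
// op) into devQKVProjArray, the phases here apply rotary embedding/scaling,
// update the KV cache and compute the attention scores, and the resulting
// attn_heads are finally copied straight to output_ptr for the downstream
// o_proj layer.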
(m, bc, shard_id, stream); + } + + int num_tokens = bc->num_active_tokens(); + cudaMemcpyAsync(output_ptr, + m->attn_heads, + m->oProjSize * num_tokens * sizeof(DT), + cudaMemcpyDeviceToDevice, + stream); +} + +std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, + int shard_id) { + std::string op_name_without_uid = + IncMultiHeadSelfAttention::get_op_name_without_uid(m); + fs::path dst_filepath = get_dst_folder("bwd", m->bwd_step, shard_id); + if (m->layer_guid.model_id > 0) { + assert(false && "Model ID > 0 not supported yet"); + } + std::string layername = "layers." + + std::to_string(m->layer_guid.transformer_layer_id) + + "." + op_name_without_uid; + dst_filepath /= layername; + return dst_filepath.string(); +} + +__global__ void transposeAdd_half_kernel( + half *out, half const *in, int width, int height, half alpha, half beta) { + int t_id = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (int i = t_id; i < width * height; i += num_threads) { + int row = i / width; + int col = i % width; + out[col * height + row] = + alpha * in[row * width + col] + beta * out[col * height + row]; + } +} + +__global__ void transposeAdd_float_kernel(float *out, + float const *in, + int width, + int height, + float alpha, + float beta) { + int t_id = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (int i = t_id; i < width * height; i += num_threads) { + int row = i / width; + int col = i % width; + out[col * height + row] = + alpha * in[row * width + col] + beta * out[col * height + row]; } } template -void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, - BatchConfig const *bc, - int shard_id, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { +void transposeAdd(DT *out, + const DT *in, + int width, + int height, + float alpha, + float beta, + cudaStream_t stream) { + assert(false && "Unsupported data type"); +} + +template <> +void transposeAdd(float *out, + float const *in, + int width, + int height, + float alpha, + float beta, + cudaStream_t stream) { + transposeAdd_float_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, alpha, beta); +} + +template <> +void transposeAdd(half *out, + half const *in, + int width, + int height, + float alpha, + float beta, + cudaStream_t stream) { + transposeAdd_half_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, __float2half(alpha), __float2half(beta)); +} + +template +void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT *input_grad_ptr, + DT const *output_grad_ptr, + cudaStream_t stream) { + assert(!m->offload); checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); assert(data_type_size(m->output_type[0]) == sizeof(DT)); cudaDataType_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - // int num_requests = 
bc->num_active_requests(); - int num_tokens = bc->num_active_tokens(); - int tokens_previous_requests = 0; - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); - assert(m->qProjSize == m->kProjSize); - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i] || - (!bc->requestsInfo[i].prompt_phase && !bc->requestsInfo[i].peft_bwd)) { - continue; - } - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; - int max_peft_tokens = bc->requestsInfo[i].max_sequence_length; - // Copy query to m->query_activation_buffer if we need to compute - // PEFT backward - if (bc->requestsInfo[i].peft_bwd) { - size_t activation_size_needed = - sizeof(DT) * max_peft_tokens * m->num_q_heads * m->qProjSize; - if (activation_size_needed > m->allocated_peft_buffer_size1) { - MemoryAllocator *allocator = m->handle.peft_activation_allocator; - m->query_activation_buffer = - allocator->allocate_instance_untyped(activation_size_needed); - m->allocated_peft_buffer_size1 = activation_size_needed; - } - int parallelism = m->hidden_size * num_tokens; - store_query_cache<<>>( - static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->query_activation_buffer), - num_tokens, - m->hidden_size); + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; } - // Step 1: compute query-key product QK.T/sqrt(d_k) + if (!bc->requestsInfo[i].peft_bwd) { + continue; + } + int num_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int num_total_tokens = bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch; + // Currently assume we are calculating gradients for all tokens + // of a request + assert(num_tokens == num_total_tokens); + int kt_block_size = m->kProjSize; + int kt_req_block_size = + kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_block_size = m->vProjSize; + int vt_req_block_size = + vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize); + // Step 1: copy gradient before final projection into workspace { - // Scale by sqrt(d_k) as per the original attention paper - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); + int m_ = m->vProjSize * m->num_q_heads; + int n_ = num_tokens; + DT *C = static_cast<DT *>
(m->handle.workSpace); + cudaMemcpyAsync(C, + output_grad_ptr + + bc->requestsInfo[i].first_token_offset_in_batch * + m->oProjSize, + m_ * n_ * sizeof(DT), + cudaMemcpyDeviceToDevice, + stream); + if (m->inference_debugging) { + // save result to file for checking + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".o_proj.input_gradient_0"; + save_tensor(C, m_ * n_, filename.c_str()); } + } + // Step 2: compute gradients w.r.t. value + { + float alpha = 1.0f, beta = 0.0f; + // matrix A: qk_prods_softmax + // matrix A's layout: [num_new_tokens, total_tokens, num_heads] + DT const *A = static_cast
<DT const *>(m->qk_prods_softmax); + // matrix B: attn_heads gradients + // matrix B's layout: [vProjSize * num_heads, num_new_tokens] + DT const *B = static_cast<DT const *>(m->handle.workSpace); + // matrix C: gradients for value (saved as part of m->devQKVProjArray) + // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + DT *C = static_cast<DT *>
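// Note: per head, this strided-batched GEMM computes the value gradient
// dV = P^T * dO, where P = softmax(QK^T / sqrt(d_k)) is one
// [num_new_tokens x total_tokens] slice of matrix A and dO is the attention
// output gradient staged in the workspace (matrix B). cuBLAS is column-major,
// so the row-major layouts are expressed via the transpose flags and the
// lda/ldb/ldc and stride values below rather than by physically transposing
// the buffers.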
(m->devQKVProjArray) + + 2 * num_tokens * + (m->qProjSize * m->num_q_heads); // skip over regions reserved + // for Q and K gradients // after transpositions - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; + int m_ = num_tokens; // total_tokens + int n_ = m->vProjSize; // num_new_tokens + int k_ = num_tokens; // num_new_tokens // before transpositions - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; + int lda = num_tokens; // num_new_tokens + int ldb = m->vProjSize * m->num_q_heads; + int ldc = num_tokens; // total_tokens // N.B. strides are applied before transpose operations - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // matrix A: devQKVProjArray - // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens] - // To get query projection, skip over Q entries from previous requests - DT const *A = static_cast
(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; - // matrix B: key cache - // matrix B's layout: [kProjSize * num_heads, total_tokens] - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
<DT const *>(m->keyCache) + i * kt_req_block_size; - // matrix C: qk_prods - // matrix C's layout: [num_new_tokens, total_tokens, num_heads] - // To get C, skip over QK.T products from previous requests - DT *C = static_cast<DT *>
(m->qk_prods); + int strideA = num_tokens * num_tokens; // num_new_tokens * total_tokens + int strideB = m->vProjSize; + int strideC = num_tokens * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_T, - CUBLAS_OP_N, + CUBLAS_OP_T, m_, - n, - k, + n_, + k_, &alpha, A, cublas_data_type, @@ -1571,57 +1161,80 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // save result to file for checking + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".v_proj.input_gradient_0"; + save_tensor(C, m_ * n_ * m->num_q_heads, filename.c_str()); + std::string filename2 = + get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax"; + save_tensor(A, m_ * k_ * m->num_q_heads, filename2.c_str()); + } } - // Step 2: Add alibi position bias to qk production - // matrix C: qk_prods - // matrix C's layout: [num_new_tokens, total_tokens, num_heads] - // To get C, skip over QK.T products from previous requests - DT *C = static_cast
(m->qk_prods); - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - apply_position_bias_qkprd<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } + // Step 3: compute gradients w.r.t. the qk_prods_softmax tensor + { + float alpha = 1.0f, beta = 0.0f; + // matrix A: attn_heads gradients + // matrix A's layout: [vProjSize * num_heads, num_new_tokens] + DT const *A = static_cast
<DT const *>(m->handle.workSpace); + // matrix B: value cache + // matrix B's layout: [vProjSize * num_heads, max_num_tokens, num_req] + DT const *B = static_cast<DT const *>(m->valueCache) + i * vt_req_block_size; + // matrix C: qk_prods_softmax gradients + // matrix C's layout: [num_new_tokens, total_tokens, num_heads] + DT *C = static_cast<DT *>
(m->qk_prods_softmax); + // after transposition & striding + int m_ = num_tokens; // num_new_tokens + int n_ = num_tokens; + int k_ = m->vProjSize; + // before transposition and striding + int lda = m->vProjSize * m->num_q_heads; + int ldb = m->vProjSize * m->num_q_heads; + int ldc = num_tokens; // num_new_tokens + int strideA = m->vProjSize; + int strideB = m->vProjSize; + int strideC = num_tokens * num_tokens; // num_new_tokens * total_tokens - // Step 3: Apply causal mask. Fill all elements above diagonal in qk prods - // with -inf to force causal attention. - assert(num_new_tokens <= total_tokens); - size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2; - if (entries_above_diagonal > 0) { - size_t parallelism = m->num_q_heads * entries_above_diagonal; - fill_entries_above_diagonal<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - entries_above_diagonal, - static_cast
(-INFINITY)); + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n_, + k_, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad"; + save_tensor( + C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + std::string filename2 = get_peft_dbg_folder(m, shard_id) + ".vcache"; + save_tensor( + B, m->vProjSize * m->num_q_heads * num_tokens, filename2.c_str()); + } } - - // Step 4: Compute Softmax(QK.T/sqrt(d_k)) + // Step 4: softmax backpropagation { - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. + float alpha = 1.0f, beta = 0.0f; int n_param = m->num_q_heads; - int c_param = total_tokens; + int c_param = num_tokens; int h_param = 1; - int w_param = num_new_tokens; + int w_param = num_tokens; checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, CUDNN_TENSOR_NCHW, cudnn_data_type, @@ -1629,79 +1242,145 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, c_param, h_param, w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax); - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax)); + checkCUDNN(cudnnSoftmaxBackward(m->handle.dnn, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + m->qk_tensor, + m->softmax_activation_buffer, + m->qk_tensor, + m->qk_prods_softmax, + &beta, + m->qk_tensor, + m->qk_prods)); + + if (m->inference_debugging) { + DT *C = static_cast
(m->qk_prods); + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".qk_prods.softmax_grad_in"; + save_tensor( + C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + } + + // TODO: fill all elements above diagonal to force causal attention + size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2; + if (entries_above_diagonal > 0) { + size_t parallelism = m->num_q_heads * entries_above_diagonal; + fill_entries_above_diagonal<<>>(static_cast
<DT *>(m->qk_prods), + num_tokens, + num_tokens, + m->num_q_heads, + entries_above_diagonal, + DT(0.0f)); + } + if (m->inference_debugging) { + DT *C = static_cast<DT *>
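// A small worked example of the causal-mask bookkeeping above, assuming a
// hypothetical block of n = 4 tokens: the strictly-upper-triangular entries
// number n * (n - 1) / 2 = 6, i.e. the positions where a query token would
// attend to a later token. In the forward pass those positions are filled with
// -INFINITY before the softmax; here in the backward pass the corresponding
// softmax-gradient entries are simply zeroed with DT(0.0f) so that no gradient
// flows through masked positions.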
(m->qk_prods); + std::string filename = get_peft_dbg_folder(m, shard_id) + + ".qk_prods.softmax_grad_in.masked"; + save_tensor( + C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + } } - // Copy C_softmax to m->softmax_activation_buffer if we need to compute - // PEFT backward - if (bc->requestsInfo[i].peft_bwd) { - DT *C_softmax = static_cast
(m->qk_prods_softmax); - size_t activation_size_needed = - sizeof(DT) * max_peft_tokens * max_peft_tokens * m->num_q_heads; - if (activation_size_needed > m->allocated_peft_buffer_size2) { - MemoryAllocator *allocator = m->handle.peft_activation_allocator; - m->softmax_activation_buffer = - allocator->allocate_instance_untyped(activation_size_needed); - m->allocated_peft_buffer_size2 = activation_size_needed; + // Step 5: compute gradients w.r.t. key + { + float alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = 1.0f / sqrt(m->kProjSize); + } + // matrix A: gradients w.r.t. qk_prods + // matrix A's layout: [num_new_tokens, num_tokens, num_heads] + DT const *A = static_cast
(m->qk_prods); + // matrix B: query activation (in query_activation_buffer) + // matrix B's layout: [m->qProjSize * num_heads, num_new_tokens] + DT const *B = static_cast
(m->query_activation_buffer); + // matrix C: gradients for key (saved as part of m->devQKVProjArray) + // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + DT *C = + static_cast
<DT *>
(m->devQKVProjArray) + + num_tokens * + (m->qProjSize * + m->num_q_heads); // skip over regions reserved for Q gradients + // after transposition & striding + int m_ = num_tokens; + int n_ = m->kProjSize; + int k_ = num_tokens; // num_new_tokens + // before transposition and striding + int lda = num_tokens; // num_new_tokens + int ldb = m->kProjSize * m->num_q_heads; + int ldc = num_tokens; + int strideA = num_tokens * num_tokens; + int strideB = m->kProjSize; + int strideC = num_tokens * m->kProjSize; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_T, + m_, + n_, + k_, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".query_activation"; + save_tensor( + B, m->qProjSize * m->num_q_heads * num_tokens, filename.c_str()); + std::string filename2 = + get_peft_dbg_folder(m, shard_id) + ".devkproj_pre"; + save_tensor( + C, num_tokens * (m->qProjSize * m->num_q_heads), filename2.c_str()); } - checkCUDA(cudaMemcpyAsync(m->softmax_activation_buffer, - C_softmax, - sizeof(DT) * total_tokens * num_new_tokens * - m->num_q_heads, - cudaMemcpyDeviceToDevice, - stream)); } - // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @ - // softmax(QK.T/sqrt(d_k)).T + // Step 6: compute gradients w.r.t query { - DT alpha = 1.0f, beta = 0.0f; - // after transpositions - int m_ = m->vProjSize; - int n = num_new_tokens; - int k = total_tokens; - // before transpositions - int lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - // N.B. strides are applied before transpose operations - int strideA = vt_block_size; - int strideB = num_new_tokens * total_tokens; - int strideC = m->vProjSize; - // matrix A: value cache - // matrix A's layout: [vProjSize, num_heads, total_tokens] - // To get A, skip over V.T entries from previous requests (all heads + - // padding) - DT *A = static_cast
<DT *>
(m->valueCache) + i * vt_req_block_size; - // matrix B: qk_prods_softmax - // matrix B's layout: [num_new_tokens, total_tokens, num_heads] - // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous - // requests (all heads) - DT *B = static_cast
<DT *>
(m->qk_prods_softmax); - // matrix C: attn heads - // matrix C's layout: [vProjSize, num_heads, num_new_tokens] - // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous - // requests - // store the result attn heads, also skip the genration tokens - DT *C = static_cast
<DT *>
(m->attn_heads) + - (bc->requestsInfo[i].first_token_offset_in_batch) * - m->num_q_heads * m->vProjSize; + float alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = 1.0f / sqrt(m->kProjSize); + } + // matrix A: gradients w.r.t. qk_prods + // matrix A's layout: [num_new_tokens, num_tokens, num_heads] + DT const *A = static_cast
(m->qk_prods); + // matrix B: key cache + // matrix B's layout: [vProjSize * num_heads, max_num_tokens, num_req] + DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; + // matrix C: gradients for query (saved as part of m->devQKVProjArray) + // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + DT *C = static_cast
<DT *>
(m->devQKVProjArray); + // after transposition & striding + int m_ = num_tokens; // num_new_tokens + int n_ = m->qProjSize; + int k_ = num_tokens; + // before transposition and striding + int lda = num_tokens; // num_new_tokens + int ldb = m->qProjSize * m->num_q_heads; + int ldc = num_tokens; + int strideA = num_tokens * num_tokens; + int strideB = m->qProjSize; + int strideC = num_tokens * m->qProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, m_, - n, - k, + n_, + k_, &alpha, A, cublas_data_type, @@ -1719,30 +1398,109 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray_pre"; + save_tensor(C, + num_tokens * m->qProjSize * m->num_q_heads * 3, + filename.c_str()); + } + } + + // Step 7: perform rotary position embeddings (RoPE) bwd + { + if (m->rotary_embedding_meta->apply_rotary_embedding) { + assert(m->hidden_size == m->qProjSize * m->num_q_heads); + assert(m->qProjSize == m->kProjSize); + /*q&k*/ + int parallelism = num_tokens * m->hidden_size; + DT *A = static_cast
<DT *>
(m->devQKVProjArray); + apply_rotary_embedding_bwd<<>>( + A, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + num_tokens, + m->hidden_size); + DT *C = static_cast
<DT *>
(m->devQKVProjArray); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".devQKVPRojArray"; + save_tensor(C, + num_tokens * m->qProjSize * m->num_q_heads * 3, + filename.c_str()); + } + } + + // matrix C: gradients for key (saved as part of m->devQKVProjArray) + // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + DT *C = + static_cast
<DT *>
(m->devQKVProjArray) + + num_tokens * + (m->qProjSize * + m->num_q_heads); // skip over regions reserved for Q gradients + if (m->inference_debugging) { + std::string filename = get_peft_dbg_folder(m, shard_id) + ".devkproj"; + save_tensor( + C, num_tokens * (m->qProjSize * m->num_q_heads), filename.c_str()); + } + } + + // Step 8: compute gradients w.r.t. input + { + float alpha = 1.0f, beta = 0.0f; + if (!m->reset_input_grads[0]) { + beta = 1.0f; + } + // matrix B: gradients w.r.t. QKV (concatenated in devQKVArray) + // matrix B's layout: [num_tokens, qProjsize * num_heads, 3] + DT const *B = static_cast
(m->devQKVProjArray); + // matrix C: gradients w.r.t. input + // matrix C's layout: [m->qSize, num_tokens] + DT *C = input_grad_ptr + + bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; + // int m_ = m->qSize; + int n_ = num_tokens; + int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); + + // The original version uses existing result and attention's projection to + // do further calculation in a way different than the usual dense layer, + // they are off by a transpose. So an explicit transpose is needed here. + // The add here is just for gradient accumulation. + transposeAdd(C, B, n_, k_, alpha, beta, stream); + + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0"; + save_tensor(C, num_tokens * m->qSize, filename.c_str()); + } } - tokens_previous_requests += num_new_tokens; - } - if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { - bc->print(); - printf("tokens_previous_requests: %i\n", tokens_previous_requests); - printf("num_tokens: %i\n", num_tokens); - printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); } - assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); } +} // namespace IncMultiHeadAttention +} // namespace Kernels + +using namespace Kernels::IncMultiHeadAttention; + /*static*/ void IncMultiHeadSelfAttention::inference_kernel_wrapper( IncMultiHeadSelfAttentionMeta *m, BatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -1751,43 +1509,14 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventRecord(t_start, stream); } - // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } if (input.data_type == DT_HALF) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - m->offload ? static_cast(m->weight_ptr) : weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - float const *bias_ptr = - use_bias ? bias.get_float_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - m->offload ? 
static_cast(m->weight_ptr) - : weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -1809,12 +1538,9 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( BatchConfig const *bc, int shard_id, GenericTensorAccessorW const &input_grad, - GenericTensorAccessorR const &weight, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorR const &output_grad) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -1823,35 +1549,23 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( cudaEventRecord(t_start, stream); } - // assert(input.data_type == weight.data_type); assert(input_grad.data_type == output_grad.data_type); - if (use_bias) { - assert(input_grad.data_type == bias.data_type); - } if (input_grad.data_type == DT_HALF) { assert(!m->offload); - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, bc, shard_id, input_grad.get_half_ptr(), - weight.get_half_ptr(), output_grad.get_half_ptr(), - bias_ptr, stream); } else if (input_grad.data_type == DT_FLOAT) { assert(!m->offload); - float const *bias_ptr = - use_bias ? bias.get_float_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, bc, shard_id, input_grad.get_float_ptr(), - weight.get_float_ptr(), output_grad.get_float_ptr(), - bias_ptr, stream); } else { assert(false && "Unspported data type"); @@ -1870,7 +1584,6 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( FFHandler handler, IncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -1885,14 +1598,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, + attn->rotary_embedding_meta, attn->scaling_query, attn->qk_prod_scaling, attn->position_bias, - attn->final_bias, attn->scaling_factor, - weight, gpu_mem_allocator, num_samples, attn->num_q_heads, @@ -1913,14 +1623,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, - bool _qkv_bias, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, bool _qk_prod_scaling, bool _position_bias, - bool _final_bias, float _scaling_factor, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _global_num_q_heads, @@ -1929,7 +1636,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _num_kv_heads, DataType _quantization_type, bool _offload) - : OpMeta(handler, attn), weight_ptr(nullptr), bias_ptr(nullptr) { + : OpMeta(handler, attn) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); checkCUDNN(cudnnSetStream(handler.dnn, stream)); @@ -1955,29 +1662,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( num_kv_heads = _num_kv_heads; hidden_size = num_q_heads * qProjSize; - weightSize = - ((qSize * qProjSize + oProjSize * (vProjSize > 0 ? 
vProjSize : vSize)) * - num_q_heads + - (kSize * kProjSize + vSize * vProjSize) * num_q_heads) * - size_of_dt; - if (quantization_type != DT_NONE) { - quantized_weightSize = get_quantization_to_byte_size( - attn->data_type, quantization_type, weightSize); - } - // biasSize = _bias ? oProjSize * size_of_dt * 4 : 0; - - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - int final_bias_size = oProjSize; - biasSize = - (_qkv_bias ? qkv_bias_size : 0) + (final_bias ? final_bias_size : 0); - - // has_load_weights = (bool *)calloc(1, sizeof(bool)); - //*has_load_weights = false; - apply_rotary_embedding = (bool *)calloc(1, sizeof(bool)); - *apply_rotary_embedding = _apply_rotary_embedding; - qkv_bias = (bool *)calloc(1, sizeof(bool)); - *qkv_bias = _qkv_bias; + rotary_embedding_meta = + (RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta)); + *rotary_embedding_meta = _rotary_embedding_meta; scaling_query = (bool *)calloc(1, sizeof(bool)); *scaling_query = _scaling_query; scaling_factor = _scaling_factor; @@ -1985,14 +1672,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( *qk_prod_scaling = _qk_prod_scaling; position_bias = (bool *)calloc(1, sizeof(bool)); *position_bias = _position_bias; - final_bias = (bool *)calloc(1, sizeof(bool)); - *final_bias = _final_bias; - - // allocate weight and bias in the reserve space for cpu offloading - if (offload) { - weight_ptr = gpu_mem_allocator.allocate_reserved_untyped(weightSize); - bias_ptr = gpu_mem_allocator.allocate_reserved_untyped(biasSize); - } // allocate memory for the seqArray and reserve space { @@ -2058,9 +1737,6 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( ? key_cache_size + value_cache_size + qkv_max_proj_size : key_cache_size + value_cache_size); - if (quantization_type != DT_NONE) { - totalSharedSize += quantized_weightSize; - } assert(gpu_mem_allocator.reserved_total_size - gpu_mem_allocator.reserved_allocated_size >= totalSharedSize); @@ -2091,29 +1767,15 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( handler.batch_config_metadata->requestsInfo); if (offload) { - // token_infos = - // gpu_mem_allocator.allocate_reserved( - // tokeninfo_size); - // offset += sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size; qk_prods = gpu_mem_allocator.allocate_reserved_untyped(qk_prod_size * size_of_dt); - // offset += qk_prod_size * size_of_dt; qk_prods_softmax = gpu_mem_allocator.allocate_reserved_untyped( qk_prod_size * size_of_dt); - // offset += qk_prod_size * size_of_dt; attn_heads = gpu_mem_allocator.allocate_reserved_untyped(attn_heads_size * size_of_dt); - // offset += attn_heads_size * size_of_dt; complex_input = gpu_mem_allocator.allocate_reserved(complex_size); - // offset += complex_size * sizeof(cuFloatComplex); - // request_infos = - // gpu_mem_allocator.allocate_reserved( - // requestinfo_size); } else { - // token_infos = - // gpu_mem_allocator.allocate_instance( - // tokeninfo_size); qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size * size_of_dt); qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped( @@ -2122,16 +1784,11 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_of_dt); complex_input = gpu_mem_allocator.allocate_instance(complex_size); - // request_infos = - // gpu_mem_allocator.allocate_instance( - // requestinfo_size); } // allocate more size for quantization data if (quantization_type != DT_NONE) { assert(offload); - quantized_weight_ptr = - 
gpu_mem_allocator.allocate_reserved(quantized_weightSize); } if (!offload) { assert(gpu_mem_allocator.reserved_total_size == @@ -2149,49 +1806,32 @@ IncMultiHeadSelfAttentionMeta::~IncMultiHeadSelfAttentionMeta(void) { } } -template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( - IncMultiHeadSelfAttentionMeta const *m, - GenericTensorAccessorR const weight, - DataType data_type, - cudaStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + float *output_ptr, + cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( - IncMultiHeadSelfAttentionMeta const *m, - GenericTensorAccessorR const weight, - DataType data_type, - cudaStream_t stream); +template void + Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + half *output_ptr, + cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( +template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, float *output_ptr, - float const *weight_ptr, - float const *bias_ptr, - int num_tokens, cudaStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( +template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, half *output_ptr, - half const *weight_ptr, - half const *bias_ptr, - int num_tokens, cudaStream_t stream); -template void - Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( - IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - float *output_ptr, - cudaStream_t stream); - -template void - Kernels::IncMultiHeadAttention::compute_attention_kernel_generation( - IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - half *output_ptr, - cudaStream_t stream); }; // namespace FlexFlow diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index d4f930db6c..3835d258e0 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -511,6 +511,7 @@ void forward_kernel(LinearMeta const *m, out_dim, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // use_bias = True if (bias_ptr != NULL) { // fuse bias and relu diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 20ad762b62..09170d3c28 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -668,11 +668,11 @@ void Linear::inference_task(Task const *task, } Linear::save_inference_tensors_to_file( m, shard_id, bc, {input}, weights_accessors, {output}); - printf("\tin=[%i,%i].T @ w=[%i,%i] -> out=[%i,%i]\n", - in_dim, - bc->num_tokens, + printf("\tw=[%i,%i].T @ in=[%i,%i] -> out=[%i,%i]\n", in_dim, out_dim, + in_dim, + bc->num_tokens, out_dim, bc->num_tokens); } diff --git a/src/ops/residual_layer_norm.cc b/src/ops/residual_layer_norm.cc index 2a30d12d6d..ce4150f9d6 100644 --- a/src/ops/residual_layer_norm.cc +++ b/src/ops/residual_layer_norm.cc @@ -988,9 +988,20 @@ void ResidualLayerNorm::inference_task( return; } - assert(regions.size() == - 3 + m->use_two_residuals + - (m->elementwise_affine ? (m->use_bias ? 
2 : 1) : 0)); + int expected_num_regions = 4; // input, residual1, added_output, output + if (m->use_two_residuals) { + expected_num_regions++; // residual2 + } + if (m->inplace_residual) { + expected_num_regions--; // added_output = input + } + if (m->elementwise_affine) { + expected_num_regions += 1; // gamma + if (m->use_bias) { + expected_num_regions += 1; // beta + } + } + assert(regions.size() == expected_num_regions); int region_idx = 0, task_region_idx = 0; GenericTensorAccessorR input = diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index 52da51fb26..aa74ecc6f5 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -52,24 +52,22 @@ bool SpecIncMultiHeadSelfAttentionParams::is_valid( return is_valid; } -Tensor - FFModel::spec_inc_multihead_self_attention(Tensor const input, - int embed_dim, - int num_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::spec_inc_multihead_self_attention( + Tensor const input, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { return spec_inc_multiquery_self_attention(input, embed_dim, num_heads, @@ -77,12 +75,10 @@ Tensor kdim, vdim, dropout, - qkv_bias, - final_bias, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -90,30 +86,27 @@ Tensor name); } -Tensor - FFModel::spec_inc_multiquery_self_attention(Tensor const input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { +Tensor FFModel::spec_inc_multiquery_self_attention( + Tensor const input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim, + int vdim, + float dropout, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + RotaryEmbeddingMeta rotary_embedding_meta, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { if (data_type == DT_NONE) { data_type = input->data_type; } Layer *li = nullptr; - int weight_num = (qkv_bias || final_bias) ? 
2 : 1; if (data_type != input->data_type) { Tensor casted_input = cast(input, data_type, "type cast for IncMHA"); li = new Layer(this, @@ -121,7 +114,7 @@ Tensor data_type, name, 1 /*inputs*/, - weight_num /*weights*/, + 0 /*weights*/, 1 /*outputs*/, casted_input); } else { @@ -130,7 +123,7 @@ Tensor data_type, name, 1 /*inputs*/, - weight_num /*weights*/, + 0 /*weights*/, 1 /*outputs*/, input); } @@ -144,51 +137,26 @@ Tensor li->outputs[0] = create_tensor_legion_ordering( numdims, dims, data_type, li, 0, true /*create_grad*/); } - // Compute weight size - int qProjSize = kdim, kProjSize = kdim, vProjSize = kdim, - oProjSize = embed_dim; - int qSize = input->dims[0], kSize = input->dims[0], vSize = input->dims[0]; - int qParas = qProjSize * qSize; - int kParas = kProjSize * kSize; - int vParas = vProjSize * vSize; - int oParas = oProjSize * (vProjSize > 0 ? vProjSize : vSize); - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; - { - int dims[1] = {weight_size}; - li->weights[0] = create_weight_legion_ordering(1, - dims, - data_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - if (qkv_bias || final_bias) { - // q, k, v, o - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + - (final_bias ? oProjSize : 0)}; - li->weights[1] = create_weight_legion_ordering(1, - dims, - data_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } + li->data_type = data_type; li->add_int_property("embed_dim", embed_dim); li->add_int_property("num_q_heads", num_q_heads); li->add_int_property("num_kv_heads", num_kv_heads); li->add_int_property("kdim", kdim); li->add_int_property("vdim", vdim); - li->add_int_property("qkv_bias", qkv_bias); - li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling", qk_prod_scaling); @@ -216,14 +184,20 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( int vdim = value; float dropout; layer->get_float_property("dropout", dropout); - layer->get_int_property("qkv_bias", value); - bool qkv_bias = (bool)value; - layer->get_int_property("final_bias", value); - bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + 
layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; @@ -242,15 +216,12 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( kdim, vdim, dropout, - qkv_bias, - final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, position_bias, - false /*allocate_weights*/, layer->name); } @@ -264,29 +235,24 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, char const *name) - // Initializer* _bias_initializer) : Op(model, OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION, _input->data_type, name, 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, + 0, 1 /*outputs*/, _input), num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -305,99 +271,44 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( dims[0].size = _embed_dim; // Currently require no parallelism along this dim assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>(dims, - this->data_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, dims, this->data_type, this); - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* // Check correctness */ - /* assert(check_output_input_weight_parallel_dims()); */ } SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, ParallelTensor const _input, - ParallelTensor const _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, char const *name) - // Initializer* _bias_initializer) : Op(model, OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION, _input->data_type, name, 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, + 0 /*weights*/, 1 /*outputs*/, - _input, - _weight), + _input), num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), scaling_factor(_scaling_factor), - qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias) -// bias_initializer(_bias_initializer) -{ + qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias) { numOutputs = 1; int numdim = _input->num_dims; ParallelDim dims[MAX_TENSOR_DIM]; @@ -407,66 +318,15 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( dims[0].size = _embed_dim; // Currently require no parallelism along this dim assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - // dims[2].size = qParas + kParas + vParas + oParas; - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>(dims, - this->data_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, dims, this->data_type, this); - - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* register_output_weight_parallel_dims(outputs[0], numdim-1, _weight, 1); */ - /* register_output_weight_parallel_dims(outputs[0], numdim-2, _weight, 2); */ - // Check correctness - /* assert(check_output_input_weight_parallel_dims()); */ } SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, SpecIncMultiHeadSelfAttention const &other, - ParallelTensor const input, - bool allocate_weights) + ParallelTensor const input) : SpecIncMultiHeadSelfAttention(model, other.layer_guid, input, @@ -476,22 +336,18 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( other.qProjSize, other.vProjSize, other.dropout, - other.qkv_bias, - other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, other.position_bias, - allocate_weights, other.name) {} SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, SpecIncMultiHeadSelfAttentionParams const ¶ms, ParallelTensor const &input, - bool allocate_weights, char const *name) : SpecIncMultiHeadSelfAttention(model, params.layer_guid, @@ -502,15 +358,12 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( params.kdim, params.vdim, params.dropout, - params.qkv_bias, - params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, params.position_bias, - allocate_weights, params.name) {} void SpecIncMultiHeadSelfAttention::init_inference( @@ -541,18 +394,12 @@ void SpecIncMultiHeadSelfAttention::init_inference( EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); @@ -580,18 +427,12 @@ void SpecIncMultiHeadSelfAttention::init(FFModel const &ff) { EXCLUSIVE, inputs[0]->region)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); launcher.add_region_requirement(RegionRequirement(outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap(ff, fm); @@ -599,8 +440,7 @@ void SpecIncMultiHeadSelfAttention::init(FFModel const &ff) { /* regions[0](I): input - 
regions[1](I): weight - regions[2](O): output + regions[1](O): output */ OpMeta *SpecIncMultiHeadSelfAttention::init_task( Task const *task, @@ -618,17 +458,10 @@ OpMeta *SpecIncMultiHeadSelfAttention::init_task( FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = - helperGetGenericTensorAccessorRO(attn->weights[0]->data_type, - regions[1], - task->regions[1], - FID_DATA, - ctx, - runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(attn->outputs[0]->data_type, - regions[2], - task->regions[2], + regions[1], + task->regions[1], FID_DATA, ctx, runtime); @@ -643,14 +476,8 @@ OpMeta *SpecIncMultiHeadSelfAttention::init_task( Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); MemoryAllocator gpu_mem_allocator(gpu_mem); // We don't do offloading for SSMs (small speculative models) - SpecIncMultiHeadSelfAttentionMeta *m = - new SpecIncMultiHeadSelfAttentionMeta(handle, - attn, - weight, - gpu_mem_allocator, - num_samples, - num_q_heads, - num_kv_heads); + SpecIncMultiHeadSelfAttentionMeta *m = new SpecIncMultiHeadSelfAttentionMeta( + handle, attn, gpu_mem_allocator, num_samples, num_q_heads, num_kv_heads); // assert that we didn't over allocate memory assert(gpu_mem_allocator.instance_allocated_size == gpu_mem_allocator.instance_total_size); @@ -658,8 +485,6 @@ OpMeta *SpecIncMultiHeadSelfAttention::init_task( m->inference_debugging = attn->inference_debugging; std::strcpy(m->op_name, attn->name); m->layer_guid = attn->layer_guid; - assert(weight.domain.get_volume() * data_type_size(weight.data_type) == - m->weightSize); return m; } @@ -697,12 +522,6 @@ FutureMap SpecIncMultiHeadSelfAttention::inference( EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(idx++, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(idx++, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, @@ -710,21 +529,12 @@ FutureMap SpecIncMultiHeadSelfAttention::inference( batch_outputs[0]->region)); launcher.add_field(idx++, FID_DATA); - if (qkv_bias || final_bias) { - launcher.add_region_requirement(RegionRequirement(weights[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[1]->region)); - launcher.add_field(idx++, FID_DATA); - } return runtime->execute_index_space(ctx, launcher); } /* regions[0](I): input - regions[3](I): weight - regions[4](O): output + regions[1](O): output */ void SpecIncMultiHeadSelfAttention::inference_task( Task const *task, @@ -741,51 +551,29 @@ void SpecIncMultiHeadSelfAttention::inference_task( SpecIncMultiHeadSelfAttentionMeta *m = *((SpecIncMultiHeadSelfAttentionMeta **)task->local_args); - assert(((*m->qkv_bias || *m->final_bias) ? 
regions.size() == 4 - : regions.size() == 3)); + assert(regions.size() == 2); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - biases = helperGetGenericTensorAccessorRO(m->weight_type[1], - regions[3], - task->regions[3], - FID_DATA, - ctx, - runtime); - Domain bias_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(bias_domain.get_dim() == 4); - } + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - Domain weight_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); + ctx, task->regions[1].region.get_index_space()); assert(input_domain.get_dim() == 4); - assert(weight_domain.get_dim() == 2); assert(output_domain.get_dim() == 4); assert(task->index_point.get_dim() == 1); SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( - m, &bc, task->index_point.point_data[0], input, weight, output, biases); + m, &bc, task->index_point.point_data[0], input, output); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; - std::vector weights_accessors; - weights_accessors.push_back(weight); - if (*m->qkv_bias || *m->final_bias) { - weights_accessors.push_back(biases); - } SpecIncMultiHeadSelfAttention::save_inference_tensors_to_file( - m, shard_id, &bc, {input}, weights_accessors, {output}); + m, shard_id, &bc, {input}, {}, {output}); } } @@ -809,8 +597,7 @@ Op *SpecIncMultiHeadSelfAttention::materialize(FFModel &ff, ParallelTensor inputs[], int num_inputs) const { SpecIncMultiHeadSelfAttentionParams params = get_params(); - return new SpecIncMultiHeadSelfAttention( - ff, params, inputs[0], true, this->name); + return new SpecIncMultiHeadSelfAttention(ff, params, inputs[0], this->name); } bool SpecIncMultiHeadSelfAttention::measure_operator_cost( @@ -823,9 +610,20 @@ bool operator==(SpecIncMultiHeadSelfAttentionParams const &lhs, return lhs.layer_guid == rhs.layer_guid && lhs.embed_dim == rhs.embed_dim && lhs.num_q_heads == rhs.num_q_heads && lhs.kdim == rhs.kdim && lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && - lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + 
lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -842,10 +640,8 @@ SpecIncMultiHeadSelfAttentionParams params.kdim = this->kProjSize; params.vdim = this->vProjSize; params.dropout = this->dropout; - params.qkv_bias = this->qkv_bias; - params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -870,10 +666,15 @@ size_t hash::operator()( hash_combine(key, params.kdim); hash_combine(key, params.vdim); hash_combine(key, params.dropout); - hash_combine(key, params.qkv_bias); - hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index aebd5e8892..b2f4e35d5e 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -16,6 +16,7 @@ #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/ffconst_utils.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" +#include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/utils/hip_helper.h" #include #include @@ -26,13 +27,310 @@ namespace FlexFlow { using Legion::coord_t; using Legion::Memory; +#define WARP_SIZE 32 + using namespace Kernels::IncMultiHeadAttention; namespace Kernels { -namespace SpecIncMultiHeadAttention { +namespace SpecIncMultiHeadSelfAttention { + +template +__device__ __forceinline__ T + WARP_SHFL(unsigned mask, T var, int srcLane, int width = warpSize) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_sync(mask, var, srcLane, width); +#else + return __shfl(var, srcLane, width); +#endif +} + +template +__device__ __forceinline__ T + WARP_SHFL_XOR(unsigned mask, T var, int laneMask, int width = warpSize) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_xor_sync(mask, var, laneMask, width); +#else + return __shfl_xor(var, laneMask, width); +#endif +} + +template +__global__ void compute_spec_inc_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int const max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, + BatchConfig::BitMask *causalMask, + bool *request_completed) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = 
typename VEC_K::Type; + using V_vec = typename VEC_V
<DT>
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // nth request idx + int const request_idx = blockIdx.y; + + // request id in batch config + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + + // request_idx = re + + BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; + + int const first_step = 0; + + // int const tlength = + // request_infos[batch_config_request_id].first_token_depth_in_request + + // request_infos[batch_config_request_id].num_tokens_in_batch; + + int const totalCacheSize = + bitmask.non_tree_cache_size + bitmask.tree_size + bitmask.prompt_size - 1; + + int first_token_idx = 0; + for (int r = 0; r < batch_config_request_id; r++) { + first_token_idx += request_completed[r] ? 0 : causalMask[r].this_layer_size; + } + + int const tree_branch_num = + beam_request_infos[batch_config_request_id].sub_request_num; + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. 
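+  // Each group of THREADS_PER_KEY lanes cooperates on a single key, so one
+  // 32-lane warp covers WARP_SIZE / THREADS_PER_KEY keys per iteration of
+  // the QK^T loop below; ti_end is rounded up to a multiple of K_PER_WARP so
+  // every warp runs the same number of iterations.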
+ constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + + int ti_end = + div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + for (int qi = 0; qi < tree_branch_num; qi += 1) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + + int const query_token = + bitmask.prompt_size + bitmask.tree_size - 1 - tree_branch_num + qi; + + __syncthreads(); + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; + + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < totalCacheSize) { + + k[ii] = *reinterpret_cast( + k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + + jj); + } + } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + + if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { + // todo add alobi here + // bool const mask = ti_circ >= totalCacheSize; + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + + // if (head_idx == 0 && ti == 0 && request_idx == 15 && !mask) { + // printf("spec inc attn qkqkqk request id %d, %.10f, %d\n", + // batch_config_request_id, + // ti, + // qk, + // qi); + // } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = mask ? 0.f : qk; + } + } + + __syncthreads(); + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, WARP_SHFL_XOR(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, WARP_SHFL_XOR(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = WARP_SHFL(uint32_t(-1), qk_max, 0); + + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("spec inc attn first token qk_max %.10f\n", qk_max); + // } + + float exp_sum = 0.f; + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); + exp_sum += logit; + qk_smem[ti - first_step] = mask ? 0.0f : logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + qk_smem[ti - first_step] *= inv_sum; + } + + __syncthreads(); + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. 
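+  // For the value pass the threads are regrouped: each group of
+  // THREADS_PER_VALUE lanes owns a V_VEC_SIZE-wide slice of one value row.
+  // Below, vo selects which cached timestep a group starts from and vi is
+  // the element offset of that slice within the head.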
+ // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + zero(out); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + batch_config_request_id * max_seq_length * hidden_size + + vi; + + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + V_vec v = *reinterpret_cast( + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } + } + + // // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); + } + __syncthreads(); + } + } + + // Output the final values. 
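+  // After the tree reduction above, the accumulated output for this
+  // (token, head) pair lives in the group with vo == 0, so only that group
+  // writes. The result is stored at
+  // output_ptr[(first_token_idx + qi) * hidden_size + head_idx * per_head_size + vi],
+  // i.e. tokens are laid out contiguously in hidden_size-sized rows.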
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float(*reinterpret_cast( + output_ptr + (first_token_idx + qi) * hidden_size + + head_idx * per_head_size + vi), + out); + } + } +} template -__global__ void spec_store_kv_cache( +__global__ void spec_inc_store_kv_cache( DT const *devQKVProjArray, DT *kCache_ptr, DT *vCache_ptr, @@ -40,16 +338,16 @@ __global__ void spec_store_kv_cache( BatchConfig::PerRequestInfo *requestInfo, BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, + BatchConfig::BitMask *causalMask, int qProjSize, int kProjSize, int vProjSize, int num_tokens, int max_seq_len, - int max_beam_width, bool is_root, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * 2) { - int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / (hidden_size); int offset = i % hidden_size; size_t val_idx = @@ -58,82 +356,25 @@ __global__ void spec_store_kv_cache( DT kVal = devQKVProjArray[val_idx]; DT vVal = devQKVProjArray[val_idx + hidden_size]; - // above no need to be changed - // int const req_id = id_map[token_idx].request_index; - // int const tok_id = id_map[token_idx].token_position; - // int const sub_req_id = id_map[token_idx].sub_request_index; - // int const parent_id = id_map[token_idx].parent_id; - // int const beam_depth = id_map[token_idx].beam_depth; - // int const beam_width = id_map[token_idx].beam_width; - int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - int const parent_id = beamRequestInfos[req_id].parent_id[sub_req_id]; - int const beam_depth = beamRequestInfos[req_id].current_depth; - int const beam_width = beamRequestInfos[req_id].beam_size; - - // new token - kCache_ptr[(req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; - - // replica in the root iteration - if (beam_depth == 1) { - for (int i = 1; i < beam_width; i++) { - kCache_ptr[(req_id * max_beam_width + i) * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_beam_width + i) * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; - } - } + // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - // naive cache stealing - if (sub_req_id != parent_id) { - if (offset == 0 && tok_id == 0) { - printf("cache stealing!, depth %d req_id %d sub_req_id %d, parentid " - "%d, tok_id %d\n", - beam_depth, - req_id, - sub_req_id, - parent_id, - tok_id); - } + int const request_token_offset = + requestInfo[req_id].first_token_offset_in_batch; - for (int depth = 0; depth < beam_depth; depth++) { - int steal_token_idx = tok_id - beam_depth + depth; - int steal_from_idx = (req_id * max_beam_width + parent_id) * - (hidden_size * max_seq_len) + - steal_token_idx * hidden_size + offset; - int steal_to_idx = (req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - steal_token_idx * hidden_size + offset; - kCache_ptr[steal_to_idx] = kCache_ptr[steal_from_idx]; - vCache_ptr[steal_to_idx] = vCache_ptr[steal_from_idx]; - - // if(data_idx == 0 && head_idx == 0 && k_cache && req_id == 1){ - // printf("cache stealing kernel!, steal_token_idx %d\n", - // 
steal_token_idx); - // } - } - } + BatchConfig::BitMask bitmask = causalMask[req_id]; - // parallel cache stealing not yet implemented - // logic shld be - // launch spec_store_kv_cache with parallelism * current depth - // from the i here, get depth index - // if depth index not the current one, check if we need to steal - // steal if needed - - // cache stealing theory - // identify which sub request does this token come from - // for initial token, 0 - // for other, may 0,0,1/ 0,1,2/ 1,1,1 to get which cache to be reuse and - // which to be delete copy beam_size bunch of blocks when sub_req_id == - // parent_id : like 0 -> 0, 1->1, 2->2, do nothing, just append the new k/v + // if prompt token -> token id + // if tree token: + + int const cache_idx = bitmask.prompt_size + bitmask.non_tree_cache_size + + bitmask.tree_size - 1 - bitmask.this_layer_size + + token_idx - request_token_offset; + + kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = vVal; } } @@ -143,11 +384,9 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, hipStream_t stream) { int num_tokens = bc->num_active_infr_tokens(); int curr_depth = bc->beamRequestsInfo[0].current_depth; - // printf("curr depth: %d\n", curr_depth); - // assert(curr_depth < 3); if (num_tokens > 0) { int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(spec_store_kv_cache
<DT>), + hipLaunchKernelGGL(HIP_KERNEL_NAME(spec_inc_store_kv_cache<DT>
), GET_BLOCKS(parallelism), min(CUDA_NUM_THREADS, parallelism), 0, @@ -159,17 +398,71 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, m->request_infos, m->beam_token_infos, m->beam_request_infos, + m->causalMask, m->qProjSize, m->kProjSize, m->vProjSize, num_tokens, - BatchConfig::max_sequence_length(), - BeamSearchBatchConfig::MAX_BEAM_WIDTH, + BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num(), /*root*/ curr_depth == 0, m->hidden_size); } } +#define LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_sz = smem_size_in_bytes
<DT>(m->qProjSize, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::max_spec_tree_token_num(), \ + THREADS_PER_VALUE, \ + THDS_PER_BLOCK); \ + compute_spec_inc_attention_kernel_generation_kernel<DT, THDS_PER_BLOCK, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE> \ + <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>( \ + static_cast<DT *>
(m->devQKVProjArray), \ + static_cast<DT *>
(m->keyCache), \ + static_cast<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::max_spec_tree_token_num(), \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + m->beam_request_infos, \ + m->causalMask, \ + m->request_completed) + +template +void compute_spec_inc_attention_kernel_generation( + SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + DT *output_ptr, + hipStream_t stream) { + // one block == one head per request + // how many generation requests + dim3 grid(m->num_q_heads, bc->get_speculative_request_num()); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + size_t smem_sz; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); + } +} + template __global__ void spec_fill_entries_above_diagonal(DT *matrix, size_t new_tokens, @@ -188,331 +481,268 @@ __global__ void spec_fill_entries_above_diagonal(DT *matrix, } template -void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - hipStream_t stream) { +void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + DT *output_ptr, + hipStream_t stream) { checkCUDA(hipblasSetStream(m->handle.blas, stream)); checkCUDNN(miopenSetStream(m->handle.dnn, stream)); hipblasDatatype_t hipblas_data_type = ff_to_cuda_datatype(m->output_type[0]); miopenDataType_t miopen_data_type = ff_to_cudnn_datatype(m->output_type[0]); assert(data_type_size(m->output_type[0]) == sizeof(DT)); hipblasDatatype_t compute_type = hipblas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // hipblasDatatype_t compute_type = hipblas_data_type; - // #else - // // TODO: currently use the hipblas_data_type - // // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // hipblasDatatype_t compute_type = hipblas_data_type; - // #endif - // int num_requests = bc->num_active_requests(); - int num_tokens = bc->num_active_infr_tokens(); + + int num_tokens = bc->num_active_tokens(); int tokens_previous_requests = 0; int tokens_prev_requests_squares = 0; - // int qkv_block_size = - // (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens; int q_block_size = m->qProjSize; + int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int kt_req_block_size = kt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_req_block_size = vt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::max_spec_tree_token_num()); assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { + if (bc->request_completed[i] || (!bc->requestsInfo[i].prompt_phase) || + 
(bc->requestsInfo[i].num_tokens_in_batch == 0)) { + continue; + } else if (tokens_previous_requests < bc->num_generation_tokens) { + tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; continue; } - for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) { - - // int num_new_tokens = bc->num_processing_tokens[i]; - // int total_tokens = bc->token_last_available_idx[i] + 1; - - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); - } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast<DT const *>
(m->devQKVProjArray) + - tokens_previous_requests * m->qProjSize * m->num_q_heads * - QKV_WEIGHT_NUM; - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast<DT const *>
(m->keyCache) + - (i * bc->MAX_BEAM_WIDTH + sub_req_id) * kt_req_block_size; - - // if (i == 0 && sub_req_id == 0 && - // bc->beam_slots.at(0).current_depth == 1) { - // int offset = (float *)B - m->keyCache; - // printf("key cache offset %d\n", kt_req_block_size); - // } - // To get C, skip over QK^T products from previous requests - DT *C = static_cast<DT *>
(m->qk_prods) + - m->num_q_heads * tokens_prev_requests_squares; - - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - strideA, - B, - hipblas_data_type, - ldb, - strideB, - &beta, - C, - hipblas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd
), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens); - if (num_new_tokens > 1) { - size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; - hipLaunchKernelGGL( - HIP_KERNEL_NAME(spec_fill_entries_above_diagonal
<DT>), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - static_cast<DT>
(-INFINITY)); - } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(miopenSet4dTensorDescriptor( - m->qk_tensor, miopen_data_type, n_param, c_param, h_param, w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax) + - m->num_q_heads * tokens_prev_requests_squares; - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax, - MIOPEN_SOFTMAX_ACCURATE, - MIOPEN_SOFTMAX_MODE_CHANNEL)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = num_new_tokens; - n = m->vProjSize; - k = total_tokens; - lda = m_, ldb = n * m->num_q_heads, ldc = m_; - strideA = num_new_tokens * total_tokens; - strideB = vt_block_size; - strideC = num_new_tokens * m->vProjSize; - // To get A, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - A = C_softmax; - // To get B, skip over V^T entries from previous requests (all heads + - // padding) - B = static_cast
(m->valueCache) + - (i * bc->MAX_BEAM_WIDTH + sub_req_id) * vt_req_block_size; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - C = static_cast
(m->attn_heads) + - tokens_previous_requests * m->num_q_heads * m->vProjSize; - - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - strideA, - B, - hipblas_data_type, - ldb, - strideB, - &beta, - C, - hipblas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - - // Project to output, save result directly on output tensor - alpha = 1.0f, beta = 0.0f; - m_ = m->oProjSize; - k = m->vProjSize * m->num_q_heads; - n = num_new_tokens; - lda = k, ldb = n, ldc = m_; - A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - B = C; - C = static_cast
(output_ptr) + - tokens_previous_requests * m->oProjSize; - - checkCUDA(hipblasGemmEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - B, - hipblas_data_type, - ldb, - &beta, - C, - hipblas_data_type, - ldc, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - tokens_previous_requests += num_new_tokens; - tokens_prev_requests_squares += num_new_tokens * total_tokens; + // all requests in prompt phase should only have one sub requests; + assert(bc->sub_requests[i] == 1); + // int num_new_tokens = bc->num_processing_tokens[i]; + // int total_tokens = bc->token_last_available_idx[i] + 1; + + int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch; + + if (num_new_tokens <= 0) { + continue; } - } - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * num_tokens; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_proj_bias_w
), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - bias_ptr, - num_tokens, - qkv_weight_size, - m->oProjSize); + + // Compute (QK^T/sqrt(d_k)) + int m_ = num_new_tokens; + int n = total_tokens; + int k = m->qProjSize; + int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, + ldc = m_; + int strideA = q_block_size; + int strideB = kt_block_size; + int strideC = num_new_tokens * total_tokens; + + // a flag of using this scaling alpha + DT alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); + } + // To get A, skip over Q entries from previous requests (same head) + DT const *A = static_cast<DT const *>
(m->devQKVProjArray) + + bc->requestsInfo[i].first_token_offset_in_batch * + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; + DT const *B = static_cast<DT const *>
(m->keyCache) + i * kt_req_block_size; + DT *C = static_cast<DT *>
(m->qk_prods); + + checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, + HIPBLAS_OP_T, + HIPBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + hipblas_data_type, + lda, + strideA, + B, + hipblas_data_type, + ldb, + strideB, + &beta, + C, + hipblas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT)); + + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd
), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); + } + // Fill all elements above diagonal in qk prods with -inf to force + // causal attention. + assert(num_new_tokens <= total_tokens); + if (num_new_tokens > 1) { + size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; + hipLaunchKernelGGL(HIP_KERNEL_NAME(spec_fill_entries_above_diagonal
<DT>), + GET_BLOCKS(parallelism), + min((size_t)CUDA_NUM_THREADS, parallelism), + 0, + stream, + C, + num_new_tokens, + total_tokens, + m->num_q_heads, + static_cast<DT>
(-INFINITY)); + } + // Compute Softmax(QK^T/sqrt(d_k)) + // Before modifying the parameters below, make sure to read the following + // description of the CUDNN_TENSOR_NCHW tensor layout, from + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: + // This tensor format specifies that the data is laid out in the following + // order: batch size, feature maps, rows, columns. The strides are + // implicitly defined in such a way that the data are contiguous in memory + // with no padding between images, feature maps, rows, and columns; the + // columns are the inner dimension and the images are the outermost + // dimension. + int n_param = m->num_q_heads; + int c_param = total_tokens; + int h_param = 1; + int w_param = num_new_tokens; + checkCUDNN(miopenSet4dTensorDescriptor( + m->qk_tensor, miopen_data_type, n_param, c_param, h_param, w_param)); + float softmax_alpha = 1.0f, softmax_beta = 0.0f; + DT *C_softmax = static_cast
(m->qk_prods_softmax) + + m->num_q_heads * tokens_prev_requests_squares; + // The softmax operation below is executed according to the + // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax, + MIOPEN_SOFTMAX_ACCURATE, + MIOPEN_SOFTMAX_MODE_CHANNEL)); + // Matmul softmax(QK^T/sqrt(d_k)) by V + alpha = 1.0f, beta = 0.0f; + m_ = m->vProjSize; + n = num_new_tokens; + k = total_tokens; + lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + strideA = vt_block_size; + strideB = num_new_tokens * total_tokens; + strideC = m->vProjSize; + // To get A, skip over V^T entries from previous requests (all heads + + // padding) + A = static_cast
(m->valueCache) + i * vt_req_block_size; + // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous + // requests (all heads) + B = C_softmax; + // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous + // requests + + int token_offset = bc->requestsInfo[i].first_token_offset_in_batch; + + C = static_cast
(m->attn_heads) + + (token_offset)*m->num_q_heads * m->vProjSize; + checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, + HIPBLAS_OP_N, + HIPBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + hipblas_data_type, + lda, + strideA, + B, + hipblas_data_type, + ldb, + strideB, + &beta, + C, + hipblas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + HIPBLAS_GEMM_DEFAULT)); + + tokens_previous_requests += num_new_tokens; + tokens_prev_requests_squares += num_new_tokens * total_tokens; } - assert(tokens_previous_requests == num_tokens); + if (tokens_previous_requests != (num_tokens - bc->num_generation_tokens)) { + bc->print(); + printf("tokens_previous_requests: %i\n", tokens_previous_requests); + printf("num_tokens: %i\n", num_tokens); + printf("bc->num_generation_tokens: %i\n", bc->num_generation_tokens); + } + assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); } template void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, BeamSearchBatchConfig const *bc, int shard_id, - DT const *input_ptr, - DT const *weight_ptr, + DT const *qkv_ptr, DT *output_ptr, - DT const *bias_ptr, hipStream_t stream) { - // here because we need postion info in infernece 1 - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - checkCUDA( - hipMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - max_tokens_per_batch * sizeof(BatchConfig::PerTokenInfo), - hipMemcpyHostToDevice, - stream)); - checkCUDA(hipMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - hipMemcpyHostToDevice, - stream)); - checkCUDA( - hipMemcpyAsync(m->beam_token_infos, - &(bc->beamTokenInfo), - max_tokens_per_batch * bc->MAX_BEAM_WIDTH * - sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo), - hipMemcpyHostToDevice, - stream)); - checkCUDA(hipMemcpyAsync( - m->beam_request_infos, - &(bc->beamRequestsInfo), - bc->max_requests_per_batch() * - sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo), - hipMemcpyHostToDevice, - stream)); + + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, num_heads, 3, num_new_tokens] + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + + hipMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here + hipMemcpyDeviceToDevice, + stream); // phase 1: Implement kernel to compute KQV for input tokens - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast
<DT *>(m->devQKVProjArray), - bias_ptr, - stream); + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first + compute_qkv_kernel( + m, bc, shard_id, static_cast<DT *>
(m->devQKVProjArray), stream); // phase 2: Update key/val cache update_kv_cache_kernel<DT>
(m, bc, stream); - + if (bc->num_generation_tokens > 0) { + compute_spec_inc_attention_kernel_generation<DT>
( + m, bc, static_cast<DT *>
(m->attn_heads), stream); + } // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - compute_attention_kernel( - m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + if (bc->num_tokens > bc->num_generation_tokens) { + compute_attention_kernel_prompt(m, bc, shard_id, output_ptr, stream); + } + + int num_tokens = bc->num_active_tokens(); + + hipMemcpyAsync(output_ptr, + m->attn_heads, + m->oProjSize * num_tokens * sizeof(DT), + hipMemcpyDeviceToDevice, + stream); } -} // namespace SpecIncMultiHeadAttention +} // namespace SpecIncMultiHeadSelfAttention } // namespace Kernels /*static*/ @@ -521,12 +751,9 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( BeamSearchBatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorW const &output) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; hipEvent_t t_start, t_end; if (m->profiling) { @@ -535,34 +762,14 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(hipEventRecord(t_start, stream)); } - assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } if (input.data_type == DT_HALF) { - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); - Kernels::SpecIncMultiHeadAttention::inference_kernel(m, - bc, - shard_id, - input.get_half_ptr(), - weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { - float const *bias_ptr = - use_bias ? 
bias.get_float_ptr() : static_cast(nullptr); - Kernels::SpecIncMultiHeadAttention::inference_kernel(m, - bc, - shard_id, - input.get_float_ptr(), - weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -581,7 +788,6 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( FFHandler handler, SpecIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -596,14 +802,11 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, + attn->rotary_embedding_meta, attn->scaling_query, attn->qk_prod_scaling, attn->position_bias, - attn->final_bias, attn->scaling_factor, - weight, gpu_mem_allocator, num_samples, attn->num_q_heads, @@ -618,43 +821,16 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - size_t beam_tokeninfo_size = - max_tokens_per_batch * BeamSearchBatchConfig::MAX_BEAM_WIDTH; - size_t requestinfo_size = BeamSearchBatchConfig::max_requests_per_batch(); - size_t beam_requestinfo_size = - BeamSearchBatchConfig::max_requests_per_batch(); - size_t total_size = - requestinfo_size * sizeof(BatchConfig::PerRequestInfo) + - beam_tokeninfo_size * - sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo) + - beam_requestinfo_size * - sizeof(BeamSearchBatchConfig:: - BeamSearchPerRequestInfo); // more components will - // be added here later - - // We always directly allocate memory for small speculative models - gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - total_size); beam_token_infos = - gpu_mem_allocator - .allocate_instance( - beam_tokeninfo_size); - // offset += beam_tokeninfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); - request_infos = - gpu_mem_allocator.allocate_instance( - requestinfo_size); - // offset += requestinfo_size * sizeof(BatchConfig::PerRequestInfo); + static_cast( + handler.batch_config_metadata->beamTokenInfo); beam_request_infos = - gpu_mem_allocator - .allocate_instance( - beam_requestinfo_size); - // offset += beam_requestinfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo); - // assert(offset == total_size); - assert(gpu_mem_allocator.instance_total_size == - gpu_mem_allocator.instance_allocated_size); + static_cast( + handler.batch_config_metadata->beamRequestsInfo); + causalMask = static_cast( + handler.batch_config_metadata->causalMask); + request_completed = + static_cast(handler.batch_config_metadata->request_completed); } checkCUDA(hipStreamSynchronize(stream)); diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 4688a8233c..d8a2008388 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -463,8 +463,6 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, BeamSearchBatchConfig const *bc, int shard_id, DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, 
stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); @@ -472,23 +470,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); assert(data_type_size(m->output_type[0]) == sizeof(DT)); cudaDataType_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - // int num_requests = bc->num_active_requests(); + int num_tokens = bc->num_active_tokens(); int tokens_previous_requests = 0; int tokens_prev_requests_squares = 0; - // int qkv_block_size = - // (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens; int q_block_size = m->qProjSize; int kt_block_size = m->kProjSize; @@ -568,8 +553,7 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // print_tensor((float*)C, 32, "C"); - // add alibi position bias to qk production + // add alibi position bias to qk production if (*m->position_bias) { size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; @@ -698,21 +682,26 @@ template void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, BeamSearchBatchConfig const *bc, int shard_id, - DT const *input_ptr, - DT const *weight_ptr, + DT const *qkv_ptr, DT *output_ptr, - DT const *bias_ptr, cudaStream_t stream) { - // phase 1: Implement kernel to compute KQV for input tokens - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast
<DT *>(m->devQKVProjArray), - bias_ptr, - stream); + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, num_heads, 3, num_new_tokens] + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + + cudaMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here + cudaMemcpyDeviceToDevice, + stream); + // phase 1: Implement kernel to compute KQV for input tokens + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first + compute_qkv_kernel( + m, bc, shard_id, static_cast<DT *>
(m->devQKVProjArray), stream); // phase 2: Update key/val cache update_kv_cache_kernel<DT>
(m, bc, stream); if (bc->num_generation_tokens > 0) { @@ -722,14 +711,16 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 if (bc->num_tokens > bc->num_generation_tokens) { - compute_attention_kernel_prompt( - m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + compute_attention_kernel_prompt(m, bc, shard_id, output_ptr, stream); } - // compute output production and bias together for all tokens + int num_tokens = bc->num_active_tokens(); - compute_o_prod_bias( - m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); + cudaMemcpyAsync(output_ptr, + m->attn_heads, + m->oProjSize * num_tokens * sizeof(DT), + cudaMemcpyDeviceToDevice, + stream); } } // namespace SpecIncMultiHeadSelfAttention @@ -741,12 +732,9 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( BeamSearchBatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -755,36 +743,14 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventRecord(t_start, stream); } - assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } if (input.data_type == DT_HALF) { - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { - float const *bias_ptr = - use_bias ? 
bias.get_float_ptr() : static_cast(nullptr); Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -797,16 +763,12 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventDestroy(t_start); cudaEventDestroy(t_end); printf("SpecIncMultiHeadSelfAttention forward time = %.2fms\n", elapsed); - // print_tensor<3, float>(acc_query.ptr, acc_query.rect, - // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - // acc_output.rect, "[Attention:forward:output]"); } } SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( FFHandler handler, SpecIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -821,14 +783,11 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, + attn->rotary_embedding_meta, attn->scaling_query, attn->qk_prod_scaling, attn->position_bias, - attn->final_bias, attn->scaling_factor, - weight, gpu_mem_allocator, num_samples, attn->num_q_heads, diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index 132a48be40..ae0795ac1e 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -61,12 +61,10 @@ Tensor FFModel::inc_multihead_self_attention_verify( int kdim, int vdim, float dropout, - bool qkv_bias, - bool final_bias, bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -79,12 +77,10 @@ Tensor FFModel::inc_multihead_self_attention_verify( kdim, vdim, dropout, - qkv_bias, - final_bias, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -100,12 +96,10 @@ Tensor FFModel::inc_multiquery_self_attention_verify( int kdim, int vdim, float dropout, - bool qkv_bias, - bool final_bias, bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -117,7 +111,6 @@ Tensor FFModel::inc_multiquery_self_attention_verify( DataType quantization_type = cpu_offload ? config.quantization_type : DT_NONE; bool offload = cpu_offload; Layer *li = nullptr; - int weight_num = (qkv_bias || final_bias) ? 
2 : 1; if (data_type != input->data_type) { Tensor casted_input = cast(input, data_type, "type cast for IncMHA"); li = new Layer(this, @@ -125,7 +118,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( data_type, name, 1 /*inputs*/, - weight_num /*weights*/, + 0, 1 /*outputs*/, casted_input); } else { @@ -134,7 +127,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( data_type, name, 1 /*inputs*/, - weight_num /*weights*/, + 0, 1 /*outputs*/, input); } @@ -148,62 +141,28 @@ Tensor FFModel::inc_multiquery_self_attention_verify( li->outputs[0] = create_tensor_legion_ordering( numdims, dims, data_type, li, 0, true /*create_grad*/); } - // Compute weight size - int qProjSize = kdim, kProjSize = kdim, vProjSize = kdim, - oProjSize = embed_dim; - int qSize = input->dims[0], kSize = input->dims[0], vSize = input->dims[0]; - int qParas = qProjSize * qSize; - int kParas = kProjSize * kSize; - int vParas = vProjSize * vSize; - int oParas = oProjSize * (vProjSize > 0 ? vProjSize : vSize); - int one_head_size = qParas + kParas + vParas + oParas; - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; - { - // compress the weight size if quantization. - if (quantization_type != DT_NONE) { - one_head_size = get_quantization_to_byte_size( - data_type, quantization_type, one_head_size); - } - int dims[1] = {weight_size}; - li->weights[0] = create_weight_legion_ordering( - 1, - dims, - quantization_type == DT_NONE ? data_type : quantization_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - if (qkv_bias || final_bias) { - // q, k, v, o - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + - (final_bias ? 
oProjSize : 0)}; - li->weights[1] = create_weight_legion_ordering(1, - dims, - data_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } li->data_type = data_type; li->add_int_property("embed_dim", embed_dim); li->add_int_property("num_q_heads", num_q_heads); li->add_int_property("num_kv_heads", num_kv_heads); li->add_int_property("kdim", kdim); li->add_int_property("vdim", vdim); - li->add_int_property("qkv_bias", qkv_bias); - li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); - li->add_int_property("qk_prod_scaling", qk_prod_scaling); li->add_int_property("position_bias", position_bias); li->add_int_property("quantization_type", quantization_type); li->add_int_property("offload", offload); @@ -230,15 +189,20 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( int vdim = value; float dropout; layer->get_float_property("dropout", dropout); - layer->get_int_property("qkv_bias", value); - bool qkv_bias = (bool)value; - layer->get_int_property("final_bias", value); - bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; - layer->get_int_property("scaling_query", value); + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; bool scaling_query = (bool)value; float scaling_factor; layer->get_float_property("scaling_factor", scaling_factor); @@ -261,15 +225,12 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( kdim, vdim, dropout, - qkv_bias, - final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, position_bias, - false /*allocate_weights*/, quantization_type, offload, tensor_parallelism_degree, @@ -286,32 +247,27 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool 
_scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, char const *name) - // Initializer* _bias_initializer) : Op(model, OP_TREE_INC_MULTIHEAD_SELF_ATTENTION, _input->data_type, name, 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, + 0, 1 /*outputs*/, _input), num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -330,63 +286,12 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[i] = _input->dims[i]; } dims[0].size = _embed_dim; - // Currently require no parallelism along this dim - assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - // dims[2].size = qParas + kParas + vParas + oParas; - if (quantization_type != DT_NONE) { - dims[1].size = get_quantization_to_byte_size( - data_type, quantization_type, dims[1].size); - } - // dims[2].degree = 1; - // dims[2].parallel_idx = -1; - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>( - dims, - quantization_type == DT_NONE ? this->data_type : quantization_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } + // No longer require no parallelism along this dim + // assert(dims[0].degree == 1); outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, dims, this->data_type, this); - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ + /* // Check correctness */ /* assert(check_output_input_weight_parallel_dims()); */ } @@ -394,40 +299,33 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( FFModel &model, const ParallelTensor _input, - const ParallelTensor _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, int _kdim, int _vdim, float _dropout, - bool _qkv_bias, - bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, bool _position_bias, - bool allocate_weights, DataType _quantization_type, bool _offload, int _tensor_parallelism_degree, char const *name) - // Initializer* _bias_initializer) : Op(model, OP_TREE_INC_MULTIHEAD_SELF_ATTENTION, _input->data_type, name, 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, + 0, 1 /*outputs*/, - _input, - _weight), + _input), num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), qSize(_input->dims[0].size), kSize(_input->dims[0].size), vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), vProjSize(_vdim), oProjSize(_embed_dim), @@ -435,9 +333,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( scaling_query(_scaling_query), scaling_factor(_scaling_factor), qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias), quantization_type(_quantization_type), offload(_offload), - tensor_parallelism_degree(_tensor_parallelism_degree) -// bias_initializer(_bias_initializer) -{ + tensor_parallelism_degree(_tensor_parallelism_degree) { numOutputs = 1; int numdim = _input->num_dims; ParallelDim dims[MAX_TENSOR_DIM]; @@ -445,64 +341,13 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[i] = _input->dims[i]; } dims[0].size = _embed_dim; - // Currently require no parallelism along this dim + // Currently require no parallelism along this dim, is this aligned with the + // previous removal of assert? assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? 
this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - // dims[2].size = qParas + kParas + vParas + oParas; - if (quantization_type != DT_NONE) { - dims[1].size = get_quantization_to_byte_size( - data_type, quantization_type, dims[1].size); - } - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>( - dims, - quantization_type == DT_NONE ? this->data_type : quantization_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, dims, this->data_type, this); - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* register_output_weight_parallel_dims(outputs[0], numdim-1, _weight, 1); */ - /* register_output_weight_parallel_dims(outputs[0], numdim-2, _weight, 2); */ // Check correctness /* assert(check_output_input_weight_parallel_dims()); */ } @@ -510,8 +355,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( FFModel &model, TreeIncMultiHeadSelfAttention const &other, - const ParallelTensor input, - bool allocate_weights) + const ParallelTensor input) : TreeIncMultiHeadSelfAttention(model, other.layer_guid, input, @@ -521,15 +365,12 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( other.qProjSize, other.vProjSize, other.dropout, - other.qkv_bias, - other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, other.position_bias, - allocate_weights, other.quantization_type, other.offload, other.tensor_parallelism_degree, @@ -539,7 +380,6 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( FFModel &model, TreeIncMultiHeadSelfAttentionParams const ¶ms, ParallelTensor const &input, - bool allocate_weights, char const *name) : TreeIncMultiHeadSelfAttention(model, params.layer_guid, @@ -550,15 +390,12 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( params.kdim, params.vdim, params.dropout, - params.qkv_bias, - params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, params.position_bias, - allocate_weights, params.quantization_type, params.offload, params.tensor_parallelism_degree, @@ -592,20 +429,12 @@ void TreeIncMultiHeadSelfAttention::init_inference( EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement( - RegionRequirement(weights[0]->part, - 0 /*projection id*/, - 
READ_ONLY, - EXCLUSIVE, - weights[0]->region, - ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(1, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); @@ -633,18 +462,12 @@ void TreeIncMultiHeadSelfAttention::init(FFModel const &ff) { EXCLUSIVE, inputs[0]->region)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); launcher.add_region_requirement(RegionRequirement(outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, outputs[0]->region)); - launcher.add_field(2, FID_DATA); + launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap(ff, fm); @@ -652,8 +475,7 @@ void TreeIncMultiHeadSelfAttention::init(FFModel const &ff) { /* regions[0](I): input - regions[1](I): weight - regions[2](O): output + regions[1](O): output */ OpMeta *TreeIncMultiHeadSelfAttention::init_task( Task const *task, @@ -671,17 +493,10 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task( FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = - helperGetGenericTensorAccessorRO(attn->weights[0]->data_type, - regions[1], - task->regions[1], - FID_DATA, - ctx, - runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(attn->outputs[0]->data_type, - regions[2], - task->regions[2], + regions[1], + task->regions[1], FID_DATA, ctx, runtime); @@ -689,14 +504,12 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task( int num_samples = input.domain.hi()[2] - input.domain.lo()[2] + 1; assert(attn->qoSeqLength == input.domain.hi()[1] - input.domain.lo()[1] + 1); assert(attn->kvSeqLength == input.domain.hi()[1] - input.domain.lo()[1] + 1); - // int num_q_heads = weight.domain.hi()[1] - weight.domain.lo()[1] + 1; + int num_q_heads = attn->num_q_heads / attn->tensor_parallelism_degree; int num_kv_heads = attn->num_kv_heads / attn->tensor_parallelism_degree + (attn->num_kv_heads % attn->tensor_parallelism_degree != 0); - assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + 1); - Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); MemoryAllocator gpu_mem_allocator(gpu_mem); if (attn->offload) { @@ -705,14 +518,8 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task( gpu_mem_allocator.register_reserved_work_space( handle.offload_reserve_space, handle.offload_reserve_space_size); } - TreeIncMultiHeadSelfAttentionMeta *m = - new TreeIncMultiHeadSelfAttentionMeta(handle, - attn, - weight, - gpu_mem_allocator, - num_samples, - num_q_heads, - num_kv_heads); + TreeIncMultiHeadSelfAttentionMeta *m = new TreeIncMultiHeadSelfAttentionMeta( + handle, attn, gpu_mem_allocator, num_samples, num_q_heads, num_kv_heads); if (!attn->offload) { // assert that we didn't over allocate memory assert(gpu_mem_allocator.reserved_allocated_size == @@ -723,10 +530,6 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task( std::strcpy(m->op_name, attn->name); m->layer_guid = attn->layer_guid; - if (attn->quantization_type == DT_NONE) { - assert(weight.domain.get_volume() * data_type_size(weight.data_type) == - m->weightSize); - } 
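+  // With the QKV and O projections disaggregated into separate Linear layers,
+  // this operator no longer owns weight tensors, so init_task takes no weight
+  // accessor and performs no weightSize check.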
return m; } @@ -764,37 +567,18 @@ FutureMap TreeIncMultiHeadSelfAttention::inference( EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(idx++, FID_DATA); - launcher.add_region_requirement( - RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region, - ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(idx++, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[0]->region)); launcher.add_field(idx++, FID_DATA); - if (qkv_bias || final_bias) { - launcher.add_region_requirement( - RegionRequirement(weights[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[1]->region, - ff.cpu_offload ? MAP_TO_ZC_MEMORY : 0)); - launcher.add_field(idx++, FID_DATA); - } return runtime->execute_index_space(ctx, launcher); } /* regions[0](I): input - regions[3](I): weight - regions[4](O): output + regions[1](O): output */ void TreeIncMultiHeadSelfAttention::inference_task( Task const *task, @@ -815,37 +599,19 @@ void TreeIncMultiHeadSelfAttention::inference_task( TreeIncMultiHeadSelfAttentionMeta *m = *((TreeIncMultiHeadSelfAttentionMeta **)task->local_args); - assert(((*m->qkv_bias || *m->final_bias) ? regions.size() == 4 - : regions.size() == 3)); + assert(regions.size() == 2); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - biases = helperGetGenericTensorAccessorRO(m->weight_type[1], - regions[3], - task->regions[3], - FID_DATA, - ctx, - runtime); - Domain bias_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(bias_domain.get_dim() == 4); - } + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - Domain weight_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); + ctx, task->regions[1].region.get_index_space()); assert(input_domain.get_dim() == 4); - assert(weight_domain.get_dim() == 2); assert(output_domain.get_dim() == 4); /* print_tensor(input.get_float_ptr(), @@ -855,18 +621,13 @@ void TreeIncMultiHeadSelfAttention::inference_task( assert(task->index_point.get_dim() == 1); TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( - m, &bc, task->index_point.point_data[0], input, weight, output, biases); + m, &bc, task->index_point.point_data[0], input, output); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; - std::vector weights_accessors; - weights_accessors.push_back(weight); - if (*m->qkv_bias || *m->final_bias) { - weights_accessors.push_back(biases); - } TreeIncMultiHeadSelfAttention::save_inference_tensors_to_file( - m, shard_id, &bc, {input}, weights_accessors, {output}); + m, shard_id, &bc, {input}, {}, {output}); } } @@ -896,9 +657,20 @@ bool operator==(TreeIncMultiHeadSelfAttentionParams 
const &lhs, return lhs.layer_guid == rhs.layer_guid && lhs.embed_dim == rhs.embed_dim && lhs.num_q_heads == rhs.num_q_heads && lhs.kdim == rhs.kdim && lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && - lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -915,10 +687,8 @@ TreeIncMultiHeadSelfAttentionParams params.kdim = this->kProjSize; params.vdim = this->vProjSize; params.dropout = this->dropout; - params.qkv_bias = this->qkv_bias; - params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -943,10 +713,15 @@ size_t hash::operator()( hash_combine(key, params.kdim); hash_combine(key, params.vdim); hash_combine(key, params.dropout); - hash_combine(key, params.qkv_bias); - hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index 890d32bc87..50e2311ca8 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -17,7 +17,6 @@ #include "flexflow/ffconst_utils.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" -#include "flexflow/ops/tree_inc_multihead_self_attention.h" #include "flexflow/utils/hip_helper.h" #include #include @@ -519,300 +518,6 @@ __global__ void tree_fill_entries_above_diagonal(DT *matrix, } } -template -void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, - TreeVerifyBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - hipStream_t stream) { - checkCUDA(hipblasSetStream(m->handle.blas, 
stream)); - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); - hipblasDatatype_t hipblas_data_type = ff_to_cuda_datatype(m->output_type[0]); - miopenDataType_t miopen_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); - hipblasDatatype_t compute_type = hipblas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // hipblasDatatype_t compute_type = hipblas_data_type; - // #else - // // TODO: currently use the hipblas_data_type - // // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // hipblasDatatype_t compute_type = hipblas_data_type; - // #endif - // int num_requests = bc->num_active_requests(); - int processed_tokens_in_batch = 0; - // int qkv_block_size = - // (m->qProjSize + m->kProjSize + m->vProjSize) * bc->num_active_tokens(); - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num(); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - assert(processed_tokens_in_batch == - bc->requestsInfo[i].first_token_offset_in_batch); - int last_token_idx_of_the_request = - processed_tokens_in_batch + bc->requestsInfo[i].num_tokens_in_batch - 1; - while (processed_tokens_in_batch <= last_token_idx_of_the_request) { - int num_new_tokens = 1; - int j = processed_tokens_in_batch; - while ((j + 1 <= last_token_idx_of_the_request) && - (bc->tokensInfo[j].abs_depth_in_request + 1 == - bc->tokensInfo[j + 1].abs_depth_in_request)) { - j++; - num_new_tokens++; - } - - int total_tokens_in_request = bc->tokensInfo[j].abs_depth_in_request + 1; - assert(num_new_tokens >= 1 && total_tokens_in_request >= num_new_tokens); - { - // update K-V cache - int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_new_tokens; - hipLaunchKernelGGL( - HIP_KERNEL_NAME(update_tree_branch_kv_cache
<DT>), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - static_cast
<DT *>(m->devQKVProjArray), - static_cast
<DT *>(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - m->qProjSize, - m->kProjSize, - m->vProjSize, - num_new_tokens, // num_tokens_in_branch - processed_tokens_in_batch, // num_processed_tokens_in_batch - m->num_active_infr_tokens, // total_tokens_in_batch - BatchConfig::max_sequence_length(), - m->hidden_size); - } - - // bc->token_last_available_idx[i] + 1; - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens_in_request; - int k = m->qProjSize; - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens_in_request; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); - } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast
(m->devQKVProjArray) + - processed_tokens_in_batch * m->qProjSize * m->num_q_heads * - QKV_WEIGHT_NUM; - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; - // To get C, skip over QK^T products from previous requests - DT *C = static_cast
(m->qk_prods); - - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - strideA, - B, - hipblas_data_type, - ldb, - strideB, - &beta, - C, - hipblas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - - if (*m->position_bias) { - size_t parallelism = - m->num_q_heads * total_tokens_in_request * num_new_tokens; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_position_bias_qkprd
), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens_in_request, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens_in_request); - if (num_new_tokens > 1) { - size_t parallelism = - m->num_q_heads * num_new_tokens * total_tokens_in_request; - hipLaunchKernelGGL( - HIP_KERNEL_NAME(tree_fill_entries_above_diagonal
<DT>), - GET_BLOCKS(parallelism), - min((size_t)CUDA_NUM_THREADS, parallelism), - 0, - stream, - C, - num_new_tokens, - total_tokens_in_request, - m->num_q_heads, - static_cast<DT>
(-INFINITY)); - } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens_in_request; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(miopenSet4dTensorDescriptor( - m->qk_tensor, miopen_data_type, n_param, c_param, h_param, w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax); - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax, - MIOPEN_SOFTMAX_ACCURATE, - MIOPEN_SOFTMAX_MODE_CHANNEL)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = m->vProjSize; - n = num_new_tokens; - k = total_tokens_in_request; - lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - strideA = vt_block_size; - strideB = num_new_tokens * total_tokens_in_request; - strideC = m->vProjSize; - // To get A, skip over V^T entries from previous requests (all heads + - // padding) - A = static_cast
(m->valueCache) + i * vt_req_block_size; - // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - B = C_softmax; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - C = static_cast
(m->attn_heads) + - processed_tokens_in_batch * m->num_q_heads * m->vProjSize; - checkCUDA(hipblasGemmStridedBatchedEx(m->handle.blas, - HIPBLAS_OP_N, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - strideA, - B, - hipblas_data_type, - ldb, - strideB, - &beta, - C, - hipblas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - processed_tokens_in_batch += num_new_tokens; - } - // Before moving to the next request - // check that we have finished all tokens of the request - assert(last_token_idx_of_the_request + 1 == processed_tokens_in_batch); - } - // Project to output, save result directly on output tensor - DT alpha = 1.0f, beta = 0.0f; - int m_ = m->oProjSize; - int k = m->vProjSize * m->num_q_heads; - int n = processed_tokens_in_batch; - int lda = k, ldb = k, ldc = m_; - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - DT const *B = static_cast
<DT *>(m->attn_heads); - DT *C = static_cast<DT *>
(output_ptr); - - checkCUDA(hipblasGemmEx(m->handle.blas, - HIPBLAS_OP_T, - HIPBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - hipblas_data_type, - lda, - B, - hipblas_data_type, - ldb, - &beta, - C, - hipblas_data_type, - ldc, - compute_type, - HIPBLAS_GEMM_DEFAULT)); - - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * processed_tokens_in_batch; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_proj_bias_w
), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - bias_ptr, - processed_tokens_in_batch, - qkv_weight_size, - m->oProjSize); - } - - assert(processed_tokens_in_batch == bc->num_active_infr_tokens()); -} - #define LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_size_in_bytes_tree
(m->qProjSize, \ @@ -895,27 +600,10 @@ template void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, TreeVerifyBatchConfig const *bc, int shard_id, - DT const *input_ptr, - DT const *weight_ptr, + DT const *qkv_ptr, DT *output_ptr, - DT const *bias_ptr, hipStream_t stream) { - // additional processing for weight uploading - if (m->handle.offload_reserve_space != nullptr) { - // Note that we update weight_ptr and bias_ptr when uploading weight and - // bias - checkCUDA(hipMemcpyAsync(m->weight_ptr, - weight_ptr, - m->weightSize, - hipMemcpyHostToDevice, - stream)); - weight_ptr = static_cast
(m->weight_ptr); - if (m->biasSize > 0) { - checkCUDA(hipMemcpyAsync( - m->bias_ptr, bias_ptr, m->biasSize, hipMemcpyHostToDevice, stream)); - bias_ptr = static_cast
(m->bias_ptr); - } - } + // copy committed tokens info to GPU for the commit_tokens kernel // Note that m->num_active_infr_tokens stores the number of active // tokens in the previous batch, which is needed for committing @@ -929,39 +617,36 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // tokens for the current batch m->num_active_infr_tokens = bc->num_active_infr_tokens(); - // here because we need postion info in infernece 1 - if (m->offload && m->biasSize > 0) { - checkCUDA(hipMemcpyAsync( - m->bias_ptr, bias_ptr, m->biasSize, hipMemcpyHostToDevice, stream)); - bias_ptr = static_cast
(m->bias_ptr); - } + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, num_heads, 3, num_new_tokens] + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + + hipMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here + hipMemcpyDeviceToDevice, + stream); + // phase 1: Implement kernel to compute KQV for input tokens - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast
(m->devQKVProjArray), - bias_ptr, - stream); + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first + compute_qkv_kernel( + m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); // phase 2: No need to update key/val cache - // IncMultiHeadSelfAttention::update_kv_cache_kernel( - // m, bc, stream); - // use the new kernel compute_attention_kernel_fused
<DT>( m, bc, static_cast<DT *>
(m->attn_heads), stream); int processed_tokens_in_batch = bc->num_active_tokens(); - compute_o_prod_bias(m, - bc, - shard_id, - output_ptr, - weight_ptr, - bias_ptr, - processed_tokens_in_batch, - stream); + int num_tokens = bc->num_active_tokens(); + hipMemcpyAsync(output_ptr, + m->attn_heads, + m->oProjSize * num_tokens * sizeof(DT), + hipMemcpyDeviceToDevice, + stream); } } // namespace TreeIncMultiHeadAttention @@ -973,12 +658,9 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( TreeVerifyBatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorW const &output) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; hipEvent_t t_start, t_end; if (m->profiling) { @@ -987,44 +669,14 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(hipEventRecord(t_start, stream)); } - // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } if (input.data_type == DT_HALF) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::TreeIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - m->offload ? static_cast(m->weight_ptr) : weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - float const *bias_ptr = - use_bias ? bias.get_float_ptr() : static_cast(nullptr); Kernels::TreeIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - m->offload ? 
static_cast(m->weight_ptr) - : weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -1037,16 +689,12 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(hipEventDestroy(t_start)); checkCUDA(hipEventDestroy(t_end)); printf("TreeIncMultiHeadSelfAttention forward time = %.2fms\n", elapsed); - // print_tensor<3, float>(acc_query.ptr, acc_query.rect, - // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - // acc_output.rect, "[Attention:forward:output]"); } } TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( FFHandler handler, TreeIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -1061,14 +709,11 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, + attn->rotary_embedding_meta, attn->scaling_query, attn->qk_prod_scaling, attn->position_bias, - attn->final_bias, attn->scaling_factor, - weight, gpu_mem_allocator, num_samples, attn->num_q_heads, diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 86c53d7ea1..8c643b1964 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -494,303 +494,6 @@ __global__ void tree_fill_entries_above_diagonal(DT *matrix, } } -template -void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, - TreeVerifyBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); - cudaDataType_t compute_type = cublas_data_type; - // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - // cudaDataType_t compute_type = cublas_data_type; - // #else - // // For best performance, set the default cublas compute type to - // // CUBLAS_COMPUTE_16F for half precision and to - // // CUBLAS_COMPUTE_32F_FAST_16F for full precision - // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - // if (m->output_type[0] == DT_FLOAT) { - // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - // } - // #endif - // int num_requests = bc->num_active_requests(); - int processed_tokens_in_batch = 0; - // int qkv_block_size = - // (m->qProjSize + m->kProjSize + m->vProjSize) * bc->num_active_tokens(); - int q_block_size = m->qProjSize; - int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num(); - int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + - BatchConfig::max_spec_tree_token_num(); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - assert(processed_tokens_in_batch == - bc->requestsInfo[i].first_token_offset_in_batch); - int last_token_idx_of_the_request = - 
processed_tokens_in_batch + bc->requestsInfo[i].num_tokens_in_batch - 1; - while (processed_tokens_in_batch <= last_token_idx_of_the_request) { - int num_new_tokens = 1; - int j = processed_tokens_in_batch; - while ((j + 1 <= last_token_idx_of_the_request) && - (bc->tokensInfo[j].abs_depth_in_request + 1 == - bc->tokensInfo[j + 1].abs_depth_in_request)) { - j++; - num_new_tokens++; - } - - int total_tokens_in_request = bc->tokensInfo[j].abs_depth_in_request + 1; - assert(num_new_tokens >= 1 && total_tokens_in_request >= num_new_tokens); - { - // update K-V cache - int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_new_tokens; - update_tree_branch_kv_cache<<>>( - static_cast
<DT *>(m->devQKVProjArray), - static_cast
<DT *>(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - m->qProjSize, - m->kProjSize, - m->vProjSize, - num_new_tokens, // num_tokens_in_branch - processed_tokens_in_batch, // num_processed_tokens_in_batch - m->num_active_infr_tokens, // total_tokens_in_batch - BatchConfig::max_sequence_length(), - m->hidden_size); - } - - // bc->token_last_available_idx[i] + 1; - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens_in_request; - int k = m->qProjSize; - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens_in_request; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); - } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast
(m->devQKVProjArray) + - processed_tokens_in_batch * m->qProjSize * m->num_q_heads * - QKV_WEIGHT_NUM; - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; - // To get C, skip over QK^T products from previous requests - DT *C = static_cast
(m->qk_prods); - - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // add alibi position bias to qk production - // add alibi position bias to qk production - if (*m->position_bias) { - size_t parallelism = - m->num_q_heads * total_tokens_in_request * num_new_tokens; - apply_position_bias_qkprd<<>>(C, - num_new_tokens, - total_tokens_in_request, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens_in_request); - if (num_new_tokens > 1) { - size_t parallelism = - m->num_q_heads * num_new_tokens * total_tokens_in_request; - tree_fill_entries_above_diagonal<<>>( - C, - num_new_tokens, - total_tokens_in_request, - m->num_q_heads, - static_cast
(-INFINITY)); - } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens_in_request; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, - CUDNN_TENSOR_NCHW, - cudnn_data_type, - n_param, - c_param, - h_param, - w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax); - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = m->vProjSize; - n = num_new_tokens; - k = total_tokens_in_request; - lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - strideA = vt_block_size; - strideB = num_new_tokens * total_tokens_in_request; - strideC = m->vProjSize; - // To get A, skip over V^T entries from previous requests (all heads + - // padding) - A = static_cast
(m->valueCache) + i * vt_req_block_size; - // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - B = C_softmax; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - C = static_cast
(m->attn_heads) + - processed_tokens_in_batch * m->num_q_heads * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - processed_tokens_in_batch += num_new_tokens; - } - // Before moving to the next request - // check that we have finished all tokens of the request - assert(last_token_idx_of_the_request + 1 == processed_tokens_in_batch); - } - // Project to output, save result directly on output tensor - DT alpha = 1.0f, beta = 0.0f; - int m_ = m->oProjSize; - int k = m->vProjSize * m->num_q_heads; - int n = processed_tokens_in_batch; - int lda = k, ldb = k, ldc = m_; - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - DT const *B = static_cast
<DT *>(m->attn_heads); - DT *C = static_cast<DT *>
(output_ptr); - - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * processed_tokens_in_batch; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - apply_proj_bias_w<<>>(output_ptr, - bias_ptr, - processed_tokens_in_batch, - qkv_weight_size, - m->oProjSize); - } - - assert(processed_tokens_in_batch == bc->num_active_infr_tokens()); -} - #define LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_size_in_bytes_tree
(m->qProjSize, \ @@ -873,27 +576,9 @@ template void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, TreeVerifyBatchConfig const *bc, int shard_id, - DT const *input_ptr, - DT const *weight_ptr, + DT const *qkv_ptr, DT *output_ptr, - DT const *bias_ptr, cudaStream_t stream) { - // additional processing for weight uploading - if (m->handle.offload_reserve_space != nullptr) { - // Note that we update weight_ptr and bias_ptr when uploading weight and - // bias - cudaMemcpyAsync(m->weight_ptr, - weight_ptr, - m->weightSize, - cudaMemcpyHostToDevice, - stream); - weight_ptr = static_cast
(m->weight_ptr); - if (m->biasSize > 0) { - cudaMemcpyAsync( - m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); - bias_ptr = static_cast
(m->bias_ptr); - } - } // copy committed tokens info to GPU for the commit_tokens kernel // Note that m->num_active_infr_tokens stores the number of active @@ -908,39 +593,36 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // tokens for the current batch m->num_active_infr_tokens = bc->num_active_infr_tokens(); - // here because we need postion info in infernece 1 - if (m->offload && m->biasSize > 0) { - cudaMemcpyAsync( - m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); - bias_ptr = static_cast
(m->bias_ptr); - } + // phase 0: copy calculated qkv into devQKVProjArray + // [qProjSize, num_heads, 3, num_new_tokens] + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + + cudaMemcpyAsync(m->devQKVProjArray, + qkv_ptr, + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here + cudaMemcpyDeviceToDevice, + stream); + // phase 1: Implement kernel to compute KQV for input tokens - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast
(m->devQKVProjArray), - bias_ptr, - stream); + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first + compute_qkv_kernel( + m, bc, shard_id, static_cast
(m->devQKVProjArray), stream); // phase 2: No need to update key/val cache - // IncMultiHeadSelfAttention::update_kv_cache_kernel( - // m, bc, stream); - // use the new kernel compute_attention_kernel_fused
( m, bc, static_cast
(m->attn_heads), stream); int processed_tokens_in_batch = bc->num_active_tokens(); - compute_o_prod_bias(m, - bc, - shard_id, - output_ptr, - weight_ptr, - bias_ptr, - processed_tokens_in_batch, - stream); + int num_tokens = bc->num_active_tokens(); + cudaMemcpyAsync(output_ptr, + m->attn_heads, + m->oProjSize * num_tokens * sizeof(DT), + cudaMemcpyDeviceToDevice, + stream); } } // namespace TreeIncMultiHeadAttention @@ -952,12 +634,9 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( TreeVerifyBatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { + GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -966,44 +645,14 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventRecord(t_start, stream); } - // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } if (input.data_type == DT_HALF) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::TreeIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - m->offload ? static_cast(m->weight_ptr) : weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { - if (m->offload) { - pre_build_weight_kernel(m, weight, input.data_type, stream); - } - float const *bias_ptr = - use_bias ? bias.get_float_ptr() : static_cast(nullptr); Kernels::TreeIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - m->offload ? 
static_cast(m->weight_ptr) - : weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -1021,7 +670,6 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( FFHandler handler, TreeIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, MemoryAllocator &gpu_mem_allocator, int num_samples, int _num_q_heads, @@ -1036,14 +684,11 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, + attn->rotary_embedding_meta, attn->scaling_query, attn->qk_prod_scaling, attn->position_bias, - attn->final_bias, attn->scaling_factor, - weight, gpu_mem_allocator, num_samples, attn->num_q_heads, diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc index dc43d80133..a4443c4066 100644 --- a/src/parallel_ops/allreduce.cc +++ b/src/parallel_ops/allreduce.cc @@ -73,7 +73,7 @@ AllReduce::AllReduce(FFModel &model, for (int i = 0; i < numdim; i++) { dims[i] = _input->dims[i]; } - assert(dims[allreduce_dim].degree > 1); + // assert(dims[allreduce_dim].degree > 1); // ParallelTensorBase::update_parallel_ids(numdim, dims); outputs[0] = model.create_parallel_tensor_legion_ordering( numdim, dims, _input->data_type, this); diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index c373e0da9b..e73893475c 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -80,51 +80,56 @@ std::string removeGuidOperatorName(std::string const &input) { } template -void load_attention_weights_multi_query(DT *ptr, - std::string layer_name, - std::string weights_folder, - size_t hidden_dim, - int num_heads) { - - std::string qkv_file = layer_name.substr(0, layer_name.find("attention")) + - "attention_query_key_value_weight"; - std::string o_file = layer_name.substr(0, layer_name.find("attention")) + - "attention_dense_weight"; +void load_attention_o_proj_bias_to_dense_v2(DT *ptr, + int num_heads, + int num_kv_heads, + size_t hidden_dim, + size_t qkv_inner_dim, + std::string layer_name, + std::string weights_folder) { + std::string filename = layer_name + ".o_proj.bias"; - // q has n_heads heads, k and v only have one head, o have n_head heads - std::vector weight_filenames = {qkv_file, o_file}; int file_index = 0; - int data_index = 0; - for (auto filename : weight_filenames) { - std::cout << "Loading weight file " << filename << std::endl; - std::string weight_filepath = join_path({weights_folder, filename}); - size_t partial_size = - file_index == 0 ? (hidden_dim + 2 * hidden_dim / num_heads) * hidden_dim - : hidden_dim * hidden_dim; - std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); - // std::cout << "Loading filename: " << weight_filepath << std::endl; - if (!in.good()) { - std::cout << "Could not open file: " << weight_filepath << std::endl; - } - assert(in.good() && "incorrect weight file path"); - std::vector
host_array(partial_size); - size_t loaded_data_size = sizeof(DT) * partial_size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - in.read((char *)host_array.data(), loaded_data_size); - size_t in_get_size = in.gcount(); + // now only opt use this. + // assert(num_heads == num_kv_heads); + int idx = 0; - if (in_get_size != loaded_data_size) { - std::cout << "load data error " << in_get_size << ", " - << loaded_data_size; - assert(false && "data size mismatch"); - } - for (int i = 0; i < partial_size; i++) { - ptr[data_index++] = host_array.at(i); - } - file_index++; + std::cout << "Loading weight file " << filename << std::endl; + std::string weight_filepath = join_path({weights_folder, filename}); + + int n_heads = num_heads; + + int replicate_num = num_heads / num_kv_heads; + + size_t out_partial_size = hidden_dim; + size_t partial_size = out_partial_size; + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + assert(in.good() && "incorrect bias file path"); + std::vector
host_array(partial_size); + size_t loaded_data_size = sizeof(DT) * partial_size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); + + if (in_get_size != loaded_data_size) { + printf( + "load bias data error: in_get_size (%lu) != loaded_data_size (%lu)\n", + in_get_size, + loaded_data_size); + assert(false); } + assert(partial_size == host_array.size()); + + size_t data_index = 0; + + for (int i = 0; i < partial_size; i++) { + ptr[i] = host_array.at(data_index); + data_index++; + } + + in.close(); } template @@ -135,44 +140,53 @@ void load_attention_bias_v2(DT *ptr, size_t qkv_inner_dim, bool final_bias, std::string layer_name, - std::string weights_folder) { + std::string weights_folder, + int tp_degree) { std::string q_file = layer_name + ".q_proj.bias"; std::string k_file = layer_name + ".k_proj.bias"; std::string v_file = layer_name + ".v_proj.bias"; std::vector bias_files = {q_file, k_file, v_file}; - if (final_bias) { - std::string o_file = layer_name + ".o_proj.bias"; - bias_files.push_back(o_file); - } - int file_index = 0; - - // now only opt use this. - // assert(num_heads == num_kv_heads); - int idx = 0; + // linear layer weights: [output_size, input_size] + // bias layer weights: [output_size] + // Q,K,V projection weights: [head_dim*num_heads, hidden_size] = [768, 768] + // QKV bias weights: [head_dim*num_heads] = [768], organized as: [head_dim_0, + // head_dim_1, ...] + + // need to rearrange: [[q_heads_shard_0], [k_heads_shard_0], + // [v_heads_shard_0], ..., [q_heads_shard_n], [k_heads_shard_n], + // [v_heads_shard_n]] where n = tp_degree + assert(num_heads % tp_degree == 0); + assert(num_kv_heads % tp_degree == 0); + assert(hidden_dim % num_heads == 0); + assert(qkv_inner_dim == hidden_dim / num_heads); + size_t q_heads_per_shard = num_heads / tp_degree; + size_t kv_heads_per_shard = num_kv_heads / tp_degree; + size_t shard_chunk_size = + (q_heads_per_shard + 2 * kv_heads_per_shard) * qkv_inner_dim; + int file_index = 0; for (auto filename : bias_files) { std::cout << "Loading weight file " << filename << std::endl; std::string weight_filepath = join_path({weights_folder, filename}); int n_heads = file_index == 0 ? num_heads : num_kv_heads; - - int replicate_num = num_heads / num_kv_heads; - - size_t qkv_partial_size = qkv_inner_dim * n_heads; - size_t qkv_replicate_size = qkv_inner_dim * num_heads; - size_t out_partial_size = hidden_dim; - size_t partial_size = - (file_index < 3) ? qkv_partial_size : out_partial_size; + assert(n_heads % tp_degree == 0); + int heads_per_shard = n_heads / tp_degree; + int qkv_prev_heads_cur_shard = + (file_index == 2) ? num_heads + num_kv_heads : file_index * num_heads; + assert(qkv_prev_heads_cur_shard % tp_degree == 0); + qkv_prev_heads_cur_shard /= tp_degree; + + // load into memory first + size_t bias_size = qkv_inner_dim * n_heads; std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); assert(in.good() && "incorrect bias file path"); - std::vector
host_array(partial_size); - size_t loaded_data_size = sizeof(DT) * partial_size; - in.seekg(0, in.end); + std::vector
host_array(bias_size); + size_t loaded_data_size = sizeof(DT) * bias_size; in.seekg(0, in.beg); in.read((char *)host_array.data(), loaded_data_size); size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) { printf( "load bias data error: in_get_size (%lu) != loaded_data_size (%lu)\n", @@ -180,43 +194,37 @@ void load_attention_bias_v2(DT *ptr, loaded_data_size); assert(false); } - assert(partial_size == host_array.size()); - - size_t data_index = 0; - - // q, o - if (file_index == 0 || file_index == 3) { - for (int i = 0; i < partial_size; i++) { - ptr[idx + i] = host_array.at(data_index); - data_index++; - } - } else { - // k, v - for (int i = 0; i < partial_size; i++) { - for (int j = 0; j < replicate_num; j++) { - ptr[idx + j * partial_size + i] = host_array.at(data_index); - } - data_index++; + assert(bias_size == host_array.size()); + + // now copy chunks into ptr + for (int i = 0; i < n_heads; i++) { + int shard_idx = i / heads_per_shard; + for (int j = 0; j < qkv_inner_dim; j++) { + int src_idx = i * qkv_inner_dim + j; + int dst_idx = shard_idx * shard_chunk_size + + qkv_prev_heads_cur_shard * qkv_inner_dim + + (i % heads_per_shard) * qkv_inner_dim + j; + ptr[dst_idx] = host_array.at(src_idx); } } - file_index++; - idx += qkv_replicate_size; - in.close(); } } template -void load_attention_weights_v2(DT *ptr, - int num_heads, - int num_kv_heads, - size_t hidden_dim, - size_t qkv_inner_dim, - std::string layer_name, - std::string weights_folder, - size_t volume, - int tensor_parallelism_degree) { +void load_attention_weights_to_dense_v2(DT *ptr, + int num_heads, + int num_kv_heads, + size_t hidden_dim, + size_t qkv_inner_dim, + std::string layer_name, + std::string weights_folder, + size_t volume, + int tensor_parallelism_degree, + bool load_o_proj) { + // layers_0_attention_wq_weight + // layers_0_self_attn_q_proj_weight std::string q_file = layer_name + ".q_proj.weight"; std::string k_file = layer_name + ".k_proj.weight"; std::string v_file = layer_name + ".v_proj.weight"; @@ -241,64 +249,64 @@ void load_attention_weights_v2(DT *ptr, int replicate_num = num_heads / num_kv_heads; // stride for q, k, v, o - size_t stride_size = (q_size + v_replicate_size + k_replicate_size + o_size) / + size_t stride_size = (q_size + v_replicate_size + k_replicate_size) / tensor_parallelism_degree; - for (auto filename : weight_filenames) { - std::cout << "Loading weight file " << filename << std::endl; - std::string weight_filepath = join_path({weights_folder, filename}); - - int data_index = 0; - size_t partial_size = (file_index == 0 || file_index == 3) - ? one_weight_file_size - : single_proj_size * num_kv_heads; - size_t one_partition_size = - one_weight_file_size / tensor_parallelism_degree; - - std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); - if (!in.good()) { - std::cout << "Could not open file: " << weight_filepath << std::endl; - } - assert(in.good() && "incorrect weight file path"); - std::vector
host_array(partial_size); - size_t loaded_data_size = sizeof(DT) * partial_size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - in.read((char *)host_array.data(), loaded_data_size); - size_t in_get_size = in.gcount(); + if (!load_o_proj) { + for (auto filename : weight_filenames) { + std::cout << "Loading weight file " << filename << std::endl; + std::string weight_filepath = join_path({weights_folder, filename}); + + int data_index = 0; + size_t partial_size = (file_index == 0 || file_index == 3) + ? one_weight_file_size + : single_proj_size * num_kv_heads; + size_t one_partition_size = + one_weight_file_size / tensor_parallelism_degree; + + std::ifstream in(weight_filepath, std::ios::in | std::ios::binary); + if (!in.good()) { + std::cout << "Could not open file: " << weight_filepath << std::endl; + } + assert(in.good() && "incorrect weight file path"); + std::vector
host_array(partial_size); + size_t loaded_data_size = sizeof(DT) * partial_size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + in.read((char *)host_array.data(), loaded_data_size); + size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) { - std::cout << "load attention data error " << in_get_size << ", " - << loaded_data_size << ", " << file_index << ", " - << weight_filepath << "\n"; - assert(false && "data size mismatch"); - } - // wq, wk, wo - if (file_index == 0) { - for (int i = 0; i < tensor_parallelism_degree; i++) { - for (int j = 0; j < one_partition_size; j++) { - ptr[base_index + i * stride_size + j] = host_array.at(data_index++); - } + if (in_get_size != loaded_data_size) { + std::cout << "load attention data error " << in_get_size << ", " + << loaded_data_size << ", " << file_index << ", " + << weight_filepath << "\n"; + assert(false && "data size mismatch"); } - } else { - for (int i = 0; i < num_heads; i++) { - int kv_idx = i / (num_heads / num_kv_heads); - int head_idx = i % (num_heads / tensor_parallelism_degree); - int tp_idx = (i / (num_heads / tensor_parallelism_degree)); - for (int j = 0; j < single_proj_size; j++) { - ptr[base_index + tp_idx * stride_size + single_proj_size * head_idx + - j] = host_array.at(kv_idx * single_proj_size + j); + // wq, wk, wo + if (file_index == 0) { + for (int i = 0; i < tensor_parallelism_degree; i++) { + for (int j = 0; j < one_partition_size; j++) { + ptr[base_index + i * stride_size + j] = host_array.at(data_index++); + } + } + } else { + for (int i = 0; i < num_heads; i++) { + int kv_idx = i / (num_heads / num_kv_heads); + int head_idx = i % (num_heads / tensor_parallelism_degree); + int tp_idx = (i / (num_heads / tensor_parallelism_degree)); + for (int j = 0; j < single_proj_size; j++) { + ptr[base_index + tp_idx * stride_size + + single_proj_size * head_idx + j] = + host_array.at(kv_idx * single_proj_size + j); + } } } + // std::cout << "host array going out of scope, releasing" << endl; + base_index += one_partition_size; + file_index++; } - - // assert(data_index == partial_size); - base_index += one_partition_size; - file_index++; - } - assert(base_index == (q_size + k_replicate_size + v_replicate_size) / - tensor_parallelism_degree); - - { + assert(base_index == (q_size + k_replicate_size + v_replicate_size) / + tensor_parallelism_degree); + } else { std::cout << "Loading weight file " << o_file << std::endl; std::string weight_filepath = join_path({weights_folder, o_file}); @@ -314,6 +322,15 @@ void load_attention_weights_v2(DT *ptr, in.read((char *)host_array.data(), loaded_data_size); size_t in_get_size = in.gcount(); + DT temp; + + for (int i = 0; i < one_weight_file_size; i++) { + temp = host_array.at(i); + } + + // std::cout<<"o_proj loaded into host array, total size: + // "<name)); + bool is_attn_proj = false, is_o_proj = false; + + // dense layers for attention projection is named as + // self_attn.qkv_proj or self_attn.o_proj + // so looking for self_attn. 
in the name can determine if it is an attention + // projection + if (weight_filename.find("attn.") != std::string::npos || + weight_filename.find("self_attention.") != std::string::npos) { + size_t pos = weight_filename.find(".o_proj"); + if (pos != std::string::npos) { + weight_filename.replace(pos, std::string(".o_proj").length(), ""); + is_o_proj = true; + } else { + pos = weight_filename.find(".qkv_proj"); + if (pos == std::string::npos) { + cout << weight_filename << endl; + } + assert(pos != std::string::npos); + weight_filename.replace(pos, std::string(".qkv_proj").length(), ""); + } + is_attn_proj = true; + } if (ff->config.benchmarking) { std::cout << "Initializing weight " << weight_filename @@ -730,28 +773,51 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, if (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || l->op_type == OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION || l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION) { - if (weight_idx == 0) { - load_attention_weights_v2(data, - num_heads, - num_kv_heads, - hidden_dim, - qkv_inner_dim, - weight_filename, - weights_folder, - volume, - tensor_parallelism_degree); + } else if (is_attn_proj) { + if (is_o_proj) { + if (weight_idx == 0) { + load_attention_weights_to_dense_v2(data, + num_heads, + num_kv_heads, + hidden_dim, + qkv_inner_dim, + weight_filename, + weights_folder, + volume, + tensor_parallelism_degree, + true); + } else { + load_attention_o_proj_bias_to_dense_v2(data, + num_heads, + num_kv_heads, + hidden_dim, + qkv_inner_dim, + weight_filename, + weights_folder); + } } else { - long long value; - l->get_int_property("final_bias", value); - bool final_bias = (bool)value; - load_attention_bias_v2(data, - num_heads, - num_kv_heads, - hidden_dim, - qkv_inner_dim, - final_bias, - weight_filename, - weights_folder); + if (weight_idx == 0) { + load_attention_weights_to_dense_v2(data, + num_heads, + num_kv_heads, + hidden_dim, + qkv_inner_dim, + weight_filename, + weights_folder, + volume, + tensor_parallelism_degree, + false); + } else { + load_attention_bias_v2(data, + num_heads, + num_kv_heads, + hidden_dim, + qkv_inner_dim, + false, // do not load o_proj bias + weight_filename, + weights_folder, + tensor_parallelism_degree); + } } } else if (l->op_type == OP_ADD_BIAS_RESIDUAL_LAYERNORM) { assert(weight_idx >= 0 || weight_idx <= 2); diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 1a38782e81..2bc64c1670 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -2331,10 +2331,17 @@ GraphOptimalViewSerialized sez.serialize(attn->qProjSize); sez.serialize(attn->vProjSize); sez.serialize(attn->dropout); - sez.serialize(attn->qkv_bias); - sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2358,10 +2365,17 @@ GraphOptimalViewSerialized sez.serialize(attn->qProjSize); 
sez.serialize(attn->vProjSize); sez.serialize(attn->dropout); - sez.serialize(attn->qkv_bias); - sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2382,10 +2396,17 @@ GraphOptimalViewSerialized sez.serialize(attn->qProjSize); sez.serialize(attn->vProjSize); sez.serialize(attn->dropout); - sez.serialize(attn->qkv_bias); - sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2817,8 +2838,9 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, offload, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, offload, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); @@ -2830,10 +2852,18 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(k_dim); dez.deserialize(v_dim); dez.deserialize(dropout); - dez.deserialize(qkv_bias); - dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2853,11 +2883,9 @@ void FFModel::deserialize_graph_optimal_view( params.kdim = k_dim; params.vdim = v_dim; params.dropout = dropout; - params.qkv_bias = qkv_bias; - params.final_bias = final_bias; params.add_zero_attn = 
add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; @@ -2874,8 +2902,8 @@ void FFModel::deserialize_graph_optimal_view( assert(num_inputs == 1); int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); @@ -2886,10 +2914,18 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(k_dim); dez.deserialize(v_dim); dez.deserialize(dropout); - dez.deserialize(qkv_bias); - dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2906,11 +2942,9 @@ void FFModel::deserialize_graph_optimal_view( params.kdim = k_dim; params.vdim = v_dim; params.dropout = dropout; - params.qkv_bias = qkv_bias; - params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; @@ -2926,8 +2960,9 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, offload, position_bias; + bool add_zero_attn, scaling_query, qk_prod_scaling, offload, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); @@ -2939,10 +2974,18 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(k_dim); dez.deserialize(v_dim); dez.deserialize(dropout); - dez.deserialize(qkv_bias); - dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + 
dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2962,11 +3005,9 @@ void FFModel::deserialize_graph_optimal_view( params.kdim = k_dim; params.vdim = v_dim; params.dropout = dropout; - params.qkv_bias = qkv_bias; - params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 1b65dfd869..f39ea91f28 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -800,6 +800,7 @@ void FFModel::compile_inference() { false /*must*/, 0 /*mapper_id*/, view.hash() /*MappingTagID*/); + index_launcher.concurrent = true; FutureMap fm = runtime->execute_index_space(ctx, index_launcher); fm.wait_all_results(); int idx = 0; diff --git a/src/runtime/layer.cc b/src/runtime/layer.cc index 8f33f6db87..72e71688c1 100644 --- a/src/runtime/layer.cc +++ b/src/runtime/layer.cc @@ -87,6 +87,11 @@ void Layer::add_int_vector_property(std::string const &key, int_vector_properties[key] = value; } +void Layer::add_string_property(std::string const &key, + std::string const &value) { + string_properties[key] = value; +} + void Layer::add_initializer(std::string const &key, Initializer *initializer) { initializers[key] = initializer; } @@ -125,6 +130,18 @@ bool Layer::get_int_vector_property(std::string const &key, } } +bool Layer::get_string_property(std::string const &key, + std::string &value) const { + auto const &it = string_properties.find(key); + if (it == string_properties.end()) { + assert(false); + return false; + } else { + value = it->second; + return true; + } +} + bool Layer::get_initializer(std::string const &key, Initializer *&initializer) const { auto const &it = initializers.find(key); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 52f1dd2220..69fe3b598d 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1156,16 +1156,25 @@ bool Op::check_output_input_weight_same_parallel_is() const { IndexSpace parallel_is = outputs[0]->parallel_is; for (int i = 0; i < numOutputs; i++) { if (outputs[i]->parallel_is != parallel_is) { + std::cout << "outputs[" << i << "] has different parallel_is " + << outputs[i]->parallel_is << " than output[0] " << parallel_is + << std::endl; return false; } } for (int i = 0; i < numInputs; i++) { if (inputs[i]->parallel_is != parallel_is) { + std::cout << "inputs[" << i << "] has different parallel_is " + << inputs[i]->parallel_is << " than output[0] " << parallel_is + << std::endl; return false; } } for (int i = 0; i < numWeights; i++) { if (weights[i]->parallel_is != parallel_is) { + std::cout << "weights[" << i << "] has different parallel_is " + << weights[i]->parallel_is << " than output[0] " << parallel_is + << std::endl; return false; } } @@ -3414,26 +3423,28 @@ bool FFModel::need_to_add_allreduce(int layer_idx) const { auto const &l = layers[layer_idx]; if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && - (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || - l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || - // mlp layer - is_mlp_block(layer_idx) || - // llama mlp layer - (l->op_type == OP_LINEAR && 
layer_idx >= 2 && - layers[layer_idx - 1]->op_type == OP_GELU && - layers[layer_idx - 2]->op_type == OP_LINEAR) || - // LLAMA without element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 5 && - layers[layer_idx - 1]->op_type == OP_EW_MUL && - layers[layer_idx - 2]->op_type == OP_EW_MUL && - layers[layer_idx - 3]->op_type == OP_SIGMOID && - layers[layer_idx - 4]->op_type == OP_LINEAR && - layers[layer_idx - 5]->op_type == OP_LINEAR) || - // LLAMA with element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 3 && - layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && - layers[layer_idx - 2]->op_type == OP_LINEAR && - layers[layer_idx - 3]->op_type == OP_LINEAR))) { + ( + // l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || + // l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || + (std::string(l->name).find("attn.o_proj") != std::string::npos) || + // mlp layer + is_mlp_block(layer_idx) || + // llama mlp layer + (l->op_type == OP_LINEAR && layer_idx >= 2 && + layers[layer_idx - 1]->op_type == OP_GELU && + layers[layer_idx - 2]->op_type == OP_LINEAR) || + // LLAMA without element-wise operator fusion + (l->op_type == OP_LINEAR && layer_idx >= 5 && + layers[layer_idx - 1]->op_type == OP_EW_MUL && + layers[layer_idx - 2]->op_type == OP_EW_MUL && + layers[layer_idx - 3]->op_type == OP_SIGMOID && + layers[layer_idx - 4]->op_type == OP_LINEAR && + layers[layer_idx - 5]->op_type == OP_LINEAR) || + // LLAMA with element-wise operator fusion + (l->op_type == OP_LINEAR && layer_idx >= 3 && + layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && + layers[layer_idx - 2]->op_type == OP_LINEAR && + layers[layer_idx - 3]->op_type == OP_LINEAR))) { return true; } return false; diff --git a/src/runtime/operator.cc b/src/runtime/operator.cc index dcac52397a..d5bfcfc48e 100644 --- a/src/runtime/operator.cc +++ b/src/runtime/operator.cc @@ -2,6 +2,7 @@ #include "flexflow/ffconst_utils.h" #include "flexflow/simulator.h" #include +#include #include namespace FlexFlow { @@ -29,7 +30,15 @@ fs::path get_dst_folder(std::string const &subdir, if (before_kernel) { step_substr += "_pre"; } + char cwd[PATH_MAX]; + getcwd(cwd, sizeof(cwd)); + + // char const *ff_cache_path = std::string(std::getenv("FF_DEBUG_PATH")) == + // "." ? + // cwd : std::getenv("FF_DEBUG_PATH"); + char const *ff_cache_path = std::getenv("FF_CACHE_PATH"); + std::string debug_dir_ = ff_cache_path ? 
std::string(ff_cache_path) + "/debug/flexflow" : std::string("~/.cache/flexflow/debug/flexflow"); @@ -38,6 +47,9 @@ fs::path get_dst_folder(std::string const &subdir, debug_dir_ = p.we_wordv[0]; wordfree(&p); fs::path debug_dir = debug_dir_; + if (!fs::is_directory(debug_dir)) { + printf("invalid debug directory: %s\n", debug_dir.c_str()); + } assert(fs::is_directory(debug_dir)); fs::path dst_folder = debug_dir / subdir / step_substr / ("shard_" + std::to_string(shard_idx)); diff --git a/src/runtime/substitution.cc b/src/runtime/substitution.cc index 9b6510fe5e..0e28c02cdf 100644 --- a/src/runtime/substitution.cc +++ b/src/runtime/substitution.cc @@ -3734,15 +3734,14 @@ bool FFModel::convert_graph_to_operators( case OP_INC_MULTIHEAD_SELF_ATTENTION: { assert(inList.size() == 1); IncMultiHeadSelfAttention *attn = (IncMultiHeadSelfAttention *)node.ptr; - new_op = new IncMultiHeadSelfAttention(*this, *attn, inputs[0], true); + new_op = new IncMultiHeadSelfAttention(*this, *attn, inputs[0]); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { assert(inList.size() == 1); TreeIncMultiHeadSelfAttention *attn = (TreeIncMultiHeadSelfAttention *)node.ptr; - new_op = - new TreeIncMultiHeadSelfAttention(*this, *attn, inputs[0], true); + new_op = new TreeIncMultiHeadSelfAttention(*this, *attn, inputs[0]); break; } case OP_RMS_NORM: { diff --git a/tests/fine_grained_alignment_test.sh b/tests/fine_grained_alignment_test.sh new file mode 100755 index 0000000000..9ad26318f9 --- /dev/null +++ b/tests/fine_grained_alignment_test.sh @@ -0,0 +1,106 @@ +#! /usr/bin/env bash +set -x +set -e + +MODEL_NAME=${MODEL_NAME:-"JackFram/llama-160m"} +MEMORY_PER_GPU=${MEMORY_PER_GPU:-14000} +ZCOPY_MEMORY=${ZCOPY_MEMORY:-40000} +TP_DEGREE=${TP_DEGREE:-2} +PP_DEGREE=${PP_DEGREE:-2} +CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"} +NUM_STEPS=${NUM_STEPS:-2} + +cleanup() { + rm -rf "${CACHE_PATH}"/debug ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt +} + +# Cd into directory holding this script +cd "${BASH_SOURCE[0]%/*}/.." + +# Initial cleanup +cleanup + +# Create test prompt file +mkdir -p ./inference/prompt +echo '["Three tips for staying healthy are: "]' > ./inference/prompt/test.json + +# Create output folder +mkdir -p ./inference/output + +# Enable backtrace in case we run into a segfault or assertion failure +export LEGION_BACKTRACE=1 +export FF_DEBG_NO_WEIGHTS=1 +FUSION=true + + +# Check if the Python code executed successfully +if ! 
PROMPT_LENGTH=$(python -c " +from transformers import AutoTokenizer +import os +tokenizer = AutoTokenizer.from_pretrained(\"$MODEL_NAME\") +tokens = tokenizer.tokenize('Three tips for staying healthy are: ') +print(len(tokens)) +"); +then + echo "Error: Failed to execute Python code" + exit 1 +fi + +MAX_LENGTH=$((PROMPT_LENGTH + NUM_STEPS + 1)) + +python ./tests/inference/huggingface_inference.py \ + --model-name "${MODEL_NAME}" \ + --max-length "${MAX_LENGTH}" \ + --prompt-file ../../inference/prompt/test.json \ + --output-file ../../inference/output/fine_grained_alignment_test_hf.txt \ + --use-full-precision \ + --inference-debugging + +NUM_GPUS=$((TP_DEGREE * PP_DEGREE)) +json_config=$(cat <<-END + { + "num_gpus": ${NUM_GPUS}, + "memory_per_gpu": ${MEMORY_PER_GPU}, + "zero_copy_memory_per_node": ${ZCOPY_MEMORY}, + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": ${TP_DEGREE}, + "pipeline_parallelism_degree": ${PP_DEGREE}, + "inference_debugging": true, + "fusion": ${FUSION}, + "refresh_cache": false, + "llm_model": "${MODEL_NAME}", + "cache_path": "${CACHE_PATH}", + "full_precision": true, + "prompt": "./inference/prompt/test.json", + "max_length": $MAX_LENGTH, + "output_file": "./inference/output/fine_grained_alignment_test_ff.txt" + } +END +) +echo "$json_config" > ./fine_grained_alignment_config.json + +python ./inference/python/incr_decoding.py -config-file ./fine_grained_alignment_config.json + +# # C++ test +# echo "C++ test" +# ./build/inference/incr_decoding/incr_decoding \ +# -ll:gpu 2 -ll:cpu 4 -ll:util 4 \ +# -tensor-parallelism-degree 2 \ +# -ll:fsize 8192 -ll:zsize 12000 \ +# -llm-model $MODEL_NAME \ +# -prompt ./inference/prompt/peft.json \ +# --use-full-precision \ +# --inference-debugging + +# Check alignment +python ./tests/inference/inference_alignment_test.py -m "$MODEL_NAME" -tp "$TP_DEGREE" -n "$NUM_STEPS" + +# Print succeess message +echo "" +echo "Inference alignment tests passed (model ${MODEL_NAME})!" 
+echo "" + +# Cleanup after the test +cleanup diff --git a/tests/inference/huggingface_inference.py b/tests/inference/huggingface_inference.py index 5e563c9974..fa72bef463 100644 --- a/tests/inference/huggingface_inference.py +++ b/tests/inference/huggingface_inference.py @@ -10,30 +10,9 @@ LlamaTokenizer, GenerationConfig, ) -######################### debugging helper functions ######################### -def pre_forward_hook(module, input): - assert module.name is not None and module.decoding_step is not None - name = module.name.replace("model.", "") - print( - f"Pre-forward hook activated on module: {name}, decoding step: {module.decoding_step}" - ) - print("Pre-Input: ", input[0].shape) - torch.save( - input, f"./hf_tensors/decoding_step_{module.decoding_step}_{name}.input" - ) -def post_forward_hook(module, input, output): - assert module.name is not None and module.decoding_step is not None - name = module.name.replace("model.", "") - print( - f"Post-forward Hook activated for module: {name}, decoding step: {module.decoding_step}" - ) - print("Post-Input/Output: ", input[0].shape, output[0].shape) - torch.save( - output, f"./hf_tensors/decoding_step_{module.decoding_step}_{name}.output" - ) - print("===") - module.decoding_step += 1 -############################################################################## +import sys +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "peft")) +from hf_utils import * def main(): # Change working dir to folder storing this script @@ -91,26 +70,20 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) generation_config = GenerationConfig.from_pretrained(args.model_name) generation_config.do_sample = args.do_sample + if not args.do_sample: + generation_config.num_beams=1 + generation_config.temperature = None + generation_config.top_p = None ################# debugging ################# if args.inference_debugging: # Print model and configs print(hf_config) print(model) - # Save weights to file - shutil.rmtree("./hf_tensors") - # Check that the output folder exists - os.makedirs("./hf_tensors", exist_ok=True) + make_debug_dirs() + register_inference_hooks(model) # Save weights - for name, params in model.named_parameters(): - torch.save(params, f"./hf_tensors/{name}") - # params.detach().cpu().numpy().tofile(f"./hf_tensors/{name}") - # Register hooks to save per-op hidden states - for name, layer in dict(model.named_modules()).items(): - layer.name = name - layer.decoding_step = 0 - print(f"Adding hooks to layer {layer.name}") - layer.register_forward_pre_hook(pre_forward_hook) - layer.register_forward_hook(post_forward_hook) + save_model_weights(model, target_modules=["lora", "lm_head", "final_layer_norm", "self_attn_layer_norm", "out_proj", "fc1", "fc2"]) + ############################################### # Generate output with open(args.output_file, "w") as f: diff --git a/tests/inference/inference_alignment_test.py b/tests/inference/inference_alignment_test.py new file mode 100644 index 0000000000..6fff4906f7 --- /dev/null +++ b/tests/inference/inference_alignment_test.py @@ -0,0 +1,817 @@ +import numpy as np +import os, torch, argparse, sys +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "peft")) +from alignment.align_test_utils import * +from transformers import AutoConfig +from tqdm import tqdm + +class AlignmentTest: + def __init__(self, hf_config, tp_degree=1): + raise NotImplementedError() + def 
check_weights_alignment(self): + raise NotImplementedError() + def check_fwd_pass(self): + raise NotImplementedError() + def check_bwd_pass(self): + raise NotImplementedError() + def check_step(self, step_idx, learning_rate=0.001): + raise NotImplementedError() + +class LllamaAlignmentTest(AlignmentTest): + def __init__(self, hf_config, tp_degree=1): + self.hf_config = hf_config + self.num_layers = self.hf_config.num_hidden_layers + self.hidden_size = self.hf_config.hidden_size + self.intermediate_size = self.hf_config.intermediate_size + self.num_attention_heads = self.hf_config.num_attention_heads + self.num_key_value_heads = self.hf_config.num_key_value_heads + self.projsize = self.hidden_size // self.num_attention_heads + self.tp_degree = tp_degree + + self.num_tokens = None + self.ff_batch_size = None + + + def check_weights_alignment(self): + def convert_hf_filename_to_ff(hf_filename): + if hf_filename == "lm_head.weight": + f_version = f"layers.{self.num_layers-1}.lm_head.weight_0" + elif hf_filename == "norm.weight": + f_version = f"layers.{self.num_layers-1}.norm.weight_0" + else: + f_version = "" + if hf_filename.startswith("layers."): + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version += f"layers.{layernum}." + f_version += hf_filename.replace(".base_layer", "").replace(".default", "") + # compute weight index, then rename lora if needed if needed + weight_index="0" + if "lora_A" in f_version: + weight_index="A" + elif "lora_B" in f_version: + weight_index="B" + f_version = f_version.replace("lora_A", "lora").replace("lora_B", "lora") + if f_version.endswith(".weight"): + if weight_index == "0": + f_version += f"_{weight_index}" + else: + f_version += f"_{weight_index}.original" + elif f_version.endswith(".gradient"): + prefix = f_version.split(".gradient")[0] + f_version = prefix + f".weight_{weight_index}.gradient" + return f_version + def get_tp_partition_dim(ff_weight_name) -> int: + # MLP layers split the intermediate size dimension + # gate_proj, up_proj: [hidden_size, intermediate_size] + # down_proj: [intermediate_size, hidden_size] + if self.tp_degree == 1: + return -1 + if "lora.weight_B" in ff_weight_name: + return -1 + if "lm_head" in ff_weight_name or "norm" in ff_weight_name: + return 1 + if "gate_proj" in ff_weight_name or "up_proj" in ff_weight_name: + return 1 + elif "down_proj" in ff_weight_name: + return 0 + else: + return -1 + print("-- Weights alignment --") + hf_weights_folder = os.path.join(hf_path, "weights", "step_0") + ff_weights_folder = os.path.join(ff_path, "weights", "step_0", "shard_0") + files_list = os.listdir(hf_weights_folder) + for hf_weight_name in tqdm(sorted(files_list)): + if hf_weight_name.endswith(".weight"): + ff_weight_name = convert_hf_filename_to_ff(hf_weight_name) + # print(hf_weight_name, ff_weight_name) + hf_w_path = os.path.join(hf_weights_folder, hf_weight_name) + ff_w_path = os.path.join(ff_weights_folder, ff_weight_name) + if not os.path.isfile(hf_w_path): + print(f"File '{hf_w_path}' not found") + if not os.path.isfile(ff_w_path): + print(f"File '{ff_w_path}' not found") + assert(os.path.isfile(hf_w_path)) + assert(os.path.isfile(ff_w_path)) + + # 1. get shape of hf weight + hf_weight = torch.load(hf_w_path, map_location='cpu') + hf_weight_shape = hf_weight.shape + ff_partition_dim = get_tp_partition_dim(ff_weight_name) + ff_weight_shape = list(hf_weight_shape)[::-1] + if ff_partition_dim >= 0: + ff_weight_shape[ff_partition_dim] //= self.tp_degree + + # 2. 
handle flexflow shards in case of tensor parallelism + ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weight_shape) for tp_idx in range(self.tp_degree)] + if self.tp_degree > 1: + if ff_partition_dim >= 0: + ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim) + else: + assert(are_np_arrays_identical(ff_weights)) + ff_weight = ff_weights[0] + else: + ff_weight = ff_weights[0] + ff_weight = torch.from_numpy(ff_weight).to(hf_weight.dtype) + + # check equivalence + try: + torch.testing.assert_close(ff_weight, hf_weight.T) + except Exception as e: + print(f"Error comparing {ff_w_path} weight to {hf_w_path}:\n{e}\n") + raise e + + def check_fwd_pass(self, step_idx=0): + hf_fwd_folder = os.path.join(hf_path, "fwd", f"step_{step_idx}") + ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0") + + def convert_hf_filename_to_ff(hf_filename): + if hf_filename == "embed_tokens": + f_version = f"layers.0.embed_tokens" + elif hf_filename == "lm_head" or hf_filename == "norm": + f_version = f"layers.{self.num_layers-1}.{hf_filename}" + else: + assert hf_filename.startswith("layers.") + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version = f"layers.{layernum}." + f_version += hf_filename.replace(".base_layer", "").replace(".default", "") + # right now, attention in flexflow is done with a single operator, so there is a single output file without the projection suffix + f_version = f_version.replace(".q_proj", ".qkv_proj").replace(".k_proj", ".qkv_proj").replace(".v_proj", ".qkv_proj")#.replace(".o_proj", "") + return f_version + + def get_hf_tensor(hf_tensor_name, tensor_comparison_idx): + hf_tensor_filename = f"{hf_tensor_name}.{tensor_comparison_idx.hf_tensor_type}_{tensor_comparison_idx.hf_tensor_idx}" + hf_tensor_path = os.path.join(hf_fwd_folder, hf_tensor_filename) + + if not os.path.isfile(hf_tensor_path): + raise FileNotFoundError(f"File '{hf_tensor_path}' not found") + print("loading hf tensor: ", hf_tensor_filename) + hf_tensor = torch.load(hf_tensor_path, map_location='cpu') + if hf_tensor_name == "embed_tokens": + self.num_tokens = hf_tensor.shape[1] + return hf_tensor + + def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE): + ff_tensor_suffix = f".{tensor_comparison_idx.ff_tensor_type}" if len(tensor_comparison_idx.ff_tensor_type) > 0 else "" + ff_tensor_idx_suffix = f"_{tensor_comparison_idx.ff_tensor_idx}" if tensor_comparison_idx.ff_tensor_idx is not None else "" + ff_tensor_filename = f"{ff_tensor_name}{ff_tensor_suffix}{ff_tensor_idx_suffix}" + ff_tensor_path = os.path.join(ff_fwd_folder, ff_tensor_filename) + if not os.path.isfile(ff_tensor_path): + raise FileNotFoundError(f"File '{ff_tensor_path}' not found") + + print("loading ff tensor: ", ff_tensor_filename) + ff_shape = list(hf_shape)[::-1] + if tp_type == TPType.PARTITION: + ff_shape[0] //= self.tp_degree + + if "layers.0.embed_tokens.input_0" in ff_tensor_path: + # get number of tokens + ff_tensor = np.loadtxt(ff_tensor_path, delimiter=',') + self.ff_batch_size = ff_tensor.shape[0] + + ff_shape = replace_value(ff_shape, self.num_tokens, self.ff_batch_size) + ff_tensors = [load_ff_tensor(ff_tensor_path.replace("shard_0", f"shard_{tp_idx}"), ff_shape) for tp_idx in range(self.tp_degree)] + if self.tp_degree > 1: + # if replicate, check that they are identical + if tp_type == TPType.REPLICATE: + assert(are_np_arrays_identical(ff_tensors)) + ff_tensor = ff_tensors[0] + # if partition, concatenate along the 
partition dimension + elif tp_type == TPType.PARTITION: + ff_tensor = np.concatenate(ff_tensors, axis=0) + # if to_reduce, sum along the partition dimension + elif tp_type == TPType.TO_REDUCE: + ff_tensor = np.sum(ff_tensors, axis=0) + else: + ff_tensor = ff_tensors[0] + ff_tensor = torch.from_numpy(ff_tensor) + ff_tensor = truncate_dimension(ff_tensor, self.ff_batch_size, self.num_tokens) + return ff_tensor + + def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance=1e-2): + ff_tensor = ff_tensor.to(hf_tensor.dtype) + hf_tensor = hf_tensor.T + if additional_ff_tensor is not None: + additional_ff_tensor = additional_ff_tensor.to(hf_tensor.dtype) + ff_tensor = ff_tensor - additional_ff_tensor + try: + # torch.testing.assert_close(hf_tensor, ff_tensor, rtol=1.3e-6, atol=tolerance) + if not np.allclose(hf_tensor.detach().numpy(), ff_tensor.detach().numpy(), atol=tolerance): + mismatches = np.where(~np.isclose(hf_tensor.detach().numpy(), ff_tensor.detach().numpy(), atol=tolerance))[0] + print(f"Pct mismatch {label}: {100.0*(np.prod(mismatches.shape) / ff_tensor.numel()):.3f}%") + assert(np.prod(mismatches.shape) <= .05 * ff_tensor.numel()) + except Exception as e: + print(f"Error in comparison {label}:\n{e}\n") + print("HF tensor:") + print(hf_tensor.squeeze()) + print(hf_tensor.shape) + print("FF tensor:") + print(ff_tensor.squeeze()) + print(ff_tensor.shape) + raise e + + print(f"-- FWD pass {step_idx}--") + + # Embedding layer + hf_tensor_name = "embed_tokens" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Embedding input") + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Embedding output") + + # Transformers blocks + for i in range(self.num_layers): + # Input laye norm + hf_tensor_name = f"layers.{i}.input_layernorm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + if i == 0: + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + else: + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Input layernorm {i} input") + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Input layernorm {i} output") + + # Attention QKV projections + hf_q_proj_tensor_name = f"layers.{i}.self_attn.q_proj" + hf_k_proj_tensor_name = f"layers.{i}.self_attn.k_proj" + hf_v_proj_tensor_name = 
f"layers.{i}.self_attn.v_proj" + ff_qkv_tensor_name = convert_hf_filename_to_ff(hf_q_proj_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_q_proj_in = get_hf_tensor(hf_q_proj_tensor_name, input_comparison) + hf_k_proj_in = get_hf_tensor(hf_k_proj_tensor_name, input_comparison) + hf_v_proj_in = get_hf_tensor(hf_v_proj_tensor_name, input_comparison) + hf_q_proj_out = get_hf_tensor(hf_q_proj_tensor_name, output_comparison) + hf_k_proj_out = get_hf_tensor(hf_k_proj_tensor_name, output_comparison) + hf_v_proj_out = get_hf_tensor(hf_v_proj_tensor_name, output_comparison) + ff_qkv_tensor_in = get_ff_tensor(ff_qkv_tensor_name, input_comparison, hf_q_proj_in.shape) + torch.testing.assert_close(hf_q_proj_in, hf_k_proj_in) + torch.testing.assert_close(hf_k_proj_in, hf_v_proj_in) + compare(hf_q_proj_in, ff_qkv_tensor_in, label=f"QKV proj {i} input") + ff_qkv_tensor_out = get_ff_tensor( + ff_qkv_tensor_name, + output_comparison, + torch.Size([hf_q_proj_out.shape[0], hf_q_proj_out.shape[1], 3*hf_q_proj_out.shape[2]]), + tp_type=TPType.PARTITION + ) + head_dim = hf_q_proj_out.shape[2] // self.num_attention_heads + heads_per_shard = self.num_attention_heads // self.tp_degree + chunk_size = head_dim * heads_per_shard + # print(ff_qkv_tensor_out.shape) + ff_qproj_out = ff_qkv_tensor_out[:chunk_size, :, :] + ff_kproj_out = ff_qkv_tensor_out[chunk_size:2*chunk_size, :, :] + ff_vproj_out = ff_qkv_tensor_out[2*chunk_size : 3*chunk_size, :, :] + qkv_chunk_size = 3*chunk_size + for tp_idx in range(1, self.tp_degree): + prev_size = tp_idx * qkv_chunk_size + ff_qproj_out_ = ff_qkv_tensor_out[prev_size : prev_size + chunk_size, :, :] + ff_kproj_out_ = ff_qkv_tensor_out[prev_size + chunk_size : prev_size + 2*chunk_size, :, :] + ff_vproj_out_ = ff_qkv_tensor_out[prev_size + 2*chunk_size : prev_size + 3*chunk_size, :, :] + ff_qproj_out = np.concatenate((ff_qproj_out, ff_qproj_out_), axis=0) + ff_kproj_out = np.concatenate((ff_kproj_out, ff_kproj_out_), axis=0) + ff_vproj_out = np.concatenate((ff_vproj_out, ff_vproj_out_), axis=0) + compare_loaded_tensors(hf_q_proj_out.T, ff_qproj_out) + compare_loaded_tensors(hf_k_proj_out.T, ff_kproj_out) + compare_loaded_tensors(hf_v_proj_out.T, ff_vproj_out) + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + ff_attn_tensor_in = get_ff_tensor( + ff_tensor_name, + input_comparison, + torch.Size([hf_q_proj_out.shape[0], hf_q_proj_out.shape[1], 3*hf_q_proj_out.shape[2]]), + tp_type=TPType.PARTITION + ) + assert torch.allclose(ff_qkv_tensor_out, ff_attn_tensor_in) + + # Attention + hf_tensor_name = f"layers.{i}.self_attn.o_proj" + ff_tensor_name = convert_hf_filename_to_ff(f"layers.{i}.self_attn") + # the raw attention result, w/o o_proj. 
This is the output of self_attn in FF and the input of o_proj in HF + output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # TP for self-attn partitions the attention heads across TP workers + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name) + compare(hf_tensor, ff_tensor, label=f"Attention {i} output") + + # Post-attention layernorm + hf_tensor_name = f"layers.{i}.post_attention_layernorm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Post-attention layernorm {i} output") + + # W1 (gate_proj) + hf_tensor_name = f"layers.{i}.mlp.gate_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"W1 {i} output") + + # W3 (up_proj) + hf_tensor_name = f"layers.{i}.mlp.up_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"W3 {i} output") + + # W2 (down_proj) + hf_tensor_name = f"layers.{i}.mlp.down_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_down_proj_out = get_hf_tensor(hf_tensor_name, output_comparison) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"W2 {i} input") + + hf_down_proj_in = hf_tensor.clone() + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_down_proj_out = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + + # Norm + hf_tensor_name = "norm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Norm output") + + # LM head + hf_tensor_name = "lm_head" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name)
+ input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE) + compare(hf_tensor, ff_tensor, label="LM head input") + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label="LM head output") + +class OPTAlignmentTest(AlignmentTest): + def __init__(self, hf_config, tp_degree=1): + self.hf_config = hf_config + self.num_layers = self.hf_config.num_hidden_layers + self.hidden_size = self.hf_config.hidden_size + self.intermediate_size = self.hf_config.ffn_dim + self.num_attention_heads = self.hf_config.num_attention_heads + self.num_key_value_heads = self.num_attention_heads + self.projsize = self.hidden_size // self.num_attention_heads + self.tp_degree = tp_degree + + self.num_tokens = None + self.ff_batch_size = None + + def check_weights_alignment(self): + def convert_hf_filename_to_ff(hf_filename): + if hf_filename == "lm_head.weight" or hf_filename == "final_layer_norm.weight": + f_version = f"layers.{self.num_layers-1}.{hf_filename}_0" + elif hf_filename == "lm_head.bias" or hf_filename == "final_layer_norm.bias": + f_version = f"layers.{self.num_layers-1}.{hf_filename.replace('bias', 'weight')}_1" + elif hf_filename.startswith("layers.") and hf_filename.endswith("self_attn.out_proj.bias"): + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version = f"layers.{layernum}.layers.{layernum}.add_bias_residual_layer_norm.weight_0" + elif hf_filename.startswith("layers.") and hf_filename.endswith(".final_layer_norm.weight"): + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version = f"layers.{layernum}.layers.{layernum}.add_bias_residual_layer_norm.weight_1" + elif hf_filename.startswith("layers.") and hf_filename.endswith(".final_layer_norm.bias"): + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version = f"layers.{layernum}.layers.{layernum}.add_bias_residual_layer_norm.weight_2" + else: + f_version = "" + if hf_filename.startswith("layers."): + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version += f"layers.{layernum}." 
+ f_version += hf_filename.replace(".base_layer", "").replace(".default", "").replace("out_proj", "o_proj") + # compute weight index, then rename lora if needed if needed + weight_index="0" + if "lora_A" in f_version: + weight_index="A" + elif "lora_B" in f_version: + weight_index="B" + f_version = f_version.replace("lora_A", "lora").replace("lora_B", "lora") + if f_version.endswith(".weight"): + if weight_index == "0": + f_version += f"_{weight_index}" + else: + f_version += f"_{weight_index}.original" + elif f_version.endswith(".gradient"): + prefix = f_version.split(".gradient")[0] + f_version = prefix + f".weight_{weight_index}.gradient" + elif f_version.endswith(".bias"): + f_version = f_version.replace(".bias", ".weight_1") + return f_version + def get_tp_partition_dim(ff_weight_name) -> int: + # MLP layers split the intermediate size dimension + # gate_proj, up_proj: [hidden_size, intermediate_size] + # down_proj: [intermediate_size, hidden_size] + if self.tp_degree == 1: + return -1 + if "lora.weight_B" in ff_weight_name: + return -1 + if "lm_head" in ff_weight_name or "fc1" in ff_weight_name: + return 1 + elif "fc2" in ff_weight_name or "o_proj.weight" in ff_weight_name: + return 0 + else: + return -1 + def get_bias_tp_partition_dim(ff_weight_name) -> int: + if self.tp_degree == 1: + return -1 + elif "lm_head" in ff_weight_name or "fc1" in ff_weight_name: + return 0 + else: + return -1 + print("-- Weights alignment --") + hf_weights_folder = os.path.join(hf_path, "weights", "step_0") + ff_weights_folder = os.path.join(ff_path, "weights", "step_0", "shard_0") + files_list = os.listdir(hf_weights_folder) + for hf_weight_name in tqdm(sorted(files_list)): + if hf_weight_name.endswith(".weight") or hf_weight_name.endswith(".bias"): + ff_weight_name = convert_hf_filename_to_ff(hf_weight_name) + # print(hf_weight_name, ff_weight_name) + hf_w_path = os.path.join(hf_weights_folder, hf_weight_name) + ff_w_path = os.path.join(ff_weights_folder, ff_weight_name) + if not os.path.isfile(hf_w_path): + print(f"File '{hf_w_path}' not found") + if not os.path.isfile(ff_w_path): + print(f"File '{ff_w_path}' not found") + assert(os.path.isfile(hf_w_path)) + assert(os.path.isfile(ff_w_path)) + + # 1. get shape of hf weight + hf_weight = torch.load(hf_w_path, map_location='cpu') + hf_weight_shape = hf_weight.shape + ff_partition_dim = get_tp_partition_dim(ff_weight_name) if hf_weight_name.endswith(".weight") else get_bias_tp_partition_dim(ff_weight_name) + ff_weight_shape = list(hf_weight_shape)[::-1] + # print(ff_partition_dim, ff_weight_name, hf_w_path, ff_weight_shape) + if ff_partition_dim >= 0: + ff_weight_shape[ff_partition_dim] //= self.tp_degree + + # 2. handle flexflow shards in case of tensor parallelism + if hf_weight_name.endswith(".bias") and ff_partition_dim == -1: + # unpartitioned bias (E.g. 
replicated bias) only lives on shard 0 + ff_weight = load_ff_tensor(ff_w_path, ff_weight_shape) + else: + ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weight_shape) for tp_idx in range(self.tp_degree)] + if self.tp_degree > 1: + if ff_partition_dim >= 0: + ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim) + else: + assert(are_np_arrays_identical(ff_weights)) + ff_weight = ff_weights[0] + else: + ff_weight = ff_weights[0] + ff_weight = torch.from_numpy(ff_weight).to(hf_weight.dtype) + # print("comparing weight tensor: ", hf_weight_name, " and ", ff_weight_name) + # check equivalence + try: + torch.testing.assert_close(ff_weight, hf_weight.T) + except Exception as e: + print(f"Error comparing {ff_w_path} weight to {hf_w_path}:\n{e}\n") + raise e + + def check_fwd_pass(self, step_idx=0): + hf_fwd_folder = os.path.join(hf_path, "fwd", f"step_{step_idx}") + ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0") + + def convert_hf_filename_to_ff(hf_filename): + if hf_filename == "embed_tokens" or hf_filename == "embed_positions": + f_version = f"layers.0.{hf_filename}" + elif hf_filename == "lm_head" or hf_filename == "final_layer_norm": + f_version = f"layers.{self.num_layers-1}.{hf_filename}" + else: + assert hf_filename.startswith("layers.") + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version = f"layers.{layernum}." + f_version += hf_filename.replace(".base_layer", "").replace(".default", "") + # right now, attention in flexflow is done with a single operator, so there is a single output file without the projection suffix + f_version = f_version.replace(".q_proj", ".qkv_proj").replace(".k_proj", ".qkv_proj").replace(".v_proj", ".qkv_proj") + return f_version + + def get_hf_tensor(hf_tensor_name, tensor_comparison_idx): + hf_tensor_filename = f"{hf_tensor_name}.{tensor_comparison_idx.hf_tensor_type}_{tensor_comparison_idx.hf_tensor_idx}" + hf_tensor_path = os.path.join(hf_fwd_folder, hf_tensor_filename) + + if not os.path.isfile(hf_tensor_path): + raise FileNotFoundError(f"File '{hf_tensor_path}' not found") + print("loading hf tensor: ", hf_tensor_filename) + hf_tensor = torch.load(hf_tensor_path, map_location='cpu') + if hf_tensor_name == "embed_tokens": + self.num_tokens = hf_tensor.shape[1] + return hf_tensor + + def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE): + ff_tensor_suffix = f".{tensor_comparison_idx.ff_tensor_type}" if len(tensor_comparison_idx.ff_tensor_type) > 0 else "" + ff_tensor_idx_suffix = f"_{tensor_comparison_idx.ff_tensor_idx}" if tensor_comparison_idx.ff_tensor_idx is not None else "" + ff_tensor_filename = f"{ff_tensor_name}{ff_tensor_suffix}{ff_tensor_idx_suffix}" + ff_tensor_path = os.path.join(ff_fwd_folder, ff_tensor_filename) + if not os.path.isfile(ff_tensor_path): + raise FileNotFoundError(f"File '{ff_tensor_path}' not found") + + print("loading ff tensor: ", ff_tensor_filename) + ff_shape = list(hf_shape)[::-1] + if tp_type == TPType.PARTITION: + ff_shape[0] //= self.tp_degree + + if "layers.0.embed_tokens.input_0" in ff_tensor_path: + # get number of tokens + ff_tensor = np.loadtxt(ff_tensor_path, delimiter=',') + self.ff_batch_size = ff_tensor.shape[0] + + ff_shape = replace_value(ff_shape, self.num_tokens, self.ff_batch_size) + ff_tensors = [load_ff_tensor(ff_tensor_path.replace("shard_0", f"shard_{tp_idx}"), ff_shape) for tp_idx in range(self.tp_degree)] + if self.tp_degree > 1: + # if replicate, check that they are identical + 
if tp_type == TPType.REPLICATE: + assert(are_np_arrays_identical(ff_tensors)) + ff_tensor = ff_tensors[0] + # if partition, concatenate along the partition dimension + elif tp_type == TPType.PARTITION: + ff_tensor = np.concatenate(ff_tensors, axis=0) + # if to_reduce, sum along the partition dimension + elif tp_type == TPType.TO_REDUCE: + ff_tensor = np.sum(ff_tensors, axis=0) + else: + ff_tensor = ff_tensors[0] + ff_tensor = torch.from_numpy(ff_tensor) + ff_tensor = truncate_dimension(ff_tensor, self.ff_batch_size, self.num_tokens) + return ff_tensor + + def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance=1e-2): + ff_tensor = ff_tensor.to(hf_tensor.dtype) + hf_tensor = hf_tensor.T + if additional_ff_tensor is not None: + additional_ff_tensor = additional_ff_tensor.to(hf_tensor.dtype) + ff_tensor = ff_tensor - additional_ff_tensor + try: + # torch.testing.assert_close(hf_tensor, ff_tensor, rtol=1.3e-6, atol=tolerance) + if not np.allclose(hf_tensor.detach().numpy(), ff_tensor.detach().numpy(), atol=tolerance): + mismatches = np.where(~np.isclose(hf_tensor.detach().numpy(), ff_tensor.detach().numpy(), atol=tolerance))[0] + print(f"Pct mismatch {label}: {100.0*(np.prod(mismatches.shape) / ff_tensor.numel()):.3f}%") + assert(np.prod(mismatches.shape) <= .05 * ff_tensor.numel()) + except Exception as e: + print(f"Error in comparison {label}:\n{e}\n") + print("HF tensor:") + print(hf_tensor.squeeze()) + print(hf_tensor.shape) + print("FF tensor:") + print(ff_tensor.squeeze()) + print(ff_tensor.shape) + raise e + + print(f"-- FWD pass {step_idx}--") + + # Embedding layer + hf_tensor_name = "embed_tokens" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Embedding input") + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Embedding output") + + # Positional embedding layer + hf_tensor_name = "embed_positions" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Position Embedding output") + + # Transformers blocks + for i in range(self.num_layers): + # Input layer norm + hf_tensor_name = f"layers.{i}.self_attn_layer_norm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Self attention layernorm {i} input") + hf_tensor = get_hf_tensor(hf_tensor_name, 
output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Self attention layernorm {i} output") + + # Attention QKV projections + hf_q_proj_tensor_name = f"layers.{i}.self_attn.q_proj" + hf_k_proj_tensor_name = f"layers.{i}.self_attn.k_proj" + hf_v_proj_tensor_name = f"layers.{i}.self_attn.v_proj" + ff_qkv_tensor_name = convert_hf_filename_to_ff(hf_q_proj_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_q_proj_in = get_hf_tensor(hf_q_proj_tensor_name, input_comparison) + hf_k_proj_in = get_hf_tensor(hf_k_proj_tensor_name, input_comparison) + hf_v_proj_in = get_hf_tensor(hf_v_proj_tensor_name, input_comparison) + hf_q_proj_out = get_hf_tensor(hf_q_proj_tensor_name, output_comparison) + hf_k_proj_out = get_hf_tensor(hf_k_proj_tensor_name, output_comparison) + hf_v_proj_out = get_hf_tensor(hf_v_proj_tensor_name, output_comparison) + ff_qkv_tensor_in = get_ff_tensor(ff_qkv_tensor_name, input_comparison, hf_q_proj_in.shape) + torch.testing.assert_close(hf_q_proj_in, hf_k_proj_in) + torch.testing.assert_close(hf_k_proj_in, hf_v_proj_in) + compare(hf_q_proj_in, ff_qkv_tensor_in, label=f"QKV proj {i} input") + ff_qkv_tensor_out = get_ff_tensor( + ff_qkv_tensor_name, + output_comparison, + torch.Size([hf_q_proj_out.shape[0], hf_q_proj_out.shape[1], 3*hf_q_proj_out.shape[2]]), + tp_type=TPType.PARTITION + ) + head_dim = hf_q_proj_out.shape[2] // self.num_attention_heads + heads_per_shard = self.num_attention_heads // self.tp_degree + chunk_size = head_dim * heads_per_shard + # print(ff_qkv_tensor_out.shape) + ff_qproj_out = ff_qkv_tensor_out[:chunk_size, :, :] + ff_kproj_out = ff_qkv_tensor_out[chunk_size:2*chunk_size, :, :] + ff_vproj_out = ff_qkv_tensor_out[2*chunk_size : 3*chunk_size, :, :] + qkv_chunk_size = 3*chunk_size + for tp_idx in range(1, self.tp_degree): + prev_size = tp_idx * qkv_chunk_size + ff_qproj_out_ = ff_qkv_tensor_out[prev_size : prev_size + chunk_size, :, :] + ff_kproj_out_ = ff_qkv_tensor_out[prev_size + chunk_size : prev_size + 2*chunk_size, :, :] + ff_vproj_out_ = ff_qkv_tensor_out[prev_size + 2*chunk_size : prev_size + 3*chunk_size, :, :] + ff_qproj_out = np.concatenate((ff_qproj_out, ff_qproj_out_), axis=0) + ff_kproj_out = np.concatenate((ff_kproj_out, ff_kproj_out_), axis=0) + ff_vproj_out = np.concatenate((ff_vproj_out, ff_vproj_out_), axis=0) + compare_loaded_tensors(hf_q_proj_out.T, ff_qproj_out) + compare_loaded_tensors(hf_k_proj_out.T, ff_kproj_out) + compare_loaded_tensors(hf_v_proj_out.T, ff_vproj_out) + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + ff_attn_tensor_in = get_ff_tensor( + ff_tensor_name, + input_comparison, + torch.Size([hf_q_proj_out.shape[0], hf_q_proj_out.shape[1], 3*hf_q_proj_out.shape[2]]), + tp_type=TPType.PARTITION + ) + assert torch.allclose(ff_qkv_tensor_out, ff_attn_tensor_in) + + # Compared scaled qproj + hf_tensor_name = f"layers.{i}.self_attn.scaled_qproj" + input_c = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + output_c = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + scaled_qproj_in = 
get_hf_tensor(hf_tensor_name, input_c) + scaled_qproj_out = get_hf_tensor(hf_tensor_name, output_c) + assert torch.allclose(scaled_qproj_in, scaled_qproj_out) + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn.scaled_qkv_proj" + scaled_qkv_proj0 = load_ff_tensor(os.path.join(ff_fwd_folder, f"{ff_tensor_name}.output_0"), [64*6,3,9]) + scaled_qkv_proj1 = load_ff_tensor(os.path.join(ff_fwd_folder, f"{ff_tensor_name}.output_0").replace("shard_0", "shard_1"), [64*6,3,9]) + ff_scaled_qkv_proj = np.concatenate([scaled_qkv_proj0, scaled_qkv_proj1], axis=0) + ff_scaled_q_proj = torch.from_numpy(ff_scaled_qkv_proj[:, :1, :]).to(scaled_qproj_out.dtype) + # print("HF scaled qproj:") + # print(scaled_qproj_out.squeeze().T) + # print("FF scaled q proj:") + # print(ff_scaled_q_proj.squeeze()) + # print("HF unscaled qproj:") + # print(hf_q_proj_out.squeeze().T) + # print("FF unscaled qproj:") + # print(torch.from_numpy(ff_qproj_out.squeeze()).to(scaled_qproj_out.dtype)) + # assert torch.allclose(hf_q_proj_out.squeeze().T, ff_scaled_q_proj.squeeze()) + + + + # check that out_proj input, attn_scores out and input are identical on the hf side + hf_tensor_name = f"layers.{i}.self_attn.attn_scores" + input_c = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + output_c = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + attn_scores_in = get_hf_tensor(hf_tensor_name, input_c) + attn_scores_out = get_hf_tensor(hf_tensor_name, output_c) + hf_tensor_name = f"layers.{i}.self_attn.out_proj" + out_proj_in = get_hf_tensor(hf_tensor_name, input_c) + assert torch.allclose(attn_scores_in, attn_scores_out) + assert torch.allclose(attn_scores_in, out_proj_in) + + # Compare out proj input. 
This should be the output of the attention without any bias involved + hf_tensor_name = f"layers.{i}.self_attn.out_proj" + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name) + compare(hf_tensor, ff_tensor, label=f"Attention o-proj {i} input") + + hf_tensor_name = f"layers.{i}.self_attn.attn_scores" + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"Attention {i} output") + + # hf_tensor_name = f"layers.{i}.final_layer_norm" + # ff_tensor_name = f"layers.{i}.layers.{i}.add_bias_residual_layer_norm" + # output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + # hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE) + # compare(hf_tensor, ff_tensor, label=f"Add Bias Residula LN {i} output 0") + + hf_tensor_name = f"layers.{i}.self_attn.out_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name.replace(".out_proj", ".o_proj")) + # # the raw attention result, w/o o_proj. 
This is the output of senf_attn of FF and the input of o_proj in HF + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # # TP for self-attn partitions the attention heads across TP workers + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name) + # compare(hf_tensor, ff_tensor, label=f"Attention oproj {i} output") + + # hf_tensor_name = f"layers.{i}.self_attn.out_proj" + # ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + # output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + # hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + # print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name) + # compare(hf_tensor, ff_tensor, label=f"Attention {i} output") + + + + # # Post-attention layernorm + # hf_tensor_name = f"layers.{i}.add_bias_residual_layer_norm" + # ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + # output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + # hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + # compare(hf_tensor, ff_tensor, label=f"Add bias residual layernorm {i} output") + + # FC1 (+ ReLU) + hf_tensor_name = f"layers.{i}.activation_fn" + ff_tensor_name = convert_hf_filename_to_ff(f"layers.{i}.fc1") + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"FC1 {i} output") + + # FC2 + hf_tensor_name = f"layers.{i}.fc2" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_down_proj_out = get_hf_tensor(hf_tensor_name, output_comparison) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"FC2 {i} input") + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # compare(hf_tensor, ff_tensor, label=f"FC2 {i} output") + + hf_down_proj_in = hf_tensor.clone() + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_down_proj_out = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + + # Norm + hf_tensor_name = "final_layer_norm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", 
ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Final layer norm output") + + # LM head + hf_tensor_name = "lm_head" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE) + compare(hf_tensor, ff_tensor, label="LM head input") + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label="LM head output") + +parser = argparse.ArgumentParser(description='Argument Parser Example') +# Adding arguments +parser.add_argument('-m', '--model-name', type=str, default="goliaro/llama-160m-lora", help='Name of the model') +parser.add_argument('-n', '--num-steps', type=int, default=1, help='Number of decoding steps') +parser.add_argument('-tp', '--tensor-parallelism-degree', type=int, default=1, help='The tensor parallelism degree used when running FlexFlow') + +# Parse the arguments from command line +args = parser.parse_args() + +if __name__ == "__main__": + hf_config = AutoConfig.from_pretrained(args.model_name) + alignment_class = None + if hf_config.architectures[0] == "LlamaForCausalLM": + alignment_class = LllamaAlignmentTest(hf_config, tp_degree=args.tensor_parallelism_degree) + elif hf_config.architectures[0] == "OPTForCausalLM": + alignment_class = OPTAlignmentTest(hf_config, tp_degree=args.tensor_parallelism_degree) + + # alignment_class.check_weights_alignment() + for i in range(args.num_steps): + alignment_class.check_fwd_pass(i) diff --git a/tests/peft/alignment/align_test_utils.py b/tests/peft/alignment/align_test_utils.py index 93727bdc89..3085bbda56 100644 --- a/tests/peft/alignment/align_test_utils.py +++ b/tests/peft/alignment/align_test_utils.py @@ -3,6 +3,8 @@ from typing import List from enum import Enum from dataclasses import dataclass +import warnings + abs_dirname = os.path.dirname(os.path.abspath(__file__)) cache_folder = os.path.expanduser(os.getenv("FF_CACHE_PATH", "~/.cache/flexflow")) @@ -472,7 +474,16 @@ def replace_value(lst, old_value, new_value): if occurrences == 0: raise ValueError(f"Value {old_value} not found in the list.") elif occurrences > 1: - raise ValueError(f"Multiple instances of {old_value} found in the list.") + warnings.warn(f"Multiple instances of {old_value} found in the list.") + occurrence_idx=0 + for i, value in enumerate(lst): + if value == old_value: + occurrence_idx += 1 + if occurrence_idx == 2: + lst[i] = new_value + break + return lst + # raise ValueError(f"Multiple instances of {old_value} found in the list.") else: index = lst.index(old_value) lst[index] = new_value diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py index 16b46cfa81..a2fc5548ab 100644 --- a/tests/peft/hf_finetune.py +++ b/tests/peft/hf_finetune.py @@ -77,7 +77,7 @@ def main(): if args.save_peft_tensors: make_debug_dirs() register_peft_hooks(model) - save_peft_weights(model, target_modules=["lora", "lm_head", 
"down_proj"]) + save_model_weights(model, target_modules=["lora", "lm_head", "down_proj"]) # Load fine-tuning dataset data = load_dataset("Abirate/english_quotes") diff --git a/tests/peft/hf_utils.py b/tests/peft/hf_utils.py index 9332c803b2..94fb96f029 100644 --- a/tests/peft/hf_utils.py +++ b/tests/peft/hf_utils.py @@ -40,7 +40,7 @@ def get_dst_folder(subdir, step_idx=0): def simplify_name(name): - return name.replace("base_model.model.model.", "").replace("base_model.model.", "") + return name.replace("base_model.model.model.", "").replace("base_model.model.", "").replace("model.layers.", "layers.").replace("model.", "").replace("decoder.", "") def get_optim_type(args): @@ -114,7 +114,7 @@ def peft_backward_hook(module, grad_input, grad_output): module.bwd_step += 1 -def peft_forward_hook(module, input, output): +def fwd_hook(module, input, output): if len(input) == 0 or len(output) == 0: return assert module.name is not None and module.fwd_step is not None @@ -312,11 +312,18 @@ def register_peft_hooks(model): layer.bwd_step = 0 if verbose: print(f"Adding hooks to layer {layer.name}") - layer.register_forward_hook(peft_forward_hook) + layer.register_forward_hook(fwd_hook) layer.register_full_backward_hook(peft_backward_hook) +def register_inference_hooks(model): + for name, layer in dict(model.named_modules()).items(): + layer.name = name + layer.fwd_step = 0 + if verbose: + print(f"Adding hooks to layer {layer.name}") + layer.register_forward_hook(fwd_hook) -def save_peft_weights(model, target_modules=[]): +def save_model_weights(model, target_modules=[]): # Save any weights of interest for name, params in model.named_parameters(): simplified_name = simplify_name(name) diff --git a/tests/peft/peft_alignment_test.py b/tests/peft/peft_alignment_test.py index 266bb64137..cc677cd51a 100644 --- a/tests/peft/peft_alignment_test.py +++ b/tests/peft/peft_alignment_test.py @@ -98,14 +98,14 @@ def get_tp_partition_dim(ff_weight_name) -> int: # 1. get shape of hf weight hf_weight = torch.load(hf_w_path, map_location='cpu') - hf_weigth_shape = hf_weight.shape + hf_weight_shape = hf_weight.shape ff_partition_dim = get_tp_partition_dim(ff_weight_name) - ff_weigth_shape = list(hf_weigth_shape)[::-1] + ff_weight_shape = list(hf_weight_shape)[::-1] if ff_partition_dim >= 0: - ff_weigth_shape[ff_partition_dim] //= self.tp_degree + ff_weight_shape[ff_partition_dim] //= self.tp_degree # 2. 
handle flexflow shards in case of tensor parallelism - ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weigth_shape) for tp_idx in range(self.tp_degree)] + ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weight_shape) for tp_idx in range(self.tp_degree)] if self.tp_degree > 1: if ff_partition_dim >= 0: ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim) @@ -149,6 +149,7 @@ def get_hf_tensor(hf_tensor_name, tensor_comparison_idx): if not os.path.isfile(hf_tensor_path): raise FileNotFoundError(f"File '{hf_tensor_path}' not found") + print("loading hf tensor: ", hf_tensor_filename) hf_tensor = torch.load(hf_tensor_path, map_location='cpu') if hf_tensor_name == "embed_tokens": self.num_tokens = hf_tensor.shape[1] @@ -162,6 +163,7 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp if not os.path.isfile(ff_tensor_path): raise FileNotFoundError(f"File '{ff_tensor_path}' not found") + print("loading ff tensor: ", ff_tensor_filename) ff_shape = list(hf_shape)[::-1] if tp_type == TPType.PARTITION: ff_shape[0] //= self.tp_degree @@ -206,8 +208,10 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance print(f"Error in comparison {label}:\n{e}\n") print("HF tensor:") print(hf_tensor.squeeze()) + print(hf_tensor.shape) print("FF tensor:") print(ff_tensor.squeeze()) + print(ff_tensor.shape) raise e print(f"-- FWD pass {step_idx}--") @@ -245,9 +249,13 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance # Attention hf_tensor_name = f"layers.{i}.self_attn.o_proj" ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) - output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + # the raw attention result, w/o o_proj. 
This is the output of senf_attn of FF and the input of o_proj in HF + output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) - ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # TP for self-attn partitions the attention heads across TP workers + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name) compare(hf_tensor, ff_tensor, label=f"Attention {i} output") # Post-attention layernorm @@ -365,6 +373,7 @@ def get_hf_tensor(hf_tensor_name, tensor_comparison_idx): if not os.path.isfile(hf_tensor_path): raise FileNotFoundError(f"File '{hf_tensor_path}' not found") + print("loading hf tensor: ", hf_tensor_filename) hf_tensor = torch.load(hf_tensor_path, map_location='cpu') return hf_tensor @@ -378,6 +387,7 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp ff_tensor_path = ff_tensor_path.replace(f"step_{step_idx}", f"step_{step_idx}_pre") if not os.path.isfile(ff_tensor_path): raise FileNotFoundError(f"File '{ff_tensor_path}' not found") + print("loading ff tensor: ", ff_tensor_filename) ff_shape = list(hf_shape)[::-1] if tp_type == TPType.PARTITION: @@ -392,8 +402,10 @@ def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPTyp tensor_comparison_idx.ff_tensor_type == "output_gradient" or tensor_comparison_idx.ff_tensor_type == "input_gradient" ) - ) + ) and + not ff_tensor_name.endswith(".self_attn.qkv_proj") ) + print(ff_tensor_filename + (" is not truncated" if intermediate_attention_tensor else " is truncated")) if not intermediate_attention_tensor: ff_shape = replace_value(ff_shape, self.num_tokens, self.ff_batch_size) @@ -432,8 +444,10 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance print(f"Error in comparison {label}:\n{e}\n") print("HF tensor:") print(hf_tensor.squeeze()) + print(hf_tensor.shape) print("FF tensor:") print(ff_tensor.squeeze()) + print(ff_tensor.shape) raise e print(f"-- BWD pass {step_idx}--") @@ -533,11 +547,12 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance # Attn O-proj hf_tensor_name = f"layers.{i}.self_attn.o_proj" - ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn.o_proj" + # ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" output_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="output_gradient", hf_tensor_idx=0, ff_tensor_idx=0) - hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) - ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE) - compare(hf_tensor, ff_tensor, label=f"Attn O-proj {i} gradient output") + # hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # compare(hf_tensor, ff_tensor, label=f"Attn O-proj {i} gradient output") ff_tensor_name = f"layers.{i}.layers.{i}.self_attn.o_proj" input_comparison = TensorComparisonIdxs(hf_tensor_type="input_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0) hf_tensor = 
get_hf_tensor(hf_tensor_name, input_comparison) @@ -579,7 +594,7 @@ def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance # FF Attn input with HF layernorm out hf_tensor_name = f"layers.{i}.input_layernorm" - ff_tensor_name = f"layers.{i}.layers.{i}.self_attn" + ff_tensor_name = f"layers.{i}.layers.{i}.self_attn.qkv_proj" input_comparison = TensorComparisonIdxs(hf_tensor_type="output_gradient", ff_tensor_type="input_gradient", hf_tensor_idx=0, ff_tensor_idx=0) hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE)
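
Editor's note on the QKV reassembly used throughout the new alignment tests: each tensor-parallel shard of the fused `qkv_proj` output stacks its own rows as `[Q | K | V]`, so recovering the full per-projection tensors means taking the matching slice out of every shard's block and concatenating, exactly as the `head_dim` / `heads_per_shard` / `chunk_size` arithmetic in `inference_alignment_test.py` does. The following standalone sketch mirrors that indexing on synthetic 2-D data; the model dimensions are made up for illustration, and it does not use the test's dump files or helper functions.

```python
# Minimal sketch of reassembling Q, K, V from a fused, tensor-parallel QKV output.
# Assumptions: 2-D tensors of shape (rows, num_tokens); real dumps are 3-D, but the
# row slicing is the same. All dimensions below are illustrative, not from the repo.
import numpy as np

hidden_size = 768
num_heads = 12
tp_degree = 2
num_tokens = 3

head_dim = hidden_size // num_heads           # projection size per attention head
heads_per_shard = num_heads // tp_degree
chunk = head_dim * heads_per_shard            # rows of Q (or K, or V) owned by one shard

# Each shard holds its own [Q | K | V] block stacked along dim 0.
shards = [np.random.randn(3 * chunk, num_tokens) for _ in range(tp_degree)]

# The fused tensor (tp_type=PARTITION) is the shard blocks concatenated along dim 0.
fused = np.concatenate(shards, axis=0)        # shape (tp_degree * 3 * chunk, num_tokens)

def split_qkv(fused, chunk, tp_degree):
    """Recover full Q, K, V by collecting each shard's Q/K/V slice in shard order."""
    q_parts, k_parts, v_parts = [], [], []
    for tp_idx in range(tp_degree):
        base = tp_idx * 3 * chunk
        q_parts.append(fused[base             : base + chunk])
        k_parts.append(fused[base + chunk     : base + 2 * chunk])
        v_parts.append(fused[base + 2 * chunk : base + 3 * chunk])
    return (np.concatenate(q_parts, axis=0),
            np.concatenate(k_parts, axis=0),
            np.concatenate(v_parts, axis=0))

q, k, v = split_qkv(fused, chunk, tp_degree)
assert q.shape == (hidden_size, num_tokens)

# Sanity check: shard 1's Q slice must reappear as the second half of the full Q.
np.testing.assert_allclose(q[chunk:], shards[1][:chunk])
print("reassembled Q/K/V shapes:", q.shape, k.shape, v.shape)
```

The `base`/`chunk` arithmetic here is what the test's `for tp_idx in range(1, self.tp_degree)` loop performs on the tensors it loads from the debug dumps before comparing them against the Hugging Face `q_proj`/`k_proj`/`v_proj` outputs.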