diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py
index 25828ef3fe..3f7b54489b 100644
--- a/mindone/diffusers/models/attention_processor.py
+++ b/mindone/diffusers/models/attention_processor.py
@@ -14,7 +14,7 @@
 from typing import Callable, Optional, Union
 
 import mindspore as ms
-from mindspore import nn, ops
+from mindspore import mint, nn, ops
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import logging
@@ -544,6 +544,17 @@ def fuse_projections(self, fuse=True):
                 concatenated_bias = ops.cat([self.to_k.bias, self.to_v.bias])
                 self.to_kv.bias.set_data(concatenated_bias)
 
+        # handle added projections for SD3 and others.
+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+            concatenated_weights = ops.cat([self.add_q_proj.weight, self.add_k_proj.weight, self.add_v_proj.weight])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Dense(in_features, out_features, has_bias=True, dtype=dtype)
+            self.to_added_qkv.weight.set_data(concatenated_weights)
+            concatenated_bias = ops.cat([self.add_q_proj.bias, self.add_k_proj.bias, self.add_v_proj.bias])
+            self.to_added_qkv.bias.set_data(concatenated_bias)
+
         self.fused_projections = fuse
@@ -882,7 +893,7 @@ def __call__(
         # `sample` projections.
         qkv = attn.to_qkv(hidden_states)
         split_size = qkv.shape[-1] // 3
-        query, key, value = ops.split(qkv, split_size, axis=-1)
+        query, key, value = mint.split(qkv, split_size, dim=-1)
 
         # `context` projections.
         encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
@@ -891,7 +902,7 @@ def __call__(
             encoder_hidden_states_query_proj,
             encoder_hidden_states_key_proj,
             encoder_hidden_states_value_proj,
-        ) = ops.split(encoder_qkv, split_size, axis=-1)
+        ) = mint.split(encoder_qkv, split_size, dim=-1)
 
         # attention
         query = ops.cat([query, encoder_hidden_states_query_proj], axis=1)
@@ -902,8 +913,9 @@ def __call__(
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = ops.bmm(attention_probs, value)
+        hidden_states = ops.operations.nn_ops.FlashAttentionScore(1, scale_value=attn.scale)(
+            query.to(ms.float16), key.to(ms.float16), value.to(ms.float16), None, None, None, attention_mask
+        )[3].to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         hidden_states = hidden_states.to(query.dtype)
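Reviewer note on the `ops.split` → `mint.split` change above: `mint.split` follows the PyTorch calling convention (chunk size first, then a `dim` keyword) rather than `ops.split`'s `axis` keyword, which is why the argument names change in the hunks above. A minimal self-contained sketch of the fused-QKV split, with made-up shapes for illustration only (not part of the diff):

```python
import mindspore as ms
from mindspore import mint, ops

# Illustrative fused QKV activation: (batch, seq_len, 3 * inner_dim).
qkv = ops.ones((2, 16, 192), ms.float32)

# mint.split mirrors torch.split: chunk size first, then `dim`.
split_size = qkv.shape[-1] // 3
query, key, value = mint.split(qkv, split_size, dim=-1)
print(query.shape, key.shape, value.shape)  # (2, 16, 64) each
```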
diff --git a/mindone/diffusers/models/transformers/transformer_sd3.py b/mindone/diffusers/models/transformers/transformer_sd3.py
index 2fcec79acb..9e789919ff 100644
--- a/mindone/diffusers/models/transformers/transformer_sd3.py
+++ b/mindone/diffusers/models/transformers/transformer_sd3.py
@@ -21,7 +21,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
@@ -166,6 +166,46 @@ def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
         for name, module in self.name_cells().items():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for _, module in self.cells_and_names():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedJointAttnProcessor())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
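Hedged usage sketch for the new SD3 API (the checkpoint id below is illustrative; substitute whichever SD3 weights you have locally):

```python
from mindone.diffusers import SD3Transformer2DModel

# Illustrative checkpoint; any SD3 transformer weights work here.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
)

# Fuse to_q/to_k/to_v (and add_q/add_k/add_v) into single Dense layers
# and route attention through FusedJointAttnProcessor.
transformer.fuse_qkv_projections()

# ... run inference ...

# Restore the original per-projection attention processors.
transformer.unfuse_qkv_projections()
```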