diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py
index 25828ef3fe..3f7b54489b 100644
--- a/mindone/diffusers/models/attention_processor.py
+++ b/mindone/diffusers/models/attention_processor.py
@@ -14,7 +14,7 @@
from typing import Callable, Optional, Union
import mindspore as ms
-from mindspore import nn, ops
+from mindspore import mint, nn, ops
from ..image_processor import IPAdapterMaskProcessor
from ..utils import logging
@@ -544,6 +544,17 @@ def fuse_projections(self, fuse=True):
concatenated_bias = ops.cat([self.to_k.bias, self.to_v.bias])
self.to_kv.bias.set_data(concatenated_bias)
+ # handle added projections for SD3 and others.
+ if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ concatenated_weights = ops.cat([self.add_q_proj.weight, self.add_k_proj.weight, self.add_v_proj.weight])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Dense(in_features, out_features, has_bias=True, dtype=dtype)
+ self.to_added_qkv.weight.set_data(concatenated_weights)
+ concatenated_bias = ops.cat([self.add_q_proj.bias, self.add_k_proj.bias, self.add_v_proj.bias])
+ self.to_added_qkv.bias.set_data(concatenated_bias)
+
self.fused_projections = fuse
@@ -882,7 +893,7 @@ def __call__(
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
- query, key, value = ops.split(qkv, split_size, axis=-1)
+ query, key, value = mint.split(qkv, split_size, dim=-1)
# `context` projections.
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
@@ -891,7 +902,7 @@ def __call__(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
- ) = ops.split(encoder_qkv, split_size, axis=-1)
+ ) = mint.split(encoder_qkv, split_size, dim=-1)
# attention
query = ops.cat([query, encoder_hidden_states_query_proj], axis=1)
@@ -902,8 +913,9 @@ def __call__(
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = ops.bmm(attention_probs, value)
+ hidden_states = ops.operations.nn_ops.FlashAttentionScore(1, scale_value=attn.scale)(
+ query.to(ms.float16), key.to(ms.float16), value.to(ms.float16), None, None, None, attention_mask
+ )[3].to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
hidden_states = hidden_states.to(query.dtype)
diff --git a/mindone/diffusers/models/transformers/transformer_sd3.py b/mindone/diffusers/models/transformers/transformer_sd3.py
index 2fcec79acb..9e789919ff 100644
--- a/mindone/diffusers/models/transformers/transformer_sd3.py
+++ b/mindone/diffusers/models/transformers/transformer_sd3.py
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import logging
@@ -166,6 +166,46 @@ def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
for name, module in self.name_cells().items():
fn_recursive_attn_processor(name, module, processor)
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for _, module in self.cells_and_names():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedJointAttnProcessor())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value