feat(diffusers): support fused qkv projections for SD3 (#659)
townwish4git authored Sep 20, 2024
1 parent 7159716 · commit 70f4251
Showing 2 changed files with 58 additions and 6 deletions.
22 changes: 17 additions & 5 deletions mindone/diffusers/models/attention_processor.py
@@ -14,7 +14,7 @@
 from typing import Callable, Optional, Union
 
 import mindspore as ms
-from mindspore import nn, ops
+from mindspore import mint, nn, ops
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import logging
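
The switch to `mint` matters for the `split` calls later in this diff: `mindspore.mint` mirrors the PyTorch API surface, so `mint.split` takes a chunk size and a `dim` keyword like `torch.split`. A minimal check of the semantics the fused processor relies on (shapes are illustrative; assumes a MindSpore version where `mindspore.mint` is available):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

# A fused qkv tensor: the last dim holds q, k, v concatenated.
qkv = ms.Tensor(np.zeros((2, 4, 12)), ms.float32)
split_size = qkv.shape[-1] // 3
q, k, v = mint.split(qkv, split_size, dim=-1)
print(q.shape, k.shape, v.shape)  # (2, 4, 4) (2, 4, 4) (2, 4, 4)
```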
@@ -544,6 +544,17 @@ def fuse_projections(self, fuse=True):
             concatenated_bias = ops.cat([self.to_k.bias, self.to_v.bias])
             self.to_kv.bias.set_data(concatenated_bias)
 
+        # handle added projections for SD3 and others.
+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+            concatenated_weights = ops.cat([self.add_q_proj.weight, self.add_k_proj.weight, self.add_v_proj.weight])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Dense(in_features, out_features, has_bias=True, dtype=dtype)
+            self.to_added_qkv.weight.set_data(concatenated_weights)
+            concatenated_bias = ops.cat([self.add_q_proj.bias, self.add_k_proj.bias, self.add_v_proj.bias])
+            self.to_added_qkv.bias.set_data(concatenated_bias)
+
         self.fused_projections = fuse
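
For intuition, the fusion above is a pure re-packaging: a single `nn.Dense` whose weight is the row-wise concatenation of the three projection weights computes q, k, and v in one matmul, and splitting its output recovers the separate projections. A minimal sketch with toy sizes (names are hypothetical, not from this commit):

```python
import numpy as np
import mindspore as ms
from mindspore import mint, nn, ops

dim = 8
to_q, to_k, to_v = nn.Dense(dim, dim), nn.Dense(dim, dim), nn.Dense(dim, dim)

# Pack the three (out, in) weight matrices and biases into one projection.
fused = nn.Dense(dim, 3 * dim)
fused.weight.set_data(ops.cat([to_q.weight, to_k.weight, to_v.weight]))
fused.bias.set_data(ops.cat([to_q.bias, to_k.bias, to_v.bias]))

x = ms.Tensor(np.random.randn(2, dim), ms.float32)
q, k, v = mint.split(fused(x), dim, dim=-1)
assert np.allclose(q.asnumpy(), to_q(x).asnumpy(), atol=1e-5)
assert np.allclose(v.asnumpy(), to_v(x).asnumpy(), atol=1e-5)
```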
@@ -882,7 +893,7 @@ def __call__(
         # `sample` projections.
         qkv = attn.to_qkv(hidden_states)
         split_size = qkv.shape[-1] // 3
-        query, key, value = ops.split(qkv, split_size, axis=-1)
+        query, key, value = mint.split(qkv, split_size, dim=-1)
 
         # `context` projections.
         encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
@@ -891,7 +902,7 @@
             encoder_hidden_states_query_proj,
             encoder_hidden_states_key_proj,
             encoder_hidden_states_value_proj,
-        ) = ops.split(encoder_qkv, split_size, axis=-1)
+        ) = mint.split(encoder_qkv, split_size, dim=-1)
 
         # attention
         query = ops.cat([query, encoder_hidden_states_query_proj], axis=1)
@@ -902,8 +913,9 @@
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = ops.bmm(attention_probs, value)
+        hidden_states = ops.operations.nn_ops.FlashAttentionScore(1, scale_value=attn.scale)(
+            query.to(ms.float16), key.to(ms.float16), value.to(ms.float16), None, None, None, attention_mask
+        )[3].to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
         hidden_states = hidden_states.to(query.dtype)
 
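One behavioral note on the hunk above: `FlashAttentionScore` is an Ascend fused-attention primitive that returns a tuple, and indexing `[3]` selects the attention output (my reading, based on MindSpore's documented output order, is that the first three elements are softmax statistics); the inputs are cast to float16 for the kernel and the result cast back to the query dtype. A hedged wrapper mirroring the diff's usage, with the argument layout taken from the call above; verify the details against your MindSpore version:

```python
import mindspore as ms
from mindspore import ops

def flash_attention(query, key, value, attn_mask=None, head_num=1, scale=1.0):
    """Sketch of the fused-attention call: cast to fp16, run the Ascend
    FlashAttentionScore primitive, take the attention output, cast back."""
    fa = ops.operations.nn_ops.FlashAttentionScore(head_num, scale_value=scale)
    outputs = fa(
        query.to(ms.float16), key.to(ms.float16), value.to(ms.float16),
        None, None, None, attn_mask,  # no pse / dropout / padding masks
    )
    return outputs[3].to(query.dtype)  # assumed: index 3 is attention_out
```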
42 changes: 41 additions & 1 deletion mindone/diffusers/models/transformers/transformer_sd3.py
@@ -21,7 +21,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
@@ -166,6 +166,46 @@ def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
         for name, module in self.name_cells().items():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for _, module in self.cells_and_names():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedJointAttnProcessor())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
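Putting the two files together, the user-facing flow is: call `fuse_qkv_projections()` on the SD3 transformer before inference, and `unfuse_qkv_projections()` to restore the original processors. A usage sketch (pipeline name, checkpoint id, and the tuple-indexing of the pipeline output are assumptions drawn from mindone/diffusers conventions, not from this commit):

```python
from mindone.diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers"
)

# Fuse q/k/v (and SD3's added projections) into single matmuls and switch
# every Attention module to FusedJointAttnProcessor.
pipe.transformer.fuse_qkv_projections()

image = pipe("a photo of an astronaut riding a horse")[0][0]

# Restore the original, unfused attention processors.
pipe.transformer.unfuse_qkv_projections()
```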
