Skip to content

Commit

Permalink
Fix Mochi Quality Issues (#10033)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/models/transformers/transformer_mochi.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
  • Loading branch information
3 people authored Dec 17, 2024
1 parent e24941b commit 128b96f
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 159 deletions.
261 changes: 172 additions & 89 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.processor(self, hidden_states)


class MochiAttention(nn.Module):
    """
    Attention layer used in Mochi.

    The module only owns the layers (projections and norms for the primary stream and for the added context
    stream); the attention computation itself is delegated to `processor`.

    Parameters:
        query_dim (`int`):
            Input feature size of the `to_q`/`to_k`/`to_v` projections. Also the fallback for the output sizes.
        added_kv_proj_dim (`int`):
            Input feature size of the `add_q_proj`/`add_k_proj`/`add_v_proj` projections.
        processor (`MochiAttnProcessor2_0`):
            Callable invoked by `forward` with this module as its first argument.
        heads (`int`, defaults to 8):
            Number of attention heads, used when `out_dim` is not given.
        dim_head (`int`, defaults to 64):
            Size of one attention head; the query/key norms are built over this dimension.
        dropout (`float`, defaults to 0.0):
            Probability of the dropout layer stored in `to_out[1]`.
        bias (`bool`, defaults to `False`):
            Whether `to_q`/`to_k`/`to_v` use a bias.
        added_proj_bias (`bool`, defaults to `True`):
            Whether the `add_*_proj` projections use a bias.
        out_dim (`int`, *optional*):
            When given, it is used both as the inner dimension and as the output size of `to_out[0]`, and the
            number of heads becomes `out_dim // dim_head`.
        out_context_dim (`int`, *optional*):
            Output size of `to_add_out`. Falls back to `query_dim` when not set.
        out_bias (`bool`, defaults to `True`):
            Whether `to_out[0]` and `to_add_out` use a bias.
        context_pre_only (`bool`, defaults to `False`):
            When truthy, the `to_add_out` projection is not created.
        eps (`float`, defaults to 1e-5):
            Epsilon forwarded to the `MochiRMSNorm` layers.
    """

    def __init__(
        self,
        query_dim: int,
        added_kv_proj_dim: int,
        processor: "MochiAttnProcessor2_0",
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_proj_bias: bool = True,
        out_dim: Optional[int] = None,
        out_context_dim: Optional[int] = None,
        out_bias: bool = True,
        context_pre_only: bool = False,
        eps: float = 1e-5,
    ):
        super().__init__()
        from .normalization import MochiRMSNorm

        # An explicit `out_dim` overrides everything derived from `heads` / `query_dim`.
        if out_dim is None:
            self.inner_dim = dim_head * heads
            self.out_dim = query_dim
            self.heads = heads
        else:
            self.inner_dim = out_dim
            self.out_dim = out_dim
            self.heads = out_dim // dim_head

        self.out_context_dim = out_context_dim or query_dim
        self.context_pre_only = context_pre_only

        # Per-head RMS norms for the primary (q/k) and the added (context q/k) streams.
        self.norm_q = MochiRMSNorm(dim_head, eps=eps, elementwise_affine=True)
        self.norm_k = MochiRMSNorm(dim_head, eps=eps, elementwise_affine=True)
        self.norm_added_q = MochiRMSNorm(dim_head, eps=eps, elementwise_affine=True)
        self.norm_added_k = MochiRMSNorm(dim_head, eps=eps, elementwise_affine=True)

        inner_dim = self.inner_dim

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=bias)

        self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim, bias=added_proj_bias)
        self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim, bias=added_proj_bias)
        # `context_pre_only` is a bool by default, so this only skips the projection when `None` is passed
        # explicitly.
        if self.context_pre_only is not None:
            self.add_q_proj = nn.Linear(added_kv_proj_dim, inner_dim, bias=added_proj_bias)

        # Output head of the primary stream: [linear projection, dropout].
        self.to_out = nn.ModuleList(
            [
                nn.Linear(inner_dim, self.out_dim, bias=out_bias),
                nn.Dropout(dropout),
            ]
        )

        if not self.context_pre_only:
            self.to_add_out = nn.Linear(inner_dim, self.out_context_dim, bias=out_bias)

        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the configured processor on the inputs; extra keyword arguments are passed through unchanged."""
        run_attention = self.processor
        return run_attention(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )


class MochiAttnProcessor2_0:
    """Attention processor used in Mochi.

    Joint attention over the primary tokens and the prompt (encoder) tokens. Attention is evaluated one batch
    element at a time so that prompt tokens whose mask entry is zero are dropped before
    `F.scaled_dot_product_attention` is called.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _apply_rotary_emb(x, freqs_cos, freqs_sin):
        """Rotate the interleaved (even, odd) feature pairs of `x` by the given cos/sin tables."""
        # The rotation is computed in float32 and cast back to the dtype of `x`.
        x_even = x[..., 0::2].float()
        x_odd = x[..., 1::2].float()

        cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
        sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

        # Re-interleave the rotated pairs along the last dimension.
        return torch.stack([cos, sin], dim=-1).flatten(-2)

    def __call__(
        self,
        attn: "MochiAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            attn (`MochiAttention`):
                Module providing the projection, norm and output layers.
            hidden_states (`torch.Tensor`):
                Primary-stream input; its last (third) dimension is split into `attn.heads` heads.
            encoder_hidden_states (`torch.Tensor`):
                Prompt-stream input, handled the same way through the `add_*_proj` layers.
            attention_mask (`torch.Tensor`, *optional*):
                Indexed per batch element; non-zero entries mark the prompt tokens that take part in attention.
                When `None`, every prompt token is used.
            image_rotary_emb (*optional*):
                Pair `(freqs_cos, freqs_sin)` applied to the primary-stream query and key only.

        Returns:
            A tuple `(hidden_states, encoder_hidden_states)`. In the encoder part, the outputs of the kept prompt
            tokens are packed at the front and the remaining positions are zero-filled before `to_add_out` (if the
            module has one) is applied.
        """
        if attention_mask is None:
            # `MochiAttention.forward` defaults the mask to `None`; treat that as "all prompt tokens are valid"
            # instead of failing on `attention_mask[idx]` below.
            attention_mask = torch.ones(
                encoder_hidden_states.shape[:2], dtype=torch.bool, device=encoder_hidden_states.device
            )

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:
            query = self._apply_rotary_emb(query, *image_rotary_emb)
            key = self._apply_rotary_emb(key, *image_rotary_emb)

        # Move the heads dimension in front of the sequence dimension for scaled_dot_product_attention.
        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        encoder_query, encoder_key, encoder_value = (
            encoder_query.transpose(1, 2),
            encoder_key.transpose(1, 2),
            encoder_value.transpose(1, 2),
        )

        sequence_length = query.size(2)
        encoder_sequence_length = encoder_query.size(2)
        total_length = sequence_length + encoder_sequence_length

        batch_size = query.size(0)
        attn_outputs = []
        for idx in range(batch_size):
            # Positions of the prompt tokens kept for this batch element.
            valid_prompt_token_indices = torch.nonzero(attention_mask[idx].flatten(), as_tuple=False).flatten()

            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]

            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)

            attn_output = F.scaled_dot_product_attention(
                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
            )
            # Zero-pad the end of the sequence dimension so every batch element has `total_length` tokens again
            # and the per-sample results can be concatenated.
            valid_sequence_length = attn_output.size(2)
            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
            attn_outputs.append(attn_output)

        hidden_states = torch.cat(attn_outputs, dim=0)
        # Merge the heads back into a single feature dimension.
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)

        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
            (sequence_length, encoder_sequence_length), dim=1
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states


class AttnProcessor:
r"""
Default processor for performing attention-related computations.
Expand Down Expand Up @@ -3868,94 +4039,6 @@ def __call__(
return hidden_states


class MochiAttnProcessor2_0:
    """Attention processor used in Mochi."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Joint attention over the primary tokens and the encoder tokens.

        Both streams are projected, split into `attn.heads` heads, optionally normalised, concatenated along the
        sequence dimension and attended to in a single `F.scaled_dot_product_attention` call. `attention_mask` is
        accepted but not used by this implementation. `image_rotary_emb`, when given, is unpacked as
        `(freqs_cos, freqs_sin)` and applied to the primary-stream query and key only.

        Returns a tuple `(hidden_states, encoder_hidden_states)`.
        """

        def split_heads(projected):
            # Split the last (third) dimension into (heads, head_dim).
            return projected.unflatten(2, (attn.heads, -1))

        # Primary stream.
        query = split_heads(attn.to_q(hidden_states))
        key = split_heads(attn.to_k(hidden_states))
        value = split_heads(attn.to_v(hidden_states))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Encoder (added) stream.
        encoder_query = split_heads(attn.add_q_proj(encoder_hidden_states))
        encoder_key = split_heads(attn.add_k_proj(encoder_hidden_states))
        encoder_value = split_heads(attn.add_v_proj(encoder_hidden_states))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:

            def rotate(x, freqs_cos, freqs_sin):
                # Rotate interleaved (even, odd) feature pairs in float32, then cast back.
                evens = x[..., 0::2].float()
                odds = x[..., 1::2].float()

                rotated_even = (evens * freqs_cos - odds * freqs_sin).to(x.dtype)
                rotated_odd = (evens * freqs_sin + odds * freqs_cos).to(x.dtype)

                return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)

            query = rotate(query, *image_rotary_emb)
            key = rotate(key, *image_rotary_emb)

        # Sequence lengths are read before the heads/sequence swap, where the sequence is still dimension 1.
        sequence_length = query.size(1)
        encoder_sequence_length = encoder_query.size(1)

        # Swap heads and sequence, then append the encoder tokens after the primary tokens.
        joint_query = torch.cat([query.transpose(1, 2), encoder_query.transpose(1, 2)], dim=2)
        joint_key = torch.cat([key.transpose(1, 2), encoder_key.transpose(1, 2)], dim=2)
        joint_value = torch.cat([value.transpose(1, 2), encoder_value.transpose(1, 2)], dim=2)

        attended = F.scaled_dot_product_attention(joint_query, joint_key, joint_value, dropout_p=0.0, is_causal=False)
        # Merge the heads back and match the query dtype.
        attended = attended.transpose(1, 2).flatten(2, 3).to(joint_query.dtype)

        hidden_states, encoder_hidden_states = attended.split_with_sizes(
            (sequence_length, encoder_sequence_length), dim=1
        )

        # linear proj followed by dropout
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        context_out_proj = getattr(attn, "to_add_out", None)
        if context_out_proj is not None:
            encoder_hidden_states = context_out_proj(encoder_hidden_states)

        return hidden_states, encoder_hidden_states


class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
Expand Down Expand Up @@ -5668,13 +5751,13 @@ def __call__(
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,
MochiAttnProcessor2_0,
StableAudioAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FusedHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
MochiAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
Expand Down
1 change: 0 additions & 1 deletion src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def forward(self, latent):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
Expand Down
57 changes: 30 additions & 27 deletions src/diffusers/models/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,33 +234,6 @@ def forward(
return x, gate_msa, scale_mlp, gate_mlp


class MochiRMSNormZero(nn.Module):
    r"""
    Adaptive RMS Norm used in Mochi.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        hidden_dim (`int`): Output size of the conditioning projection; `forward` splits it into four chunks.
        eps (`float`, defaults to 1e-5): Epsilon passed to the inner `RMSNorm`.
        elementwise_affine (`bool`, defaults to `False`): Passed to the inner `RMSNorm`.
    """

    def __init__(
        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, hidden_dim)
        self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalise `hidden_states` and scale them with a modulation derived from `emb`.

        Returns the modulated hidden states followed by the three remaining chunks of the projected embedding
        (`gate_msa`, `scale_mlp`, `gate_mlp`), which are handed back untouched.
        """
        modulation = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = torch.chunk(modulation, 4, dim=1)

        # `scale_msa` gets an extra axis at position 1 so it broadcasts against the normalised states.
        normalized = self.norm(hidden_states)
        modulated = normalized * (1 + scale_msa[:, None])

        return modulated, gate_msa, scale_mlp, gate_mlp


class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
Expand Down Expand Up @@ -549,6 +522,36 @@ def forward(self, hidden_states):
return hidden_states


# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
class MochiRMSNorm(nn.Module):
    """
    RMS normalisation over the last dimension, used in Mochi.

    The statistics are computed in float32 and the result is cast back to the dtype of the input.

    Parameters:
        dim: Shape of the learnable weight; a single integer is wrapped into a 1-tuple.
        eps (`float`): Value added to the mean square before taking the reciprocal square root.
        elementwise_affine (`bool`, defaults to `True`): Whether to create a learnable weight (initialised to
            ones). When `False`, `weight` is `None` and no scaling is applied.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        shape = (dim,) if isinstance(dim, numbers.Integral) else dim

        self.eps = eps
        self.dim = torch.Size(shape)
        self.weight = nn.Parameter(torch.ones(shape)) if elementwise_affine else None

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype

        # Mean of squares over the last dimension, accumulated in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)

        if self.weight is not None:
            normalized = normalized * self.weight

        return normalized.to(original_dtype)


class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
Expand Down
Loading

0 comments on commit 128b96f

Please sign in to comment.