From 3335e2262d47e7d7e311a44dea7f454b5f01b643 Mon Sep 17 00:00:00 2001
From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com>
Date: Thu, 5 Dec 2024 18:42:48 +0530
Subject: [PATCH] [FIX] Bug in FluxPosEmbed (#10115)

* Fix get_1d_rotary_pos_embed in embedding.py

* Update embeddings.py

---------

Co-authored-by: hlky
---
 src/diffusers/models/embeddings.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 91451fa9aac2..8f8f1073da74 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -959,7 +959,12 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         freqs_dtype = torch.float32 if is_mps else torch.float64
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
            )
             cos_out.append(cos)
             sin_out.append(sin)
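
Context for reviewers: before this patch, FluxPosEmbed.forward never forwarded its configured self.theta to get_1d_rotary_pos_embed, so the helper silently fell back to its default base frequency. The sketch below illustrates how theta shapes the 1D rotary frequencies; rotary_freqs_sketch is a simplified, hypothetical illustration of the idea, not the diffusers implementation.

    import torch

    def rotary_freqs_sketch(dim: int, pos: torch.Tensor, theta: float = 10000.0):
        # Sketch only: theta is the base of the geometric frequency schedule.
        # If theta is never passed through (the bug fixed above), every
        # embedding uses the default base regardless of configuration.
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
        angles = torch.outer(pos.to(torch.float64), freqs)  # [num_pos, dim // 2]
        # Mirrors repeat_interleave_real=True: pair each frequency for the
        # real-valued (cos, sin) form.
        cos = angles.cos().repeat_interleave(2, dim=1)  # [num_pos, dim]
        sin = angles.sin().repeat_interleave(2, dim=1)
        return cos, sin

    # With the fix, different theta values yield different embeddings:
    pos = torch.arange(4)
    cos_a, _ = rotary_freqs_sketch(8, pos, theta=10000.0)
    cos_b, _ = rotary_freqs_sketch(8, pos, theta=500.0)
    assert not torch.allclose(cos_a, cos_b)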