Add dynamic_shifting to SD3 (#10236)
* Add `dynamic_shifting` to SD3

* calculate_shift

* FlowMatchHeunDiscreteScheduler doesn't support mu

* Inpaint/img2img
hlky authored Dec 16, 2024
1 parent 672bd49 commit a7d5052
Showing 3 changed files with 112 additions and 8 deletions.
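
For orientation, here is a minimal usage sketch of the new parameter. The checkpoint id and settings are illustrative assumptions; `mu` only takes effect with a scheduler that supports dynamic shifting, and is otherwise computed automatically as the diff below shows.

import torch
from diffusers import StableDiffusion3Pipeline

# Illustrative checkpoint; any SD3/SD3.5 checkpoint whose flow-match scheduler
# exposes `use_dynamic_shifting` in its config behaves the same way.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16
).to("cuda")

# With `use_dynamic_shifting` enabled, `mu` is derived from the image resolution;
# passing `mu` explicitly overrides that computation.
image = pipe("a photo of an astronaut", num_inference_steps=28, mu=1.0).images[0]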
File 1 of 3

@@ -68,6 +68,20 @@
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
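
# Sanity check (not part of the diff): calculate_shift is the straight line
# through (base_seq_len, base_shift) and (max_seq_len, max_shift), so the
# defaults recover the two shift values at the endpoints:
#     abs(calculate_shift(256) - 0.5) < 1e-9    # base_seq_len -> base_shift
#     abs(calculate_shift(4096) - 1.16) < 1e-9  # max_seq_len -> max_shift
# A 1024x1024 image (128x128 latents, patch_size=2) yields
# (128 // 2) * (128 // 2) = 4096 tokens, i.e. mu == max_shift == 1.16.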


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -702,6 +716,7 @@ def __call__(
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_start: float = 0.01,
mu: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -802,6 +817,7 @@ def __call__(
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples:
@@ -882,12 +898,7 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

# 5. Prepare latent variables
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -900,6 +911,33 @@
latents,
)

# 5. Prepare timesteps
scheduler_kwargs = {}
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
_, _, height, width = latents.shape
image_seq_len = (height // self.transformer.config.patch_size) * (
width // self.transformer.config.patch_size
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
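Note that the text-to-image pipeline above derives `image_seq_len` from the latent tensor's shape, while the img2img and inpaint pipelines below derive it from the pixel `height`/`width`. The two routes agree because the latents are `height // vae_scale_factor` by `width // vae_scale_factor`. A quick sketch of the equivalence, assuming the usual SD3 geometry (vae_scale_factor = 8, patch_size = 2):

height, width = 1024, 1024           # pixel resolution
vae_scale_factor, patch_size = 8, 2  # assumed SD3 defaults

# Route 1: from the latent shape, as in the text-to-image pipeline.
latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor
from_latents = (latent_h // patch_size) * (latent_w // patch_size)

# Route 2: from pixel dimensions, as in the img2img/inpaint pipelines.
from_pixels = (height // vae_scale_factor // patch_size) * (
    width // vae_scale_factor // patch_size
)
assert from_latents == from_pixels == 4096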
File 2 of 3

@@ -75,6 +75,20 @@
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -748,6 +762,7 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
mu: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -832,6 +847,7 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples:
@@ -913,7 +929,24 @@ def __call__(
image = self.image_processor.preprocess(image, height=height, width=width)

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
scheduler_kwargs = {}
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
int(width) // self.vae_scale_factor // self.transformer.config.patch_size
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

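As the commit message notes, FlowMatchHeunDiscreteScheduler does not support `mu`, which is why the pipelines add it to `scheduler_kwargs` only when dynamic shifting is enabled or the caller supplies it, rather than forwarding it unconditionally. Conceptually, `retrieve_timesteps` hands any extra kwargs to `scheduler.set_timesteps`, roughly like this (a simplified sketch, not the real helper, which also handles custom `timesteps` and `sigmas` lists):

def retrieve_timesteps_sketch(scheduler, num_inference_steps, device, **kwargs):
    # Schedulers that implement dynamic shifting (e.g. FlowMatchEulerDiscreteScheduler)
    # accept `mu` here; passing it to one that does not raises a TypeError.
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps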
File 3 of 3

@@ -74,6 +74,20 @@
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -838,6 +852,7 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
mu: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -947,6 +962,7 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples:
@@ -1023,7 +1039,24 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

# 3. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
scheduler_kwargs = {}
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
int(width) // self.vae_scale_factor // self.transformer.config.patch_size
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
# check that number of inference steps is not < 1 - as this doesn't make sense
if num_inference_steps < 1:
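In both the img2img and inpaint paths, the dynamically shifted schedule is built first and then truncated by `strength` via `get_timesteps`. A sketch of that truncation, mirroring the standard diffusers img2img pattern (an approximation, not the exact method body):

def get_timesteps_sketch(timesteps, num_inference_steps, strength, order=1):
    # Keep only the tail of the schedule: higher strength retains more steps.
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = max(num_inference_steps - int(init_timestep), 0)
    return timesteps[t_start * order :], num_inference_steps - t_start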
