
fix(diffusers): fixed up timing and generation/s calculation.
imangohari1 committed Aug 19, 2024
1 parent cb8bae0 commit 1d6d047
Showing 11 changed files with 56 additions and 23 deletions.
@@ -31,7 +31,7 @@
 from optimum.utils import logging

 from ....transformers.gaudi_configuration import GaudiConfig
-from ....utils import HabanaProfile, speed_metrics
+from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment
 from ..pipeline_utils import GaudiDiffusionPipeline
 from ..stable_diffusion.pipeline_stable_diffusion import (
     GaudiStableDiffusionPipeline,
@@ -497,11 +497,17 @@ def __call__(

         # 8. Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
+        use_warmup_inference_steps = (
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+        )
+
         for j in self.progress_bar(range(num_batches)):
             # The throughput is calculated from the 3rd iteration
             # because compilation occurs in the first two iterations
             if j == throughput_warmup_steps:
                 t1 = time.time()
+            if use_warmup_inference_steps:
+                t0_inf = time.time()

             latents_batch = latents_batches[0]
             latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
@@ -510,6 +516,11 @@ def __call__(
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

             for i in range(num_inference_steps):
+                ts = time.time()
+                if use_warmup_inference_steps and i == throughput_warmup_steps:
+                    t1_inf = time.time()
+                    t1 += t1_inf - t0_inf
+
                 t = timesteps[0]
                 timesteps = torch.roll(timesteps, shifts=-1, dims=0)
@@ -598,6 +609,13 @@ def __call__(

                 hb_profiler.step()

+                logger.info(f"i {i} elapsed {time.time()-ts}")
+
+            if use_warmup_inference_steps:
+                t1 = warmup_inference_steps_time_adjustment(
+                    t1, t1_inf, num_inference_steps, throughput_warmup_steps
+                )
+
             if not output_type == "latent":
                 # 8. Post-processing
                 output_image = self.vae.decode(
@@ -617,9 +635,9 @@ def __call__(
             split=speed_metrics_prefix,
             start_time=t0,
             num_samples=num_batches * batch_size
-            if t1 == t0
+            if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
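Note on the new import: warmup_inference_steps_time_adjustment backs the in-loop warmup path added above, which is used when there are too few batches for the usual batch-level warmup. Below is a minimal standalone sketch of the timing protocol these hunks implement; the timed_denoising_loop wrapper and step_fn are assumptions of this example, not part of the commit, and the helper's real body lives elsewhere in optimum-habana and is not shown in this diff.

import time

def timed_denoising_loop(step_fn, num_inference_steps, throughput_warmup_steps):
    # Mirrors the diff's variable names: t0 is the overall start, t1 the
    # "start time after warmup" later handed to speed_metrics.
    t0 = time.time()
    t1 = t0
    t0_inf = time.time()  # start of this batch's inference-step loop
    for i in range(num_inference_steps):
        if i == throughput_warmup_steps:
            t1_inf = time.time()
            t1 += t1_inf - t0_inf  # fold this batch's warmup window into t1
        step_fn(i)  # one denoising step (hypothetical stand-in)
    # The commit then calls warmup_inference_steps_time_adjustment(t1, t1_inf,
    # num_inference_steps, throughput_warmup_steps) to rescale t1 so that
    # generation/s reflects steady-state steps only.
    return t0, t1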
@@ -496,7 +496,7 @@ def __call__(
         # 8. Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )

         for j in self.progress_bar(range(num_batches)):
@@ -513,6 +513,7 @@ def __call__(
             text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)

             for i in range(len(timesteps)):
+                ts = time.time()
                 if use_warmup_inference_steps and i == throughput_warmup_steps:
                     t1_inf = time.time()
                     t1 += t1_inf - t0_inf
@@ -573,6 +574,7 @@ def __call__(
                     callback(step_idx, timestep, latents_batch)

                 hb_profiler.step()
+                logger.info(f"i {i} elapsed {time.time()-ts}")

             if use_warmup_inference_steps:
                 t1 = warmup_inference_steps_time_adjustment(
@@ -600,7 +602,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
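The recurring one-character fix in this and the following files changes the predicate from "<" to "<=": when num_batches equals throughput_warmup_steps, the old condition fell through to batch-level warmup, leaving (num_batches - throughput_warmup_steps) == 0 samples for the throughput calculation. A self-contained illustration, with values assumed for the example:

# Illustrative values only; the predicate is copied from the diff.
num_batches, num_inference_steps, throughput_warmup_steps = 3, 50, 3

use_warmup_inference_steps = (
    num_batches <= throughput_warmup_steps
    and num_inference_steps > throughput_warmup_steps
)
assert use_warmup_inference_steps  # the old "<" would leave this False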
@@ -313,7 +313,7 @@ def __call__(
         t1 = t0
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )
         for j in self.progress_bar(range(num_batches)):
             # The throughput is calculated from the 3rd iteration
@@ -376,7 +376,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -553,7 +553,7 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )

         self._num_timesteps = len(timesteps)
@@ -715,7 +715,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -396,7 +396,7 @@ def __call__(
         t1 = t0
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )
         for j in self.progress_bar(range(num_batches)):
             # The throughput is calculated from the 3rd iteration
@@ -414,6 +414,7 @@ def __call__(
             prompt_embeds_batches = torch.roll(prompt_embeds_batches, shifts=-1, dims=0)

             for i in range(len(timesteps)):
+                ts = time.time()
                 if use_warmup_inference_steps and i == throughput_warmup_steps:
                     t1_inf = time.time()
                     t1 += t1_inf - t0_inf
@@ -473,6 +474,8 @@ def __call__(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents_batch)
                 hb_profiler.step()
+                logger.info(f"i {i} elapsed {time.time()-ts}")
+
             if use_warmup_inference_steps:
                 t1 = warmup_inference_steps_time_adjustment(
                     t1, t1_inf, num_inference_steps, throughput_warmup_steps
@@ -487,14 +490,15 @@ def __call__(
             self.htcore.mark_step()

         hb_profiler.stop()
+        logger.info(f"t1-t0 {t1-t0} num_samples {num_batches * batch_size if t1 == t0 or use_warmup_inference_steps else (num_batches - throughput_warmup_steps) * batch_size}")
         speed_metrics_prefix = "generation"
         speed_measures = speed_metrics(
             split=speed_metrics_prefix,
             start_time=t0,
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -341,7 +341,7 @@ def __call__(
         # 8. Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )

         for j in self.progress_bar(range(num_batches)):
@@ -358,6 +358,7 @@ def __call__(
             text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)

             for i in range(len(timesteps)):
+                ts = time.time()
                 if use_warmup_inference_steps and i == throughput_warmup_steps:
                     t1_inf = time.time()
                     t1 += t1_inf - t0_inf
@@ -398,6 +399,8 @@ def __call__(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, timestep, latents_batch)

+                logger.info(f"i {i} elapsed {time.time()-ts}")
+
             if use_warmup_inference_steps:
                 t1 = warmup_inference_steps_time_adjustment(
                     t1, t1_inf, num_inference_steps, throughput_warmup_steps
@@ -420,7 +423,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -438,7 +438,7 @@ def __call__(
         # 10. Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )

         for j in self.progress_bar(range(num_batches)):
@@ -541,7 +541,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -693,7 +693,7 @@ def __call__(
         # 8.3 Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )

         for j in self.progress_bar(range(num_batches)):
@@ -718,6 +718,7 @@ def __call__(
                 self.scheduler._init_step_index(timesteps[0])

             for i in range(num_inference_steps):
+                ts = time.time()
                 if use_warmup_inference_steps and i == throughput_warmup_steps:
                     t1_inf = time.time()
                     t1 += t1_inf - t0_inf
@@ -792,6 +793,7 @@ def __call__(
                     callback(step_idx, timestep, latents)

                 hb_profiler.step()
+                logger.info(f"i {i} elapsed {time.time()-ts}")

             if use_warmup_inference_steps:
                 t1 = warmup_inference_steps_time_adjustment(
@@ -823,7 +825,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -540,7 +540,7 @@ def denoising_value_valid(dnv):
         # 8.3 Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )
         for j in self.progress_bar(range(num_batches)):
             # The throughput is calculated from the 3rd iteration
@@ -559,6 +559,7 @@ def denoising_value_valid(dnv):
             add_time_ids_batch = add_time_ids_batches[0]
             add_time_ids_batches = torch.roll(add_time_ids_batches, shifts=-1, dims=0)

+
             if hasattr(self.scheduler, "_init_step_index"):
                 # Reset scheduler step index for next batch
                 self.scheduler._init_step_index(timesteps[0])
@@ -672,7 +673,7 @@ def denoising_value_valid(dnv):
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -752,7 +752,7 @@ def denoising_value_valid(dnv):
         t1 = t0
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )

         for j in self.progress_bar(range(num_batches)):
@@ -920,7 +920,7 @@ def denoising_value_valid(dnv):
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
@@ -472,7 +472,7 @@ def __call__(
         # 10. Denoising loop
         throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
         use_warmup_inference_steps = (
-            num_batches < throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
+            num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps
         )
         self._num_timesteps = len(timesteps)
         for j in self.progress_bar(range(num_batches)):
@@ -493,6 +493,7 @@ def __call__(
             added_time_ids_batches = torch.roll(added_time_ids_batches, shifts=-1, dims=0)

             for i in self.progress_bar(range(num_inference_steps)):
+                ts = time.time()
                 if use_warmup_inference_steps and i == throughput_warmup_steps:
                     t1 += time.time() - t0_inf

@@ -533,6 +534,8 @@ def __call__(

                 latents_batch = callback_outputs.pop("latents", latents_batch)

+                logger.info(f"i {i} elapsed {time.time()-ts}")
+
             if not output_type == "latent":
                 # cast back to fp16/bf16 if needed
                 if needs_upcasting:
@@ -552,7 +555,7 @@ def __call__(
             num_samples=num_batches * batch_size
             if t1 == t0 or use_warmup_inference_steps
             else (num_batches - throughput_warmup_steps) * batch_size,
-            num_steps=num_batches,
+            num_steps=num_batches * batch_size * num_inference_steps,
             start_time_after_warmup=t1,
         )
         logger.info(f"Speed metrics: {speed_measures}")
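Across all eleven files, num_steps changes from num_batches to num_batches * batch_size * num_inference_steps, so the reported steps-per-second rate counts every denoising step of every sample rather than whole batches. Below is a hedged sketch of how the reported generation/s figures plausibly follow from these arguments; the real speed_metrics() helper in optimum-habana may differ in detail, and this function is an illustration, not its actual body.

import time

def speed_metrics_sketch(split, start_time, num_samples, num_steps, start_time_after_warmup):
    now = time.time()
    # Throughput excludes warmup by measuring from start_time_after_warmup (t1).
    warm_runtime = now - start_time_after_warmup
    return {
        f"{split}_runtime": now - start_time,
        f"{split}_samples_per_second": num_samples / warm_runtime,
        f"{split}_steps_per_second": num_steps / warm_runtime,
    }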
