Skip to content

Commit

Permalink
fea(): Updates the func inputs. Added G1 test
Browse files Browse the repository at this point in the history
  • Loading branch information
imangohari1 committed Aug 28, 2024
1 parent 21b5829 commit 2eb1fe1
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions tests/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,11 +1256,16 @@ def test_stable_diffusion_xl_inference_script(self):
# Ensure the run finished without any issue
self.assertEqual(return_code, 0)

_sdxl_inferece_throughput_data = (("ddim", 0.301), ("euler_discrete", 0.301))

@parameterized.expand(_sdxl_inferece_throughput_data)
def test_stable_diffusion_xl_generation_throughput(self, scheduler, baseline):
def _sdxl_generation(self, scheduler: str, baseline: float):
if IS_GAUDI2:
_sdxl_inferece_throughput_data = (("ddim", 1, 10, 0.301), ("euler_discrete", 1, 10, 0.301))
else:
_sdxl_inferece_throughput_data = (("ddim", 1, 10, 0.074),)

@parameterized.expand(_sdxl_inferece_throughput_data, skip_on_empty=True)
def test_stable_diffusion_xl_generation_throughput(
self, scheduler: str, batch_size: int, num_images_per_prompt: int, baseline: float
):
def _sdxl_generation(self, scheduler: str, batch_size: int, num_images_per_prompt: int, baseline: float):
kwargs = {"timestep_spacing": "linspace"}
if scheduler == "euler_discrete":
scheduler = GaudiEulerDiscreteScheduler.from_pretrained(
Expand All @@ -1281,18 +1286,18 @@ def _sdxl_generation(self, scheduler: str, baseline: float):
"stabilityai/stable-diffusion-xl-base-1.0",
**kwargs,
)
num_images_per_prompt = 10
num_images_per_prompt = num_images_per_prompt
res = {}
outputs = pipeline(
prompt="Sailing ship painting by Van Gogh",
num_images_per_prompt=num_images_per_prompt,
batch_size=1,
batch_size=batch_size,
num_inference_steps=30,
**res,
)
self.assertGreaterEqual(outputs.throughput, 0.95 * baseline)

_sdxl_generation(self, scheduler, baseline)
_sdxl_generation(self, scheduler, batch_size, num_images_per_prompt, baseline)


class GaudiStableDiffusion3PipelineTester(TestCase):
Expand Down

0 comments on commit 2eb1fe1

Please sign in to comment.