Skip to content

Commit

Permalink
add testcase
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
  • Loading branch information
sywangyi committed Aug 29, 2024
1 parent 8394c03 commit 7f81362
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ slow_tests_custom_file_input: test_installs
# Run single-card non-regression tests
slow_tests_1x: test_installs
python -m pytest tests/test_examples.py -v -s -k "single_card"
python -m pip install peft==0.10.0
python -m pip install peft==0.12.0
python -m pytest tests/test_peft_inference.py
python -m pytest tests/test_pipeline.py

Expand All @@ -96,7 +96,7 @@ slow_tests_deepspeed: test_installs
slow_tests_diffusers: test_installs
python -m pytest tests/test_diffusers.py -v -s -k "test_no_"
python -m pytest tests/test_diffusers.py -v -s -k "test_textual_inversion"
python -m pip install peft==0.7.0
python -m pip install peft==0.12.0
python -m pytest tests/test_diffusers.py -v -s -k "test_train_text_to_image_"
python -m pytest tests/test_diffusers.py -v -s -k "test_train_controlnet"
python -m pytest tests/test_diffusers.py -v -s -k "test_deterministic_image_generation"
Expand Down
121 changes: 121 additions & 0 deletions tests/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,88 @@ class GaudiStableDiffusionPipelineTester(TestCase):
Tests the StableDiffusionPipeline for Gaudi.
"""

def merge_peft_adapter(self, model, adapter):
from peft import BOFTConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model

UNET_TARGET_MODULES = [
"to_q",
"to_k",
"to_v",
"proj",
"proj_in",
"proj_out",
"conv",
"conv1",
"conv2",
"conv_shortcut",
"to_out.0",
"time_emb_proj",
"ff.net.2",
]
TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
target_modules = (
UNET_TARGET_MODULES if isinstance(model, UNet2DConditionModel) else TEXT_ENCODER_TARGET_MODULES
)

if adapter == "lora":
config = LoraConfig(
r=2,
lora_alpha=2,
target_modules=target_modules,
lora_dropout=0.0,
bias="none",
init_lora_weights=True,
)
elif adapter == "loha":
config = LoHaConfig(
r=2,
alpha=2,
target_modules=target_modules,
rank_dropout=0.0,
module_dropout=0.0,
use_effective_conv2d=False,
init_weights=True,
)
elif adapter == "lokr":
config = LoKrConfig(
r=2,
alpha=2,
target_modules=target_modules,
rank_dropout=0.0,
module_dropout=0.0,
use_effective_conv2d=False,
decompose_both=False,
decompose_factor=-1,
init_weights=True,
)
elif adapter == "oft":
config = OFTConfig(
r=2,
target_modules=target_modules,
module_dropout=0.0,
init_weights=True,
coft=False,
eps=0.0,
)
elif adapter == "boft":
from peft import tuners

from optimum.habana.peft.layer import GaudiBoftGetDeltaWeight

tuners.boft.layer.Linear.get_delta_weight = GaudiBoftGetDeltaWeight
tuners.boft.layer.Conv2d.get_delta_weight = GaudiBoftGetDeltaWeight
tuners.boft.layer._FBD_CUDA = False
config = BOFTConfig(
boft_block_size=1,
boft_block_num=0,
boft_n_butterfly_factor=1,
target_modules=target_modules,
boft_dropout=0.1,
bias="boft_only",
)
model = get_peft_model(model, config)
return model.merge_and_unload()

def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
Expand Down Expand Up @@ -612,6 +694,45 @@ def test_stable_diffusion_hpu_graphs(self):
self.assertEqual(len(images), 10)
self.assertEqual(images[-1].shape, (64, 64, 3))

@parameterized.expand(["lora", "loha", "lokr", "oft", "boft"])
@slow
def test_no_peft_regression_bf16(self, peft_adapter):
prompts = [
"An image of a squirrel in Picasso style",
]
num_images_per_prompt = 1
batch_size = 1
model_name = "runwayml/stable-diffusion-v1-5"
scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
model_name,
scheduler=scheduler,
use_habana=True,
use_hpu_graphs=True,
gaudi_config=GaudiConfig.from_pretrained("Habana/stable-diffusion"),
torch_dtype=torch.bfloat16,
)
if peft_adapter not in ["boft", "oft"]:
with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
pipeline.unet = self.merge_peft_adapter(pipeline.unet, peft_adapter)
pipeline.text_encoder = self.merge_peft_adapter(pipeline.text_encoder, peft_adapter)
else:
# WA torch.inverse issue in Synapse AI 1.17 for oft and boft
pipeline.unet = pipeline.unet.to(torch.float32)
pipeline.unet = self.merge_peft_adapter(pipeline.unet, peft_adapter)
pipeline.unet = pipeline.unet.to(torch.bfloat16)
pipeline.text_encoder = pipeline.text_encoder.to(torch.float32)
pipeline.text_encoder = self.merge_peft_adapter(pipeline.text_encoder, peft_adapter)
pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)

set_seed(27)
outputs = pipeline(
prompt=prompts,
num_images_per_prompt=num_images_per_prompt,
batch_size=batch_size,
)
self.assertEqual(len(outputs.images), num_images_per_prompt * len(prompts))

@slow
def test_no_throughput_regression_bf16(self):
prompts = [
Expand Down

0 comments on commit 7f81362

Please sign in to comment.