Skip to content

Commit

Permalink
Manually implement PR #1359
Browse files Browse the repository at this point in the history
Thanks @Imageder for the contribution!
  • Loading branch information
d8ahazard committed Nov 15, 2023
1 parent a24c616 commit e27f707
Show file tree
Hide file tree
Showing 6 changed files with 1,508 additions and 1,500 deletions.
7 changes: 3 additions & 4 deletions dreambooth/diff_to_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dreambooth.shared import status
from dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, \
disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
safe_unpickle_disabled, import_model_class_from_model_name_or_path
from dreambooth.utils.utils import printi
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import merge_lora_to_model
Expand Down Expand Up @@ -562,9 +562,8 @@ def load_model(model_path: str, map_location: str):
if ".safetensors" in model_path:
return safetensors.torch.load_file(model_path, device=map_location)
else:
disable_safe_unpickle()
loaded = torch.load(model_path, map_location=map_location)
enable_safe_unpickle()
with safe_unpickle_disabled():
loaded = torch.load(model_path, map_location=map_location)
return loaded


Expand Down
106 changes: 53 additions & 53 deletions dreambooth/sd_to_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.utils.model_utils import enable_safe_unpickle, disable_safe_unpickle, unload_system_models, \
from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \
reload_system_models


Expand Down Expand Up @@ -131,7 +131,6 @@ def extract_checkpoint(
# sh.update_status(status)
# else:
# modules.shared.status.update(status)
disable_safe_unpickle()
if image_size is None:
image_size = 512
if model_type == "v2x":
Expand Down Expand Up @@ -162,59 +161,60 @@ def extract_checkpoint(
db_config.resolution = image_size
db_config.save()
try:
if from_safetensors:
if model_type == "SDXL":
pipe = StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=checkpoint_file,
with safe_unpickle_disabled():
if from_safetensors:
if model_type == "SDXL":
pipe = StableDiffusionXLPipeline.from_single_file(
pretrained_model_link_or_path=checkpoint_file,
)
else:
pipe = StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=checkpoint_file,
)
elif model_type == "SDXL":
pipe = StableDiffusionXLPipeline.from_pretrained(
checkpoint_path_or_dict=checkpoint_file,
original_config_file=original_config_file,
image_size=image_size,
prediction_type=prediction_type,
model_type=pipeline_type,
extract_ema=extract_ema,
scheduler_type=scheduler_type,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
from_safetensors=from_safetensors,
device=device,
pretrained_model_name_or_path=checkpoint_file,
stable_unclip=stable_unclip,
stable_unclip_prior=stable_unclip_prior,
clip_stats_path=clip_stats_path,
controlnet=controlnet,
vae_path=vae_path,
pipeline_class=pipeline_class,
half=half
)
else:
pipe = StableDiffusionPipeline.from_single_file(
pretrained_model_link_or_path=checkpoint_file,
pipe = StableDiffusionPipeline.from_pretrained(
checkpoint_path_or_dict=checkpoint_file,
original_config_file=original_config_file,
image_size=image_size,
prediction_type=prediction_type,
model_type=pipeline_type,
extract_ema=extract_ema,
scheduler_type=scheduler_type,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
from_safetensors=from_safetensors,
device=device,
pretrained_model_name_or_path=checkpoint_file,
stable_unclip=stable_unclip,
stable_unclip_prior=stable_unclip_prior,
clip_stats_path=clip_stats_path,
controlnet=controlnet,
vae_path=vae_path,
pipeline_class=pipeline_class,
half=half
)
elif model_type == "SDXL":
pipe = StableDiffusionXLPipeline.from_pretrained(
checkpoint_path_or_dict=checkpoint_file,
original_config_file=original_config_file,
image_size=image_size,
prediction_type=prediction_type,
model_type=pipeline_type,
extract_ema=extract_ema,
scheduler_type=scheduler_type,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
from_safetensors=from_safetensors,
device=device,
pretrained_model_name_or_path=checkpoint_file,
stable_unclip=stable_unclip,
stable_unclip_prior=stable_unclip_prior,
clip_stats_path=clip_stats_path,
controlnet=controlnet,
vae_path=vae_path,
pipeline_class=pipeline_class,
half=half
)
else:
pipe = StableDiffusionPipeline.from_pretrained(
checkpoint_path_or_dict=checkpoint_file,
original_config_file=original_config_file,
image_size=image_size,
prediction_type=prediction_type,
model_type=pipeline_type,
extract_ema=extract_ema,
scheduler_type=scheduler_type,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
from_safetensors=from_safetensors,
device=device,
pretrained_model_name_or_path=checkpoint_file,
stable_unclip=stable_unclip,
stable_unclip_prior=stable_unclip_prior,
clip_stats_path=clip_stats_path,
controlnet=controlnet,
vae_path=vae_path,
pipeline_class=pipeline_class,
half=half
)

dump_path = db_config.get_pretrained_model_name_or_path()
if controlnet:
Expand Down Expand Up @@ -246,7 +246,7 @@ def extract_checkpoint(
print(f"Couldn't find {full_path}")
break
remove_dirs = ["logging", "samples"]
enable_safe_unpickle()

reload_system_models()
if success:
for rd in remove_dirs:
Expand Down
Loading

0 comments on commit e27f707

Please sign in to comment.