Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hako-mikan committed Jan 10, 2025
1 parent 469465e commit b165dcf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
7 changes: 6 additions & 1 deletion scripts/kohyas/sdxl_model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,24 @@ def convert_key(key):

return new_sd, logit_scale

delkeys =[
'text_model.embeddings.position_ids'
]

# load state_dict without allocating new tensors
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
# dtype will use fp32 as default
missing_keys = list(model.state_dict().keys() - state_dict.keys())
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())

unexpected_keys = list(set(unexpected_keys) - set(delkeys))

# similar to model.load_state_dict()
if not missing_keys and not unexpected_keys:
for k in list(state_dict.keys()):
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
return "<All keys matched successfully>"

# error_msgs
error_msgs: List[str] = []
if missing_keys:
Expand Down
17 changes: 12 additions & 5 deletions scripts/mergers/pluslora.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,22 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
return "ERROR: No model Selected"
gc.collect()

currentinfo = shared.sd_model.sd_checkpoint_info
try:
currentinfo = shared.sd_model.sd_checkpoint_info
except:
currentinfo = None

checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
load_model(checkpoint_info)

model = shared.sd_model
print(type(model).__name__)
print("XL" in type(model).__name__)

is_sdxl = hasattr(model, 'conditioner')
is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
is_sd1 = not model.is_sdxl and not model.is_sd2
is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model,'is_sdxl', False)
is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model,'is_sd2', False)
is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model,'is_sd1', False)
is_flux = type(model).__name__ == "Flux" or getattr(model,'is_flux', False)

print(f"Detected model type: SDXL: {is_sdxl}, SD2.X: {is_sd2}, SD1.X: {is_sd1}")

Expand Down Expand Up @@ -333,7 +339,8 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_

result = ext.svd(args)

load_model(currentinfo)
if currentinfo:
load_model(currentinfo)
return result

##############################################################
Expand Down

0 comments on commit b165dcf

Please sign in to comment.