Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hako-mikan committed Jan 14, 2025
1 parent f6bf9b5 commit c694678
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
12 changes: 8 additions & 4 deletions scripts/A1111/lyco_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ def make_weight_cp(t, wa, wb):


def rebuild_conventional(up, down, shape, dyn_dim=None):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
up = cpufloat(up.reshape(up.size(0), -1))
down = cpufloat(down.reshape(down.size(0), -1))
if dyn_dim is not None:
up = up[:, :dyn_dim]
down = down[:dyn_dim, :]
return (up @ down).reshape(shape)


def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
up = cpufloat(up.reshape(up.size(0), -1))
down = cpufloat(down.reshape(down.size(0), -1))
mid = cpufloat(mid)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)


Expand Down Expand Up @@ -66,3 +67,6 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
n, m = m, n
return m, n

def cpufloat(module):
if not module: return module #None対策
return module.to(torch.float) if module.device.type == "cpu" else module
17 changes: 9 additions & 8 deletions scripts/mergers/mergers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,10 @@ def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,m

save = True if SAVEMODES[0] in save_sets else False

if not forge:
result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel
result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel

model_loader(checkpoint_info, theta_0, metadata, currentmodel)

if forge and save:
result = forge_save(custom_name if custom_name else currentmodel.replace(" ","").replace(",","_").replace("(","_").replace(")","_"))

cachedealer(False)

Expand Down Expand Up @@ -308,7 +305,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if not(len(weights_b) == 25 or len(weights_b) == 19 or len(weights_a) == 60): return f"ERROR: weights beta value must be 20 or 26 or 61.",*NON4

caster("model load start",hearm)
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems)
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems,device)

theta_1=load_model_weights_m(model_b,2,cachetarget,device).copy()

Expand Down Expand Up @@ -417,6 +414,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode

theta_0[key] = theta_0[key].to(device)
theta_1[key] = theta_1[key].to(device)

try:
theta_2[key] = theta_2[key].to(device)
except Exception as e:
Expand Down Expand Up @@ -567,12 +565,14 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if flux and not calcmode == "smoothAdd MT":
theta_1[key] = None
del theta_1[key]

theta_0[key] = theta_0[key].to("cpu")
try:
theta_1[key] = theta_1[key].to("cpu")
except:
pass


#flux
if qtype[0]:
dellist = []
Expand Down Expand Up @@ -950,9 +950,9 @@ def elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha)

def forkforker(filename,device):
if forge:
return load_torch_file(filename)
return load_torch_file(filename, device = torch.device(device))
try:
return sd_models.read_state_dict(filename,map_location = device)
return sd_models.read_state_dict(filename, map_location = device)
except:
return sd_models.read_state_dict(filename)

Expand Down Expand Up @@ -1550,7 +1550,7 @@ def getcachelist():
################################################
##### print

def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems):
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems,device):
print(f" model A \t: {model_a}")
print(f" model B \t: {model_b}")
print(f" model C \t: {model_c}")
Expand All @@ -1564,6 +1564,7 @@ def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,
print(f" Weights Seed\t: {lucks}")
print(f" {inex} \t: {ex_blocks,ex_elems}")
print(f" Adjust \t: {fine}")
print(f" Device \t: {device}")

def caster(news,hear):
if hear: print(news)
Expand Down
1 change: 1 addition & 0 deletions scripts/mergers/pluslora.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
except:
currentinfo = None

lowvram.module_in_gpu = None #web-uiのバグ対策

checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
load_model(checkpoint_info)
Expand Down

0 comments on commit c694678

Please sign in to comment.