Skip to content

Commit

Permalink
SD - Fix civitai download on Windows +improvements (#1907)
Browse files Browse the repository at this point in the history
  • Loading branch information
one-lithe-rune authored Oct 21, 2023
1 parent 7cd14fd commit 1344419
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ generated_imgs/

# Custom model related artefacts
variants.json
models/
/models/

# models folder
apps/stable_diffusion/web/models/
Expand Down
23 changes: 12 additions & 11 deletions apps/stable_diffusion/src/models/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import subprocess
import sys
import os
import requests
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
Expand All @@ -16,6 +17,7 @@
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
get_civitai_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
Expand Down Expand Up @@ -94,28 +96,27 @@ def __init__(
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.custom_weights = custom_weights.strip()
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)

# act as if we were given the local file as custom_weights originally
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path

# needed to ensure webui sets the correct model name metadata
args.ckpt_loc = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)

self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@
resize_stencil,
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
42 changes: 42 additions & 0 deletions apps/stable_diffusion/src/utils/civitai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import re
import requests
from apps.stable_diffusion.src.utils.stable_args import args

from pathlib import Path
from tqdm import tqdm


def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()

# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = (
Path.cwd() / (args.ckpt_dir or "models") / base_filename
)

# we don't have this model downloaded yet
if not destination_path.is_file():
print(
f"downloading civitai model from {url} to {destination_path}"
)

size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)

with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))

progress_bar.close()

# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")

response.close()
return destination_path.as_posix()

0 comments on commit 1344419

Please sign in to comment.