Skip to content

Commit

Permalink
Updated caching for diffusers
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Aug 20, 2024
1 parent a0ab5da commit 777129f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
23 changes: 18 additions & 5 deletions .github/download-models-weights.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,34 @@
import argparse

import torch
import diffusers

fonts = {
"sold2_wireframe": "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth",
models = {
"sold2_wireframe": ("torchhub", "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"),
"CompVis/stable-diffusion-v1-4": ("diffusers", "StableDiffusionPipeline"),
"runwayml/stable-diffusion-v1-5": ("diffusers", "StableDiffusionPipeline"),
"stabilityai/stable-diffusion-v2-1": ("diffusers", "StableDiffusionPipeline"),
}


if __name__ == "__main__":
parser = argparse.ArgumentParser("WeightsDownloader")
parser.add_argument("--target_directory", "-t", required=False, default="target_directory")

args = parser.parse_args()

torch.hub.set_dir(args.target_directory)
# For HuggingFace model caching
os.environ["HF_HOME"] = args.target_directory

for name, url in fonts.items():
print(f"Downloading weights of `{name}` from `url`. Caching to dir `{args.target_directory}`")
torch.hub.load_state_dict_from_url(url, model_dir=args.target_directory, map_location=torch.device("cpu"))
for name, (src, path) in models.items():
if src == "torchhub":
print(f"Downloading weights of `{name}` from `{path}`. Caching to dir `{args.target_directory}`")
torch.hub.load_state_dict_from_url(url, model_dir=args.target_directory, map_location=torch.device("cpu"))
elif src == "diffusers":
print(f"Downloading `{name}` from diffusers. Caching to dir `{args.target_directory}`")
if path == "StableDiffusionPipeline":
diffusers.StableDiffusionPipeline.from_pretrained(
name, cache_dir=args.target_directory, device_map=torch.device("cpu"))

raise SystemExit(0)
2 changes: 2 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def pytest_sessionstart(session):

os.makedirs(WEIGHTS_CACHE_DIR, exist_ok=True)
torch.hub.set_dir(WEIGHTS_CACHE_DIR)
# For HuggingFace model caching
os.environ["HF_HOME"] = WEIGHTS_CACHE_DIR


def _get_env_info() -> Dict[str, Dict[str, str]]:
Expand Down

0 comments on commit 777129f

Please sign in to comment.