Skip to content

Commit

Permalink
(SD) Fix tokenizers imports in pyinstaller builds. (#1828)
Browse files Browse the repository at this point in the history
* Fix tokenizers metadata.

* (SD) Disable VAE lowering configs (rdna3) and add versioned tunings.

* Update sd_annotation.py

* (SD) Add cv2 to spec.

* Update stencil pipeline with the new img2img arg.
  • Loading branch information
monorimet authored Sep 12, 2023
1 parent b817bb8 commit 684943a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
7 changes: 4 additions & 3 deletions apps/stable_diffusion/shark_studio_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
Expand All @@ -31,18 +31,17 @@
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
Expand All @@ -53,6 +52,7 @@
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
Expand Down Expand Up @@ -81,3 +81,4 @@
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def generate_images(
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
Expand Down
8 changes: 2 additions & 6 deletions apps/stable_diffusion/src/utils/sd_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)

full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir

Expand Down Expand Up @@ -281,13 +281,9 @@ def sd_model_annotation(mlir_model, model_name, base_model_id=None):
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
use_winograd = True
winograd_config_dir = load_winograd_configs()
winograd_model = annotate_with_winograd(
tuned_model = annotate_with_winograd(
mlir_model, winograd_config_dir, model_name
)
lowering_config_dir = load_lower_configs(base_model_id)
tuned_model = annotate_with_lower_configs(
winograd_model, lowering_config_dir, model_name, use_winograd
)
else:
tuned_model = mlir_model
else:
Expand Down

0 comments on commit 684943a

Please sign in to comment.