From 6a88b8118fdaccbded084bc9151f2606a3703ba0 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 7 Nov 2023 14:10:48 -0600 Subject: [PATCH 01/16] Better requirement handling? --- postinstall.py | 136 ++++++++++++++++++++++++++++++++++++----------- requirements.txt | 27 ++++------ 2 files changed, 114 insertions(+), 49 deletions(-) diff --git a/postinstall.py b/postinstall.py index 886c0b88..9a083193 100644 --- a/postinstall.py +++ b/postinstall.py @@ -4,8 +4,13 @@ import subprocess import sys from dataclasses import dataclass +from typing import Optional import git +from packaging import version as pv + +from importlib import metadata + from packaging.version import Version from dreambooth import shared @@ -56,6 +61,32 @@ def pip_install(*args): print(line) +def is_installed(pkg: str, version: Optional[str] = None, check_strict: bool = True) -> bool: + try: + # Retrieve the package version from the installed package metadata + installed_version = metadata.version(pkg) + + # If version is not specified, just return True as the package is installed + if version is None: + return True + + # Compare the installed version with the required version + if check_strict: + # Strict comparison (must be an exact match) + return pv.parse(installed_version) == pv.parse(version) + else: + # Non-strict comparison (installed version must be greater than or equal to the required version) + return pv.parse(installed_version) >= pv.parse(version) + + except metadata.PackageNotFoundError: + # The package is not installed + return False + except Exception as e: + # Any other exceptions encountered + print(f"Error: {e}") + return False + + def install_requirements(): dreambooth_skip_install = os.environ.get("DREAMBOOTH_SKIP_INSTALL", False) req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") @@ -63,26 +94,52 @@ def install_requirements(): if dreambooth_skip_install or req_file == req_file_startup_arg: return - + print("Checking Dreambooth requirements...") has_diffusers = importlib.util.find_spec("diffusers") is not None has_tqdm = importlib.util.find_spec("tqdm") is not None transformers_version = importlib_metadata.version("transformers") + strict = True + non_strict_separators = ["==", ">=", "<=", ">", "<", "~="] + # Load the requirements file + with open(req_file_startup_arg, "r") as f: + reqs = f.readlines() + + if os.name == "darwin": + reqs.append("tensorboard==2.11.2") + else: + reqs.append("tensorboard==2.13.0") + + for line in reqs: + try: + package = line.strip() + if package and not package.startswith("#"): + package_version = None + strict = True + for separator in non_strict_separators: + if separator in package: + strict = False + package, package_version = line.split(separator) + package = package.strip() + package_version = package_version.strip() + break + if not is_installed(package, package_version, strict): + print(f"[Dreambooth] {package} v{package_version} is not installed.") + pip_install(line) + else: + print(f"[Dreambooth] {package} v{package_version} is already installed.") - try: - pip_install("-r", req_file) - - if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"): - print() - print("Does your project take forever to startup?") - print("Repetitive dependency installation may be the reason.") - print("Automatic1111's base project sets strict requirements on outdated dependencies.") - print("If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.") - print() - 
except subprocess.CalledProcessError as grepexc: - error_msg = grepexc.stdout.decode() - print_requirement_installation_error(error_msg) - raise grepexc + except subprocess.CalledProcessError as grepexc: + error_msg = grepexc.stdout.decode() + print_requirement_installation_error(error_msg) + if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"): + print() + print("Does your project take forever to startup?") + print("Repetitive dependency installation may be the reason.") + print("Automatic1111's base project sets strict requirements on outdated dependencies.") + print( + "If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.") + print() def check_xformers(): """ @@ -90,18 +147,27 @@ def check_xformers(): """ try: xformers_version = importlib_metadata.version("xformers") - xformers_outdated = Version(xformers_version) < Version("0.0.20") + xformers_outdated = Version(xformers_version) < Version("0.0.21") if xformers_outdated: try: torch_version = importlib_metadata.version("torch") is_torch_1 = Version(torch_version) < Version("2") + is_torch_2_1 = Version(torch_version) < Version("2.0") if is_torch_1: print_xformers_torch1_instructions(xformers_version) + # Torch 2.0.1 is not available on PyPI for xformers version 22 + elif is_torch_2_1 is False: + os_string = "win_amd64" if os.name == "nt" else "manylinux2014_x86_64" + # Get the version of python + py_string = f"cp{sys.version_info.major}{sys.version_info.minor}-cp{sys.version_info.major}{sys.version_info.minor}" + wheel_url = f"https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-{py_string}-{os_string}.whl" + pip_install(wheel_url, "--upgrade", "--no-deps") else: - pip_install("--force-reinstall", "xformers") + pip_install("xformers==0.0.22", "--upgrade") except subprocess.CalledProcessError as grepexc: error_msg = grepexc.stdout.decode() - print_xformers_installation_error(error_msg) + if "WARNING: Ignoring invalid distribution" not in error_msg: + print_xformers_installation_error(error_msg) except: pass @@ -113,15 +179,23 @@ def check_bitsandbytes(): bitsandbytes_version = importlib_metadata.version("bitsandbytes") if os.name == "nt": if bitsandbytes_version != "0.41.1": - try: - print("Installing bitsandbytes") - pip_install("--force-install","==prefer-binary","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui","bitsandbytes==0.41.1") - except: - print("Bitsandbytes 0.41.1 installation failed.") - print("Some features such as 8bit optimizers will be unavailable") - print("Please install manually with") - print("'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'") - pass + venv_path = os.environ.get("VIRTUAL_ENV", None) + # Check for the dll in venv/lib/site-packages/bitsandbytes/libbitsandbytes_cuda118.dll + # If it doesn't exist, append the requirement + if not venv_path: + print("Could not find the virtual environment path. 
Skipping bitsandbytes installation.") + else: + win_dll = os.path.join(venv_path, "lib", "site-packages", "bitsandbytes", "libbitsandbytes_cuda118.dll") + if not os.path.exists(win_dll): + try: + print("Installing bitsandbytes") + pip_install("--force-install","==prefer-binary","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui","bitsandbytes==0.41.1") + except: + print("Bitsandbytes 0.41.1 installation failed.") + print("Some features such as 8bit optimizers will be unavailable") + print("Please install manually with") + print("'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'") + pass else: if bitsandbytes_version != "0.41.1": try: @@ -150,12 +224,12 @@ def check_versions(): #Probably a bad idea but update ALL the dependencies dependencies = [ - Dependency(module="xformers", version="0.0.21", required=False), + Dependency(module="xformers", version="0.0.22", required=False), Dependency(module="torch", version="1.13.1" if is_mac else "2.0.1+cu118"), Dependency(module="torchvision", version="0.14.1" if is_mac else "0.15.2+cu118"), - Dependency(module="accelerate", version="0.22.0"), - Dependency(module="diffusers", version="0.20.1"), - Dependency(module="transformers", version="4.25.1"), + Dependency(module="accelerate", version="0.21.0"), + Dependency(module="diffusers", version="0.22.1"), + Dependency(module="transformers", version="4.35.0"), Dependency(module="bitsandbytes", version="0.41.1"), ] diff --git a/requirements.txt b/requirements.txt index cc4c3193..2b0d986c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,21 +1,12 @@ -accelerate~=0.23.0 +accelerate==0.21.0 bitsandbytes~=0.41.1 dadaptation==3.1 -diffusers~=0.21.2 -discord-webhook~=1.1.0 -fastapi~=0.94.1 -gitpython~=3.1.31 -pytorch_optimizer==2.11.1 -Pillow==9.5.0 -tqdm==4.65.0 +diffusers~=0.22.1 +discord-webhook==1.3.0 +fastapi +gitpython==3.1.40 +pytorch_optimizer==2.12.0 +Pillow +tqdm tomesd~=0.1.2 -transformers~=4.32.1; # > 4.26.x causes issues (db extension #1110) -# Get prebuilt Windows wheels from jllllll -bitsandbytes~=0.41.1; sys_platform == 'win32' or platform_machine == 'AMD64' \ ---extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary -# Get Linux and MacOS wheels from PyPi -bitsandbytes~=0.41.1; sys_platform != 'win32' or platform_machine != 'AMD64' --prefer-binary -# Tensor -tensorboard==2.13.0; sys_platform != 'darwin' or platform_machine != 'arm64' -# Tensor MacOS -tensorboard==2.11.2; sys_platform == 'darwin' and platform_machine == 'arm64' +transformers~=4.30.2 \ No newline at end of file From fbc0dea7e257a214c342f7376edb6fd7c5090a2a Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 7 Nov 2023 14:11:23 -0600 Subject: [PATCH 02/16] Potential fix for #1363 --- dreambooth/train_dreambooth.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index dbab4251..c3c6b250 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -31,6 +31,7 @@ ) from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict from diffusers.models.attention_processor import LoRAAttnProcessor2_0, LoRAAttnProcessor +from diffusers.training_utils import unet_lora_state_dict from diffusers.utils import logging as dl from diffusers.utils.torch_utils import randn_tensor from torch.cuda.profiler import profile @@ -57,7 +58,7 @@ disable_safe_unpickle, 
enable_safe_unpickle, xformerify, - torch2ify, unet_attn_processors_state_dict + torch2ify ) from dreambooth.utils.text_utils import encode_hidden_state, save_token_counts from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed, @@ -1097,7 +1098,7 @@ def lora_save_function(weights, filename): pbar2.set_description("Saving Lora Weights...") # setup directory logger.debug(f"Saving lora to {lora_save_file}") - unet_lora_layers_to_save = unet_attn_processors_state_dict(unet) + unet_lora_layers_to_save = unet_lora_state_dict(unet) text_encoder_one_lora_layers_to_save = None text_encoder_two_lora_layers_to_save = None if args.stop_text_encoder != 0: From c6622c19d0d8488d6c773bc58873416e7845e1ce Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 7 Nov 2023 14:11:35 -0600 Subject: [PATCH 03/16] Fix gradio style warning --- scripts/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/main.py b/scripts/main.py index 0d6df7b2..59ed6a1c 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -298,7 +298,7 @@ def on_ui_tabs(): db_cancel = gr.Button(value="Cancel", elem_id="db_cancel") with gr.Row(): gr.HTML(value="Select or create a model to begin.", elem_id="hint_row") - with gr.Row().style(equal_height=False): + with gr.Row(equal_height=False): with gr.Column(variant="panel", elem_id="ModelPanel"): with gr.Column(): gr.HTML(value="Model") @@ -1532,8 +1532,8 @@ def toggle_loss_items(scale): outputs=[hub_row, local_row], ) - def toggle_shared_row(shared_row): - return gr.update(visible=shared_row), gr.update(value="") + def toggle_shared_row(row): + return gr.update(visible=row), gr.update(value="") db_use_shared_src.change( fn=toggle_shared_row, From 150619b7d9397888d8d57c6dfe9a12dc942e3440 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 7 Nov 2023 14:26:17 -0600 Subject: [PATCH 04/16] Fix transformers warning from attempted upgrade/revert --- postinstall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postinstall.py b/postinstall.py index 9a083193..4bb0644b 100644 --- a/postinstall.py +++ b/postinstall.py @@ -229,7 +229,7 @@ def check_versions(): Dependency(module="torchvision", version="0.14.1" if is_mac else "0.15.2+cu118"), Dependency(module="accelerate", version="0.21.0"), Dependency(module="diffusers", version="0.22.1"), - Dependency(module="transformers", version="4.35.0"), + Dependency(module="transformers", version="4.30.2"), Dependency(module="bitsandbytes", version="0.41.1"), ] From 955318d74c4cb79bdb21f61a448933e82a352261 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 7 Nov 2023 14:31:59 -0600 Subject: [PATCH 05/16] Ensure model type is set. 
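
Previously an empty model type coming from the UI fell through the resolution
checks in create_model(). With this change the model-type dropdown in
scripts/main.py defaults to "v1x" and create_model() falls back to "v1x" when
no value is passed. A minimal sketch of the guard follows; the 768 and 1024
branches are assumptions for illustration, since only the 512 cases appear in
this diff:

    def resolve_resolution(model_type: str) -> int:
        # Fall back to "v1x" when the UI passes an empty or missing value.
        if not model_type:
            model_type = "v1x"
        if model_type in ("v1x", "v2x-512"):
            return 512
        if model_type == "v2x":
            return 768   # assumed base resolution for v2x; not shown in this diff
        return 1024      # assumed for SDXL and other types; not shown in this diff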
--- dreambooth/ui_functions.py | 6 ++++-- scripts/main.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dreambooth/ui_functions.py b/dreambooth/ui_functions.py index 1e6c5fa8..3be90bd8 100644 --- a/dreambooth/ui_functions.py +++ b/dreambooth/ui_functions.py @@ -941,11 +941,13 @@ def create_model( new_model_token="", extract_ema=False, train_unfrozen=False, - model_type="v1" + model_type="v1x" ): + if not model_type: + model_type = "v1x" printm("Extracting model.") res = 512 - is_512 = model_type == "v1" + is_512 = model_type == "v1x" if model_type == "v1x" or model_type=="v2x-512": res = 512 elif model_type == "v2x": diff --git a/scripts/main.py b/scripts/main.py index 59ed6a1c..3ba1d947 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -365,7 +365,7 @@ def on_ui_tabs(): db_create_from_hub = gr.Checkbox( label="Create From Hub", value=False ) - db_model_type_select=gr.Dropdown(label="Model Type", choices=["v1x", "v2x-512", "v2x", "SDXL", "ControlNet"]) + db_model_type_select=gr.Dropdown(label="Model Type", choices=["v1x", "v2x-512", "v2x", "SDXL", "ControlNet"], value="v1x") db_use_shared_src = gr.Checkbox( label="Experimental Shared Src", value=False, visible=False ) From 1e963923763f181ce5a34aea2f533c1e53581a08 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Wed, 8 Nov 2023 11:53:51 -0600 Subject: [PATCH 06/16] Fix dumb postinstall --- postinstall.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/postinstall.py b/postinstall.py index 4bb0644b..8e421ee8 100644 --- a/postinstall.py +++ b/postinstall.py @@ -147,7 +147,7 @@ def check_xformers(): """ try: xformers_version = importlib_metadata.version("xformers") - xformers_outdated = Version(xformers_version) < Version("0.0.21") + xformers_outdated = Version(xformers_version) < Version("0.0.20") if xformers_outdated: try: torch_version = importlib_metadata.version("torch") @@ -156,14 +156,14 @@ def check_xformers(): if is_torch_1: print_xformers_torch1_instructions(xformers_version) # Torch 2.0.1 is not available on PyPI for xformers version 22 - elif is_torch_2_1 is False: + elif is_torch_2_1: os_string = "win_amd64" if os.name == "nt" else "manylinux2014_x86_64" # Get the version of python py_string = f"cp{sys.version_info.major}{sys.version_info.minor}-cp{sys.version_info.major}{sys.version_info.minor}" wheel_url = f"https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-{py_string}-{os_string}.whl" pip_install(wheel_url, "--upgrade", "--no-deps") else: - pip_install("xformers==0.0.22", "--upgrade") + pip_install("xformers==0.0.21", "--index-url https://download.pytorch.org/whl/cu118") except subprocess.CalledProcessError as grepexc: error_msg = grepexc.stdout.decode() if "WARNING: Ignoring invalid distribution" not in error_msg: From 8dd24d665975a8af495b32522032fb85e42d787d Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Wed, 8 Nov 2023 11:54:11 -0600 Subject: [PATCH 07/16] Begin UI Rework --- dreambooth/ui_functions.py | 4 +- javascript/dreambooth.js | 126 +---- preprocess/preprocess_utils.py | 50 ++ scripts/main.py | 958 +++++++++++++++++---------------- style.css | 8 + 5 files changed, 563 insertions(+), 583 deletions(-) create mode 100644 preprocess/preprocess_utils.py diff --git a/dreambooth/ui_functions.py b/dreambooth/ui_functions.py index 3be90bd8..18398a9a 100644 --- a/dreambooth/ui_functions.py +++ b/dreambooth/ui_functions.py @@ -11,6 +11,7 @@ import traceback from collections import OrderedDict +import gradio import torch import torch.utils.data.dataloader from 
accelerate import find_executable_batch_size @@ -639,7 +640,7 @@ def load_model_params(model_name): if config is None: print("Can't load config!") msg = f"Error loading model params: '{model_name}'." - return "", "", "", "", "", db_model_snapshots, msg + return gradio.update(visible=False), "", "", "", "", "", db_model_snapshots, msg else: snaps = get_model_snapshots(config) snap_selection = config.revision if str(config.revision) in snaps else "" @@ -649,6 +650,7 @@ def load_model_params(model_name): db_lora_models = gr_update(choices=loras) msg = f"Selected model: '{model_name}'." return ( + gradio.update(visible=True), config.model_dir, config.revision, config.epoch, diff --git a/javascript/dreambooth.js b/javascript/dreambooth.js index 361be54b..4d22fa24 100644 --- a/javascript/dreambooth.js +++ b/javascript/dreambooth.js @@ -22,7 +22,7 @@ function save_config() { } function toggleComponents(enable, disableAll) { - const elements = ['DbTopRow', 'SettingsPanel']; + const elements = ["DbTopRow", "TabConcepts", "TabSettings", "TabSave", "TabGenerate", "TabDebug"]; if (disableAll) { console.log("Disabling all DB elements!"); elements.push("ModelPanel") @@ -53,124 +53,7 @@ function toggleComponents(enable, disableAll) { }); } -// Disconnect a gradio mutation observer, update the element value, and reconnect the observer? -function updateInputValue(elements, newValue) { - const savedListeners = []; - const savedObservers = []; - - elements.forEach((element) => { - // Save any existing listeners and remove them - const listeners = []; - const events = ['change', 'input']; - events.forEach((event) => { - if (element['on' + event]) { - listeners.push({ - event, - listener: element['on' + event], - }); - element['on' + event] = null; - } - const eventListeners = element.getEventListeners?.(event); - if (eventListeners) { - eventListeners.forEach(({ listener }) => { - listeners.push({ - event, - listener, - }); - element.removeEventListener(event, listener); - }); - } - }); - savedListeners.push(listeners); - - // Save any existing MutationObservers and disconnect them - const observer = new MutationObserver(() => { - }); - if (observer && element.tagName === 'INPUT') { - observer.observe(element, { - attributes: true, - attributeFilter: ['value'], - }); - savedObservers.push(observer); - observer.disconnect(); - } else { - savedObservers.push(null); - } - - // Update the value of the element - element.value = newValue; - }); - - // Restore any saved listeners and MutationObservers - savedListeners.forEach((listeners, i) => { - const element = elements[i]; - listeners.forEach(({ event, listener }) => { - if (listener) { - element.addEventListener(event, listener); - } - }); - }); - - savedObservers.forEach((observer, i) => { - const element = elements[i]; - if (observer) { - observer.observe(element, { - attributes: true, - attributeFilter: ['value'], - }); - } - }); -} - - -// Fix steps on sliders. God this is a lot of work for one stupid thing... 
-function handleNumberInputs() { - const numberInputs = gradioApp() - .querySelector('#tab_dreambooth_interface') - ?.querySelectorAll('input[type="number"]'); - numberInputs?.forEach((numberInput) => { - const step = Number(numberInput.step) || 1; - const parentDiv = numberInput.parentElement; - const labelFor = parentDiv.querySelector('label'); - if (labelFor) { - const tgt = labelFor.getAttribute("for"); - if (listeners[tgt]) return; - const rangeInput = getRealElement(tgt); - if (rangeInput && rangeInput.type === 'range') { - let timeouts = []; - listeners[tgt] = true; - numberInput.oninput = () => { - if (timeouts[tgt]) { - clearTimeout(timeouts[tgt]); - } - timeouts[tgt] = setTimeout(() => { - let value = Number(numberInput.value) || 0; - const min = parseFloat(rangeInput.min) || 0; - const max = parseFloat(rangeInput.max) || 100; - if (value < min) { - value = min; - } else if (value > max) { - value = max; - } - const remainder = value % step; - if (remainder !== 0) { - value -= remainder; - if (remainder >= step / 2) { - value += step; - } - } - if (value !== numberInput.value) { - numberInput.value = value; - } - }, 500); - }; - - } - } - }); -} - - +// Don't delete this, it's used by the UI function check_save() { let do_save = true; if (params_loaded === false) { @@ -545,11 +428,6 @@ onUiUpdate(function () { observer.observe(btn, options); }); - try { - handleNumberInputs(); - } catch (e) { - console.log("Gotcha: ", e); - } }); diff --git a/preprocess/preprocess_utils.py b/preprocess/preprocess_utils.py new file mode 100644 index 00000000..5d4c203a --- /dev/null +++ b/preprocess/preprocess_utils.py @@ -0,0 +1,50 @@ +import os +from typing import Tuple, List, Dict + +import gradio as gr + +from dreambooth.utils.image_utils import FilenameTextGetter + +image_data = [] + +def load_image_data(input_path: str, recurse: bool = False) -> List[Dict[str,str]]: + if not os.path.exists(input_path): + print(f"Input path {input_path} does not exist") + return [] + global image_data + results = [] + from dreambooth.utils.image_utils import list_features, is_image + pil_features = list_features() + # Get a list from PIL of all the image formats it supports + + for root, dirs, files in os.walk(input_path): + for file in files: + full_path = os.path.join(root, file) + print(f"Checking {full_path}") + if is_image(full_path, pil_features): + results.append(full_path) + if not recurse: + break + + output = [] + text_getter = FilenameTextGetter() + for img_path in results: + file_text = text_getter.read_text(img_path) + output.append({'image': img_path, 'text': file_text}) + image_data = output + return output + +def check_preprocess_path(input_path: str, recurse: bool = False) -> Tuple[gr.update, gr.update]: + output_status = gr.update(visible=True) + output_gallery = gr.update(visible=True) + results = load_image_data(input_path, recurse) + if len(results) == 0: + return output_status, output_gallery + else: + images = [r['image'] for r in results] + output_status = gr.update(visible=True, value='Found {len(results)} images') + output_gallery = gr.update(visible=True, value=images) + return output_status, output_gallery + +def load_image_caption(evt: gr.SelectData): # SelectData is a subclass of EventData + return gr.update(value=f"You selected {evt.value} at {evt.index} from {evt.target}") \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 3ba1d947..f6ed51d6 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -53,7 +53,9 @@ from helpers.version_helper import 
check_updates from modules import script_callbacks, sd_models from modules.ui import gr_show, create_refresh_button +from preprocess.preprocess_utils import check_preprocess_path, load_image_caption +preprocess_params = [] params_to_save = [] params_to_load = [] refresh_symbol = "\U0001f504" # 🔄 @@ -298,494 +300,171 @@ def on_ui_tabs(): db_cancel = gr.Button(value="Cancel", elem_id="db_cancel") with gr.Row(): gr.HTML(value="Select or create a model to begin.", elem_id="hint_row") - with gr.Row(equal_height=False): - with gr.Column(variant="panel", elem_id="ModelPanel"): - with gr.Column(): - gr.HTML(value="Model") - with gr.Tab("Select"): - with gr.Row(): - db_model_name = gr.Dropdown( - label="Model", choices=sorted(get_db_models()) - ) - create_refresh_button( - db_model_name, - get_db_models, - lambda: {"choices": sorted(get_db_models())}, - "refresh_db_models", - ) - with gr.Row(): - db_snapshot = gr.Dropdown( - label="Snapshot to Resume", - choices=sorted(get_model_snapshots()), - ) - create_refresh_button( - db_snapshot, - get_model_snapshots, - lambda: {"choices": sorted(get_model_snapshots())}, - "refresh_db_snapshots", - ) - with gr.Row(visible=False) as lora_model_row: - db_lora_model_name = gr.Dropdown( - label="Lora Model", choices=get_sorted_lora_models() - ) - create_refresh_button( - db_lora_model_name, - get_sorted_lora_models, - lambda: {"choices": get_sorted_lora_models()}, - "refresh_lora_models", - ) - with gr.Row(): + with gr.Row(elem_id="ModelDetailRow", visible=False, variant="compact") as db_model_info: + with gr.Column(): + with gr.Row(variant="compact"): + with gr.Column(): + with gr.Row(variant="compact"): gr.HTML(value="Loaded Model:") db_model_path = gr.HTML() - with gr.Row(): - gr.HTML(value="Model Revision:") - db_revision = gr.HTML(elem_id="db_revision") - with gr.Row(): - gr.HTML(value="Model Epoch:") - db_epochs = gr.HTML(elem_id="db_epochs") - with gr.Row(): - gr.HTML(value="Model type:") - db_model_type = gr.HTML(elem_id="db_model_type") - with gr.Row(): - gr.HTML(value="Has EMA:") - db_has_ema = gr.HTML(elem_id="db_has_ema") - with gr.Row(): + with gr.Row(variant="compact"): gr.HTML(value="Source Checkpoint:") db_src = gr.HTML() - with gr.Row(visible=False): - gr.HTML(value="Experimental Shared Source:") - db_shared_diffusers_path = gr.HTML() - with gr.Tab("Create"): - with gr.Column(): - db_create_model = gr.Button( - value="Create Model", variant="primary" - ) - db_new_model_name = gr.Textbox(label="Name") - with gr.Row(): - db_create_from_hub = gr.Checkbox( - label="Create From Hub", value=False - ) - db_model_type_select=gr.Dropdown(label="Model Type", choices=["v1x", "v2x-512", "v2x", "SDXL", "ControlNet"], value="v1x") - db_use_shared_src = gr.Checkbox( - label="Experimental Shared Src", value=False, visible=False - ) - with gr.Column(visible=False) as hub_row: - db_new_model_url = gr.Textbox( - label="Model Path", - placeholder="runwayml/stable-diffusion-v1-5", - ) - db_new_model_token = gr.Textbox( - label="HuggingFace Token", value="" - ) - with gr.Column(visible=True) as local_row: + with gr.Column(): + with gr.Row(variant="compact"): + gr.HTML(value="Model Epoch:") + db_epochs = gr.HTML(elem_id="db_epochs") + with gr.Row(variant="compact"): + gr.HTML(value="Model Revision:") + db_revision = gr.HTML(elem_id="db_revision") + with gr.Column(): + with gr.Row(variant="compact"): + gr.HTML(value="Model type:") + db_model_type = gr.HTML(elem_id="db_model_type") + with gr.Row(variant="compact"): + gr.HTML(value="Has EMA:") + db_has_ema = 
gr.HTML(elem_id="db_has_ema") + with gr.Row(variant="compact", visible=False): + gr.HTML(value="Experimental Shared Source:") + db_shared_diffusers_path = gr.HTML() + with gr.Row(equal_height=False): + with gr.Column(variant="panel", elem_id="SettingsPanel"): + gr.HTML(value="Settings") + with gr.Tab("Model", elem_id="ModelPanel"): + with gr.Column(): + with gr.Tab("Select"): with gr.Row(): - db_new_model_src = gr.Dropdown( - label="Source Checkpoint", - choices=sorted(get_sd_models()), + db_model_name = gr.Dropdown( + label="Model", choices=sorted(get_db_models()) ) create_refresh_button( - db_new_model_src, - get_sd_models, - lambda: {"choices": sorted(get_sd_models())}, - "refresh_sd_models", + db_model_name, + get_db_models, + lambda: {"choices": sorted(get_db_models())}, + "refresh_db_models", ) - with gr.Column(visible=False) as shared_row: with gr.Row(): - db_new_model_shared_src = gr.Dropdown( - label="EXPERIMENTAL: LoRA Shared Diffusers Source", - choices=sorted(get_shared_models()), - value="", - visible=False + db_snapshot = gr.Dropdown( + label="Snapshot to Resume", + choices=sorted(get_model_snapshots()), ) create_refresh_button( - db_new_model_shared_src, - get_shared_models, - lambda: {"choices": sorted(get_shared_models())}, - "refresh_shared_models", - ) - db_new_model_extract_ema = gr.Checkbox( - label="Extract EMA Weights", value=False - ) - db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False) - with gr.Column(): - with gr.Accordion(open=False, label="Resources"): - with gr.Column(): - gr.HTML( - value="Beginners guide", - ) - gr.HTML( - value="Release notes", - ) - with gr.Column(variant="panel", elem_id="SettingsPanel"): - gr.HTML(value="Input") - with gr.Tab("Settings", elem_id="TabSettings"): - db_performance_wizard = gr.Button(value="Performance Wizard (WIP)") - with gr.Accordion(open=True, label="Basic"): - with gr.Column(): - gr.HTML(value="General") - db_use_lora = gr.Checkbox(label="Use LORA", value=False) - db_use_lora_extended = gr.Checkbox( - label="Use Lora Extended", - value=False, - visible=False, - ) - db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False) - db_train_inpainting = gr.Checkbox( - label="Train Inpainting Model", - value=False, - visible=False, - ) - with gr.Column(): - gr.HTML(value="Intervals") - db_num_train_epochs = gr.Slider( - label="Training Steps Per Image (Epochs)", - value=100, - maximum=1000, - step=1, - ) - db_epoch_pause_frequency = gr.Slider( - label="Pause After N Epochs", - value=0, - maximum=100, - step=1, - ) - db_epoch_pause_time = gr.Slider( - label="Amount of time to pause between Epochs (s)", - value=0, - maximum=3600, - step=1, - ) - db_save_embedding_every = gr.Slider( - label="Save Model Frequency (Epochs)", - value=25, - maximum=1000, - step=1, - ) - db_save_preview_every = gr.Slider( - label="Save Preview(s) Frequency (Epochs)", - value=5, - maximum=1000, - step=1, - ) - - with gr.Column(): - gr.HTML(value="Batching") - db_train_batch_size = gr.Slider( - label="Batch Size", - value=1, - minimum=1, - maximum=100, - step=1, - ) - db_gradient_accumulation_steps = gr.Slider( - label="Gradient Accumulation Steps", - value=1, - minimum=1, - maximum=100, - step=1, - ) - db_sample_batch_size = gr.Slider( - label="Class Batch Size", - minimum=1, - maximum=100, - value=1, - step=1, - ) - db_gradient_set_to_none = gr.Checkbox( - label="Set Gradients to None When Zeroing", value=True - ) - db_gradient_checkpointing = gr.Checkbox( - label="Gradient Checkpointing", value=True - ) - - 
with gr.Column(): - gr.HTML(value="Learning Rate") - with gr.Row(visible=False) as lora_lr_row: - db_lora_learning_rate = gr.Number( - label="Lora UNET Learning Rate", value=1e-4 - ) - db_lora_txt_learning_rate = gr.Number( - label="Lora Text Encoder Learning Rate", value=5e-5 + db_snapshot, + get_model_snapshots, + lambda: {"choices": sorted(get_model_snapshots())}, + "refresh_db_snapshots", ) - with gr.Row() as standard_lr_row: - db_learning_rate = gr.Number( - label="Learning Rate", value=2e-6 + with gr.Row(visible=False) as lora_model_row: + db_lora_model_name = gr.Dropdown( + label="Lora Model", choices=get_sorted_lora_models() ) - db_txt_learning_rate = gr.Number( - label="Text Encoder Learning Rate", value=1e-6 + create_refresh_button( + db_lora_model_name, + get_sorted_lora_models, + lambda: {"choices": get_sorted_lora_models()}, + "refresh_lora_models", ) - - db_lr_scheduler = gr.Dropdown( - label="Learning Rate Scheduler", - value="constant_with_warmup", - choices=list_schedulers(), - ) - db_learning_rate_min = gr.Number( - label="Min Learning Rate", value=1e-6, visible=False - ) - db_lr_cycles = gr.Number( - label="Number of Hard Resets", - value=1, - precision=0, - visible=False, - ) - db_lr_factor = gr.Number( - label="Constant/Linear Starting Factor", - value=0.5, - precision=2, - visible=False, - ) - db_lr_power = gr.Number( - label="Polynomial Power", - value=1.0, - precision=1, - visible=False, - ) - db_lr_scale_pos = gr.Slider( - label="Scale Position", - value=0.5, - minimum=0, - maximum=1, - step=0.05, - visible=False, - ) - db_lr_warmup_steps = gr.Slider( - label="Learning Rate Warmup Steps", - value=500, - step=5, - maximum=1000, - ) - - with gr.Column(visible=False) as lora_rank_col: - gr.HTML("Lora") - db_lora_unet_rank = gr.Slider( - label="Lora UNET Rank", - value=4, - minimum=2, - maximum=128, - step=2, - ) - db_lora_txt_rank = gr.Slider( - label="Lora Text Encoder Rank", - value=4, - minimum=2, - maximum=128, - step=2, - ) - db_lora_weight = gr.Slider( - label="Lora Weight (Alpha)", - value=0.8, - minimum=0.1, - maximum=1, - step=0.1, - ) - - with gr.Column(): - gr.HTML(value="Image Processing") - db_resolution = gr.Slider( - label="Max Resolution", - step=64, - minimum=128, - value=512, - maximum=2048, - elem_id="max_res", - ) - db_hflip = gr.Checkbox( - label="Apply Horizontal Flip", value=False - ) - db_dynamic_img_norm = gr.Checkbox( - label="Dynamic Image Normalization", value=False - ) - - with gr.Column(): - gr.HTML(value="Tuning") - db_use_ema = gr.Checkbox( - label="Use EMA", value=False - ) - db_optimizer = gr.Dropdown( - label="Optimizer", - value="8bit AdamW", - choices=list_optimizer(), - ) - db_mixed_precision = gr.Dropdown( - label="Mixed Precision", - value=select_precision(), - choices=list_precisions(), - ) - db_full_mixed_precision = gr.Checkbox( - label="Full Mixed Precision", value=True - ) - db_attention = gr.Dropdown( - label="Memory Attention", - value=select_attention(), - choices=list_attention(), - ) - db_cache_latents = gr.Checkbox( - label="Cache Latents", value=True - ) - db_train_unet = gr.Checkbox( - label="Train UNET", value=True - ) - db_stop_text_encoder = gr.Slider( - label="Step Ratio of Text Encoder Training", - minimum=0, - maximum=1, - step=0.05, - value=1.0, - visible=True, - ) - db_offset_noise = gr.Slider( - label="Offset Noise", - minimum=-1, - maximum=1, - step=0.01, - value=0, - ) - db_freeze_clip_normalization = gr.Checkbox( - label="Freeze CLIP Normalization Layers", - visible=True, - value=False, - ) - 
db_clip_skip = gr.Slider( - label="Clip Skip", - value=2, - minimum=1, - maximum=12, - step=1, - ) - db_weight_decay = gr.Slider( - label="Weight Decay", - minimum=0, - maximum=1, - step=0.001, - value=0.01, - visible=True, - ) - db_tenc_weight_decay = gr.Slider( - label="TENC Weight Decay", - minimum=0, - maximum=1, - step=0.001, - value=0.01, - visible=True, - ) - db_tenc_grad_clip_norm = gr.Slider( - label="TENC Gradient Clip Norm", - minimum=0, - maximum=128, - step=0.25, - value=0, - visible=True, - ) - db_min_snr_gamma = gr.Slider( - label="Min SNR Gamma", - minimum=0, - maximum=10, - step=0.1, - visible=True, - ) - db_pad_tokens = gr.Checkbox( - label="Pad Tokens", value=True - ) - db_strict_tokens = gr.Checkbox( - label="Strict Tokens", value=False - ) - db_shuffle_tags = gr.Checkbox( - label="Shuffle Tags", value=True - ) - db_max_token_length = gr.Slider( - label="Max Token Length", - minimum=75, - maximum=300, - step=75, - ) - with gr.Column(): - gr.HTML(value="Prior Loss") - db_prior_loss_scale = gr.Checkbox( - label="Scale Prior Loss", value=False - ) - db_prior_loss_weight = gr.Slider( - label="Prior Loss Weight", - minimum=0.01, - maximum=1, - step=0.01, - value=0.75, - ) - db_prior_loss_target = gr.Number( - label="Prior Loss Target", - value=100, - visible=False, - ) - db_prior_loss_weight_min = gr.Slider( - label="Minimum Prior Loss Weight", - minimum=0.01, - maximum=1, - step=0.01, - value=0.1, - visible=False, - ) - - with gr.Accordion(open=False, label="Advanced"): - with gr.Row(): + with gr.Tab("Create"): with gr.Column(): - gr.HTML(value="Sanity Samples") - db_sanity_prompt = gr.Textbox( - label="Sanity Sample Prompt", - placeholder="A generic prompt used to generate a sample image " - "to verify model fidelity.", + db_create_model = gr.Button( + value="Create Model", variant="primary" ) - db_sanity_negative_prompt = gr.Textbox( - label="Sanity Sample Negative Prompt", - placeholder="A negative prompt for the generic sample image.", - ) - db_sanity_seed = gr.Number( - label="Sanity Sample Seed", value=420420 + db_new_model_name = gr.Textbox(label="Name") + with gr.Row(): + db_create_from_hub = gr.Checkbox( + label="Create From Hub", value=False ) - - with gr.Column(): - gr.HTML(value="Miscellaneous") - db_pretrained_vae_name_or_path = gr.Textbox( - label="Pretrained VAE Name or Path", - placeholder="Leave blank to use base model VAE.", - value="", + db_model_type_select = gr.Dropdown(label="Model Type", + choices=["v1x", "v2x-512", "v2x", "SDXL", + "ControlNet"], value="v1x") + db_use_shared_src = gr.Checkbox( + label="Experimental Shared Src", value=False, visible=False ) - db_use_concepts = gr.Checkbox( - label="Use Concepts List", value=False + with gr.Column(visible=False) as hub_row: + db_new_model_url = gr.Textbox( + label="Model Path", + placeholder="runwayml/stable-diffusion-v1-5", ) - db_concepts_path = gr.Textbox( - label="Concepts List", - placeholder="Path to JSON file with concepts to train.", + db_new_model_token = gr.Textbox( + label="HuggingFace Token", value="" ) + with gr.Column(visible=True) as local_row: with gr.Row(): - db_secret = gr.Textbox( - label="API Key", value=get_secret, interactive=False + db_new_model_src = gr.Dropdown( + label="Source Checkpoint", + choices=sorted(get_sd_models()), ) - db_refresh_button = gr.Button( - value=refresh_symbol, elem_id="refresh_secret" + create_refresh_button( + db_new_model_src, + get_sd_models, + lambda: {"choices": sorted(get_sd_models())}, + "refresh_sd_models", ) - db_clear_secret = gr.Button( - 
value=delete_symbol, elem_id="clear_secret" + with gr.Column(visible=False) as shared_row: + with gr.Row(): + db_new_model_shared_src = gr.Dropdown( + label="EXPERIMENTAL: LoRA Shared Diffusers Source", + choices=sorted(get_shared_models()), + value="", + visible=False ) - + create_refresh_button( + db_new_model_shared_src, + get_shared_models, + lambda: {"choices": sorted(get_shared_models())}, + "refresh_shared_models", + ) + db_new_model_extract_ema = gr.Checkbox( + label="Extract EMA Weights", value=False + ) + db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False) + with gr.Column(): + with gr.Accordion(open=False, label="Resources"): with gr.Column(): - # In the future change this to something more generic and list the supported types - # from DreamboothWebhookTarget enum; for now, Discord is what I use ;) - # Add options to include notifications on training complete and exceptions that halt training - db_notification_webhook_url = gr.Textbox( - label="Discord Webhook", - placeholder="https://discord.com/api/webhooks/XXX/XXXX", - value="", + gr.HTML( + value="Beginners guide", ) - notification_webhook_test_btn = gr.Button( - value="Save and Test Webhook" + gr.HTML( + value="Release notes", ) - + with gr.Tab("Preprocess", elem_id="PreprocessPanel", visible=False): with gr.Row(): - with gr.Column(scale=2): - gr.HTML(value="") + with gr.Column(scale=2, variant="compact"): + db_preprocess_path = gr.Textbox( + label="Image Path", value="", placeholder="Enter the path to your images" + ) + with gr.Column(variant="compact"): + db_preprocess_recursive = gr.Checkbox( + label="Recursive", value=False, container=True, elem_classes=["singleCheckbox"] + ) + with gr.Row(): + with gr.Tab("Auto-Caption"): + with gr.Row(): + gr.HTML(value="Auto-Caption") + with gr.Tab("Edit Captions"): + with gr.Row(): + db_preprocess_autosave = gr.Checkbox( + label="Autosave", value=False + ) + with gr.Row(): + gr.HTML(value="Edit Captions") + with gr.Tab("Edit Images"): + with gr.Row(): + gr.HTML(value="Edit Images") + with gr.Row(): + db_preprocess = gr.Button( + value="Preprocess", variant="primary" + ) + db_preprocess_all = gr.Button( + value="Preprocess All", variant="primary" + ) + with gr.Row(): + db_preprocess_all = gr.Button( + value="Preprocess All", variant="primary" + ) with gr.Tab("Concepts", elem_id="TabConcepts") as concept_tab: with gr.Column(variant="panel"): - with gr.Row(): + with gr.Row(visible=False): db_train_wizard_person = gr.Button( value="Training Wizard (Person)" ) @@ -875,6 +554,352 @@ def on_ui_tabs(): c4_save_guidance_scale, c4_save_infer_steps, ) = build_concept_panel(4) + with gr.Tab("Parameters", elem_id="TabSettings"): + db_performance_wizard = gr.Button(value="Performance Wizard (WIP)", visible=False) + with gr.Accordion(open=False, label="Performance"): + db_use_ema = gr.Checkbox( + label="Use EMA", value=False + ) + db_optimizer = gr.Dropdown( + label="Optimizer", + value="8bit AdamW", + choices=list_optimizer(), + ) + db_mixed_precision = gr.Dropdown( + label="Mixed Precision", + value=select_precision(), + choices=list_precisions(), + ) + db_full_mixed_precision = gr.Checkbox( + label="Full Mixed Precision", value=True + ) + db_attention = gr.Dropdown( + label="Memory Attention", + value=select_attention(), + choices=list_attention(), + ) + db_cache_latents = gr.Checkbox( + label="Cache Latents", value=True + ) + db_train_unet = gr.Checkbox( + label="Train UNET", value=True + ) + db_stop_text_encoder = gr.Slider( + label="Step Ratio of Text Encoder Training", + 
minimum=0, + maximum=1, + step=0.05, + value=1.0, + visible=True, + ) + db_offset_noise = gr.Slider( + label="Offset Noise", + minimum=-1, + maximum=1, + step=0.01, + value=0, + ) + db_freeze_clip_normalization = gr.Checkbox( + label="Freeze CLIP Normalization Layers", + visible=True, + value=False, + ) + db_clip_skip = gr.Slider( + label="Clip Skip", + value=2, + minimum=1, + maximum=12, + step=1, + ) + db_weight_decay = gr.Slider( + label="Weight Decay", + minimum=0, + maximum=1, + step=0.001, + value=0.01, + visible=True, + ) + db_tenc_weight_decay = gr.Slider( + label="TENC Weight Decay", + minimum=0, + maximum=1, + step=0.001, + value=0.01, + visible=True, + ) + db_tenc_grad_clip_norm = gr.Slider( + label="TENC Gradient Clip Norm", + minimum=0, + maximum=128, + step=0.25, + value=0, + visible=True, + ) + db_min_snr_gamma = gr.Slider( + label="Min SNR Gamma", + minimum=0, + maximum=10, + step=0.1, + visible=True, + ) + db_pad_tokens = gr.Checkbox( + label="Pad Tokens", value=True + ) + db_strict_tokens = gr.Checkbox( + label="Strict Tokens", value=False + ) + db_shuffle_tags = gr.Checkbox( + label="Shuffle Tags", value=True + ) + db_max_token_length = gr.Slider( + label="Max Token Length", + minimum=75, + maximum=300, + step=75, + ) + with gr.Accordion(open=False, label="Intervals"): + db_num_train_epochs = gr.Slider( + label="Training Steps Per Image (Epochs)", + value=100, + maximum=1000, + step=1, + ) + db_epoch_pause_frequency = gr.Slider( + label="Pause After N Epochs", + value=0, + maximum=100, + step=1, + ) + db_epoch_pause_time = gr.Slider( + label="Amount of time to pause between Epochs (s)", + value=0, + maximum=3600, + step=1, + ) + db_save_embedding_every = gr.Slider( + label="Save Model Frequency (Epochs)", + value=25, + maximum=1000, + step=1, + ) + db_save_preview_every = gr.Slider( + label="Save Preview(s) Frequency (Epochs)", + value=5, + maximum=1000, + step=1, + ) + with gr.Accordion(open=False, label="Batch Sizes"): + db_train_batch_size = gr.Slider( + label="Batch Size", + value=1, + minimum=1, + maximum=100, + step=1, + ) + db_gradient_accumulation_steps = gr.Slider( + label="Gradient Accumulation Steps", + value=1, + minimum=1, + maximum=100, + step=1, + ) + db_sample_batch_size = gr.Slider( + label="Class Batch Size", + minimum=1, + maximum=100, + value=1, + step=1, + ) + db_gradient_set_to_none = gr.Checkbox( + label="Set Gradients to None When Zeroing", value=True + ) + db_gradient_checkpointing = gr.Checkbox( + label="Gradient Checkpointing", value=True + ) + with gr.Accordion(open=False, label="Learning Rate"): + with gr.Row(visible=False) as lora_lr_row: + db_lora_learning_rate = gr.Number( + label="Lora UNET Learning Rate", value=1e-4 + ) + db_lora_txt_learning_rate = gr.Number( + label="Lora Text Encoder Learning Rate", value=5e-5 + ) + with gr.Row() as standard_lr_row: + db_learning_rate = gr.Number( + label="Learning Rate", value=2e-6 + ) + db_txt_learning_rate = gr.Number( + label="Text Encoder Learning Rate", value=1e-6 + ) + + db_lr_scheduler = gr.Dropdown( + label="Learning Rate Scheduler", + value="constant_with_warmup", + choices=list_schedulers(), + ) + db_learning_rate_min = gr.Number( + label="Min Learning Rate", value=1e-6, visible=False + ) + db_lr_cycles = gr.Number( + label="Number of Hard Resets", + value=1, + precision=0, + visible=False, + ) + db_lr_factor = gr.Number( + label="Constant/Linear Starting Factor", + value=0.5, + precision=2, + visible=False, + ) + db_lr_power = gr.Number( + label="Polynomial Power", + value=1.0, + 
precision=1, + visible=False, + ) + db_lr_scale_pos = gr.Slider( + label="Scale Position", + value=0.5, + minimum=0, + maximum=1, + step=0.05, + visible=False, + ) + db_lr_warmup_steps = gr.Slider( + label="Learning Rate Warmup Steps", + value=500, + step=5, + maximum=1000, + ) + with gr.Accordion(open=False, label="Lora"): + db_use_lora = gr.Checkbox(label="Use LORA", value=False) + db_use_lora_extended = gr.Checkbox( + label="Use Lora Extended", + value=False, + visible=False, + ) + db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False) + db_train_inpainting = gr.Checkbox( + label="Train Inpainting Model", + value=False, + visible=False, + ) + with gr.Column(visible=False) as lora_rank_col: + db_lora_unet_rank = gr.Slider( + label="Lora UNET Rank", + value=4, + minimum=2, + maximum=128, + step=2, + ) + db_lora_txt_rank = gr.Slider( + label="Lora Text Encoder Rank", + value=4, + minimum=2, + maximum=128, + step=2, + ) + db_lora_weight = gr.Slider( + label="Lora Weight (Alpha)", + value=0.8, + minimum=0.1, + maximum=1, + step=0.1, + ) + with gr.Accordion(open=False, label="Image Processing"): + db_resolution = gr.Slider( + label="Max Resolution", + step=64, + minimum=128, + value=512, + maximum=2048, + elem_id="max_res", + ) + db_hflip = gr.Checkbox( + label="Apply Horizontal Flip", value=False + ) + db_dynamic_img_norm = gr.Checkbox( + label="Dynamic Image Normalization", value=False + ) + with gr.Accordion(open=False, label="Prior Loss"): + db_prior_loss_scale = gr.Checkbox( + label="Scale Prior Loss", value=False + ) + db_prior_loss_weight = gr.Slider( + label="Prior Loss Weight", + minimum=0.01, + maximum=1, + step=0.01, + value=0.75, + ) + db_prior_loss_target = gr.Number( + label="Prior Loss Target", + value=100, + visible=False, + ) + db_prior_loss_weight_min = gr.Slider( + label="Minimum Prior Loss Weight", + minimum=0.01, + maximum=1, + step=0.01, + value=0.1, + visible=False, + ) + with gr.Accordion(open=False, label="Extras"): + with gr.Column(): + gr.HTML(value="Sanity Samples") + db_sanity_prompt = gr.Textbox( + label="Sanity Sample Prompt", + placeholder="A generic prompt used to generate a sample image " + "to verify model fidelity.", + ) + db_sanity_negative_prompt = gr.Textbox( + label="Sanity Sample Negative Prompt", + placeholder="A negative prompt for the generic sample image.", + ) + db_sanity_seed = gr.Number( + label="Sanity Sample Seed", value=420420 + ) + + with gr.Column(): + gr.HTML(value="Miscellaneous") + db_pretrained_vae_name_or_path = gr.Textbox( + label="Pretrained VAE Name or Path", + placeholder="Leave blank to use base model VAE.", + value="", + ) + db_use_concepts = gr.Checkbox( + label="Use Concepts List", value=False + ) + db_concepts_path = gr.Textbox( + label="Concepts List", + placeholder="Path to JSON file with concepts to train.", + ) + with gr.Row(): + db_secret = gr.Textbox( + label="API Key", value=get_secret, interactive=False + ) + db_refresh_button = gr.Button( + value=refresh_symbol, elem_id="refresh_secret" + ) + db_clear_secret = gr.Button( + value=delete_symbol, elem_id="clear_secret" + ) + + with gr.Column(): + gr.HTML(value="Webhooks") + # In the future change this to something more generic and list the supported types + # from DreamboothWebhookTarget enum; for now, Discord is what I use ;) + # Add options to include notifications on training complete and exceptions that halt training + db_notification_webhook_url = gr.Textbox( + label="Discord Webhook", + 
placeholder="https://discord.com/api/webhooks/XXX/XXXX", + value="", + ) + notification_webhook_test_btn = gr.Button( + value="Save and Test Webhook" + ) with gr.Tab("Saving", elme_id="TabSave"): with gr.Column(): gr.HTML("General") @@ -1310,6 +1335,22 @@ def format_updates(): with gr.Column(): change_log = gr.HTML(format_updates(), elem_id="change_log") + + global preprocess_params + + preprocess_params = [ + db_preprocess_path, + db_preprocess_recursive + ] + + db_preprocess_path.change( + fn=check_preprocess_path, + inputs=[db_preprocess_path, db_preprocess_recursive], + outputs=[db_status, db_gallery] + ) + + db_gallery.select(load_image_caption, None, db_status) + global params_to_save global params_to_load @@ -1657,6 +1698,7 @@ def class_gen_method_changed(method): fn=load_model_params, inputs=[db_model_name], outputs=[ + db_model_info, db_model_path, db_revision, db_epochs, diff --git a/style.css b/style.css index 65eb2230..29e85331 100644 --- a/style.css +++ b/style.css @@ -7,6 +7,10 @@ max-height: 20px; } +#modelDetailRow { + text-align: center; +} + button:disabled, input:disabled { background: #4444 !important; color: #6666 !important; @@ -41,6 +45,10 @@ button:disabled, input:disabled { border: 1px grey solid; } +.singleCheckbox { + margin-top: 29px !important; +} + #change_modal.active { display: block; position: fixed; From 29db0f26fa30c83f66904cd424642976fb182b2e Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Wed, 8 Nov 2023 16:56:52 -0600 Subject: [PATCH 08/16] UI Rework, continued. --- javascript/dreambooth.js | 17 +- scripts/main.py | 748 +++++++++++++++++++++------------------ style.css | 37 +- 3 files changed, 447 insertions(+), 355 deletions(-) diff --git a/javascript/dreambooth.js b/javascript/dreambooth.js index 4d22fa24..35e4ca00 100644 --- a/javascript/dreambooth.js +++ b/javascript/dreambooth.js @@ -7,6 +7,7 @@ let locked = false; let listenersSet = false; let timeouts = []; let listeners = {}; +let elementsHidden = false; function save_config() { let btn = gradioApp().getElementById("db_save_config"); @@ -90,7 +91,7 @@ function update_params() { let btn = gradioApp().getElementById("db_update_params"); if (btn == null) return; btn.click(); - }, 500); + }, 100); } function getRealElement(selector) { @@ -338,8 +339,22 @@ let db_titles = { "Weight Decay": "Values closer to 0 closely match your training dataset, and values closer to 1 generalize more and deviate from your training dataset. Default is 1e-2, values lower than 0.1 are recommended. 
For D-Adaptation values between 0.02 and 0.04 are recommended", } +function hideElements() { + if (!elementsHidden) { + let btn = gradioApp().getElementById("db_hide_advanced"); + if (btn == null) return; + elementsHidden = true; + console.log("Hiding advanced elements!"); + btn.click(); + } +} + // Do a thing when the UI updates onUiUpdate(function () { + setTimeout(function () { + hideElements(); + },100); + let db_active = document.getElementById("db_active"); if (db_active) { db_active.parentElement.style.display = "none"; diff --git a/scripts/main.py b/scripts/main.py index f6ed51d6..e9a49fed 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -62,6 +62,7 @@ delete_symbol = "\U0001F5D1" # 🗑️ update_symbol = "\U0001F51D" # 🠝 log_parser = LogParser() +show_advanced = False def read_metadata_from_safetensors(filename): @@ -329,7 +330,12 @@ def on_ui_tabs(): db_shared_diffusers_path = gr.HTML() with gr.Row(equal_height=False): with gr.Column(variant="panel", elem_id="SettingsPanel"): - gr.HTML(value="Settings") + with gr.Row(): + with gr.Column(scale=1, min_width=100, elem_classes="halfElement"): + gr.HTML(value="Settings") + with gr.Column(scale=1, min_width=100, elem_classes="halfElement"): + db_show_advanced = gr.Button(value="Show Advanced", size="sm", elem_classes="advBtn", visible=False) + db_hide_advanced = gr.Button(value="Hide Advanced", variant="primary", size="sm", elem_id="db_hide_advanced", elem_classes="advBtn") with gr.Tab("Model", elem_id="ModelPanel"): with gr.Column(): with gr.Tab("Select"): @@ -343,7 +349,7 @@ def on_ui_tabs(): lambda: {"choices": sorted(get_db_models())}, "refresh_db_models", ) - with gr.Row(): + with gr.Row() as db_snapshot_row: db_snapshot = gr.Dropdown( label="Snapshot to Resume", choices=sorted(get_model_snapshots()), @@ -417,7 +423,7 @@ def on_ui_tabs(): db_new_model_extract_ema = gr.Checkbox( label="Extract EMA Weights", value=False ) - db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False) + db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=True) with gr.Column(): with gr.Accordion(open=False, label="Resources"): with gr.Column(): @@ -427,51 +433,44 @@ def on_ui_tabs(): gr.HTML( value="Release notes", ) - with gr.Tab("Preprocess", elem_id="PreprocessPanel", visible=False): - with gr.Row(): - with gr.Column(scale=2, variant="compact"): - db_preprocess_path = gr.Textbox( - label="Image Path", value="", placeholder="Enter the path to your images" - ) - with gr.Column(variant="compact"): - db_preprocess_recursive = gr.Checkbox( - label="Recursive", value=False, container=True, elem_classes=["singleCheckbox"] - ) - with gr.Row(): - with gr.Tab("Auto-Caption"): - with gr.Row(): - gr.HTML(value="Auto-Caption") - with gr.Tab("Edit Captions"): - with gr.Row(): - db_preprocess_autosave = gr.Checkbox( - label="Autosave", value=False - ) - with gr.Row(): - gr.HTML(value="Edit Captions") - with gr.Tab("Edit Images"): - with gr.Row(): - gr.HTML(value="Edit Images") - with gr.Row(): - db_preprocess = gr.Button( - value="Preprocess", variant="primary" - ) - db_preprocess_all = gr.Button( - value="Preprocess All", variant="primary" - ) - with gr.Row(): - db_preprocess_all = gr.Button( - value="Preprocess All", variant="primary" - ) + # with gr.Tab("Preprocess", elem_id="PreprocessPanel", visible=False): + # with gr.Row(): + # with gr.Column(scale=2, variant="compact"): + # db_preprocess_path = gr.Textbox( + # label="Image Path", value="", placeholder="Enter the path to your images" + # ) + # with gr.Column(variant="compact"): + # 
db_preprocess_recursive = gr.Checkbox( + # label="Recursive", value=False, container=True, elem_classes=["singleCheckbox"] + # ) + # with gr.Row(): + # with gr.Tab("Auto-Caption"): + # with gr.Row(): + # gr.HTML(value="Auto-Caption") + # with gr.Tab("Edit Captions"): + # with gr.Row(): + # db_preprocess_autosave = gr.Checkbox( + # label="Autosave", value=False + # ) + # with gr.Row(): + # gr.HTML(value="Edit Captions") + # with gr.Tab("Edit Images"): + # with gr.Row(): + # gr.HTML(value="Edit Images") + # with gr.Row(): + # db_preprocess = gr.Button( + # value="Preprocess", variant="primary" + # ) + # db_preprocess_all = gr.Button( + # value="Preprocess All", variant="primary" + # ) + # with gr.Row(): + # db_preprocess_all = gr.Button( + # value="Preprocess All", variant="primary" + # ) with gr.Tab("Concepts", elem_id="TabConcepts") as concept_tab: with gr.Column(variant="panel"): - with gr.Row(visible=False): - db_train_wizard_person = gr.Button( - value="Training Wizard (Person)" - ) - db_train_wizard_object = gr.Button( - value="Training Wizard (Object/Style)" - ) - with gr.Tab("Concept 1"): + with gr.Accordion(open=False, label="Concept 1"): ( c1_instance_data_dir, c1_class_data_dir, @@ -492,7 +491,7 @@ def on_ui_tabs(): c1_save_infer_steps, ) = build_concept_panel(1) - with gr.Tab("Concept 2"): + with gr.Accordion(open=False, label="Concept 2"): ( c2_instance_data_dir, c2_class_data_dir, @@ -513,7 +512,7 @@ def on_ui_tabs(): c2_save_infer_steps, ) = build_concept_panel(2) - with gr.Tab("Concept 3"): + with gr.Accordion(open=False, label="Concept 3"): ( c3_instance_data_dir, c3_class_data_dir, @@ -534,7 +533,7 @@ def on_ui_tabs(): c3_save_infer_steps, ) = build_concept_panel(3) - with gr.Tab("Concept 4"): + with gr.Accordion(open=False, label="Concept 4"): ( c4_instance_data_dir, c4_class_data_dir, @@ -688,7 +687,7 @@ def on_ui_tabs(): maximum=1000, step=1, ) - with gr.Accordion(open=False, label="Batch Sizes"): + with gr.Accordion(open=False, label="Batch Sizes") as db_batch_size_view: db_train_batch_size = gr.Slider( label="Batch Size", value=1, @@ -822,7 +821,7 @@ def on_ui_tabs(): db_dynamic_img_norm = gr.Checkbox( label="Dynamic Image Normalization", value=False ) - with gr.Accordion(open=False, label="Prior Loss"): + with gr.Accordion(open=False, label="Prior Loss") as db_prior_loss_view: db_prior_loss_scale = gr.Checkbox( label="Scale Prior Loss", value=False ) @@ -846,6 +845,168 @@ def on_ui_tabs(): value=0.1, visible=False, ) + with gr.Accordion(open=False, label="Saving", elme_id="TabSave") as db_save_tab: + with gr.Column(): + gr.HTML("General") + db_custom_model_name = gr.Textbox( + label="Custom Model Name", + value="", + placeholder="Enter a model name for saving checkpoints and lora models.", + ) + db_save_safetensors = gr.Checkbox( + label="Save in .safetensors format", + value=True, + visible=False, + ) + db_save_ema = gr.Checkbox( + label="Save EMA Weights to Generated Models", value=True + ) + db_infer_ema = gr.Checkbox( + label="Use EMA Weights for Inference", value=False + ) + with gr.Column(): + gr.HTML("Checkpoints") + db_half_model = gr.Checkbox(label="Half Model", value=False) + db_use_subdir = gr.Checkbox( + label="Save Checkpoint to Subdirectory", value=True + ) + db_save_ckpt_during = gr.Checkbox( + label="Generate a .ckpt file when saving during training." 
+ ) + db_save_ckpt_after = gr.Checkbox( + label="Generate a .ckpt file when training completes.", + value=True, + ) + db_save_ckpt_cancel = gr.Checkbox( + label="Generate a .ckpt file when training is canceled." + ) + with gr.Column(visible=False) as lora_save_col: + db_save_lora_during = gr.Checkbox( + label="Generate lora weights when saving during training." + ) + db_save_lora_after = gr.Checkbox( + label="Generate lora weights when training completes.", + value=True, + ) + db_save_lora_cancel = gr.Checkbox( + label="Generate lora weights when training is canceled." + ) + db_save_lora_for_extra_net = gr.Checkbox( + label="Generate lora weights for extra networks." + ) + with gr.Column(): + gr.HTML("Diffusion Weights (training snapshots)") + db_save_state_during = gr.Checkbox( + label="Save separate diffusers snapshots when saving during training." + ) + db_save_state_after = gr.Checkbox( + label="Save separate diffusers snapshots when training completes." + ) + db_save_state_cancel = gr.Checkbox( + label="Save separate diffusers snapshots when training is canceled." + ) + with gr.Accordion(open=False, label="Image Generation", elem_id="TabGenerate") as db_generate_tab: + gr.HTML(value="Class Generation Schedulers") + db_class_gen_method = gr.Dropdown( + label="Image Generation Library", + value="Native Diffusers", + choices=[ + "A1111 txt2img (Euler a)", + "Native Diffusers", + ] + ) + db_scheduler = gr.Dropdown( + label="Image Generation Scheduler", + value="DEISMultistep", + choices=get_scheduler_names(), + ) + gr.HTML(value="Manual Class Generation") + with gr.Column(): + db_generate_classes = gr.Button(value="Generate Class Images") + db_generate_graph = gr.Button(value="Generate Graph") + db_graph_smoothing = gr.Slider( + value=50, + label="Graph Smoothing Steps", + minimum=10, + maximum=500, + ) + db_debug_buckets = gr.Button(value="Debug Buckets") + db_bucket_epochs = gr.Slider( + value=10, + step=1, + minimum=1, + maximum=1000, + label="Epochs to Simulate", + ) + db_bucket_batch = gr.Slider( + value=1, + step=1, + minimum=1, + maximum=500, + label="Batch Size to Simulate", + ) + db_generate_sample = gr.Button(value="Generate Sample Images") + db_sample_prompt = gr.Textbox(label="Sample Prompt") + db_sample_negative = gr.Textbox(label="Sample Negative Prompt") + db_sample_prompt_file = gr.Textbox(label="Sample Prompt File") + db_sample_width = gr.Slider( + label="Sample Width", + value=512, + step=64, + minimum=128, + maximum=2048, + ) + db_sample_height = gr.Slider( + label="Sample Height", + value=512, + step=64, + minimum=128, + maximum=2048, + ) + db_sample_seed = gr.Number( + label="Sample Seed", value=-1, precision=0 + ) + db_num_samples = gr.Slider( + label="Number of Samples to Generate", + value=1, + minimum=1, + maximum=1000, + step=1, + ) + db_gen_sample_batch_size = gr.Slider( + label="Sample Batch Size", + value=1, + step=1, + minimum=1, + maximum=100, + interactive=True, + ) + db_sample_steps = gr.Slider( + label="Sample Steps", + value=20, + minimum=1, + maximum=500, + step=1, + ) + db_sample_scale = gr.Slider( + label="Sample CFG Scale", + value=7.5, + step=0.1, + minimum=1, + maximum=20, + ) + with gr.Column(variant="panel", visible=has_face_swap()): + db_swap_faces = gr.Checkbox(label="Swap Sample Faces") + db_swap_prompt = gr.Textbox(label="Swap Prompt") + db_swap_negative = gr.Textbox(label="Swap Negative Prompt") + db_swap_steps = gr.Slider(label="Swap Steps", value=40) + db_swap_batch = gr.Slider(label="Swap Batch", value=40) + + db_sample_txt2img = 
gr.Checkbox( + label="Use txt2img", + value=False, + visible=False # db_sample_txt2img not implemented yet + ) with gr.Accordion(open=False, label="Extras"): with gr.Column(): gr.HTML(value="Sanity Samples") @@ -861,8 +1022,7 @@ def on_ui_tabs(): db_sanity_seed = gr.Number( label="Sanity Sample Seed", value=420420 ) - - with gr.Column(): + with gr.Column() as db_misc_view: gr.HTML(value="Miscellaneous") db_pretrained_vae_name_or_path = gr.Textbox( label="Pretrained VAE Name or Path", @@ -886,8 +1046,7 @@ def on_ui_tabs(): db_clear_secret = gr.Button( value=delete_symbol, elem_id="clear_secret" ) - - with gr.Column(): + with gr.Column() as db_hook_view: gr.HTML(value="Webhooks") # In the future change this to something more generic and list the supported types # from DreamboothWebhookTarget enum; for now, Discord is what I use ;) @@ -900,223 +1059,64 @@ def on_ui_tabs(): notification_webhook_test_btn = gr.Button( value="Save and Test Webhook" ) - with gr.Tab("Saving", elme_id="TabSave"): - with gr.Column(): - gr.HTML("General") - db_custom_model_name = gr.Textbox( - label="Custom Model Name", - value="", - placeholder="Enter a model name for saving checkpoints and lora models.", - ) - db_save_safetensors = gr.Checkbox( - label="Save in .safetensors format", - value=True, + with gr.Column() as db_test_tab: + gr.HTML(value="Experimental Settings") + db_tomesd = gr.Slider( + value=0, + label="Token Merging (ToMe)", + minimum=0, + maximum=1, + step=0.1, + ) + db_split_loss = gr.Checkbox( + label="Calculate Split Loss", value=True + ) + db_disable_class_matching = gr.Checkbox(label="Disable Class Matching") + db_disable_logging = gr.Checkbox(label="Disable Logging") + db_deterministic = gr.Checkbox(label="Deterministic") + db_ema_predict = gr.Checkbox(label="Use EMA for prediction") + db_lora_use_buggy_requires_grad = gr.Checkbox(label="LoRA use buggy requires grad") + db_noise_scheduler = gr.Dropdown( + label="Noise scheduler", + value="DDPM", + choices=[ + "DDPM", + "DEIS", + "UniPC" + ] + ) + db_update_extension = gr.Button( + value="Update Extension and Restart" + ) + + with gr.Column(variant="panel"): + gr.HTML(value="Bucket Cropping") + db_crop_src_path = gr.Textbox(label="Source Path") + db_crop_dst_path = gr.Textbox(label="Dest Path") + db_crop_max_res = gr.Slider( + label="Max Res", value=512, step=64, maximum=2048 + ) + db_crop_bucket_step = gr.Slider( + label="Bucket Steps", value=8, step=8, maximum=512 + ) + db_crop_dry = gr.Checkbox(label="Dry Run") + db_start_crop = gr.Button("Start Cropping") + with gr.Column(variant="panel"): + with gr.Row(): + with gr.Column(scale=1, min_width=110): + gr.HTML(value="Output") + with gr.Column(scale=1, min_width=110): + db_check_progress_initial = gr.Button( + value=update_symbol, + elem_id="db_check_progress_initial", visible=False, ) - db_save_ema = gr.Checkbox( - label="Save EMA Weights to Generated Models", value=True - ) - db_infer_ema = gr.Checkbox( - label="Use EMA Weights for Inference", value=False - ) - with gr.Column(): - gr.HTML("Checkpoints") - db_half_model = gr.Checkbox(label="Half Model", value=False) - db_use_subdir = gr.Checkbox( - label="Save Checkpoint to Subdirectory", value=True - ) - db_save_ckpt_during = gr.Checkbox( - label="Generate a .ckpt file when saving during training." - ) - db_save_ckpt_after = gr.Checkbox( - label="Generate a .ckpt file when training completes.", - value=True, - ) - db_save_ckpt_cancel = gr.Checkbox( - label="Generate a .ckpt file when training is canceled." 
- ) - with gr.Column(visible=False) as lora_save_col: - db_save_lora_during = gr.Checkbox( - label="Generate lora weights when saving during training." - ) - db_save_lora_after = gr.Checkbox( - label="Generate lora weights when training completes.", - value=True, - ) - db_save_lora_cancel = gr.Checkbox( - label="Generate lora weights when training is canceled." - ) - db_save_lora_for_extra_net = gr.Checkbox( - label="Generate lora weights for extra networks." - ) - with gr.Column(): - gr.HTML("Diffusion Weights (training snapshots)") - db_save_state_during = gr.Checkbox( - label="Save separate diffusers snapshots when saving during training." - ) - db_save_state_after = gr.Checkbox( - label="Save separate diffusers snapshots when training completes." - ) - db_save_state_cancel = gr.Checkbox( - label="Save separate diffusers snapshots when training is canceled." - ) - with gr.Tab("Generate", elem_id="TabGenerate"): - gr.HTML(value="Class Generation Schedulers") - db_class_gen_method = gr.Dropdown( - label="Image Generation Library", - value="Native Diffusers", - choices=[ - "A1111 txt2img (Euler a)", - "Native Diffusers", - ] - ) - db_scheduler = gr.Dropdown( - label="Image Generation Scheduler", - value="DEISMultistep", - choices=get_scheduler_names(), - ) - gr.HTML(value="Manual Class Generation") - with gr.Column(): - db_generate_classes = gr.Button(value="Generate Class Images") - db_generate_graph = gr.Button(value="Generate Graph") - db_graph_smoothing = gr.Slider( - value=50, - label="Graph Smoothing Steps", - minimum=10, - maximum=500, - ) - db_debug_buckets = gr.Button(value="Debug Buckets") - db_bucket_epochs = gr.Slider( - value=10, - step=1, - minimum=1, - maximum=1000, - label="Epochs to Simulate", - ) - db_bucket_batch = gr.Slider( - value=1, - step=1, - minimum=1, - maximum=500, - label="Batch Size to Simulate", - ) - db_generate_sample = gr.Button(value="Generate Sample Images") - db_sample_prompt = gr.Textbox(label="Sample Prompt") - db_sample_negative = gr.Textbox(label="Sample Negative Prompt") - db_sample_prompt_file = gr.Textbox(label="Sample Prompt File") - db_sample_width = gr.Slider( - label="Sample Width", - value=512, - step=64, - minimum=128, - maximum=2048, - ) - db_sample_height = gr.Slider( - label="Sample Height", - value=512, - step=64, - minimum=128, - maximum=2048, - ) - db_sample_seed = gr.Number( - label="Sample Seed", value=-1, precision=0 - ) - db_num_samples = gr.Slider( - label="Number of Samples to Generate", - value=1, - minimum=1, - maximum=1000, - step=1, - ) - db_gen_sample_batch_size = gr.Slider( - label="Sample Batch Size", - value=1, - step=1, - minimum=1, - maximum=100, - interactive=True, - ) - db_sample_steps = gr.Slider( - label="Sample Steps", - value=20, - minimum=1, - maximum=500, - step=1, - ) - db_sample_scale = gr.Slider( - label="Sample CFG Scale", - value=7.5, - step=0.1, - minimum=1, - maximum=20, - ) - with gr.Column(variant="panel", visible=has_face_swap()): - db_swap_faces = gr.Checkbox(label="Swap Sample Faces") - db_swap_prompt = gr.Textbox(label="Swap Prompt") - db_swap_negative = gr.Textbox(label="Swap Negative Prompt") - db_swap_steps = gr.Slider(label="Swap Steps", value=40) - db_swap_batch = gr.Slider(label="Swap Batch", value=40) - - db_sample_txt2img = gr.Checkbox( - label="Use txt2img", - value=False, - visible=False # db_sample_txt2img not implemented yet - ) - with gr.Tab("Testing", elem_id="TabDebug"): - gr.HTML(value="Experimental Settings") - db_tomesd = gr.Slider( - value=0, - label="Token Merging (ToMe)", - 
minimum=0, - maximum=1, - step=0.1, - ) - db_split_loss = gr.Checkbox( - label="Calculate Split Loss", value=True - ) - db_disable_class_matching = gr.Checkbox(label="Disable Class Matching") - db_disable_logging = gr.Checkbox(label="Disable Logging") - db_deterministic = gr.Checkbox(label="Deterministic") - db_ema_predict = gr.Checkbox(label="Use EMA for prediction") - db_lora_use_buggy_requires_grad = gr.Checkbox(label="LoRA use buggy requires grad") - db_noise_scheduler = gr.Dropdown( - label="Noise scheduler", - value="DDPM", - choices=[ - "DDPM", - "DEIS", - "UniPC" - ] - ) - db_update_extension = gr.Button( - value="Update Extension and Restart" - ) + # These two should be updated while doing things + db_active = gr.Checkbox(elem_id="db_active", value=False, visible=False) - with gr.Column(variant="panel"): - gr.HTML(value="Bucket Cropping") - db_crop_src_path = gr.Textbox(label="Source Path") - db_crop_dst_path = gr.Textbox(label="Dest Path") - db_crop_max_res = gr.Slider( - label="Max Res", value=512, step=64, maximum=2048 + ui_check_progress_initial = gr.Button( + value="Refresh", elem_id="ui_check_progress_initial", elem_classes="advBtn", size="sm" ) - db_crop_bucket_step = gr.Slider( - label="Bucket Steps", value=8, step=8, maximum=512 - ) - db_crop_dry = gr.Checkbox(label="Dry Run") - db_start_crop = gr.Button("Start Cropping") - with gr.Column(variant="panel"): - gr.HTML(value="Output") - db_check_progress_initial = gr.Button( - value=update_symbol, - elem_id="db_check_progress_initial", - visible=False, - ) - # These two should be updated while doing things - db_active = gr.Checkbox(elem_id="db_active", value=False, visible=False) - - ui_check_progress_initial = gr.Button( - value=update_symbol, elem_id="ui_check_progress_initial" - ) db_status = gr.HTML(elem_id="db_status", value="") db_progressbar = gr.HTML(elem_id="db_progressbar") db_gallery = gr.Gallery( @@ -1335,20 +1335,104 @@ def format_updates(): with gr.Column(): change_log = gr.HTML(format_updates(), elem_id="change_log") - - global preprocess_params - - preprocess_params = [ - db_preprocess_path, - db_preprocess_recursive + advanced_elements = [ + db_snapshot_row, + db_create_from_hub, + db_new_model_extract_ema, + db_train_unfrozen, + db_use_ema, + db_freeze_clip_normalization, + db_full_mixed_precision, + db_offset_noise, + db_weight_decay, + db_tenc_weight_decay, + db_tenc_grad_clip_norm, + db_min_snr_gamma, + db_pad_tokens, + db_strict_tokens, + db_max_token_length, + db_epoch_pause_frequency, + db_epoch_pause_time, + db_batch_size_view, + db_lr_scheduler, + db_lr_warmup_steps, + db_hflip, + db_prior_loss_view, + db_misc_view, + db_hook_view, + db_save_tab, + db_generate_tab, + db_test_tab, + db_dynamic_img_norm, + db_tomesd, + db_split_loss, + db_disable_class_matching, + db_disable_logging, + db_deterministic, + db_ema_predict, + db_lora_use_buggy_requires_grad, + db_noise_scheduler, + c1_class_guidance_scale, + c1_class_infer_steps, + c1_save_sample_negative_prompt, + c1_sample_seed, + c1_save_guidance_scale, + c1_save_infer_steps, + c2_class_guidance_scale, + c2_class_infer_steps, + c2_save_sample_negative_prompt, + c2_sample_seed, + c2_save_guidance_scale, + c2_save_infer_steps, + c3_class_guidance_scale, + c3_class_infer_steps, + c3_save_sample_negative_prompt, + c3_sample_seed, + c3_save_guidance_scale, + c3_save_infer_steps, + c4_class_guidance_scale, + c4_class_infer_steps, + c4_save_sample_negative_prompt, + c4_sample_seed, + c4_save_guidance_scale, + c4_save_infer_steps, ] - 
db_preprocess_path.change( - fn=check_preprocess_path, - inputs=[db_preprocess_path, db_preprocess_recursive], - outputs=[db_status, db_gallery] + def toggle_advanced(): + global show_advanced + show_advanced = False if show_advanced else True + outputs = [gr.update(visible=True), gr.update(visible=False)] + print(f"Advanced elements visible: {show_advanced}") + for _ in advanced_elements: + outputs.append(gr.update(visible=show_advanced)) + + return outputs + # Merge db_show advanced, db_hide_advanced, and advanced elements into one list + db_show_advanced.click( + fn=toggle_advanced, + inputs=[], + outputs=[db_hide_advanced, db_show_advanced, *advanced_elements] ) + db_hide_advanced.click( + fn=toggle_advanced, + inputs=[], + outputs=[db_show_advanced, db_hide_advanced, *advanced_elements] + ) + + global preprocess_params + + # preprocess_params = [ + # db_preprocess_path, + # db_preprocess_recursive + # ] + # + # db_preprocess_path.change( + # fn=check_preprocess_path, + # inputs=[db_preprocess_path, db_preprocess_recursive], + # outputs=[db_status, db_gallery] + # ) + db_gallery.select(load_image_caption, None, db_status) global params_to_save @@ -1754,33 +1838,6 @@ def class_gen_method_changed(method): ], ) - db_train_wizard_person.click( - fn=training_wizard_person, - _js="db_start_twizard", - inputs=[db_model_name], - outputs=[ - db_num_train_epochs, - c1_num_class_images_per, - c2_num_class_images_per, - c3_num_class_images_per, - c4_num_class_images_per, - db_status, - ], - ) - - db_train_wizard_object.click( - fn=training_wizard, - _js="db_start_twizard", - inputs=[db_model_name], - outputs=[ - db_num_train_epochs, - c1_num_class_images_per, - c2_num_class_images_per, - c3_num_class_images_per, - c4_num_class_images_per, - db_status, - ], - ) db_generate_sample.click( fn=wrap_gpu_call(generate_samples), @@ -1878,53 +1935,61 @@ def set_gen_sample(): def build_concept_panel(concept: int): - with gr.Column(): - gr.HTML(value="Directories") + with gr.Tab(label="Instance Images"): instance_data_dir = gr.Textbox( - label="Dataset Directory", + label="Directory", placeholder="Path to directory with input images", elem_id=f"idd{concept}", ) + instance_prompt = gr.Textbox(label="Prompt", value="[filewords]") + gr.HTML(value="Use [filewords] here to read prompts from caption files/filename, or a prompt to describe your training images.
" + "If using [filewords], your instance and class tokens will be inserted into the prompt as necessary for training.", elem_classes="hintHtml") + instance_token = gr.Textbox(label="Instance Token") + gr.HTML(value="If using [filewords] above, this is the unique word used for your subject, like 'fydodog' or 'ohwx'.", + elem_classes="hintHtml") + class_token = gr.Textbox(label="Class Token") + gr.HTML(value="If using [filewords] above, this is the generic word used for your subject, like 'dog' or 'person'.", + elem_classes="hintHtml") + + with gr.Tab(label="Class Images"): class_data_dir = gr.Textbox( - label="Classification Dataset Directory", + label="Directory", placeholder="(Optional) Path to directory with " "classification/regularization images", elem_id=f"cdd{concept}", ) - with gr.Column(): - gr.HTML(value="Filewords") - instance_token = gr.Textbox( - label="Instance Token", - placeholder="When using [filewords], this is the subject to use when building prompts.", - ) - class_token = gr.Textbox( - label="Class Token", - placeholder="When using [filewords], this is the class to use when building prompts.", - ) + class_prompt = gr.Textbox(label="Prompt", value="[filewords]") + gr.HTML( + value="Use [filewords] here to read prompts from caption files/filename, or a prompt to describe your training images.
" + "If using [filewords], your class token will be inserted into the file prompts if it is not found.", + elem_classes="hintHtml") - with gr.Column(): - gr.HTML(value="Training Prompts") - instance_prompt = gr.Textbox( - label="Instance Prompt", - placeholder="Optionally use [filewords] to read image " - "captions from files.", + class_negative_prompt = gr.Textbox( + label="Negative Prompt" ) - class_prompt = gr.Textbox( - label="Class Prompt", - placeholder="Optionally use [filewords] to read image " - "captions from files.", + num_class_images_per = gr.Slider( + label="Class Images Per Instance Image", value=0, precision=0 ) - class_negative_prompt = gr.Textbox( - label="Classification Image Negative Prompt" + gr.HTML(value="For every instance image, this many classification images will be used/generated. Leave at 0 to disable.", + elem_classes="hintHtml") + class_guidance_scale = gr.Slider( + label="Classification CFG Scale", value=7.5, maximum=12, minimum=1, step=0.1 ) - with gr.Column(): - gr.HTML(value="Sample Prompts") + class_infer_steps = gr.Slider( + label="Classification Steps", value=40, minimum=10, maximum=200, step=1 + ) + with gr.Tab(label="Sample Images"): save_sample_prompt = gr.Textbox( label="Sample Image Prompt", placeholder="Leave blank to use instance prompt. " "Optionally use [filewords] to base " "sample captions on instance images.", ) + gr.HTML( + value="Leave blank or use [filewords] here to randomly select prompts from the existing instance prompt(s).
" + "If using [filewords], your instance token will be inserted into the file prompts if it is not found.", + elem_classes="hintHtml") + save_sample_negative_prompt = gr.Textbox( label="Sample Negative Prompt" ) @@ -1932,21 +1997,8 @@ def build_concept_panel(concept: int): label="Sample Prompt Template File", placeholder="Enter the path to a txt file containing sample prompts.", ) - - with gr.Column(): - gr.HTML("Class Image Generation") - num_class_images_per = gr.Slider( - label="Class Images Per Instance Image", value=0, precision=0 - ) - class_guidance_scale = gr.Slider( - label="Classification CFG Scale", value=7.5, maximum=12, minimum=1, step=0.1 - ) - class_infer_steps = gr.Slider( - label="Classification Steps", value=40, minimum=10, maximum=200, step=1 - ) - - with gr.Column(): - gr.HTML("Sample Image Generation") + gr.HTML(value="When enabled the above prompt and negative prompt will be ignored.", + elem_classes="hintHtml") n_save_sample = gr.Slider( label="Number of Samples to Generate", value=1, maximum=100, step=1 ) diff --git a/style.css b/style.css index 29e85331..230fa24a 100644 --- a/style.css +++ b/style.css @@ -49,6 +49,26 @@ button:disabled, input:disabled { margin-top: 29px !important; } +button.advBtn { + width: 110px !important; + position: absolute; + right: 0; +} + +.halfElement { + max-width: 50% !important; +} + +.hideAdvanced { + display: none !important; + width: 0 !important; + max-width: 0 !important; + height: 0 !important; + max-height: 0 !important; + margin: 0 !important; + padding: 0 !important; +} + #change_modal.active { display: block; position: fixed; @@ -116,18 +136,13 @@ button:disabled, input:disabled { } -#refresh_db_models, #refresh_lora_models, #refresh_sd_models, #refresh_secret, #clear_secret, #ui_check_progress_initial { +#refresh_db_models, #refresh_lora_models, #refresh_sd_models, #refresh_secret, #clear_secret { margin-top: 0.75em; max-width: 2.5em; min-width: 2.5em; height: 2.4em; } -#ui_check_progress_initial { - position: absolute; - top: -3px; - right: 10px; -} #db_gen_ckpt_during, #db_train_sample { display: none; @@ -171,4 +186,14 @@ button:disabled, input:disabled { .hyperlink { text-decoration: underline +} + +.hintHtml { + padding-bottom: 10px; + padding-top: 5px; + color: #b3b3b3; +} + +#TabSettings .gap, #TabConcepts .gap { + display: block !important; } \ No newline at end of file From 20b5c3aba606019881d99c153ba52cbc48eae185 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Wed, 8 Nov 2023 18:01:37 -0600 Subject: [PATCH 09/16] Fix up the modal too. 
--- scripts/main.py | 18 +++++++++--------- style.css | 36 +++++++++++++++++++++++++++++++----- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/scripts/main.py b/scripts/main.py index e9a49fed..67e9adc0 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -62,7 +62,7 @@ delete_symbol = "\U0001F5D1" # 🗑️ update_symbol = "\U0001F51D" # 🠝 log_parser = LogParser() -show_advanced = False +show_advanced = True def read_metadata_from_safetensors(filename): @@ -284,21 +284,21 @@ def on_ui_tabs(): with gr.Blocks() as dreambooth_interface: # Top button row with gr.Row(equal_height=True, elem_id="DbTopRow"): - db_load_params = gr.Button(value="Load Settings", elem_id="db_load_params") - db_save_params = gr.Button(value="Save Settings", elem_id="db_save_config") + db_load_params = gr.Button(value="Load Settings", elem_id="db_load_params", size="sm") + db_save_params = gr.Button(value="Save Settings", elem_id="db_save_config", size="sm") db_train_model = gr.Button( - value="Train", variant="primary", elem_id="db_train" + value="Train", variant="primary", elem_id="db_train", size="sm" ) db_generate_checkpoint = gr.Button( - value="Generate Ckpt", elem_id="db_gen_ckpt" + value="Generate Ckpt", elem_id="db_gen_ckpt", size="sm" ) db_generate_checkpoint_during = gr.Button( - value="Save Weights", elem_id="db_gen_ckpt_during" + value="Save Weights", elem_id="db_gen_ckpt_during", size="sm" ) db_train_sample = gr.Button( - value="Generate Samples", elem_id="db_train_sample" + value="Generate Samples", elem_id="db_train_sample", size="sm" ) - db_cancel = gr.Button(value="Cancel", elem_id="db_cancel") + db_cancel = gr.Button(value="Cancel", elem_id="db_cancel", size="sm") with gr.Row(): gr.HTML(value="Select or create a model to begin.", elem_id="hint_row") with gr.Row(elem_id="ModelDetailRow", visible=False, variant="compact") as db_model_info: @@ -1292,7 +1292,7 @@ def update_model_options(model_type): fn=lambda: check_progress_call(), show_progress=False, inputs=[], - outputs=progress_elements, + outputs=progress_elements ) db_check_progress_initial.click( diff --git a/style.css b/style.css index 230fa24a..eeba49cc 100644 --- a/style.css +++ b/style.css @@ -43,6 +43,28 @@ button:disabled, input:disabled { .commitDiv { border: 1px grey solid; + padding: 5px; +} + +/* Targets the first .commitDiv in its parent */ +.commitDiv:first-child { + border-top-left-radius: 5px; + border-top-right-radius: 5px; +} + +/* Targets the last .commitDiv in its parent */ +.commitDiv:last-child { + border-bottom-left-radius: 5px; + border-bottom-right-radius: 5px; +} + +/* Removes border-radius for every other .commitDiv */ +.commitDiv:not(:first-child):not(:last-child) { + border-radius: 0; +} + +.commitDiv h3 { + margin-top: 0!important; } .singleCheckbox { @@ -82,7 +104,7 @@ button.advBtn { margin: 0 auto; z-index: 10000; border: 1px solid white; - overflow-y: scroll; + overflow-y: auto; overflow-x: hidden; } @@ -94,10 +116,14 @@ button.advBtn { #close_modal { - min-width: 10px; - max-width: 10px; - min-height: 10px; - max-height: 35px; + min-width: 30px; + max-width: 30px; + min-height: 30px; + max-height: 30px; + height: 30px; + width: 30px; + padding: 3px; + font-family: monospace; position: absolute; right: 11px; } From 1bb9987e80574b2f63a155d0ef48d86fe549b248 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 9 Nov 2023 13:19:15 -0600 Subject: [PATCH 10/16] Fix saving checkpoint, wandb messages... 
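The wandb change below silences the library at import time. As a hedged aside that is not part of this patch, wandb can also be quieted through its documented environment variables before it is ever imported, which avoids touching wandb objects at module scope; the placement and values below are only a sketch of that alternative.

import os

# Set these before anything imports wandb (assumed goal: no login prompts, no popups).
os.environ.setdefault("WANDB_SILENT", "true")   # suppress wandb console output
os.environ.setdefault("WANDB_MODE", "offline")  # run without logging in or syncing a run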
--- dreambooth/train_dreambooth.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index c3c6b250..c55f312f 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -20,6 +20,7 @@ import torch.backends.cuda import torch.backends.cudnn import torch.nn.functional as F +import wandb from accelerate import Accelerator from accelerate.utils.random import set_seed as set_seed2 from diffusers import ( @@ -72,6 +73,9 @@ set_lora_requires_grad, ) +# Disable annoying wandb popup? +wandb.config.auto_init = False + logger = logging.getLogger(__name__) # define a Handler which writes DEBUG messages or higher to the sys.stderr dl.set_verbosity_error() @@ -933,7 +937,8 @@ def check_save(is_epoch_check=False): if global_step > 0: save_image = True save_model = True - save_lora = True + if args.use_lora: + save_lora = True save_snapshot = False @@ -966,6 +971,8 @@ def check_save(is_epoch_check=False): if save_checkpoint and args.use_lora: save_checkpoint = False save_lora = True + if not args.use_lora: + save_lora = False if ( save_checkpoint @@ -1073,6 +1080,7 @@ def save_weights( weights_dir = f"{weights_dir}_temp" os.makedirs(weights_dir, exist_ok=True) else: + save_lora = False logger.debug(f"Save checkpoint: {save_checkpoint} save lora {save_lora}.") # Is inference_mode() needed here to prevent issues when saving? logger.debug(f"Loras dir: {loras_dir}") From 32f664def1973491e81c16e6302d3e2a1ccc4e1d Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 9 Nov 2023 13:19:25 -0600 Subject: [PATCH 11/16] Update defaults --- scripts/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/main.py b/scripts/main.py index 67e9adc0..9ab47c4b 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1981,12 +1981,10 @@ def build_concept_panel(concept: int): with gr.Tab(label="Sample Images"): save_sample_prompt = gr.Textbox( label="Sample Image Prompt", - placeholder="Leave blank to use instance prompt. " - "Optionally use [filewords] to base " - "sample captions on instance images.", + value='[filewords]' ) gr.HTML( - value="Leave blank or use [filewords] here to randomly select prompts from the existing instance prompt(s).
" + value="A prompt to generate samples from, or use [filewords] here to randomly select prompts from the existing instance prompt(s).
" "If using [filewords], your instance token will be inserted into the file prompts if it is not found.", elem_classes="hintHtml") From f6ab7934e2739c9cfb1e91c12035556d9a96022e Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 9 Nov 2023 13:52:14 -0600 Subject: [PATCH 12/16] Only show model names in info display --- dreambooth/ui_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dreambooth/ui_functions.py b/dreambooth/ui_functions.py index 18398a9a..c76f574c 100644 --- a/dreambooth/ui_functions.py +++ b/dreambooth/ui_functions.py @@ -651,12 +651,12 @@ def load_model_params(model_name): msg = f"Selected model: '{model_name}'." return ( gradio.update(visible=True), - config.model_dir, + os.path.basename(config.model_dir), config.revision, config.epoch, config.model_type, "True" if config.has_ema and not config.use_lora else "False", - config.src, + os.path.basename(config.src), config.shared_diffusers_path, db_model_snapshots, db_lora_models, From 63f57f41a745d4d02a7a28edee10229cd7f2756b Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 9 Nov 2023 14:00:59 -0600 Subject: [PATCH 13/16] Update ui_functions.py --- dreambooth/ui_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dreambooth/ui_functions.py b/dreambooth/ui_functions.py index c76f574c..eb364411 100644 --- a/dreambooth/ui_functions.py +++ b/dreambooth/ui_functions.py @@ -649,6 +649,9 @@ def load_model_params(model_name): loras = get_lora_models(config) db_lora_models = gr_update(choices=loras) msg = f"Selected model: '{model_name}'." + src_name = os.path.basename(config.src) + # Strip the extension + src_name = os.path.splitext(src_name)[0] return ( gradio.update(visible=True), os.path.basename(config.model_dir), @@ -656,7 +659,7 @@ def load_model_params(model_name): config.epoch, config.model_type, "True" if config.has_ema and not config.use_lora else "False", - os.path.basename(config.src), + src_name, config.shared_diffusers_path, db_model_snapshots, db_lora_models, From dd0841db0b4b9f8c2c7bc23fab5eaca53b22f2bc Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 9 Nov 2023 14:10:59 -0600 Subject: [PATCH 14/16] Catch wandb errors. --- dreambooth/train_dreambooth.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index c55f312f..ec8481ec 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -20,7 +20,6 @@ import torch.backends.cuda import torch.backends.cudnn import torch.nn.functional as F -import wandb from accelerate import Accelerator from accelerate.utils.random import set_seed as set_seed2 from diffusers import ( @@ -73,8 +72,13 @@ set_lora_requires_grad, ) -# Disable annoying wandb popup? -wandb.config.auto_init = False +try: + import wandb + + # Disable annoying wandb popup? 
+ wandb.config.auto_init = False +except: + pass logger = logging.getLogger(__name__) # define a Handler which writes DEBUG messages or higher to the sys.stderr From a24c6167df2233dcfcd0564e408c1dda83f9c4b9 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 14 Nov 2023 09:15:20 -0600 Subject: [PATCH 15/16] KILL IT WITH FIRE --- dreambooth/train_dreambooth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index ec8481ec..40f7255b 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -985,6 +985,7 @@ def check_save(is_epoch_check=False): or save_image or save_model ): + disable_safe_unpickle() save_weights( save_image, save_model, @@ -992,6 +993,7 @@ def check_save(is_epoch_check=False): save_checkpoint, save_lora ) + enable_safe_unpickle() return save_model, save_image From e27f707a296e69701b991d802081ad38e2b78067 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Wed, 15 Nov 2023 09:02:45 -0600 Subject: [PATCH 16/16] Manually implement PR #1359 Thanks @Imageder for the contribution! --- dreambooth/diff_to_sd.py | 7 +- dreambooth/sd_to_diff.py | 106 +- dreambooth/train_dreambooth.py | 2696 +++++++++++++++---------------- dreambooth/utils/model_utils.py | 26 +- helpers/image_builder.py | 105 +- lora_diffusion/lora.py | 68 +- 6 files changed, 1508 insertions(+), 1500 deletions(-) diff --git a/dreambooth/diff_to_sd.py b/dreambooth/diff_to_sd.py index fbef6690..3123e55f 100644 --- a/dreambooth/diff_to_sd.py +++ b/dreambooth/diff_to_sd.py @@ -20,7 +20,7 @@ from dreambooth.shared import status from dreambooth.utils.model_utils import unload_system_models, \ reload_system_models, \ - disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path + safe_unpickle_disabled, import_model_class_from_model_name_or_path from dreambooth.utils.utils import printi from helpers.mytqdm import mytqdm from lora_diffusion.lora import merge_lora_to_model @@ -562,9 +562,8 @@ def load_model(model_path: str, map_location: str): if ".safetensors" in model_path: return safetensors.torch.load_file(model_path, device=map_location) else: - disable_safe_unpickle() - loaded = torch.load(model_path, map_location=map_location) - enable_safe_unpickle() + with safe_unpickle_disabled(): + loaded = torch.load(model_path, map_location=map_location) return loaded diff --git a/dreambooth/sd_to_diff.py b/dreambooth/sd_to_diff.py index ffd11e2f..3467a1be 100644 --- a/dreambooth/sd_to_diff.py +++ b/dreambooth/sd_to_diff.py @@ -25,7 +25,7 @@ from dreambooth import shared from dreambooth.dataclasses.db_config import DreamboothConfig -from dreambooth.utils.model_utils import enable_safe_unpickle, disable_safe_unpickle, unload_system_models, \ +from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \ reload_system_models @@ -131,7 +131,6 @@ def extract_checkpoint( # sh.update_status(status) # else: # modules.shared.status.update(status) - disable_safe_unpickle() if image_size is None: image_size = 512 if model_type == "v2x": @@ -162,59 +161,60 @@ def extract_checkpoint( db_config.resolution = image_size db_config.save() try: - if from_safetensors: - if model_type == "SDXL": - pipe = StableDiffusionXLPipeline.from_single_file( - pretrained_model_link_or_path=checkpoint_file, + with safe_unpickle_disabled(): + if from_safetensors: + if model_type == "SDXL": + pipe = StableDiffusionXLPipeline.from_single_file( + pretrained_model_link_or_path=checkpoint_file, + ) + else: + pipe = 
StableDiffusionPipeline.from_single_file( + pretrained_model_link_or_path=checkpoint_file, + ) + elif model_type == "SDXL": + pipe = StableDiffusionXLPipeline.from_pretrained( + checkpoint_path_or_dict=checkpoint_file, + original_config_file=original_config_file, + image_size=image_size, + prediction_type=prediction_type, + model_type=pipeline_type, + extract_ema=extract_ema, + scheduler_type=scheduler_type, + num_in_channels=num_in_channels, + upcast_attention=upcast_attention, + from_safetensors=from_safetensors, + device=device, + pretrained_model_name_or_path=checkpoint_file, + stable_unclip=stable_unclip, + stable_unclip_prior=stable_unclip_prior, + clip_stats_path=clip_stats_path, + controlnet=controlnet, + vae_path=vae_path, + pipeline_class=pipeline_class, + half=half ) else: - pipe = StableDiffusionPipeline.from_single_file( - pretrained_model_link_or_path=checkpoint_file, + pipe = StableDiffusionPipeline.from_pretrained( + checkpoint_path_or_dict=checkpoint_file, + original_config_file=original_config_file, + image_size=image_size, + prediction_type=prediction_type, + model_type=pipeline_type, + extract_ema=extract_ema, + scheduler_type=scheduler_type, + num_in_channels=num_in_channels, + upcast_attention=upcast_attention, + from_safetensors=from_safetensors, + device=device, + pretrained_model_name_or_path=checkpoint_file, + stable_unclip=stable_unclip, + stable_unclip_prior=stable_unclip_prior, + clip_stats_path=clip_stats_path, + controlnet=controlnet, + vae_path=vae_path, + pipeline_class=pipeline_class, + half=half ) - elif model_type == "SDXL": - pipe = StableDiffusionXLPipeline.from_pretrained( - checkpoint_path_or_dict=checkpoint_file, - original_config_file=original_config_file, - image_size=image_size, - prediction_type=prediction_type, - model_type=pipeline_type, - extract_ema=extract_ema, - scheduler_type=scheduler_type, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - from_safetensors=from_safetensors, - device=device, - pretrained_model_name_or_path=checkpoint_file, - stable_unclip=stable_unclip, - stable_unclip_prior=stable_unclip_prior, - clip_stats_path=clip_stats_path, - controlnet=controlnet, - vae_path=vae_path, - pipeline_class=pipeline_class, - half=half - ) - else: - pipe = StableDiffusionPipeline.from_pretrained( - checkpoint_path_or_dict=checkpoint_file, - original_config_file=original_config_file, - image_size=image_size, - prediction_type=prediction_type, - model_type=pipeline_type, - extract_ema=extract_ema, - scheduler_type=scheduler_type, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - from_safetensors=from_safetensors, - device=device, - pretrained_model_name_or_path=checkpoint_file, - stable_unclip=stable_unclip, - stable_unclip_prior=stable_unclip_prior, - clip_stats_path=clip_stats_path, - controlnet=controlnet, - vae_path=vae_path, - pipeline_class=pipeline_class, - half=half - ) dump_path = db_config.get_pretrained_model_name_or_path() if controlnet: @@ -246,7 +246,7 @@ def extract_checkpoint( print(f"Couldn't find {full_path}") break remove_dirs = ["logging", "samples"] - enable_safe_unpickle() + reload_system_models() if success: for rd in remove_dirs: diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index 40f7255b..dadbec73 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -55,8 +55,7 @@ from dreambooth.utils.model_utils import ( unload_system_models, import_model_class_from_model_name_or_path, - disable_safe_unpickle, - 
enable_safe_unpickle, + safe_unpickle_disabled, xformerify, torch2ify ) @@ -312,1456 +311,1485 @@ def create_vae(): if args.pretrained_vae_name_or_path else args.get_pretrained_model_name_or_path() ) - disable_safe_unpickle() - new_vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder=None if args.pretrained_vae_name_or_path else "vae", - revision=args.revision, - ) - enable_safe_unpickle() + with safe_unpickle_disabled(): + new_vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder=None if args.pretrained_vae_name_or_path else "vae", + revision=args.revision, + ) new_vae.requires_grad_(False) new_vae.to(accelerator.device, dtype=weight_dtype) return new_vae - disable_safe_unpickle() - # Load the tokenizer - pbar2.set_description("Loading tokenizer...") - pbar2.update() - pbar2.set_postfix(refresh=True) - tokenizer = AutoTokenizer.from_pretrained( - os.path.join(pretrained_path, "tokenizer"), - revision=args.revision, - use_fast=False, - ) - - tokenizer_two = None - if args.model_type == "SDXL": - pbar2.set_description("Loading tokenizer 2...") + with safe_unpickle_disabled(): + # Load the tokenizer + pbar2.set_description("Loading tokenizer...") pbar2.update() pbar2.set_postfix(refresh=True) - tokenizer_two = AutoTokenizer.from_pretrained( - os.path.join(pretrained_path, "tokenizer_2"), + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(pretrained_path, "tokenizer"), revision=args.revision, use_fast=False, ) - # import correct text encoder class - text_encoder_cls = import_model_class_from_model_name_or_path( - args.get_pretrained_model_name_or_path(), args.revision - ) - - pbar2.set_description("Loading text encoder...") - pbar2.update() - pbar2.set_postfix(refresh=True) - # Load models and create wrapper for stable diffusion - text_encoder = text_encoder_cls.from_pretrained( - args.get_pretrained_model_name_or_path(), - subfolder="text_encoder", - revision=args.revision, - torch_dtype=torch.float32, - ) + tokenizer_two = None + if args.model_type == "SDXL": + pbar2.set_description("Loading tokenizer 2...") + pbar2.update() + pbar2.set_postfix(refresh=True) + tokenizer_two = AutoTokenizer.from_pretrained( + os.path.join(pretrained_path, "tokenizer_2"), + revision=args.revision, + use_fast=False, + ) - if args.model_type == "SDXL": # import correct text encoder class - text_encoder_cls_two = import_model_class_from_model_name_or_path( - args.get_pretrained_model_name_or_path(), args.revision, subfolder="text_encoder_2" + text_encoder_cls = import_model_class_from_model_name_or_path( + args.get_pretrained_model_name_or_path(), args.revision ) - pbar2.set_description("Loading text encoder 2...") + pbar2.set_description("Loading text encoder...") pbar2.update() pbar2.set_postfix(refresh=True) # Load models and create wrapper for stable diffusion - text_encoder_two = text_encoder_cls_two.from_pretrained( + text_encoder = text_encoder_cls.from_pretrained( args.get_pretrained_model_name_or_path(), - subfolder="text_encoder_2", + subfolder="text_encoder", revision=args.revision, torch_dtype=torch.float32, ) - printm("Created tenc") - pbar2.set_description("Loading VAE...") - pbar2.update() - vae = create_vae() - printm("Created vae") - - pbar2.set_description("Loading unet...") - pbar2.update() - unet = UNet2DConditionModel.from_pretrained( - args.get_pretrained_model_name_or_path(), - subfolder="unet", - revision=args.revision, - torch_dtype=torch.float32, - ) - - if args.attention == "xformers" and not shared.force_cpu: - xformerify(unet, use_lora=args.use_lora) - 
xformerify(vae, use_lora=args.use_lora) - - unet = torch2ify(unet) - - if args.full_mixed_precision: - if args.mixed_precision == "fp16": - patch_accelerator_for_fp16_training(accelerator) - unet.to(accelerator.device, dtype=weight_dtype) - else: - # Check that all trainable models are in full precision - low_precision_error_string = ( - "Please make sure to always have all model weights in full float32 precision when starting training - " - "even if doing mixed precision training. copy of the weights should still be float32." - ) - - if accelerator.unwrap_model(unet).dtype != torch.float32: - logger.warning( - f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + if args.model_type == "SDXL": + # import correct text encoder class + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.get_pretrained_model_name_or_path(), args.revision, subfolder="text_encoder_2" ) - if ( - args.stop_text_encoder != 0 - and accelerator.unwrap_model(text_encoder).dtype != torch.float32 - ): - logger.warning( - f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." - f" {low_precision_error_string}" - ) - - if ( - args.stop_text_encoder != 0 - and accelerator.unwrap_model(text_encoder_two).dtype != torch.float32 - ): - logger.warning( - f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder_two).dtype}." - f" {low_precision_error_string}" - ) - - if args.gradient_checkpointing: - if args.train_unet: - unet.enable_gradient_checkpointing() - if stop_text_percentage != 0: - text_encoder.gradient_checkpointing_enable() - if args.model_type == "SDXL": - text_encoder_two.gradient_checkpointing_enable() - if args.use_lora: - # We need to enable gradients on an input for gradient checkpointing to work - # This will not be optimized because it is not a param to optimizer - text_encoder.text_model.embeddings.position_embedding.requires_grad_(True) - if args.model_type == "SDXL": - text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(True) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - if args.model_type == "SDXL": - text_encoder_two.to(accelerator.device, dtype=weight_dtype) - - ema_model = None - if args.use_ema: - if os.path.exists( - os.path.join( - args.get_pretrained_model_name_or_path(), - "ema_unet", - "diffusion_pytorch_model.safetensors", - ) - ): - ema_unet = UNet2DConditionModel.from_pretrained( + pbar2.set_description("Loading text encoder 2...") + pbar2.update() + pbar2.set_postfix(refresh=True) + # Load models and create wrapper for stable diffusion + text_encoder_two = text_encoder_cls_two.from_pretrained( args.get_pretrained_model_name_or_path(), - subfolder="ema_unet", + subfolder="text_encoder_2", revision=args.revision, - torch_dtype=weight_dtype, + torch_dtype=torch.float32, ) - if args.attention == "xformers" and not shared.force_cpu: - xformerify(ema_unet, use_lora=args.use_lora) - ema_model = EMAModel( - ema_unet, device=accelerator.device, dtype=weight_dtype - ) - del ema_unet + printm("Created tenc") + pbar2.set_description("Loading VAE...") + pbar2.update() + vae = create_vae() + printm("Created vae") + + pbar2.set_description("Loading unet...") + pbar2.update() + unet = UNet2DConditionModel.from_pretrained( + args.get_pretrained_model_name_or_path(), + subfolder="unet", + revision=args.revision, + torch_dtype=torch.float32, + ) + + if args.attention == "xformers" and not shared.force_cpu: + xformerify(unet, use_lora=args.use_lora) + 
xformerify(vae, use_lora=args.use_lora) + + unet = torch2ify(unet) + + if args.full_mixed_precision: + if args.mixed_precision == "fp16": + patch_accelerator_for_fp16_training(accelerator) + unet.to(accelerator.device, dtype=weight_dtype) else: - ema_model = EMAModel( - unet, device=accelerator.device, dtype=weight_dtype + # Check that all trainable models are in full precision + low_precision_error_string = ( + "Please make sure to always have all model weights in full float32 precision when starting training - " + "even if doing mixed precision training. copy of the weights should still be float32." ) - # Create shared unet/tenc learning rate variables - - learning_rate = args.learning_rate - txt_learning_rate = args.txt_learning_rate - if args.use_lora: - learning_rate = args.lora_learning_rate - txt_learning_rate = args.lora_txt_learning_rate - - if args.use_lora or not args.train_unet: - unet.requires_grad_(False) - - unet_lora_params = None - - if args.use_lora: - pbar2.reset(1) - pbar2.set_description("Loading LoRA...") - # now we will add new LoRA weights to the attention layers - # Set correct lora layers - unet_lora_attn_procs = {} - unet_lora_params = [] - rank = args.lora_unet_rank - - for name, attn_processor in unet.attn_processors.items(): - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim - hidden_size = None - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - - lora_attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - if hidden_size is None: - logger.warning(f"Could not find hidden size for {name}. Skipping...") - continue - module = lora_attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank - ) - unet_lora_attn_procs[name] = module - unet_lora_params.extend(module.parameters()) + if accelerator.unwrap_model(unet).dtype != torch.float32: + logger.warning( + f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) - unet.set_attn_processor(unet_lora_attn_procs) + if ( + args.stop_text_encoder != 0 + and accelerator.unwrap_model(text_encoder).dtype != torch.float32 + ): + logger.warning( + f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." + f" {low_precision_error_string}" + ) - # The text encoder comes from 🤗 transformers, so we cannot directly modify it. - # So, instead, we monkey-patch the forward calls of its attention-blocks. - if stop_text_percentage != 0: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder( - text_encoder, dtype=torch.float32, rank=args.lora_txt_rank - ) + if ( + args.stop_text_encoder != 0 + and accelerator.unwrap_model(text_encoder_two).dtype != torch.float32 + ): + logger.warning( + f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder_two).dtype}." 
+ f" {low_precision_error_string}" + ) + + if args.gradient_checkpointing: + if args.train_unet: + unet.enable_gradient_checkpointing() + if stop_text_percentage != 0: + text_encoder.gradient_checkpointing_enable() + if args.model_type == "SDXL": + text_encoder_two.gradient_checkpointing_enable() + if args.use_lora: + # We need to enable gradients on an input for gradient checkpointing to work + # This will not be optimized because it is not a param to optimizer + text_encoder.text_model.embeddings.position_embedding.requires_grad_(True) + if args.model_type == "SDXL": + text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(True) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + if args.model_type == "SDXL": + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + ema_model = None + if args.use_ema: + if os.path.exists( + os.path.join( + args.get_pretrained_model_name_or_path(), + "ema_unet", + "diffusion_pytorch_model.safetensors", + ) + ): + ema_unet = UNet2DConditionModel.from_pretrained( + args.get_pretrained_model_name_or_path(), + subfolder="ema_unet", + revision=args.revision, + torch_dtype=weight_dtype, + ) + if args.attention == "xformers" and not shared.force_cpu: + xformerify(ema_unet, use_lora=args.use_lora) - if args.model_type == "SDXL": - text_encoder_lora_params_two = LoraLoaderMixin._modify_text_encoder( - text_encoder_two, dtype=torch.float32, rank=args.lora_txt_rank + ema_model = EMAModel( + ema_unet, device=accelerator.device, dtype=weight_dtype ) - params_to_optimize = ( - itertools.chain(unet_lora_params, text_encoder_lora_params, text_encoder_lora_params_two)) + del ema_unet else: - params_to_optimize = (itertools.chain(unet_lora_params, text_encoder_lora_params)) + ema_model = EMAModel( + unet, device=accelerator.device, dtype=weight_dtype + ) - else: - params_to_optimize = unet_lora_params + # Create shared unet/tenc learning rate variables - # Load LoRA weights if specified - if args.lora_model_name is not None and args.lora_model_name != "": - logger.debug(f"Load lora from {args.lora_model_name}") - lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(args.lora_model_name) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet) + learning_rate = args.learning_rate + txt_learning_rate = args.txt_learning_rate + if args.use_lora: + learning_rate = args.lora_learning_rate + txt_learning_rate = args.lora_txt_learning_rate - LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder) - if text_encoder_two is not None: - LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two) + if args.use_lora or not args.train_unet: + unet.requires_grad_(False) + unet_lora_params = None + + if args.use_lora: + pbar2.reset(1) + pbar2.set_description("Loading LoRA...") + # now we will add new LoRA weights to the attention layers + # Set correct lora layers + unet_lora_attn_procs = {} + unet_lora_params = [] + rank = args.lora_unet_rank + + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + hidden_size = None + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif 
name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + if hidden_size is None: + logger.warning(f"Could not find hidden size for {name}. Skipping...") + continue + module = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + unet_lora_attn_procs[name] = module + unet_lora_params.extend(module.parameters()) + + unet.set_attn_processor(unet_lora_attn_procs) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if stop_text_percentage != 0: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder( + text_encoder, dtype=torch.float32, rank=args.lora_txt_rank + ) + + if args.model_type == "SDXL": + text_encoder_lora_params_two = LoraLoaderMixin._modify_text_encoder( + text_encoder_two, dtype=torch.float32, rank=args.lora_txt_rank + ) + params_to_optimize = ( + itertools.chain(unet_lora_params, text_encoder_lora_params, text_encoder_lora_params_two)) + else: + params_to_optimize = (itertools.chain(unet_lora_params, text_encoder_lora_params)) - elif stop_text_percentage != 0: - if args.train_unet: - if args.model_type == "SDXL": - params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters(), - text_encoder_two.parameters()) - else: - params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) - else: - if args.model_type == "SDXL": - params_to_optimize = itertools.chain(text_encoder.parameters(), text_encoder_two.parameters()) else: - params_to_optimize = itertools.chain(text_encoder.parameters()) - else: - params_to_optimize = unet.parameters() + params_to_optimize = unet_lora_params - optimizer = get_optimizer(args.optimizer, learning_rate, args.weight_decay, params_to_optimize) - if len(optimizer.param_groups) > 1: - try: - optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay - optimizer.param_groups[1]["grad_clip_norm"] = args.tenc_grad_clip_norm - except: - logger.warning("Exception setting tenc weight decay") - traceback.print_exc() + # Load LoRA weights if specified + if args.lora_model_name is not None and args.lora_model_name != "": + logger.debug(f"Load lora from {args.lora_model_name}") + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(args.lora_model_name) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet) - if len(optimizer.param_groups) > 2: - try: - optimizer.param_groups[2]["weight_decay"] = args.tenc_weight_decay - optimizer.param_groups[2]["grad_clip_norm"] = args.tenc_grad_clip_norm - except: - logger.warning("Exception setting tenc weight decay") - traceback.print_exc() + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder) + if text_encoder_two is not None: + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two) - noise_scheduler = get_noise_scheduler(args) - global to_delete - to_delete = [unet, text_encoder, text_encoder_two, tokenizer, tokenizer_two, optimizer, vae] - def cleanup_memory(): - try: - if unet: - del unet - if text_encoder: - del 
text_encoder - if text_encoder_two: - del text_encoder_two - if tokenizer: - del tokenizer - if tokenizer_two: - del tokenizer_two - if optimizer: - del optimizer - if train_dataloader: - del train_dataloader - if train_dataset: - del train_dataset - if lr_scheduler: - del lr_scheduler - if vae: - del vae - if unet_lora_params: - del unet_lora_params - except: - pass - cleanup(True) - if args.cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() + elif stop_text_percentage != 0: + if args.train_unet: + if args.model_type == "SDXL": + params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters(), + text_encoder_two.parameters()) + else: + params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + if args.model_type == "SDXL": + params_to_optimize = itertools.chain(text_encoder.parameters(), text_encoder_two.parameters()) + else: + params_to_optimize = itertools.chain(text_encoder.parameters()) + else: + params_to_optimize = unet.parameters() - if status.interrupted: - result.msg = "Training interrupted." - stop_profiler(profiler) - return result + optimizer = get_optimizer(args.optimizer, learning_rate, args.weight_decay, params_to_optimize) + if len(optimizer.param_groups) > 1: + try: + optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay + optimizer.param_groups[1]["grad_clip_norm"] = args.tenc_grad_clip_norm + except: + logger.warning("Exception setting tenc weight decay") + traceback.print_exc() - printm("Loading dataset...") - pbar2.reset() - pbar2.set_description("Loading dataset") - - with_prior_preservation = False - tokenizers = [tokenizer] if tokenizer_two is None else [tokenizer, tokenizer_two] - text_encoders = [text_encoder] if text_encoder_two is None else [text_encoder, text_encoder_two] - train_dataset = generate_dataset( - model_name=args.model_name, - instance_prompts=instance_prompts, - class_prompts=class_prompts, - batch_size=args.train_batch_size, - tokenizer=tokenizers, - text_encoder=text_encoders, - accelerator=accelerator, - vae=vae if args.cache_latents else None, - debug=False, - model_dir=args.model_dir, - max_token_length=args.max_token_length, - pbar=pbar2 - ) - if train_dataset.class_count > 0: - with_prior_preservation = True - pbar2.reset() - printm("Dataset loaded.") - tokenizer_max_length = tokenizer.model_max_length - if args.cache_latents: - printm("Unloading vae.") - del vae - # Preserve reference to vae for later checks - vae = None - # TODO: Try unloading tokenizers here? - del tokenizer - if tokenizer_two is not None: - del tokenizer_two - tokenizer = None - tokenizer2 = None + if len(optimizer.param_groups) > 2: + try: + optimizer.param_groups[2]["weight_decay"] = args.tenc_weight_decay + optimizer.param_groups[2]["grad_clip_norm"] = args.tenc_grad_clip_norm + except: + logger.warning("Exception setting tenc weight decay") + traceback.print_exc() - if status.interrupted: - result.msg = "Training interrupted." 
- stop_profiler(profiler) - return result + noise_scheduler = get_noise_scheduler(args) + global to_delete + to_delete = [unet, text_encoder, text_encoder_two, tokenizer, tokenizer_two, optimizer, vae] + def cleanup_memory(): + try: + if unet: + del unet + if text_encoder: + del text_encoder + if text_encoder_two: + del text_encoder_two + if tokenizer: + del tokenizer + if tokenizer_two: + del tokenizer_two + if optimizer: + del optimizer + if train_dataloader: + del train_dataloader + if train_dataset: + del train_dataset + if lr_scheduler: + del lr_scheduler + if vae: + del vae + if unet_lora_params: + del unet_lora_params + except: + pass + cleanup(True) - if train_dataset.__len__ == 0: - msg = "Please provide a directory with actual images in it." - logger.warning(msg) - status.textinfo = msg - update_status({"status": status}) - cleanup_memory() - result.msg = msg - result.config = args - stop_profiler(profiler) - return result + if args.cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() - def collate_fn_db(examples): - input_ids = [example["input_ids"] for example in examples] - pixel_values = [example["image"] for example in examples] - types = [example["is_class"] for example in examples] - weights = [ - current_prior_loss_weight if example["is_class"] else 1.0 - for example in examples - ] - loss_avg = 0 - for weight in weights: - loss_avg += weight - loss_avg /= len(weights) - pixel_values = torch.stack(pixel_values) - if not args.cache_latents: - pixel_values = pixel_values.to( - memory_format=torch.contiguous_format - ).float() - input_ids = torch.cat(input_ids, dim=0) - - batch_data = { - "input_ids": input_ids, - "images": pixel_values, - "types": types, - "loss_avg": loss_avg, - } - if "input_ids2" in examples[0]: - input_ids_2 = [example["input_ids2"] for example in examples] - input_ids_2 = torch.stack(input_ids_2) - - batch_data["input_ids2"] = input_ids_2 - batch_data["original_sizes_hw"] = torch.stack( - [torch.LongTensor(x["original_sizes_hw"]) for x in examples]) - batch_data["crop_top_lefts"] = torch.stack([torch.LongTensor(x["crop_top_lefts"]) for x in examples]) - batch_data["target_sizes_hw"] = torch.stack([torch.LongTensor(x["target_sizes_hw"]) for x in examples]) - return batch_data - - def collate_fn_sdxl(examples): - input_ids = [example["input_ids"] for example in examples if not example["is_class"]] - pixel_values = [example["image"] for example in examples if not example["is_class"]] - add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if - not example["is_class"]] - add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if - not example["is_class"]] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. 
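# Minimal sketch of the latent-caching step set up above (cache_latents): images are
# encoded once through a frozen, eval-mode VAE so the VAE can be unloaded for the rest
# of training. Assumes a diffusers AutoencoderKL; the model path is illustrative, and
# 0.18215 is the standard Stable Diffusion latent scaling factor.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.requires_grad_(False)
vae.eval()

@torch.no_grad()
def encode_to_latents(pixel_values: torch.Tensor) -> torch.Tensor:
    # pixel_values: (batch, 3, H, W), normalized to [-1, 1]
    latents = vae.encode(pixel_values).latent_dist.sample()
    return latents * 0.18215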
- if with_prior_preservation: - input_ids += [example["input_ids"] for example in examples if example["is_class"]] - pixel_values += [example["image"] for example in examples if example["is_class"]] - add_text_embeds += [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if - example["is_class"]] - add_time_ids += [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if - example["is_class"]] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - input_ids = torch.cat(input_ids, dim=0) - add_text_embeds = torch.cat(add_text_embeds, dim=0) - add_time_ids = torch.cat(add_time_ids, dim=0) - - batch = { - "input_ids": input_ids, - "images": pixel_values, - "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, - } + if status.interrupted: + result.msg = "Training interrupted." + stop_profiler(profiler) + return result - return batch + printm("Loading dataset...") + pbar2.reset() + pbar2.set_description("Loading dataset") + + with_prior_preservation = False + tokenizers = [tokenizer] if tokenizer_two is None else [tokenizer, tokenizer_two] + text_encoders = [text_encoder] if text_encoder_two is None else [text_encoder, text_encoder_two] + train_dataset = generate_dataset( + model_name=args.model_name, + instance_prompts=instance_prompts, + class_prompts=class_prompts, + batch_size=args.train_batch_size, + tokenizer=tokenizers, + text_encoder=text_encoders, + accelerator=accelerator, + vae=vae if args.cache_latents else None, + debug=False, + model_dir=args.model_dir, + max_token_length=args.max_token_length, + pbar=pbar2 + ) + if train_dataset.class_count > 0: + with_prior_preservation = True + pbar2.reset() + printm("Dataset loaded.") + tokenizer_max_length = tokenizer.model_max_length + if args.cache_latents: + printm("Unloading vae.") + del vae + # Preserve reference to vae for later checks + vae = None + # TODO: Try unloading tokenizers here? + del tokenizer + if tokenizer_two is not None: + del tokenizer_two + tokenizer = None + tokenizer2 = None + + if status.interrupted: + result.msg = "Training interrupted." + stop_profiler(profiler) + return result + + if train_dataset.__len__ == 0: + msg = "Please provide a directory with actual images in it." 
+ logger.warning(msg) + status.textinfo = msg + update_status({"status": status}) + cleanup_memory() + result.msg = msg + result.config = args + stop_profiler(profiler) + return result + + def collate_fn_db(examples): + input_ids = [example["input_ids"] for example in examples] + pixel_values = [example["image"] for example in examples] + types = [example["is_class"] for example in examples] + weights = [ + current_prior_loss_weight if example["is_class"] else 1.0 + for example in examples + ] + loss_avg = 0 + for weight in weights: + loss_avg += weight + loss_avg /= len(weights) + pixel_values = torch.stack(pixel_values) + if not args.cache_latents: + pixel_values = pixel_values.to( + memory_format=torch.contiguous_format + ).float() + input_ids = torch.cat(input_ids, dim=0) + + batch_data = { + "input_ids": input_ids, + "images": pixel_values, + "types": types, + "loss_avg": loss_avg, + } + if "input_ids2" in examples[0]: + input_ids_2 = [example["input_ids2"] for example in examples] + input_ids_2 = torch.stack(input_ids_2) + + batch_data["input_ids2"] = input_ids_2 + batch_data["original_sizes_hw"] = torch.stack( + [torch.LongTensor(x["original_sizes_hw"]) for x in examples]) + batch_data["crop_top_lefts"] = torch.stack([torch.LongTensor(x["crop_top_lefts"]) for x in examples]) + batch_data["target_sizes_hw"] = torch.stack([torch.LongTensor(x["target_sizes_hw"]) for x in examples]) + return batch_data + + def collate_fn_sdxl(examples): + input_ids = [example["input_ids"] for example in examples if not example["is_class"]] + pixel_values = [example["image"] for example in examples if not example["is_class"]] + add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if + not example["is_class"]] + add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if + not example["is_class"]] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. 
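# Sketch of the prior-preservation trick described in the comment above: instance and
# class (regularization) examples are concatenated into a single batch so one UNet
# forward pass covers both, and the prediction is chunked back apart to compute the
# instance loss and the prior loss separately. All tensors here are dummies and
# prior_loss_weight is an assumed hyperparameter.
import torch
import torch.nn.functional as F

instance_pred, class_pred = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
instance_target, class_target = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)

model_pred = torch.cat([instance_pred, class_pred], dim=0)   # one joint forward pass
target = torch.cat([instance_target, class_target], dim=0)

model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

prior_loss_weight = 1.0
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss = loss + prior_loss_weight * F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")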
+ if with_prior_preservation: + input_ids += [example["input_ids"] for example in examples if example["is_class"]] + pixel_values += [example["image"] for example in examples if example["is_class"]] + add_text_embeds += [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if + example["is_class"]] + add_time_ids += [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if + example["is_class"]] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + add_text_embeds = torch.cat(add_text_embeds, dim=0) + add_time_ids = torch.cat(add_time_ids, dim=0) + + batch = { + "input_ids": input_ids, + "images": pixel_values, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } - sampler = BucketSampler(train_dataset, train_batch_size) + return batch - collate_fn = collate_fn_db - if args.model_type == "SDXL": - collate_fn = collate_fn_sdxl - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=1, - batch_sampler=sampler, - collate_fn=collate_fn, - num_workers=n_workers, - ) + sampler = BucketSampler(train_dataset, train_batch_size) - max_train_steps = args.num_train_epochs * len(train_dataset) - - # This is separate, because optimizer.step is only called once per "step" in training, so it's not - # affected by batch size - sched_train_steps = args.num_train_epochs * train_dataset.num_train_images - - lr_scale_pos = args.lr_scale_pos - if class_prompts: - lr_scale_pos *= 2 - - lr_scheduler = UniversalScheduler( - name=args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps, - total_training_steps=sched_train_steps, - min_lr=args.learning_rate_min, - total_epochs=args.num_train_epochs, - num_cycles=args.lr_cycles, - power=args.lr_power, - factor=args.lr_factor, - scale_pos=lr_scale_pos, - unet_lr=learning_rate, - tenc_lr=txt_learning_rate, - ) + collate_fn = collate_fn_db + if args.model_type == "SDXL": + collate_fn = collate_fn_sdxl + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=1, + batch_sampler=sampler, + collate_fn=collate_fn, + num_workers=n_workers, + ) - # create ema, fix OOM - if args.use_ema: - if stop_text_percentage != 0: - ( - ema_model.model, - unet, - text_encoder, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - ema_model.model, - unet, - text_encoder, - optimizer, - train_dataloader, - lr_scheduler, - ) - else: - ( - ema_model.model, - unet, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - ema_model.model, unet, optimizer, train_dataloader, lr_scheduler - ) - else: - if stop_text_percentage != 0: - ( - unet, - text_encoder, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) + max_train_steps = args.num_train_epochs * len(train_dataset) + + # This is separate, because optimizer.step is only called once per "step" in training, so it's not + # affected by batch size + sched_train_steps = args.num_train_epochs * train_dataset.num_train_images + + lr_scale_pos = args.lr_scale_pos + if class_prompts: + lr_scale_pos *= 2 + + lr_scheduler = UniversalScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + total_training_steps=sched_train_steps, + min_lr=args.learning_rate_min, + total_epochs=args.num_train_epochs, + 
num_cycles=args.lr_cycles, + power=args.lr_power, + factor=args.lr_factor, + scale_pos=lr_scale_pos, + unet_lr=learning_rate, + tenc_lr=txt_learning_rate, + ) + + # create ema, fix OOM + if args.use_ema: + if stop_text_percentage != 0: + ( + ema_model.model, + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + ema_model.model, + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + ( + ema_model.model, + unet, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + ema_model.model, unet, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + if stop_text_percentage != 0: + ( + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) - if not args.cache_latents and vae is not None: - vae.to(accelerator.device, dtype=weight_dtype) + if not args.cache_latents and vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) - if stop_text_percentage == 0: - text_encoder.to(accelerator.device, dtype=weight_dtype) - # Afterwards we recalculate our number of training epochs - # We need to initialize the trackers we use, and also store our configuration. - # The trackers will initialize automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("dreambooth") + if stop_text_percentage == 0: + text_encoder.to(accelerator.device, dtype=weight_dtype) + # Afterwards we recalculate our number of training epochs + # We need to initialize the trackers we use, and also store our configuration. + # The trackers will initialize automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth") - # Train! - total_batch_size = ( - train_batch_size * accelerator.num_processes * gradient_accumulation_steps - ) - max_train_epochs = args.num_train_epochs - # we calculate our number of tenc training epochs - text_encoder_epochs = round(max_train_epochs * stop_text_percentage) - global_step = 0 - global_epoch = 0 - session_epoch = 0 - first_epoch = 0 - resume_step = 0 - last_model_save = 0 - last_image_save = 0 - resume_from_checkpoint = False - new_hotness = os.path.join( - args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}" - ) - if os.path.exists(new_hotness): - logger.debug(f"Resuming from checkpoint {new_hotness}") + # Train! 
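# Worked example (assumed values) of the bookkeeping computed just below: the effective
# batch size folds in data parallelism and gradient accumulation, and the text encoder
# is only trained for the first stop_text_percentage of the epochs.
train_batch_size = 2
num_processes = 1                  # accelerator.num_processes
gradient_accumulation_steps = 4
num_train_epochs = 100
stop_text_percentage = 0.75

total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps  # 8
text_encoder_epochs = round(num_train_epochs * stop_text_percentage)               # 75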
+ total_batch_size = ( + train_batch_size * accelerator.num_processes * gradient_accumulation_steps + ) + max_train_epochs = args.num_train_epochs + # we calculate our number of tenc training epochs + text_encoder_epochs = round(max_train_epochs * stop_text_percentage) + global_step = 0 + global_epoch = 0 + session_epoch = 0 + first_epoch = 0 + resume_step = 0 + last_model_save = 0 + last_image_save = 0 + resume_from_checkpoint = False + new_hotness = os.path.join( + args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}" + ) + if os.path.exists(new_hotness): + logger.debug(f"Resuming from checkpoint {new_hotness}") - try: - import modules.shared - no_safe = modules.shared.cmd_opts.disable_safe_unpickle - modules.shared.cmd_opts.disable_safe_unpickle = True - except: - no_safe = False + try: + import modules.shared + no_safe = modules.shared.cmd_opts.disable_safe_unpickle + modules.shared.cmd_opts.disable_safe_unpickle = True + except: + no_safe = False - try: - import modules.shared - accelerator.load_state(new_hotness) - modules.shared.cmd_opts.disable_safe_unpickle = no_safe - global_step = resume_step = args.revision - resume_from_checkpoint = True - first_epoch = args.lifetime_epoch - global_epoch = args.lifetime_epoch - except Exception as lex: - logger.warning(f"Exception loading checkpoint: {lex}") - logger.debug(" ***** Running training *****") - if shared.force_cpu: - logger.debug(f" TRAINING WITH CPU ONLY") - logger.debug(f" Num batches each epoch = {len(train_dataset) // train_batch_size}") - logger.debug(f" Num Epochs = {max_train_epochs}") - logger.debug(f" Batch Size Per Device = {train_batch_size}") - logger.debug(f" Gradient Accumulation steps = {gradient_accumulation_steps}") - logger.debug(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.debug(f" Text Encoder Epochs: {text_encoder_epochs}") - logger.debug(f" Total optimization steps = {sched_train_steps}") - logger.debug(f" Total training steps = {max_train_steps}") - logger.debug(f" Resuming from checkpoint: {resume_from_checkpoint}") - logger.debug(f" First resume epoch: {first_epoch}") - logger.debug(f" First resume step: {resume_step}") - logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}") - logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}") - logger.debug(f" EMA: {args.use_ema}") - logger.debug(f" UNET: {args.train_unet}") - logger.debug(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}") - logger.debug(f" LR{' (Lora)' if args.use_lora else ''}: {learning_rate}") - if stop_text_percentage > 0: - logger.debug(f" Tenc LR{' (Lora)' if args.use_lora else ''}: {txt_learning_rate}") - logger.debug(f" V2: {args.v2}") - - os.environ.__setattr__("CUDA_LAUNCH_BLOCKING", 1) - - def check_save(is_epoch_check=False): - nonlocal last_model_save - nonlocal last_image_save - save_model_interval = args.save_embedding_every - save_image_interval = args.save_preview_every - save_completed = session_epoch >= max_train_epochs - save_canceled = status.interrupted - save_image = False - save_model = False - save_lora = False - - if not save_canceled and not save_completed: - # Check to see if the number of epochs since last save is gt the interval - if 0 < save_model_interval <= session_epoch - last_model_save: - save_model = True - if args.use_lora: - save_lora = True - last_model_save = session_epoch + try: + import modules.shared + accelerator.load_state(new_hotness) + 
modules.shared.cmd_opts.disable_safe_unpickle = no_safe + global_step = resume_step = args.revision + resume_from_checkpoint = True + first_epoch = args.lifetime_epoch + global_epoch = args.lifetime_epoch + except Exception as lex: + logger.warning(f"Exception loading checkpoint: {lex}") + logger.debug(" ***** Running training *****") + if shared.force_cpu: + logger.debug(f" TRAINING WITH CPU ONLY") + logger.debug(f" Num batches each epoch = {len(train_dataset) // train_batch_size}") + logger.debug(f" Num Epochs = {max_train_epochs}") + logger.debug(f" Batch Size Per Device = {train_batch_size}") + logger.debug(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.debug(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.debug(f" Text Encoder Epochs: {text_encoder_epochs}") + logger.debug(f" Total optimization steps = {sched_train_steps}") + logger.debug(f" Total training steps = {max_train_steps}") + logger.debug(f" Resuming from checkpoint: {resume_from_checkpoint}") + logger.debug(f" First resume epoch: {first_epoch}") + logger.debug(f" First resume step: {resume_step}") + logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}") + logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}") + logger.debug(f" EMA: {args.use_ema}") + logger.debug(f" UNET: {args.train_unet}") + logger.debug(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}") + logger.debug(f" LR{' (Lora)' if args.use_lora else ''}: {learning_rate}") + if stop_text_percentage > 0: + logger.debug(f" Tenc LR{' (Lora)' if args.use_lora else ''}: {txt_learning_rate}") + logger.debug(f" V2: {args.v2}") + + os.environ.__setattr__("CUDA_LAUNCH_BLOCKING", 1) + + def check_save(is_epoch_check=False): + nonlocal last_model_save + nonlocal last_image_save + save_model_interval = args.save_embedding_every + save_image_interval = args.save_preview_every + save_completed = session_epoch >= max_train_epochs + save_canceled = status.interrupted + save_image = False + save_model = False + save_lora = False - # Repeat for sample images - if 0 < save_image_interval <= session_epoch - last_image_save: - save_image = True - last_image_save = session_epoch + if not save_canceled and not save_completed: + # Check to see if the number of epochs since last save is gt the interval + if 0 < save_model_interval <= session_epoch - last_model_save: + save_model = True + if args.use_lora: + save_lora = True + last_model_save = session_epoch - else: - logger.debug("\nSave completed/canceled.") - if global_step > 0: - save_image = True - save_model = True - if args.use_lora: - save_lora = True + # Repeat for sample images + if 0 < save_image_interval <= session_epoch - last_image_save: + save_image = True + last_image_save = session_epoch - save_snapshot = False + else: + logger.debug("\nSave completed/canceled.") + if global_step > 0: + save_image = True + save_model = True + if args.use_lora: + save_lora = True - if is_epoch_check: - if shared.status.do_save_samples: - save_image = True - shared.status.do_save_samples = False + save_snapshot = False - if shared.status.do_save_model: - if args.use_lora: + if is_epoch_check: + if shared.status.do_save_samples: + save_image = True + shared.status.do_save_samples = False + + if shared.status.do_save_model: + if args.use_lora: + save_lora = True + save_model = True + shared.status.do_save_model = False + + save_checkpoint = False + if save_model: + if save_canceled: + if 
global_step > 0: + logger.debug("Canceled, enabling saves.") + save_snapshot = args.save_state_cancel + save_checkpoint = args.save_ckpt_cancel + elif save_completed: + if global_step > 0: + logger.debug("Completed, enabling saves.") + save_snapshot = args.save_state_after + save_checkpoint = args.save_ckpt_after + else: + save_snapshot = args.save_state_during + save_checkpoint = args.save_ckpt_during + if save_checkpoint and args.use_lora: + save_checkpoint = False save_lora = True - save_model = True - shared.status.do_save_model = False + if not args.use_lora: + save_lora = False - save_checkpoint = False - if save_model: - if save_canceled: - if global_step > 0: - logger.debug("Canceled, enabling saves.") - save_snapshot = args.save_state_cancel - save_checkpoint = args.save_ckpt_cancel - elif save_completed: - if global_step > 0: - logger.debug("Completed, enabling saves.") - save_snapshot = args.save_state_after - save_checkpoint = args.save_ckpt_after - else: - save_snapshot = args.save_state_during - save_checkpoint = args.save_ckpt_during - if save_checkpoint and args.use_lora: - save_checkpoint = False - save_lora = True - if not args.use_lora: - save_lora = False + if ( + save_checkpoint + or save_snapshot + or save_lora + or save_image + or save_model + ): + save_weights( + save_image, + save_model, + save_snapshot, + save_checkpoint, + save_lora + ) + + return save_model, save_image - if ( - save_checkpoint - or save_snapshot - or save_lora - or save_image - or save_model + def save_weights( + save_image, save_diffusers, save_snapshot, save_checkpoint, save_lora ): - disable_safe_unpickle() - save_weights( - save_image, - save_model, - save_snapshot, - save_checkpoint, - save_lora - ) - enable_safe_unpickle() + global last_samples + global last_prompts + nonlocal vae + nonlocal pbar2 - return save_model, save_image + printm(" Saving weights.") + pbar2.reset() + pbar2.set_description("Saving weights/samples...") + pbar2.set_postfix(refresh=True) + + # Create the pipeline using the trained modules and save it. + if accelerator.is_main_process: + printm("Pre-cleanup.") + torch_rng_state = None + cuda_gpu_rng_state = None + cuda_cpu_rng_state = None + # Save random states so sample generation doesn't impact training. 
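# Sketch of the RNG bookkeeping introduced above: snapshot the CPU and CUDA RNG states
# before generating preview samples, then restore them afterwards so sampling does not
# change the sequence of random numbers the training loop would otherwise see.
import torch

torch_rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

# ... generate preview images here ...

torch.set_rng_state(torch_rng_state)
if cuda_rng_state is not None:
    torch.cuda.set_rng_state(cuda_rng_state)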
+ if shared.device.type == 'cuda': + torch_rng_state = torch.get_rng_state() + cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda") + cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu") + + optim_to(profiler, optimizer) + + if profiler is None: + cleanup() + + if vae is None: + printm("Loading vae.") + vae = create_vae() + + printm("Creating pipeline.") + if args.model_type == "SDXL": + s_pipeline = StableDiffusionXLPipeline.from_pretrained( + args.get_pretrained_model_name_or_path(), + unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True), + text_encoder=accelerator.unwrap_model( + text_encoder, keep_fp32_wrapper=True + ), + text_encoder_2=accelerator.unwrap_model( + text_encoder_two, keep_fp32_wrapper=True + ), + vae=vae.to(accelerator.device), + torch_dtype=weight_dtype, + revision=args.revision, + ) + xformerify(s_pipeline.unet,use_lora=args.use_lora) + else: + s_pipeline = DiffusionPipeline.from_pretrained( + args.get_pretrained_model_name_or_path(), + unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True), + text_encoder=accelerator.unwrap_model( + text_encoder, keep_fp32_wrapper=True + ), + vae=vae, + torch_dtype=weight_dtype, + revision=args.revision, + ) + xformerify(s_pipeline.unet,use_lora=args.use_lora) + xformerify(s_pipeline.vae,use_lora=args.use_lora) - def save_weights( - save_image, save_diffusers, save_snapshot, save_checkpoint, save_lora - ): - global last_samples - global last_prompts - nonlocal vae - nonlocal pbar2 + weights_dir = args.get_pretrained_model_name_or_path() - printm(" Saving weights.") - pbar2.reset() - pbar2.set_description("Saving weights/samples...") - pbar2.set_postfix(refresh=True) + if user_model_dir != "": + loras_dir = os.path.join(user_model_dir, "Lora") + else: + model_dir = shared.models_path + loras_dir = os.path.join(model_dir, "Lora") + delete_tmp_lora = False + # Update the temp path if we just need to save an image + if save_image: + logger.debug("Save image is set.") + if args.use_lora: + if not save_lora: + logger.debug("Saving lora weights instead of checkpoint, using temp dir.") + save_lora = True + delete_tmp_lora = True + save_checkpoint = False + save_diffusers = False + os.makedirs(loras_dir, exist_ok=True) + elif not save_diffusers: + logger.debug("Saving checkpoint, using temp dir.") + save_diffusers = True + weights_dir = f"{weights_dir}_temp" + os.makedirs(weights_dir, exist_ok=True) + else: + save_lora = False + logger.debug(f"Save checkpoint: {save_checkpoint} save lora {save_lora}.") + # Is inference_mode() needed here to prevent issues when saving? + logger.debug(f"Loras dir: {loras_dir}") - # Create the pipeline using the trained modules and save it. - if accelerator.is_main_process: - printm("Pre-cleanup.") - torch_rng_state = None - cuda_gpu_rng_state = None - cuda_cpu_rng_state = None - # Save random states so sample generation doesn't impact training. 
- if shared.device.type == 'cuda': - torch_rng_state = torch.get_rng_state() - cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda") - cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu") - - optim_to(profiler, optimizer) - - if profiler is None: - cleanup() + # setup pt path + if args.custom_model_name == "": + lora_model_name = args.model_name + else: + lora_model_name = args.custom_model_name - if vae is None: - printm("Loading vae.") - vae = create_vae() + lora_save_file = os.path.join(loras_dir, f"{lora_model_name}_{args.revision}.safetensors") - printm("Creating pipeline.") - if args.model_type == "SDXL": - s_pipeline = StableDiffusionXLPipeline.from_pretrained( - args.get_pretrained_model_name_or_path(), - unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True), - text_encoder=accelerator.unwrap_model( - text_encoder, keep_fp32_wrapper=True - ), - text_encoder_2=accelerator.unwrap_model( - text_encoder_two, keep_fp32_wrapper=True - ), - vae=vae.to(accelerator.device), - torch_dtype=weight_dtype, - revision=args.revision, - ) - xformerify(s_pipeline.unet,use_lora=args.use_lora) - else: - s_pipeline = DiffusionPipeline.from_pretrained( - args.get_pretrained_model_name_or_path(), - unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True), - text_encoder=accelerator.unwrap_model( - text_encoder, keep_fp32_wrapper=True - ), - vae=vae, - torch_dtype=weight_dtype, - revision=args.revision, - ) - xformerify(s_pipeline.unet,use_lora=args.use_lora) - xformerify(s_pipeline.vae,use_lora=args.use_lora) + with accelerator.autocast(), torch.inference_mode(): - weights_dir = args.get_pretrained_model_name_or_path() + def lora_save_function(weights, filename): + metadata = args.export_ss_metadata() + logger.debug(f"Saving lora to {filename}") + safetensors.torch.save_file(weights, filename, metadata=metadata) - if user_model_dir != "": - loras_dir = os.path.join(user_model_dir, "Lora") - else: - model_dir = shared.models_path - loras_dir = os.path.join(model_dir, "Lora") - delete_tmp_lora = False - # Update the temp path if we just need to save an image - if save_image: - logger.debug("Save image is set.") - if args.use_lora: - if not save_lora: - logger.debug("Saving lora weights instead of checkpoint, using temp dir.") - save_lora = True - delete_tmp_lora = True - save_checkpoint = False - save_diffusers = False - os.makedirs(loras_dir, exist_ok=True) - elif not save_diffusers: - logger.debug("Saving checkpoint, using temp dir.") - save_diffusers = True - weights_dir = f"{weights_dir}_temp" - os.makedirs(weights_dir, exist_ok=True) - else: - save_lora = False - logger.debug(f"Save checkpoint: {save_checkpoint} save lora {save_lora}.") - # Is inference_mode() needed here to prevent issues when saving? - logger.debug(f"Loras dir: {loras_dir}") - - # setup pt path - if args.custom_model_name == "": - lora_model_name = args.model_name - else: - lora_model_name = args.custom_model_name - - lora_save_file = os.path.join(loras_dir, f"{lora_model_name}_{args.revision}.safetensors") - - with accelerator.autocast(), torch.inference_mode(): - - def lora_save_function(weights, filename): - metadata = args.export_ss_metadata() - logger.debug(f"Saving lora to {filename}") - safetensors.torch.save_file(weights, filename, metadata=metadata) - - if save_lora: - # TODO: Add a version for the lora model? 
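# Sketch of what lora_save_function (defined below in this hunk) does: safetensors
# files can carry a string-to-string metadata dict alongside the tensors, which is how
# the exported training settings end up embedded in the saved LoRA. The tensor name,
# file name, and metadata keys here are illustrative, not taken from the patch.
import torch
import safetensors.torch

weights = {"example_lora_down.weight": torch.zeros(4, 320)}
metadata = {"ss_network_dim": "4", "ss_learning_rate": "1e-4"}  # values must be strings
safetensors.torch.save_file(weights, "example_lora.safetensors", metadata=metadata)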
- pbar2.reset(1) - pbar2.set_description("Saving Lora Weights...") - # setup directory - logger.debug(f"Saving lora to {lora_save_file}") - unet_lora_layers_to_save = unet_lora_state_dict(unet) - text_encoder_one_lora_layers_to_save = None - text_encoder_two_lora_layers_to_save = None - if args.stop_text_encoder != 0: - text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder) - if args.model_type == "SDXL": + if save_lora: + # TODO: Add a version for the lora model? + pbar2.reset(1) + pbar2.set_description("Saving Lora Weights...") + # setup directory + logger.debug(f"Saving lora to {lora_save_file}") + unet_lora_layers_to_save = unet_lora_state_dict(unet) + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None if args.stop_text_encoder != 0: - text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder_two) - StableDiffusionXLPipeline.save_lora_weights( - loras_dir, - unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, - text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, - weight_name=lora_save_file, - safe_serialization=True, - save_function=lora_save_function - ) - scheduler_args = {} + text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder) + if args.model_type == "SDXL": + if args.stop_text_encoder != 0: + text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder_two) + StableDiffusionXLPipeline.save_lora_weights( + loras_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + weight_name=lora_save_file, + safe_serialization=True, + save_function=lora_save_function + ) + scheduler_args = {} - if "variance_type" in s_pipeline.scheduler.config: - variance_type = s_pipeline.scheduler.config.variance_type + if "variance_type" in s_pipeline.scheduler.config: + variance_type = s_pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" - scheduler_args["variance_type"] = variance_type + scheduler_args["variance_type"] = variance_type - s_pipeline.scheduler = UniPCMultistepScheduler.from_config(s_pipeline.scheduler.config, **scheduler_args) + s_pipeline.scheduler = UniPCMultistepScheduler.from_config(s_pipeline.scheduler.config, **scheduler_args) + save_lora = False + save_model = False + else: + StableDiffusionPipeline.save_lora_weights( + loras_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + weight_name=lora_save_file, + safe_serialization=True + ) + s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config( + s_pipeline.scheduler.config) + s_pipeline.scheduler.config.solver_type = "bh2" save_lora = False save_model = False - else: - StableDiffusionPipeline.save_lora_weights( - loras_dir, - unet_lora_layers=unet_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, - weight_name=lora_save_file, - safe_serialization=True - ) - s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config( - s_pipeline.scheduler.config) - s_pipeline.scheduler.config.solver_type = "bh2" - save_lora = False - save_model = False - - elif save_diffusers: - # We are saving weights, we need to ensure revision is saved - 
if "_tmp" not in weights_dir: - args.save() - try: - out_file = None - status.textinfo = ( - f"Saving diffusion model at step {args.revision}..." - ) - update_status({"status": status.textinfo}) - pbar2.reset(1) - - pbar2.set_description("Saving diffusion model") - s_pipeline.save_pretrained( - weights_dir, - safe_serialization=False, - ) - if ema_model is not None: - ema_model.save_pretrained( - os.path.join( - weights_dir, - "ema_unet", - ), - safe_serialization=False, - ) - pbar2.update() - if save_snapshot: - pbar2.reset(1) - pbar2.set_description("Saving Snapshot") + elif save_diffusers: + # We are saving weights, we need to ensure revision is saved + if "_tmp" not in weights_dir: + args.save() + try: + out_file = None status.textinfo = ( - f"Saving snapshot at step {args.revision}..." + f"Saving diffusion model at step {args.revision}..." ) update_status({"status": status.textinfo}) - accelerator.save_state( - os.path.join( - args.model_dir, - "checkpoints", - f"checkpoint-{args.revision}", - ) + pbar2.reset(1) + + pbar2.set_description("Saving diffusion model") + s_pipeline.save_pretrained( + weights_dir, + safe_serialization=False, ) + if ema_model is not None: + ema_model.save_pretrained( + os.path.join( + weights_dir, + "ema_unet", + ), + safe_serialization=False, + ) pbar2.update() - # We should save this regardless, because it's our fallback if no snapshot exists. + if save_snapshot: + pbar2.reset(1) + pbar2.set_description("Saving Snapshot") + status.textinfo = ( + f"Saving snapshot at step {args.revision}..." + ) + update_status({"status": status.textinfo}) + accelerator.save_state( + os.path.join( + args.model_dir, + "checkpoints", + f"checkpoint-{args.revision}", + ) + ) + pbar2.update() - # package pt into checkpoint - if save_checkpoint: - pbar2.reset(1) - pbar2.set_description("Compiling Checkpoint") - snap_rev = str(args.revision) if save_snapshot else "" - if export_diffusers: - copy_diffusion_model(args.model_name, os.path.join(user_model_dir, "diffusers")) - else: - if args.model_type == "SDXL": - compile_checkpoint_xl(args.model_name, reload_models=False, - lora_file_name=out_file, - log=False, snap_rev=snap_rev, pbar=pbar2) + # We should save this regardless, because it's our fallback if no snapshot exists. 
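# Sketch of the snapshot mechanism used above: accelerate's save_state()/load_state()
# persist everything that was passed through accelerator.prepare() (models, optimizer,
# LR scheduler) plus RNG state, which is what the "Resuming from checkpoint" path
# earlier in this patch relies on. The directory layout below is illustrative.
import os
from accelerate import Accelerator

accelerator = Accelerator()
checkpoint_dir = os.path.join("model_dir", "checkpoints", "checkpoint-100")

accelerator.save_state(checkpoint_dir)   # called when save_snapshot is enabled
# ... later, on restart ...
accelerator.load_state(checkpoint_dir)   # restores prepared objects and RNG state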
+ + # package pt into checkpoint + if save_checkpoint: + pbar2.reset(1) + pbar2.set_description("Compiling Checkpoint") + snap_rev = str(args.revision) if save_snapshot else "" + if export_diffusers: + copy_diffusion_model(args.model_name, os.path.join(user_model_dir, "diffusers")) else: - compile_checkpoint(args.model_name, reload_models=False, - lora_file_name=out_file, - log=False, snap_rev=snap_rev, pbar=pbar2) - printm("Restored, moved to acc.device.") - pbar2.update() + if args.model_type == "SDXL": + compile_checkpoint_xl(args.model_name, reload_models=False, + lora_file_name=out_file, + log=False, snap_rev=snap_rev, pbar=pbar2) + else: + compile_checkpoint(args.model_name, reload_models=False, + lora_file_name=out_file, + log=False, snap_rev=snap_rev, pbar=pbar2) + printm("Restored, moved to acc.device.") + pbar2.update() + + except Exception as ex: + logger.warning(f"Exception saving checkpoint/model: {ex}") + traceback.print_exc() + pass + save_dir = args.model_dir + + if save_image: + logger.debug("Saving images...") + # Get the path to a temporary directory + del s_pipeline + logger.debug(f"Loading image pipeline from {weights_dir}...") + if args.model_type == "SDXL": + s_pipeline = StableDiffusionXLPipeline.from_pretrained( + weights_dir, vae=vae, revision=args.revision, + torch_dtype=weight_dtype + ) + else: + s_pipeline = StableDiffusionPipeline.from_pretrained( + weights_dir, vae=vae, revision=args.revision, + torch_dtype=weight_dtype + ) + if args.tomesd: + tomesd.apply_patch(s_pipeline, ratio=args.tomesd, use_rand=False) + if args.use_lora: + s_pipeline.load_lora_weights(lora_save_file) - except Exception as ex: - logger.warning(f"Exception saving checkpoint/model: {ex}") - traceback.print_exc() + try: + s_pipeline.enable_vae_tiling() + s_pipeline.enable_vae_slicing() + s_pipeline.enable_sequential_cpu_offload() + s_pipeline.enable_xformers_memory_efficient_attention() + except: pass - save_dir = args.model_dir - if save_image: - logger.debug("Saving images...") - # Get the path to a temporary directory - del s_pipeline - logger.debug(f"Loading image pipeline from {weights_dir}...") - if args.model_type == "SDXL": - s_pipeline = StableDiffusionXLPipeline.from_pretrained( - weights_dir, vae=vae, revision=args.revision, - torch_dtype=weight_dtype - ) - else: - s_pipeline = StableDiffusionPipeline.from_pretrained( - weights_dir, vae=vae, revision=args.revision, - torch_dtype=weight_dtype + samples = [] + sample_prompts = [] + last_samples = [] + last_prompts = [] + status.textinfo = ( + f"Saving preview image(s) at step {args.revision}..." 
) - if args.tomesd: - tomesd.apply_patch(s_pipeline, ratio=args.tomesd, use_rand=False) - if args.use_lora: - s_pipeline.load_lora_weights(lora_save_file) + update_status({"status": status.textinfo}) + try: + s_pipeline.set_progress_bar_config(disable=True) + sample_dir = os.path.join(save_dir, "samples") + os.makedirs(sample_dir, exist_ok=True) + + sd = SampleDataset(args) + prompts = sd.prompts + logger.debug(f"Generating {len(prompts)} samples...") + + concepts = args.concepts() + if args.sanity_prompt: + epd = PromptData( + prompt=args.sanity_prompt, + seed=args.sanity_seed, + negative_prompt=concepts[ + 0 + ].save_sample_negative_prompt, + resolution=(args.resolution, args.resolution), + ) + prompts.append(epd) - try: - s_pipeline.enable_vae_tiling() - s_pipeline.enable_vae_slicing() - s_pipeline.enable_sequential_cpu_offload() - s_pipeline.enable_xformers_memory_efficient_attention() - except: - pass + prompt_lengths = len(prompts) + if args.disable_logging: + pbar2.reset(prompt_lengths) + else: + pbar2.reset(prompt_lengths + 2) + pbar2.set_description("Generating Samples") + ci = 0 + for c in prompts: + c.out_dir = os.path.join(args.model_dir, "samples") + generator = torch.manual_seed(int(c.seed)) + s_image = s_pipeline( + c.prompt, + num_inference_steps=c.steps, + guidance_scale=c.scale, + negative_prompt=c.negative_prompt, + height=c.resolution[1], + width=c.resolution[0], + generator=generator, + ).images[0] + sample_prompts.append(c.prompt) + image_name = db_save_image( + s_image, + c, + custom_name=f"sample_{args.revision}-{ci}", + ) + shared.status.current_image = image_name + shared.status.sample_prompts = [c.prompt] + update_status({"images": [image_name], "prompts": [c.prompt]}) + samples.append(image_name) + pbar2.update() + ci += 1 + for sample in samples: + last_samples.append(sample) + for prompt in sample_prompts: + last_prompts.append(prompt) + del samples + del prompts + except: + logger.warning(f"Exception saving sample.") + traceback.print_exc() + pass - samples = [] - sample_prompts = [] - last_samples = [] - last_prompts = [] - status.textinfo = ( - f"Saving preview image(s) at step {args.revision}..." 
- ) - update_status({"status": status.textinfo}) - try: - s_pipeline.set_progress_bar_config(disable=True) - sample_dir = os.path.join(save_dir, "samples") - os.makedirs(sample_dir, exist_ok=True) - - sd = SampleDataset(args) - prompts = sd.prompts - logger.debug(f"Generating {len(prompts)} samples...") - - concepts = args.concepts() - if args.sanity_prompt: - epd = PromptData( - prompt=args.sanity_prompt, - seed=args.sanity_seed, - negative_prompt=concepts[ - 0 - ].save_sample_negative_prompt, - resolution=(args.resolution, args.resolution), - ) - prompts.append(epd) + del s_pipeline + printm("Starting cleanup.") - prompt_lengths = len(prompts) - if args.disable_logging: - pbar2.reset(prompt_lengths) - else: - pbar2.reset(prompt_lengths + 2) - pbar2.set_description("Generating Samples") - ci = 0 - for c in prompts: - c.out_dir = os.path.join(args.model_dir, "samples") - generator = torch.manual_seed(int(c.seed)) - s_image = s_pipeline( - c.prompt, - num_inference_steps=c.steps, - guidance_scale=c.scale, - negative_prompt=c.negative_prompt, - height=c.resolution[1], - width=c.resolution[0], - generator=generator, - ).images[0] - sample_prompts.append(c.prompt) - image_name = db_save_image( - s_image, - c, - custom_name=f"sample_{args.revision}-{ci}", - ) - shared.status.current_image = image_name - shared.status.sample_prompts = [c.prompt] - update_status({"images": [image_name], "prompts": [c.prompt]}) - samples.append(image_name) - pbar2.update() - ci += 1 - for sample in samples: - last_samples.append(sample) - for prompt in sample_prompts: - last_prompts.append(prompt) - del samples - del prompts - except: - logger.warning(f"Exception saving sample.") - traceback.print_exc() - pass + if os.path.isdir(loras_dir) and "_tmp" in loras_dir: + shutil.rmtree(loras_dir) - del s_pipeline - printm("Starting cleanup.") + if os.path.isdir(weights_dir) and "_tmp" in weights_dir: + shutil.rmtree(weights_dir) - if os.path.isdir(loras_dir) and "_tmp" in loras_dir: - shutil.rmtree(loras_dir) + if "generator" in locals(): + del generator - if os.path.isdir(weights_dir) and "_tmp" in weights_dir: - shutil.rmtree(weights_dir) + if not args.disable_logging: + try: + printm("Parse logs.") + log_images, log_names = log_parser.parse_logs(model_name=args.model_name) + pbar2.update() + for log_image in log_images: + last_samples.append(log_image) + for log_name in log_names: + last_prompts.append(log_name) + + del log_images + del log_names + except Exception as l: + traceback.print_exc() + logger.warning(f"Exception parsing logz: {l}") + pass + + send_training_update( + last_samples, + args.model_name, + last_prompts, + global_step, + args.revision + ) - if "generator" in locals(): - del generator + status.sample_prompts = last_prompts + status.current_image = last_samples + update_status({"images": last_samples, "prompts": last_prompts}) + pbar2.update() - if not args.disable_logging: - try: - printm("Parse logs.") - log_images, log_names = log_parser.parse_logs(model_name=args.model_name) - pbar2.update() - for log_image in log_images: - last_samples.append(log_image) - for log_name in log_names: - last_prompts.append(log_name) - - del log_images - del log_names - except Exception as l: - traceback.print_exc() - logger.warning(f"Exception parsing logz: {l}") - pass - send_training_update( - last_samples, - args.model_name, - last_prompts, - global_step, - args.revision - ) + if args.cache_latents: + printm("Unloading vae.") + del vae + # Preserve the reference again + vae = None - status.sample_prompts = 
last_prompts status.current_image = last_samples - update_status({"images": last_samples, "prompts": last_prompts}) - pbar2.update() + update_status({"images": last_samples}) + cleanup() + printm("Cleanup.") + optim_to(profiler, optimizer, accelerator.device) - if args.cache_latents: - printm("Unloading vae.") - del vae - # Preserve the reference again - vae = None + # Restore all random states to avoid having sampling impact training. + if shared.device.type == 'cuda': + torch.set_rng_state(torch_rng_state) + torch.cuda.set_rng_state(cuda_cpu_rng_state, device="cpu") + torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda") - status.current_image = last_samples - update_status({"images": last_samples}) - cleanup() - printm("Cleanup.") + cleanup() - optim_to(profiler, optimizer, accelerator.device) + # Save the lora weights if we are saving the model + if os.path.isfile(lora_save_file) and not delete_tmp_lora: + meta = args.export_ss_metadata() + convert_diffusers_to_kohya_lora(lora_save_file, meta, args.lora_weight) + else: + if os.path.isfile(lora_save_file): + os.remove(lora_save_file) - # Restore all random states to avoid having sampling impact training. - if shared.device.type == 'cuda': - torch.set_rng_state(torch_rng_state) - torch.cuda.set_rng_state(cuda_cpu_rng_state, device="cpu") - torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda") + printm("Completed saving weights.") + pbar2.reset() - cleanup() + # Only show the progress bar once on each machine, and do not send statuses to the new UI. + progress_bar = mytqdm( + range(global_step, max_train_steps), + disable=not accelerator.is_local_main_process, + position=0 + ) + progress_bar.set_description("Steps") + progress_bar.set_postfix(refresh=True) + args.revision = ( + args.revision if isinstance(args.revision, int) else + int(args.revision) if str(args.revision).strip() else + 0 + ) + lifetime_step = args.revision + lifetime_epoch = args.epoch + status.job_count = max_train_steps + status.job_no = global_step + update_status({"progress_1_total": max_train_steps, "progress_1_job_current": global_step}) + training_complete = False + msg = "" - # Save the lora weights if we are saving the model - if os.path.isfile(lora_save_file) and not delete_tmp_lora: - meta = args.export_ss_metadata() - convert_diffusers_to_kohya_lora(lora_save_file, meta, args.lora_weight) - else: - if os.path.isfile(lora_save_file): - os.remove(lora_save_file) + last_tenc = 0 < text_encoder_epochs + if stop_text_percentage == 0: + last_tenc = False + + cleanup() + stats = { + "loss": 0.0, + "prior_loss": 0.0, + "instance_loss": 0.0, + "unet_lr": learning_rate, + "tenc_lr": txt_learning_rate, + "session_epoch": 0, + "lifetime_epoch": args.epoch, + "total_session_epoch": args.num_train_epochs, + "total_lifetime_epoch": args.epoch + args.num_train_epochs, + "lifetime_step": args.revision, + "session_step": 0, + "total_session_step": max_train_steps, + "total_lifetime_step": args.revision + max_train_steps, + "steps_per_epoch": len(train_dataset), + "iterations_per_second": 0.0, + "vram": round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1) + } + for epoch in range(first_epoch, max_train_epochs): + if training_complete: + logger.debug("Training complete, breaking epoch.") + break - printm("Completed saving weights.") - pbar2.reset() + if args.train_unet: + unet.train() + elif args.use_lora and not args.lora_use_buggy_requires_grad: + set_lora_requires_grad(unet, False) - # Only show the progress bar once on each machine, and do not send statuses to the 
new UI. - progress_bar = mytqdm( - range(global_step, max_train_steps), - disable=not accelerator.is_local_main_process, - position=0 - ) - progress_bar.set_description("Steps") - progress_bar.set_postfix(refresh=True) - args.revision = ( - args.revision if isinstance(args.revision, int) else - int(args.revision) if str(args.revision).strip() else - 0 - ) - lifetime_step = args.revision - lifetime_epoch = args.epoch - status.job_count = max_train_steps - status.job_no = global_step - update_status({"progress_1_total": max_train_steps, "progress_1_job_current": global_step}) - training_complete = False - msg = "" - - last_tenc = 0 < text_encoder_epochs - if stop_text_percentage == 0: - last_tenc = False - - cleanup() - stats = { - "loss": 0.0, - "prior_loss": 0.0, - "instance_loss": 0.0, - "unet_lr": learning_rate, - "tenc_lr": txt_learning_rate, - "session_epoch": 0, - "lifetime_epoch": args.epoch, - "total_session_epoch": args.num_train_epochs, - "total_lifetime_epoch": args.epoch + args.num_train_epochs, - "lifetime_step": args.revision, - "session_step": 0, - "total_session_step": max_train_steps, - "total_lifetime_step": args.revision + max_train_steps, - "steps_per_epoch": len(train_dataset), - "iterations_per_second": 0.0, - "vram": round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1) - } - for epoch in range(first_epoch, max_train_epochs): - if training_complete: - logger.debug("Training complete, breaking epoch.") - break - - if args.train_unet: - unet.train() - elif args.use_lora and not args.lora_use_buggy_requires_grad: - set_lora_requires_grad(unet, False) - - train_tenc = epoch < text_encoder_epochs - if stop_text_percentage == 0: - train_tenc = False + train_tenc = epoch < text_encoder_epochs + if stop_text_percentage == 0: + train_tenc = False - if args.freeze_clip_normalization: - text_encoder.eval() - if args.model_type == "SDXL": - text_encoder_two.eval() - else: - text_encoder.train(train_tenc) - if args.model_type == "SDXL": - text_encoder_two.train(train_tenc) + if args.freeze_clip_normalization: + text_encoder.eval() + if args.model_type == "SDXL": + text_encoder_two.eval() + else: + text_encoder.train(train_tenc) + if args.model_type == "SDXL": + text_encoder_two.train(train_tenc) - if args.use_lora: - if not args.lora_use_buggy_requires_grad: - set_lora_requires_grad(text_encoder, train_tenc) - # We need to enable gradients on an input for gradient checkpointing to work - # This will not be optimized because it is not a param to optimizer - text_encoder.text_model.embeddings.position_embedding.requires_grad_(train_tenc) + if args.use_lora: + if not args.lora_use_buggy_requires_grad: + set_lora_requires_grad(text_encoder, train_tenc) + # We need to enable gradients on an input for gradient checkpointing to work + # This will not be optimized because it is not a param to optimizer + text_encoder.text_model.embeddings.position_embedding.requires_grad_(train_tenc) + if args.model_type == "SDXL": + set_lora_requires_grad(text_encoder_two, train_tenc) + text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(train_tenc) + else: + text_encoder.requires_grad_(train_tenc) if args.model_type == "SDXL": - set_lora_requires_grad(text_encoder_two, train_tenc) - text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(train_tenc) - else: - text_encoder.requires_grad_(train_tenc) - if args.model_type == "SDXL": - text_encoder_two.requires_grad_(train_tenc) + text_encoder_two.requires_grad_(train_tenc) - if last_tenc != train_tenc: - last_tenc = 
train_tenc - cleanup() + if last_tenc != train_tenc: + last_tenc = train_tenc + cleanup() - loss_total = 0 + loss_total = 0 - current_prior_loss_weight = current_prior_loss( - args, current_epoch=global_epoch - ) - - instance_loss = None - prior_loss = None + current_prior_loss_weight = current_prior_loss( + args, current_epoch=global_epoch + ) - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if ( - resume_from_checkpoint - and epoch == first_epoch - and step < resume_step - ): - progress_bar.update(train_batch_size) - progress_bar.reset() - status.job_count = max_train_steps - status.job_no += train_batch_size - stats["session_step"] += train_batch_size - stats["lifetime_step"] += train_batch_size - update_status(stats) - continue + instance_loss = None + prior_loss = None + + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if ( + resume_from_checkpoint + and epoch == first_epoch + and step < resume_step + ): + progress_bar.update(train_batch_size) + progress_bar.reset() + status.job_count = max_train_steps + status.job_no += train_batch_size + stats["session_step"] += train_batch_size + stats["lifetime_step"] += train_batch_size + update_status(stats) + continue + + with ConditionalAccumulator(accelerator, unet, text_encoder, text_encoder_two): + # Convert images to latent space + with torch.no_grad(): + if args.cache_latents: + latents = batch["images"].to(accelerator.device) + else: + latents = vae.encode( + batch["images"].to(dtype=weight_dtype) + ).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the model input + noise = torch.randn_like(latents, device=latents.device) + if args.offset_noise != 0: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.offset_noise * torch.randn( + (latents.shape[0], + latents.shape[1], + 1, + 1), + device=latents.device + ) + b_size, channels, height, width = latents.shape - with ConditionalAccumulator(accelerator, unet, text_encoder, text_encoder_two): - # Convert images to latent space - with torch.no_grad(): - if args.cache_latents: - latents = batch["images"].to(accelerator.device) - else: - latents = vae.encode( - batch["images"].to(dtype=weight_dtype) - ).latent_dist.sample() - latents = latents * 0.18215 - - # Sample noise that we'll add to the model input - noise = torch.randn_like(latents, device=latents.device) - if args.offset_noise != 0: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.offset_noise * torch.randn( - (latents.shape[0], - latents.shape[1], - 1, - 1), + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (b_size,), device=latents.device ) - b_size, channels, height, width = latents.shape - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (b_size,), - device=latents.device - ) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - pad_tokens = args.pad_tokens if train_tenc else False - input_ids = batch["input_ids"] - encoder_hidden_states = None - if args.model_type != "SDXL" and text_encoder is not None: - encoder_hidden_states = encode_hidden_state( - text_encoder, - batch["input_ids"], - pad_tokens, - b_size, - 
args.max_token_length, - tokenizer_max_length, - args.clip_skip, - ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + pad_tokens = args.pad_tokens if train_tenc else False + input_ids = batch["input_ids"] + encoder_hidden_states = None + if args.model_type != "SDXL" and text_encoder is not None: + encoder_hidden_states = encode_hidden_state( + text_encoder, + batch["input_ids"], + pad_tokens, + b_size, + args.max_token_length, + tokenizer_max_length, + args.clip_skip, + ) - if unet.config.in_channels > channels: - needed_additional_channels = unet.config.in_channels - channels - additional_latents = randn_tensor( - (b_size, needed_additional_channels, height, width), - device=noisy_latents.device, - dtype=noisy_latents.dtype, - ) - noisy_latents = torch.cat([additional_latents, noisy_latents], dim=1) - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + if unet.config.in_channels > channels: + needed_additional_channels = unet.config.in_channels - channels + additional_latents = randn_tensor( + (b_size, needed_additional_channels, height, width), + device=noisy_latents.device, + dtype=noisy_latents.dtype, + ) + noisy_latents = torch.cat([additional_latents, noisy_latents], dim=1) + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if args.model_type == "SDXL": - with accelerator.autocast(): - model_pred = unet( - noisy_latents, timesteps, batch["input_ids"], - added_cond_kwargs=batch["unet_added_conditions"] - ).sample - else: - # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.model_type != "SDXL": - # TODO: set a prior preservation flag and use that to ensure this ony happens in dreambooth - if not args.split_loss and not with_prior_preservation: - loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean") - loss *= batch["loss_avg"] + if args.model_type == "SDXL": + with accelerator.autocast(): + model_pred = unet( + noisy_latents, timesteps, batch["input_ids"], + added_cond_kwargs=batch["unet_added_conditions"] + ).sample else: - # Predict the noise residual - if model_pred.shape[1] == 6: - model_pred, _ = torch.chunk(model_pred, 2, dim=1) - - if model_pred.shape[0] > 1 and with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
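# Sketch of the noising and target selection implemented above, using a plain diffusers
# DDPMScheduler as a stand-in for get_noise_scheduler(args); shapes and values are dummies.
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")
latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
# Optional offset noise: a per-(sample, channel) constant shift, as in the blog post linked above.
offset_noise = 0.1
noise = noise + offset_noise * torch.randn(latents.shape[0], latents.shape[1], 1, 1)

timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],)).long()
noisy_latents = scheduler.add_noise(latents, noise, timesteps)   # forward diffusion

if scheduler.config.prediction_type == "epsilon":
    target = noise                                               # model predicts the added noise
else:                                                            # "v_prediction"
    target = scheduler.get_velocity(latents, noise, timesteps)

model_pred = torch.randn_like(target)                            # stand-in for unet(...).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")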
- print("model shape:") - print(model_pred.shape) - model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.model_type != "SDXL": + # TODO: set a prior preservation flag and use that to ensure this ony happens in dreambooth + if not args.split_loss and not with_prior_preservation: + loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss *= batch["loss_avg"] + else: + # Predict the noise residual + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + if model_pred.shape[0] > 1 and with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + print("model shape:") + print(model_pred.shape) + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), + reduction="mean") + else: + # Compute loss loss = instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + if with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), - reduction="mean") + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss else: - # Compute loss - loss = instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - else: - if with_prior_preservation: - # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) - target, target_prior = torch.chunk(target, 2, dim=0) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + accelerator.backward(loss) - # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + if accelerator.sync_gradients and not args.use_lora: + if train_tenc: + if args.model_type == "SDXL": + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters(), + text_encoder_two.parameters()) + else: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, 1) + + optimizer.step() + lr_scheduler.step(train_batch_size) + if args.use_ema and ema_model is not None: + ema_model.step(unet) + if profiler is not None: + profiler.step() + + optimizer.zero_grad(set_to_none=args.gradient_set_to_none) + + allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1) + cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1) + lr_data = lr_scheduler.get_last_lr() + last_lr = lr_data[0] + last_tenc_lr = 0 + stats["lr_data"] = lr_data + try: + if len(optimizer.param_groups) > 1: + last_tenc_lr = optimizer.param_groups[1]["lr"] if train_tenc else 0 + except: + logger.debug("Exception getting tenc lr") + pass - # Add the prior loss to the instance loss. - loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + if 'adapt' in args.optimizer: + last_lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + if len(optimizer.param_groups) > 1: + try: + last_tenc_lr = optimizer.param_groups[1]["d"] * optimizer.param_groups[1]["lr"] + except: + logger.warning("Exception setting tenc weight decay") + traceback.print_exc() - accelerator.backward(loss) + update_status(stats) + del latents + del encoder_hidden_states + del noise + del timesteps + del noisy_latents + del target + + global_step += train_batch_size + args.revision += train_batch_size + status.job_no += train_batch_size + loss_step = loss.detach().item() + loss_total += loss_step - if accelerator.sync_gradients and not args.use_lora: - if train_tenc: - if args.model_type == "SDXL": - params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters(), - text_encoder_two.parameters()) - else: - params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) - else: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, 1) - - optimizer.step() - lr_scheduler.step(train_batch_size) - if args.use_ema and ema_model is not None: - ema_model.step(unet) - if profiler is not None: - profiler.step() - - optimizer.zero_grad(set_to_none=args.gradient_set_to_none) - - allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1) - cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1) - lr_data = lr_scheduler.get_last_lr() - last_lr = lr_data[0] - last_tenc_lr = 0 - stats["lr_data"] = lr_data - try: - if len(optimizer.param_groups) > 1: - last_tenc_lr = optimizer.param_groups[1]["lr"] if train_tenc else 0 - except: - logger.debug("Exception getting tenc lr") - pass + stats["session_step"] += train_batch_size + stats["lifetime_step"] += train_batch_size + stats["loss"] = loss_step - if 'adapt' in args.optimizer: - last_lr = 
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - if len(optimizer.param_groups) > 1: - try: - last_tenc_lr = optimizer.param_groups[1]["d"] * optimizer.param_groups[1]["lr"] - except: - logger.warning("Exception setting tenc weight decay") - traceback.print_exc() + logs = { + "lr": float(last_lr), + "loss": float(loss_step), + "vram": float(cached), + } - update_status(stats) - del latents - del encoder_hidden_states - del noise - del timesteps - del noisy_latents - del target - - global_step += train_batch_size - args.revision += train_batch_size - status.job_no += train_batch_size - loss_step = loss.detach().item() - loss_total += loss_step - - stats["session_step"] += train_batch_size - stats["lifetime_step"] += train_batch_size - stats["loss"] = loss_step - - logs = { - "lr": float(last_lr), - "loss": float(loss_step), - "vram": float(cached), - } + stats["vram"] = logs["vram"] + stats["unet_lr"] = '{:.2E}'.format(Decimal(last_lr)) + stats["tenc_lr"] = '{:.2E}'.format(Decimal(last_tenc_lr)) - stats["vram"] = logs["vram"] - stats["unet_lr"] = '{:.2E}'.format(Decimal(last_lr)) - stats["tenc_lr"] = '{:.2E}'.format(Decimal(last_tenc_lr)) + if args.split_loss and with_prior_preservation and args.model_type != "SDXL": + logs["inst_loss"] = float(instance_loss.detach().item()) - if args.split_loss and with_prior_preservation and args.model_type != "SDXL": - logs["inst_loss"] = float(instance_loss.detach().item()) - - if prior_loss is not None: - logs["prior_loss"] = float(prior_loss.detach().item()) + if prior_loss is not None: + logs["prior_loss"] = float(prior_loss.detach().item()) + else: + logs["prior_loss"] = None # or some other default value + stats["instance_loss"] = logs["inst_loss"] + stats["prior_loss"] = logs["prior_loss"] + + if 'adapt' in args.optimizer: + status.textinfo2 = ( + f"Loss: {'%.2f' % loss_step}, UNET DLR: {'{:.2E}'.format(Decimal(last_lr))}, TENC DLR: {'{:.2E}'.format(Decimal(last_tenc_lr))}, " + f"VRAM: {allocated}/{cached} GB" + ) else: - logs["prior_loss"] = None # or some other default value - stats["instance_loss"] = logs["inst_loss"] - stats["prior_loss"] = logs["prior_loss"] - - if 'adapt' in args.optimizer: - status.textinfo2 = ( - f"Loss: {'%.2f' % loss_step}, UNET DLR: {'{:.2E}'.format(Decimal(last_lr))}, TENC DLR: {'{:.2E}'.format(Decimal(last_tenc_lr))}, " - f"VRAM: {allocated}/{cached} GB" - ) - else: - status.textinfo2 = ( - f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, " - f"VRAM: {allocated}/{cached} GB" - ) + status.textinfo2 = ( + f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, " + f"VRAM: {allocated}/{cached} GB" + ) - progress_bar.update(train_batch_size) - rate = progress_bar.format_dict["rate"] if "rate" in progress_bar.format_dict else None - if rate is None: - rate_string = "" - else: - if rate > 1: - rate_string = f"{rate:.2f} it/s" + progress_bar.update(train_batch_size) + rate = progress_bar.format_dict["rate"] if "rate" in progress_bar.format_dict else None + if rate is None: + rate_string = "" else: - rate_string = f"{1 / rate:.2f} s/it" if rate != 0 else "N/A" - stats["iterations_per_second"] = rate_string - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=args.revision) + if rate > 1: + rate_string = f"{rate:.2f} it/s" + else: + rate_string = f"{1 / rate:.2f} s/it" if rate != 0 else "N/A" + stats["iterations_per_second"] = rate_string + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=args.revision) + + logs = {"epoch_loss": loss_total / 
len(train_dataloader)} + accelerator.log(logs, step=global_step) + stats["epoch_loss"] = '%.2f' % (loss_total / len(train_dataloader)) + + status.job_count = max_train_steps + status.job_no = global_step + stats["lifetime_step"] = args.revision + stats["session_step"] = global_step + # status0 = f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}" + # status1 = f"{args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}" + status.textinfo = ( + f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}" + f" {args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}" + ) + update_status(stats) - logs = {"epoch_loss": loss_total / len(train_dataloader)} - accelerator.log(logs, step=global_step) - stats["epoch_loss"] = '%.2f' % (loss_total / len(train_dataloader)) + if math.isnan(loss_step): + logger.warning("Loss is NaN, your model is dead. Cancelling training.") + status.interrupted = True + if status_handler: + status_handler.end("Training interrrupted due to NaN loss.") + + # Log completion message + if training_complete or status.interrupted: + shared.in_progress = False + shared.in_progress_step = 0 + shared.in_progress_epoch = 0 + logger.debug(" Training complete (step check).") + if status.interrupted: + state = "canceled" + else: + state = "complete" + status.textinfo = ( + f"Training {state} {global_step}/{max_train_steps}, {args.revision}" + f" total." + ) + if status_handler: + status_handler.end(status.textinfo) + break + + accelerator.wait_for_everyone() + + args.epoch += 1 + global_epoch += 1 + lifetime_epoch += 1 + session_epoch += 1 + stats["session_epoch"] += 1 + stats["lifetime_epoch"] += 1 + lr_scheduler.step(is_epoch=True) status.job_count = max_train_steps status.job_no = global_step - stats["lifetime_step"] = args.revision - stats["session_step"] = global_step - # status0 = f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}" - # status1 = f"{args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}" - status.textinfo = ( - f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}" - f" {args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}" - ) update_status(stats) + check_save(True) - if math.isnan(loss_step): - logger.warning("Loss is NaN, your model is dead. 
Cancelling training.") - status.interrupted = True - if status_handler: - status_handler.end("Training interrrupted due to NaN loss.") + if args.num_train_epochs > 1: + training_complete = session_epoch >= max_train_epochs - # Log completion message if training_complete or status.interrupted: - shared.in_progress = False - shared.in_progress_step = 0 - shared.in_progress_epoch = 0 logger.debug(" Training complete (step check).") if status.interrupted: state = "canceled" @@ -1776,59 +1804,27 @@ def lora_save_function(weights, filename): status_handler.end(status.textinfo) break - accelerator.wait_for_everyone() - - args.epoch += 1 - global_epoch += 1 - lifetime_epoch += 1 - session_epoch += 1 - stats["session_epoch"] += 1 - stats["lifetime_epoch"] += 1 - lr_scheduler.step(is_epoch=True) - status.job_count = max_train_steps - status.job_no = global_step - update_status(stats) - check_save(True) - - if args.num_train_epochs > 1: - training_complete = session_epoch >= max_train_epochs - - if training_complete or status.interrupted: - logger.debug(" Training complete (step check).") - if status.interrupted: - state = "canceled" - else: - state = "complete" + # Do this at the very END of the epoch, only after we're sure we're not done + if args.epoch_pause_frequency > 0 and args.epoch_pause_time > 0: + if not session_epoch % args.epoch_pause_frequency: + logger.debug( + f"Giving the GPU a break for {args.epoch_pause_time} seconds." + ) + for i in range(args.epoch_pause_time): + if status.interrupted: + training_complete = True + logger.debug("Training complete, interrupted.") + if status_handler: + status_handler.end("Training interrrupted.") + break + time.sleep(1) - status.textinfo = ( - f"Training {state} {global_step}/{max_train_steps}, {args.revision}" - f" total." - ) - if status_handler: - status_handler.end(status.textinfo) - break - - # Do this at the very END of the epoch, only after we're sure we're not done - if args.epoch_pause_frequency > 0 and args.epoch_pause_time > 0: - if not session_epoch % args.epoch_pause_frequency: - logger.debug( - f"Giving the GPU a break for {args.epoch_pause_time} seconds." 
- ) - for i in range(args.epoch_pause_time): - if status.interrupted: - training_complete = True - logger.debug("Training complete, interrupted.") - if status_handler: - status_handler.end("Training interrrupted.") - break - time.sleep(1) - - cleanup_memory() - accelerator.end_training() - result.msg = msg - result.config = args - result.samples = last_samples - stop_profiler(profiler) - return result + cleanup_memory() + accelerator.end_training() + result.msg = msg + result.config = args + result.samples = last_samples + stop_profiler(profiler) + return result return inner_loop() diff --git a/dreambooth/utils/model_utils.py b/dreambooth/utils/model_utils.py index 3461015d..56d04615 100644 --- a/dreambooth/utils/model_utils.py +++ b/dreambooth/utils/model_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +import contextlib import json import logging import os @@ -257,23 +258,40 @@ def get_checkpoint_match(search_string): return None +disable_safe_unpickle_count = 0 + def disable_safe_unpickle(): + global disable_safe_unpickle_count try: from modules import shared as auto_shared - auto_shared.cmd_opts.disable_safe_unpickle = True - torch.load = unsafe_torch_load + if not auto_shared.cmd_opts.disable_safe_unpickle: + auto_shared.cmd_opts.disable_safe_unpickle = True + torch.load = unsafe_torch_load + disable_safe_unpickle_count += 1 except: pass def enable_safe_unpickle(): + global disable_safe_unpickle_count try: from modules import shared as auto_shared - auto_shared.cmd_opts.disable_safe_unpickle = False - torch.load = load + if disable_safe_unpickle_count > 0: + disable_safe_unpickle_count -= 1 + if disable_safe_unpickle_count == 0 and auto_shared.cmd_opts.disable_safe_unpickle: + auto_shared.cmd_opts.disable_safe_unpickle = False + torch.load = load except: pass +@contextlib.contextmanager +def safe_unpickle_disabled(): + disable_safe_unpickle() + try: + yield + finally: + enable_safe_unpickle() + def xformerify(obj, use_lora): try: diff --git a/helpers/image_builder.py b/helpers/image_builder.py index 2863e2f8..4b8c4919 100644 --- a/helpers/image_builder.py +++ b/helpers/image_builder.py @@ -14,12 +14,11 @@ from dreambooth import shared from dreambooth.dataclasses.db_config import DreamboothConfig from dreambooth.dataclasses.prompt_data import PromptData -from dreambooth.shared import disable_safe_unpickle from dreambooth.utils import image_utils from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class from dreambooth.utils.model_utils import get_checkpoint_match, \ reload_system_models, \ - enable_safe_unpickle, disable_safe_unpickle, unload_system_models + safe_unpickle_disabled, unload_system_models from helpers.mytqdm import mytqdm from lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \ get_target_module @@ -83,71 +82,69 @@ def __init__( msg = f"Exception initializing accelerator: {e}" print(msg) torch_dtype = torch.float16 if shared.device.type == "cuda" else torch.float32 - disable_safe_unpickle() + with safe_unpickle_disabled(): - self.image_pipe = DiffusionPipeline.from_pretrained(config.get_pretrained_model_name_or_path(), torch_dtype=torch.float16) + self.image_pipe = DiffusionPipeline.from_pretrained(config.get_pretrained_model_name_or_path(), torch_dtype=torch.float16) - if config.pretrained_vae_name_or_path: - logging.getLogger(__name__).info("Using pretrained VAE.") - self.image_pipe.vae = AutoencoderKL.from_pretrained( - config.pretrained_vae_name_or_path or 
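[Editor's note: the model_utils.py hunk above replaces the paired disable_safe_unpickle()/enable_safe_unpickle() calls with a counted safe_unpickle_disabled() context manager. A minimal usage sketch follows (the file name is hypothetical); because of the counter, the extension only restores the safe loader if it was the one that turned it off, so a webui launched with safe unpickling already disabled keeps that setting.]

import torch
from dreambooth.utils.model_utils import safe_unpickle_disabled

with safe_unpickle_disabled():
    # Stock torch.load is active inside the context; A1111's safe loader is
    # restored on exit only if this extension was the one that disabled it.
    lora_state = torch.load("example_lora_weights.pt")  # hypothetical path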
config.get_pretrained_model_name_or_path(), - subfolder=None if config.pretrained_vae_name_or_path else "vae", - revision=config.revision, - torch_dtype=torch_dtype - ) + if config.pretrained_vae_name_or_path: + logging.getLogger(__name__).info("Using pretrained VAE.") + self.image_pipe.vae = AutoencoderKL.from_pretrained( + config.pretrained_vae_name_or_path or config.get_pretrained_model_name_or_path(), + subfolder=None if config.pretrained_vae_name_or_path else "vae", + revision=config.revision, + torch_dtype=torch_dtype + ) - if config.infer_ema: - logging.getLogger(__name__).info("Using EMA model for inference.") - ema_path = os.path.join(config.get_pretrained_model_name_or_path(), "ema_unet", - "diffusion_pytorch_model.safetensors") - if os.path.isfile(ema_path): - self.image_pipe.unet = UNet2DConditionModel.from_pretrained(ema_path, torch_dtype=torch.float16), + if config.infer_ema: + logging.getLogger(__name__).info("Using EMA model for inference.") + ema_path = os.path.join(config.get_pretrained_model_name_or_path(), "ema_unet", + "diffusion_pytorch_model.safetensors") + if os.path.isfile(ema_path): + self.image_pipe.unet = UNet2DConditionModel.from_pretrained(ema_path, torch_dtype=torch.float16), - self.image_pipe.enable_model_cpu_offload() - self.image_pipe.unet.set_attn_processor(AttnProcessor2_0()) - if os.name != "nt": - self.image_pipe.unet = torch.compile(self.image_pipe.unet) - self.image_pipe.enable_xformers_memory_efficient_attention() - self.image_pipe.vae.enable_slicing() - tomesd.apply_patch(self.image_pipe, ratio=0.5) - self.image_pipe.scheduler.config["solver_type"] = "bh2" - self.image_pipe.progress_bar = self.progress_bar + self.image_pipe.enable_model_cpu_offload() + self.image_pipe.unet.set_attn_processor(AttnProcessor2_0()) + if os.name != "nt": + self.image_pipe.unet = torch.compile(self.image_pipe.unet) + self.image_pipe.enable_xformers_memory_efficient_attention() + self.image_pipe.vae.enable_slicing() + tomesd.apply_patch(self.image_pipe, ratio=0.5) + self.image_pipe.scheduler.config["solver_type"] = "bh2" + self.image_pipe.progress_bar = self.progress_bar - if scheduler is None: - scheduler = config.scheduler + if scheduler is None: + scheduler = config.scheduler - print(f"Using scheduler: {scheduler}") - scheduler_class = get_scheduler_class(scheduler) + print(f"Using scheduler: {scheduler}") + scheduler_class = get_scheduler_class(scheduler) - self.image_pipe.scheduler = scheduler_class.from_config(self.image_pipe.scheduler.config) + self.image_pipe.scheduler = scheduler_class.from_config(self.image_pipe.scheduler.config) - if "UniPC" in scheduler: - self.image_pipe.scheduler.config.solver_type = "bh2" + if "UniPC" in scheduler: + self.image_pipe.scheduler.config.solver_type = "bh2" - self.image_pipe.to(accelerator.device) - new_hotness = os.path.join(config.model_dir, "checkpoints", f"checkpoint-{config.revision}") - if os.path.exists(new_hotness): - accelerator.print(f"Resuming from checkpoint {new_hotness}") - disable_safe_unpickle() - accelerator.load_state(new_hotness) - enable_safe_unpickle() + self.image_pipe.to(accelerator.device) + new_hotness = os.path.join(config.model_dir, "checkpoints", f"checkpoint-{config.revision}") + if os.path.exists(new_hotness): + accelerator.print(f"Resuming from checkpoint {new_hotness}") + accelerator.load_state(new_hotness) - if config.use_lora and lora_model: - lora_model_path = shared.ui_lora_models_path - if os.path.exists(lora_model_path): - patch_pipe( - pipe=self.image_pipe, - 
maybe_unet_path=lora_model_path, - unet_target_replace_module=get_target_module("module", config.use_lora_extended), - token=None, - r=lora_unet_rank, - r_txt=lora_txt_rank - ) - tune_lora_scale(self.image_pipe.unet, config.lora_weight) + if config.use_lora and lora_model: + lora_model_path = shared.ui_lora_models_path + if os.path.exists(lora_model_path): + patch_pipe( + pipe=self.image_pipe, + maybe_unet_path=lora_model_path, + unet_target_replace_module=get_target_module("module", config.use_lora_extended), + token=None, + r=lora_unet_rank, + r_txt=lora_txt_rank + ) + tune_lora_scale(self.image_pipe.unet, config.lora_weight) lora_txt_path = _text_lora_path_ui(lora_model_path) if os.path.exists(lora_txt_path): - tune_lora_scale(self.image_pipe.text_encoder, config.lora_txt_weight) + tune_lora_scale(self.image_pipe.text_encoder, config.lora_weight) else: try: diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index 47af5797..01cb1e1c 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -8,7 +8,7 @@ from safetensors.torch import save_file as safe_save from torch import dtype -from dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle +from dreambooth.utils.model_utils import safe_unpickle_disabled class LoraInjectedLinear(nn.Module): @@ -215,12 +215,13 @@ def inject_trainable_lora( if target_replace_module is None: target_replace_module = DEFAULT_TARGET_REPLACE - disable_safe_unpickle() + require_grad_params = [] names = [] if loras is not None: - loras = torch.load(loras) + with safe_unpickle_disabled(): + loras = torch.load(loras) for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear] @@ -252,7 +253,6 @@ def inject_trainable_lora( _module._modules[name].lora_down.weight.requires_grad = True names.append(name) - enable_safe_unpickle() return require_grad_params, names @@ -267,12 +267,12 @@ def inject_trainable_lora_extended( """ if target_replace_module is None: target_replace_module = UNET_EXTENDED_TARGET_REPLACE - disable_safe_unpickle() require_grad_params = [] names = [] if loras is not None: - loras = torch.load(loras) + with safe_unpickle_disabled(): + loras = torch.load(loras) for _module, name, _child_module in _find_modules( model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] @@ -326,7 +326,6 @@ def inject_trainable_lora_extended( _module._modules[name].lora_down.weight.requires_grad = True names.append(name) - enable_safe_unpickle() return require_grad_params, names @@ -458,9 +457,9 @@ def convert_loras_to_safeloras_with_embeds( for name, (path, target_replace_module, r) in modelmap.items(): metadata[name] = json.dumps(list(target_replace_module)) - disable_safe_unpickle() - lora = torch.load(path) - enable_safe_unpickle() + with safe_unpickle_disabled(): + lora = torch.load(path) + for i, weight in enumerate(lora): is_up = i % 2 == 0 i = i // 2 @@ -903,9 +902,8 @@ def load_learned_embed_in_clip( token: Optional[Union[str, List[str]]] = None, idempotent=False, ): - disable_safe_unpickle() - learned_embeds = torch.load(learned_embeds_path) - enable_safe_unpickle() + with safe_unpickle_disabled(): + learned_embeds = torch.load(learned_embeds_path) apply_learned_embed_in_clip( learned_embeds, text_encoder, tokenizer, token, idempotent ) @@ -941,30 +939,30 @@ def patch_pipe( ti_path = _ti_lora_path(unet_path) text_path = _text_lora_path_ui(unet_path) - disable_safe_unpickle() - if patch_unet: - print("LoRA : Patching Unet") - lora_patch = get_target_module( - "patch", 
- bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE) - ) + with safe_unpickle_disabled(): + if patch_unet: + print("LoRA : Patching Unet") + lora_patch = get_target_module( + "patch", + bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE) + ) - lora_patch( - pipe.unet, - torch.load(unet_path), - r=r, - target_replace_module=unet_target_replace_module, - ) + lora_patch( + pipe.unet, + torch.load(unet_path), + r=r, + target_replace_module=unet_target_replace_module, + ) - if patch_text: - print("LoRA : Patching text encoder") - monkeypatch_or_replace_lora( - pipe.text_encoder, - torch.load(text_path), - target_replace_module=text_target_replace_module, - r=r_txt, - ) - enable_safe_unpickle() + if patch_text: + print("LoRA : Patching text encoder") + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=text_target_replace_module, + r=r_txt, + ) + if patch_ti: print("LoRA : Patching token input") token = load_learned_embed_in_clip(