diff --git a/dreambooth/diff_to_sd.py b/dreambooth/diff_to_sd.py
index fbef6690..3123e55f 100644
--- a/dreambooth/diff_to_sd.py
+++ b/dreambooth/diff_to_sd.py
@@ -20,7 +20,7 @@
from dreambooth.shared import status
from dreambooth.utils.model_utils import unload_system_models, \
reload_system_models, \
- disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path
+ safe_unpickle_disabled, import_model_class_from_model_name_or_path
from dreambooth.utils.utils import printi
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import merge_lora_to_model
@@ -562,9 +562,8 @@ def load_model(model_path: str, map_location: str):
if ".safetensors" in model_path:
return safetensors.torch.load_file(model_path, device=map_location)
else:
- disable_safe_unpickle()
- loaded = torch.load(model_path, map_location=map_location)
- enable_safe_unpickle()
+ with safe_unpickle_disabled():
+ loaded = torch.load(model_path, map_location=map_location)
return loaded
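
Note: the removed disable_safe_unpickle()/enable_safe_unpickle() pair is assumed to be replaced by a context manager in dreambooth/utils/model_utils.py; that helper is not part of this diff. A minimal sketch of what safe_unpickle_disabled() would look like, so the safe-unpickle guard is restored even if torch.load raises:

    from contextlib import contextmanager

    @contextmanager
    def safe_unpickle_disabled():
        # Hypothetical sketch only; the real helper lives in model_utils.py and is not shown here.
        disable_safe_unpickle()      # turn the safe-unpickle guard off for the duration of the block
        try:
            yield                    # caller runs torch.load / from_pretrained inside the with-block
        finally:
            enable_safe_unpickle()   # always re-arm the guard, even when loading raises
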
diff --git a/dreambooth/sd_to_diff.py b/dreambooth/sd_to_diff.py
index ffd11e2f..3467a1be 100644
--- a/dreambooth/sd_to_diff.py
+++ b/dreambooth/sd_to_diff.py
@@ -25,7 +25,7 @@
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
-from dreambooth.utils.model_utils import enable_safe_unpickle, disable_safe_unpickle, unload_system_models, \
+from dreambooth.utils.model_utils import safe_unpickle_disabled, unload_system_models, \
reload_system_models
@@ -131,7 +131,6 @@ def extract_checkpoint(
# sh.update_status(status)
# else:
# modules.shared.status.update(status)
- disable_safe_unpickle()
if image_size is None:
image_size = 512
if model_type == "v2x":
@@ -162,59 +161,60 @@ def extract_checkpoint(
db_config.resolution = image_size
db_config.save()
try:
- if from_safetensors:
- if model_type == "SDXL":
- pipe = StableDiffusionXLPipeline.from_single_file(
- pretrained_model_link_or_path=checkpoint_file,
+ with safe_unpickle_disabled():
+ if from_safetensors:
+ if model_type == "SDXL":
+ pipe = StableDiffusionXLPipeline.from_single_file(
+ pretrained_model_link_or_path=checkpoint_file,
+ )
+ else:
+ pipe = StableDiffusionPipeline.from_single_file(
+ pretrained_model_link_or_path=checkpoint_file,
+ )
+ elif model_type == "SDXL":
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ checkpoint_path_or_dict=checkpoint_file,
+ original_config_file=original_config_file,
+ image_size=image_size,
+ prediction_type=prediction_type,
+ model_type=pipeline_type,
+ extract_ema=extract_ema,
+ scheduler_type=scheduler_type,
+ num_in_channels=num_in_channels,
+ upcast_attention=upcast_attention,
+ from_safetensors=from_safetensors,
+ device=device,
+ pretrained_model_name_or_path=checkpoint_file,
+ stable_unclip=stable_unclip,
+ stable_unclip_prior=stable_unclip_prior,
+ clip_stats_path=clip_stats_path,
+ controlnet=controlnet,
+ vae_path=vae_path,
+ pipeline_class=pipeline_class,
+ half=half
)
else:
- pipe = StableDiffusionPipeline.from_single_file(
- pretrained_model_link_or_path=checkpoint_file,
+ pipe = StableDiffusionPipeline.from_pretrained(
+ checkpoint_path_or_dict=checkpoint_file,
+ original_config_file=original_config_file,
+ image_size=image_size,
+ prediction_type=prediction_type,
+ model_type=pipeline_type,
+ extract_ema=extract_ema,
+ scheduler_type=scheduler_type,
+ num_in_channels=num_in_channels,
+ upcast_attention=upcast_attention,
+ from_safetensors=from_safetensors,
+ device=device,
+ pretrained_model_name_or_path=checkpoint_file,
+ stable_unclip=stable_unclip,
+ stable_unclip_prior=stable_unclip_prior,
+ clip_stats_path=clip_stats_path,
+ controlnet=controlnet,
+ vae_path=vae_path,
+ pipeline_class=pipeline_class,
+ half=half
)
- elif model_type == "SDXL":
- pipe = StableDiffusionXLPipeline.from_pretrained(
- checkpoint_path_or_dict=checkpoint_file,
- original_config_file=original_config_file,
- image_size=image_size,
- prediction_type=prediction_type,
- model_type=pipeline_type,
- extract_ema=extract_ema,
- scheduler_type=scheduler_type,
- num_in_channels=num_in_channels,
- upcast_attention=upcast_attention,
- from_safetensors=from_safetensors,
- device=device,
- pretrained_model_name_or_path=checkpoint_file,
- stable_unclip=stable_unclip,
- stable_unclip_prior=stable_unclip_prior,
- clip_stats_path=clip_stats_path,
- controlnet=controlnet,
- vae_path=vae_path,
- pipeline_class=pipeline_class,
- half=half
- )
- else:
- pipe = StableDiffusionPipeline.from_pretrained(
- checkpoint_path_or_dict=checkpoint_file,
- original_config_file=original_config_file,
- image_size=image_size,
- prediction_type=prediction_type,
- model_type=pipeline_type,
- extract_ema=extract_ema,
- scheduler_type=scheduler_type,
- num_in_channels=num_in_channels,
- upcast_attention=upcast_attention,
- from_safetensors=from_safetensors,
- device=device,
- pretrained_model_name_or_path=checkpoint_file,
- stable_unclip=stable_unclip,
- stable_unclip_prior=stable_unclip_prior,
- clip_stats_path=clip_stats_path,
- controlnet=controlnet,
- vae_path=vae_path,
- pipeline_class=pipeline_class,
- half=half
- )
dump_path = db_config.get_pretrained_model_name_or_path()
if controlnet:
@@ -246,7 +246,7 @@ def extract_checkpoint(
print(f"Couldn't find {full_path}")
break
remove_dirs = ["logging", "samples"]
- enable_safe_unpickle()
+
reload_system_models()
if success:
for rd in remove_dirs:
diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py
index dbab4251..dadbec73 100644
--- a/dreambooth/train_dreambooth.py
+++ b/dreambooth/train_dreambooth.py
@@ -31,6 +31,7 @@
)
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor2_0, LoRAAttnProcessor
+from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import logging as dl
from diffusers.utils.torch_utils import randn_tensor
from torch.cuda.profiler import profile
@@ -54,10 +55,9 @@
from dreambooth.utils.model_utils import (
unload_system_models,
import_model_class_from_model_name_or_path,
- disable_safe_unpickle,
- enable_safe_unpickle,
+ safe_unpickle_disabled,
xformerify,
- torch2ify, unet_attn_processors_state_dict
+ torch2ify
)
from dreambooth.utils.text_utils import encode_hidden_state, save_token_counts
from dreambooth.utils.utils import (cleanup, printm, verify_locon_installed,
@@ -71,6 +71,14 @@
set_lora_requires_grad,
)
+try:
+    import wandb
+
+    # Disable annoying wandb popup
+    wandb.config.auto_init = False
+except Exception:
+    pass
+
logger = logging.getLogger(__name__)
# define a Handler which writes DEBUG messages or higher to the sys.stderr
dl.set_verbosity_error()
@@ -303,1450 +311,1485 @@ def create_vae():
if args.pretrained_vae_name_or_path
else args.get_pretrained_model_name_or_path()
)
- disable_safe_unpickle()
- new_vae = AutoencoderKL.from_pretrained(
- vae_path,
- subfolder=None if args.pretrained_vae_name_or_path else "vae",
- revision=args.revision,
- )
- enable_safe_unpickle()
+ with safe_unpickle_disabled():
+ new_vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
+ revision=args.revision,
+ )
new_vae.requires_grad_(False)
new_vae.to(accelerator.device, dtype=weight_dtype)
return new_vae
- disable_safe_unpickle()
- # Load the tokenizer
- pbar2.set_description("Loading tokenizer...")
- pbar2.update()
- pbar2.set_postfix(refresh=True)
- tokenizer = AutoTokenizer.from_pretrained(
- os.path.join(pretrained_path, "tokenizer"),
- revision=args.revision,
- use_fast=False,
- )
-
- tokenizer_two = None
- if args.model_type == "SDXL":
- pbar2.set_description("Loading tokenizer 2...")
+ with safe_unpickle_disabled():
+ # Load the tokenizer
+ pbar2.set_description("Loading tokenizer...")
pbar2.update()
pbar2.set_postfix(refresh=True)
- tokenizer_two = AutoTokenizer.from_pretrained(
- os.path.join(pretrained_path, "tokenizer_2"),
+ tokenizer = AutoTokenizer.from_pretrained(
+ os.path.join(pretrained_path, "tokenizer"),
revision=args.revision,
use_fast=False,
)
- # import correct text encoder class
- text_encoder_cls = import_model_class_from_model_name_or_path(
- args.get_pretrained_model_name_or_path(), args.revision
- )
-
- pbar2.set_description("Loading text encoder...")
- pbar2.update()
- pbar2.set_postfix(refresh=True)
- # Load models and create wrapper for stable diffusion
- text_encoder = text_encoder_cls.from_pretrained(
- args.get_pretrained_model_name_or_path(),
- subfolder="text_encoder",
- revision=args.revision,
- torch_dtype=torch.float32,
- )
+ tokenizer_two = None
+ if args.model_type == "SDXL":
+ pbar2.set_description("Loading tokenizer 2...")
+ pbar2.update()
+ pbar2.set_postfix(refresh=True)
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ os.path.join(pretrained_path, "tokenizer_2"),
+ revision=args.revision,
+ use_fast=False,
+ )
- if args.model_type == "SDXL":
# import correct text encoder class
- text_encoder_cls_two = import_model_class_from_model_name_or_path(
- args.get_pretrained_model_name_or_path(), args.revision, subfolder="text_encoder_2"
+ text_encoder_cls = import_model_class_from_model_name_or_path(
+ args.get_pretrained_model_name_or_path(), args.revision
)
- pbar2.set_description("Loading text encoder 2...")
+ pbar2.set_description("Loading text encoder...")
pbar2.update()
pbar2.set_postfix(refresh=True)
# Load models and create wrapper for stable diffusion
- text_encoder_two = text_encoder_cls_two.from_pretrained(
+ text_encoder = text_encoder_cls.from_pretrained(
args.get_pretrained_model_name_or_path(),
- subfolder="text_encoder_2",
+ subfolder="text_encoder",
revision=args.revision,
torch_dtype=torch.float32,
)
- printm("Created tenc")
- pbar2.set_description("Loading VAE...")
- pbar2.update()
- vae = create_vae()
- printm("Created vae")
-
- pbar2.set_description("Loading unet...")
- pbar2.update()
- unet = UNet2DConditionModel.from_pretrained(
- args.get_pretrained_model_name_or_path(),
- subfolder="unet",
- revision=args.revision,
- torch_dtype=torch.float32,
- )
-
- if args.attention == "xformers" and not shared.force_cpu:
- xformerify(unet, use_lora=args.use_lora)
- xformerify(vae, use_lora=args.use_lora)
-
- unet = torch2ify(unet)
-
- if args.full_mixed_precision:
- if args.mixed_precision == "fp16":
- patch_accelerator_for_fp16_training(accelerator)
- unet.to(accelerator.device, dtype=weight_dtype)
- else:
- # Check that all trainable models are in full precision
- low_precision_error_string = (
- "Please make sure to always have all model weights in full float32 precision when starting training - "
- "even if doing mixed precision training. copy of the weights should still be float32."
- )
-
- if accelerator.unwrap_model(unet).dtype != torch.float32:
- logger.warning(
- f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ if args.model_type == "SDXL":
+ # import correct text encoder class
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.get_pretrained_model_name_or_path(), args.revision, subfolder="text_encoder_2"
)
- if (
- args.stop_text_encoder != 0
- and accelerator.unwrap_model(text_encoder).dtype != torch.float32
- ):
- logger.warning(
- f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
- f" {low_precision_error_string}"
- )
-
- if (
- args.stop_text_encoder != 0
- and accelerator.unwrap_model(text_encoder_two).dtype != torch.float32
- ):
- logger.warning(
- f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder_two).dtype}."
- f" {low_precision_error_string}"
- )
-
- if args.gradient_checkpointing:
- if args.train_unet:
- unet.enable_gradient_checkpointing()
- if stop_text_percentage != 0:
- text_encoder.gradient_checkpointing_enable()
- if args.model_type == "SDXL":
- text_encoder_two.gradient_checkpointing_enable()
- if args.use_lora:
- # We need to enable gradients on an input for gradient checkpointing to work
- # This will not be optimized because it is not a param to optimizer
- text_encoder.text_model.embeddings.position_embedding.requires_grad_(True)
- if args.model_type == "SDXL":
- text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(True)
- else:
- text_encoder.to(accelerator.device, dtype=weight_dtype)
- if args.model_type == "SDXL":
- text_encoder_two.to(accelerator.device, dtype=weight_dtype)
-
- ema_model = None
- if args.use_ema:
- if os.path.exists(
- os.path.join(
- args.get_pretrained_model_name_or_path(),
- "ema_unet",
- "diffusion_pytorch_model.safetensors",
- )
- ):
- ema_unet = UNet2DConditionModel.from_pretrained(
+ pbar2.set_description("Loading text encoder 2...")
+ pbar2.update()
+ pbar2.set_postfix(refresh=True)
+ # Load models and create wrapper for stable diffusion
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
args.get_pretrained_model_name_or_path(),
- subfolder="ema_unet",
+ subfolder="text_encoder_2",
revision=args.revision,
- torch_dtype=weight_dtype,
+ torch_dtype=torch.float32,
)
- if args.attention == "xformers" and not shared.force_cpu:
- xformerify(ema_unet, use_lora=args.use_lora)
- ema_model = EMAModel(
- ema_unet, device=accelerator.device, dtype=weight_dtype
- )
- del ema_unet
- else:
- ema_model = EMAModel(
- unet, device=accelerator.device, dtype=weight_dtype
- )
+ printm("Created tenc")
+ pbar2.set_description("Loading VAE...")
+ pbar2.update()
+ vae = create_vae()
+ printm("Created vae")
- # Create shared unet/tenc learning rate variables
-
- learning_rate = args.learning_rate
- txt_learning_rate = args.txt_learning_rate
- if args.use_lora:
- learning_rate = args.lora_learning_rate
- txt_learning_rate = args.lora_txt_learning_rate
-
- if args.use_lora or not args.train_unet:
- unet.requires_grad_(False)
-
- unet_lora_params = None
-
- if args.use_lora:
- pbar2.reset(1)
- pbar2.set_description("Loading LoRA...")
- # now we will add new LoRA weights to the attention layers
- # Set correct lora layers
- unet_lora_attn_procs = {}
- unet_lora_params = []
- rank = args.lora_unet_rank
-
- for name, attn_processor in unet.attn_processors.items():
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- hidden_size = None
- if name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
-
- lora_attn_processor_class = (
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
- )
- if hidden_size is None:
- logger.warning(f"Could not find hidden size for {name}. Skipping...")
- continue
- module = lora_attn_processor_class(
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
- )
- unet_lora_attn_procs[name] = module
- unet_lora_params.extend(module.parameters())
+ pbar2.set_description("Loading unet...")
+ pbar2.update()
+ unet = UNet2DConditionModel.from_pretrained(
+ args.get_pretrained_model_name_or_path(),
+ subfolder="unet",
+ revision=args.revision,
+ torch_dtype=torch.float32,
+ )
- unet.set_attn_processor(unet_lora_attn_procs)
+ if args.attention == "xformers" and not shared.force_cpu:
+ xformerify(unet, use_lora=args.use_lora)
+ xformerify(vae, use_lora=args.use_lora)
- # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
- # So, instead, we monkey-patch the forward calls of its attention-blocks.
- if stop_text_percentage != 0:
- # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
- text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder(
- text_encoder, dtype=torch.float32, rank=args.lora_txt_rank
- )
-
- if args.model_type == "SDXL":
- text_encoder_lora_params_two = LoraLoaderMixin._modify_text_encoder(
- text_encoder_two, dtype=torch.float32, rank=args.lora_txt_rank
- )
- params_to_optimize = (
- itertools.chain(unet_lora_params, text_encoder_lora_params, text_encoder_lora_params_two))
- else:
- params_to_optimize = (itertools.chain(unet_lora_params, text_encoder_lora_params))
+ unet = torch2ify(unet)
+ if args.full_mixed_precision:
+ if args.mixed_precision == "fp16":
+ patch_accelerator_for_fp16_training(accelerator)
+ unet.to(accelerator.device, dtype=weight_dtype)
else:
- params_to_optimize = unet_lora_params
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ "Please make sure to always have all model weights in full float32 precision when starting training - "
+ "even if doing mixed precision training. copy of the weights should still be float32."
+ )
- # Load LoRA weights if specified
- if args.lora_model_name is not None and args.lora_model_name != "":
- logger.debug(f"Load lora from {args.lora_model_name}")
- lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(args.lora_model_name)
- LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet)
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ logger.warning(
+ f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
- LoraLoaderMixin.load_lora_into_text_encoder(
- lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder)
- if text_encoder_two is not None:
- LoraLoaderMixin.load_lora_into_text_encoder(
- lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two)
+ if (
+ args.stop_text_encoder != 0
+ and accelerator.unwrap_model(text_encoder).dtype != torch.float32
+ ):
+ logger.warning(
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+ f" {low_precision_error_string}"
+ )
+ if (
+ args.stop_text_encoder != 0
+ and accelerator.unwrap_model(text_encoder_two).dtype != torch.float32
+ ):
+ logger.warning(
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder_two).dtype}."
+ f" {low_precision_error_string}"
+ )
- elif stop_text_percentage != 0:
- if args.train_unet:
- if args.model_type == "SDXL":
- params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters(),
- text_encoder_two.parameters())
- else:
- params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
- else:
- if args.model_type == "SDXL":
- params_to_optimize = itertools.chain(text_encoder.parameters(), text_encoder_two.parameters())
+ if args.gradient_checkpointing:
+ if args.train_unet:
+ unet.enable_gradient_checkpointing()
+ if stop_text_percentage != 0:
+ text_encoder.gradient_checkpointing_enable()
+ if args.model_type == "SDXL":
+ text_encoder_two.gradient_checkpointing_enable()
+ if args.use_lora:
+ # We need to enable gradients on an input for gradient checkpointing to work
+ # This will not be optimized because it is not a param to optimizer
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(True)
+ if args.model_type == "SDXL":
+ text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(True)
else:
- params_to_optimize = itertools.chain(text_encoder.parameters())
- else:
- params_to_optimize = unet.parameters()
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ if args.model_type == "SDXL":
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ ema_model = None
+ if args.use_ema:
+ if os.path.exists(
+ os.path.join(
+ args.get_pretrained_model_name_or_path(),
+ "ema_unet",
+ "diffusion_pytorch_model.safetensors",
+ )
+ ):
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.get_pretrained_model_name_or_path(),
+ subfolder="ema_unet",
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ if args.attention == "xformers" and not shared.force_cpu:
+ xformerify(ema_unet, use_lora=args.use_lora)
- optimizer = get_optimizer(args.optimizer, learning_rate, args.weight_decay, params_to_optimize)
- if len(optimizer.param_groups) > 1:
- try:
- optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay
- optimizer.param_groups[1]["grad_clip_norm"] = args.tenc_grad_clip_norm
- except:
- logger.warning("Exception setting tenc weight decay")
- traceback.print_exc()
+ ema_model = EMAModel(
+ ema_unet, device=accelerator.device, dtype=weight_dtype
+ )
+ del ema_unet
+ else:
+ ema_model = EMAModel(
+ unet, device=accelerator.device, dtype=weight_dtype
+ )
- if len(optimizer.param_groups) > 2:
- try:
- optimizer.param_groups[2]["weight_decay"] = args.tenc_weight_decay
- optimizer.param_groups[2]["grad_clip_norm"] = args.tenc_grad_clip_norm
- except:
- logger.warning("Exception setting tenc weight decay")
- traceback.print_exc()
+ # Create shared unet/tenc learning rate variables
- noise_scheduler = get_noise_scheduler(args)
- global to_delete
- to_delete = [unet, text_encoder, text_encoder_two, tokenizer, tokenizer_two, optimizer, vae]
- def cleanup_memory():
- try:
- if unet:
- del unet
- if text_encoder:
- del text_encoder
- if text_encoder_two:
- del text_encoder_two
- if tokenizer:
- del tokenizer
- if tokenizer_two:
- del tokenizer_two
- if optimizer:
- del optimizer
- if train_dataloader:
- del train_dataloader
- if train_dataset:
- del train_dataset
- if lr_scheduler:
- del lr_scheduler
- if vae:
- del vae
- if unet_lora_params:
- del unet_lora_params
- except:
- pass
- cleanup(True)
+ learning_rate = args.learning_rate
+ txt_learning_rate = args.txt_learning_rate
+ if args.use_lora:
+ learning_rate = args.lora_learning_rate
+ txt_learning_rate = args.lora_txt_learning_rate
- if args.cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
- vae.requires_grad_(False)
- vae.eval()
+ if args.use_lora or not args.train_unet:
+ unet.requires_grad_(False)
- if status.interrupted:
- result.msg = "Training interrupted."
- stop_profiler(profiler)
- return result
+ unet_lora_params = None
- printm("Loading dataset...")
- pbar2.reset()
- pbar2.set_description("Loading dataset")
-
- with_prior_preservation = False
- tokenizers = [tokenizer] if tokenizer_two is None else [tokenizer, tokenizer_two]
- text_encoders = [text_encoder] if text_encoder_two is None else [text_encoder, text_encoder_two]
- train_dataset = generate_dataset(
- model_name=args.model_name,
- instance_prompts=instance_prompts,
- class_prompts=class_prompts,
- batch_size=args.train_batch_size,
- tokenizer=tokenizers,
- text_encoder=text_encoders,
- accelerator=accelerator,
- vae=vae if args.cache_latents else None,
- debug=False,
- model_dir=args.model_dir,
- max_token_length=args.max_token_length,
- pbar=pbar2
- )
- if train_dataset.class_count > 0:
- with_prior_preservation = True
- pbar2.reset()
- printm("Dataset loaded.")
- tokenizer_max_length = tokenizer.model_max_length
- if args.cache_latents:
- printm("Unloading vae.")
- del vae
- # Preserve reference to vae for later checks
- vae = None
- # TODO: Try unloading tokenizers here?
- del tokenizer
- if tokenizer_two is not None:
- del tokenizer_two
- tokenizer = None
- tokenizer2 = None
+ if args.use_lora:
+ pbar2.reset(1)
+ pbar2.set_description("Loading LoRA...")
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_attn_procs = {}
+ unet_lora_params = []
+ rank = args.lora_unet_rank
+
+ for name, attn_processor in unet.attn_processors.items():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ hidden_size = None
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_processor_class = (
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
+ )
+ if hidden_size is None:
+ logger.warning(f"Could not find hidden size for {name}. Skipping...")
+ continue
+ module = lora_attn_processor_class(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+ )
+ unet_lora_attn_procs[name] = module
+ unet_lora_params.extend(module.parameters())
- if status.interrupted:
- result.msg = "Training interrupted."
- stop_profiler(profiler)
- return result
+ unet.set_attn_processor(unet_lora_attn_procs)
- if train_dataset.__len__ == 0:
- msg = "Please provide a directory with actual images in it."
- logger.warning(msg)
- status.textinfo = msg
- update_status({"status": status})
- cleanup_memory()
- result.msg = msg
- result.config = args
- stop_profiler(profiler)
- return result
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if stop_text_percentage != 0:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder(
+ text_encoder, dtype=torch.float32, rank=args.lora_txt_rank
+ )
- def collate_fn_db(examples):
- input_ids = [example["input_ids"] for example in examples]
- pixel_values = [example["image"] for example in examples]
- types = [example["is_class"] for example in examples]
- weights = [
- current_prior_loss_weight if example["is_class"] else 1.0
- for example in examples
- ]
- loss_avg = 0
- for weight in weights:
- loss_avg += weight
- loss_avg /= len(weights)
- pixel_values = torch.stack(pixel_values)
- if not args.cache_latents:
- pixel_values = pixel_values.to(
- memory_format=torch.contiguous_format
- ).float()
- input_ids = torch.cat(input_ids, dim=0)
-
- batch_data = {
- "input_ids": input_ids,
- "images": pixel_values,
- "types": types,
- "loss_avg": loss_avg,
- }
- if "input_ids2" in examples[0]:
- input_ids_2 = [example["input_ids2"] for example in examples]
- input_ids_2 = torch.stack(input_ids_2)
-
- batch_data["input_ids2"] = input_ids_2
- batch_data["original_sizes_hw"] = torch.stack(
- [torch.LongTensor(x["original_sizes_hw"]) for x in examples])
- batch_data["crop_top_lefts"] = torch.stack([torch.LongTensor(x["crop_top_lefts"]) for x in examples])
- batch_data["target_sizes_hw"] = torch.stack([torch.LongTensor(x["target_sizes_hw"]) for x in examples])
- return batch_data
-
- def collate_fn_sdxl(examples):
- input_ids = [example["input_ids"] for example in examples if not example["is_class"]]
- pixel_values = [example["image"] for example in examples if not example["is_class"]]
- add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if
- not example["is_class"]]
- add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if
- not example["is_class"]]
-
- # Concat class and instance examples for prior preservation.
- # We do this to avoid doing two forward passes.
- if with_prior_preservation:
- input_ids += [example["input_ids"] for example in examples if example["is_class"]]
- pixel_values += [example["image"] for example in examples if example["is_class"]]
- add_text_embeds += [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if
- example["is_class"]]
- add_time_ids += [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if
- example["is_class"]]
-
- pixel_values = torch.stack(pixel_values)
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
- input_ids = torch.cat(input_ids, dim=0)
- add_text_embeds = torch.cat(add_text_embeds, dim=0)
- add_time_ids = torch.cat(add_time_ids, dim=0)
-
- batch = {
- "input_ids": input_ids,
- "images": pixel_values,
- "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
- }
+ if args.model_type == "SDXL":
+ text_encoder_lora_params_two = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_two, dtype=torch.float32, rank=args.lora_txt_rank
+ )
+ params_to_optimize = (
+ itertools.chain(unet_lora_params, text_encoder_lora_params, text_encoder_lora_params_two))
+ else:
+ params_to_optimize = (itertools.chain(unet_lora_params, text_encoder_lora_params))
- return batch
+ else:
+ params_to_optimize = unet_lora_params
- sampler = BucketSampler(train_dataset, train_batch_size)
+ # Load LoRA weights if specified
+ if args.lora_model_name is not None and args.lora_model_name != "":
+ logger.debug(f"Load lora from {args.lora_model_name}")
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(args.lora_model_name)
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet)
- collate_fn = collate_fn_db
- if args.model_type == "SDXL":
- collate_fn = collate_fn_sdxl
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- batch_size=1,
- batch_sampler=sampler,
- collate_fn=collate_fn,
- num_workers=n_workers,
- )
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder)
+ if text_encoder_two is not None:
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two)
- max_train_steps = args.num_train_epochs * len(train_dataset)
-
- # This is separate, because optimizer.step is only called once per "step" in training, so it's not
- # affected by batch size
- sched_train_steps = args.num_train_epochs * train_dataset.num_train_images
-
- lr_scale_pos = args.lr_scale_pos
- if class_prompts:
- lr_scale_pos *= 2
-
- lr_scheduler = UniversalScheduler(
- name=args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps,
- total_training_steps=sched_train_steps,
- min_lr=args.learning_rate_min,
- total_epochs=args.num_train_epochs,
- num_cycles=args.lr_cycles,
- power=args.lr_power,
- factor=args.lr_factor,
- scale_pos=lr_scale_pos,
- unet_lr=learning_rate,
- tenc_lr=txt_learning_rate,
- )
- # create ema, fix OOM
- if args.use_ema:
- if stop_text_percentage != 0:
- (
- ema_model.model,
- unet,
- text_encoder,
- optimizer,
- train_dataloader,
- lr_scheduler,
- ) = accelerator.prepare(
- ema_model.model,
- unet,
- text_encoder,
- optimizer,
- train_dataloader,
- lr_scheduler,
- )
- else:
- (
- ema_model.model,
- unet,
- optimizer,
- train_dataloader,
- lr_scheduler,
- ) = accelerator.prepare(
- ema_model.model, unet, optimizer, train_dataloader, lr_scheduler
- )
- else:
- if stop_text_percentage != 0:
- (
- unet,
- text_encoder,
- optimizer,
- train_dataloader,
- lr_scheduler,
- ) = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler
- )
+ elif stop_text_percentage != 0:
+ if args.train_unet:
+ if args.model_type == "SDXL":
+ params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters(),
+ text_encoder_two.parameters())
+ else:
+ params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
+ else:
+ if args.model_type == "SDXL":
+ params_to_optimize = itertools.chain(text_encoder.parameters(), text_encoder_two.parameters())
+ else:
+ params_to_optimize = itertools.chain(text_encoder.parameters())
else:
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, optimizer, train_dataloader, lr_scheduler
- )
+ params_to_optimize = unet.parameters()
- if not args.cache_latents and vae is not None:
- vae.to(accelerator.device, dtype=weight_dtype)
+ optimizer = get_optimizer(args.optimizer, learning_rate, args.weight_decay, params_to_optimize)
+ if len(optimizer.param_groups) > 1:
+ try:
+ optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay
+ optimizer.param_groups[1]["grad_clip_norm"] = args.tenc_grad_clip_norm
+ except:
+ logger.warning("Exception setting tenc weight decay")
+ traceback.print_exc()
- if stop_text_percentage == 0:
- text_encoder.to(accelerator.device, dtype=weight_dtype)
- # Afterwards we recalculate our number of training epochs
- # We need to initialize the trackers we use, and also store our configuration.
- # The trackers will initialize automatically on the main process.
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
+ if len(optimizer.param_groups) > 2:
+ try:
+ optimizer.param_groups[2]["weight_decay"] = args.tenc_weight_decay
+ optimizer.param_groups[2]["grad_clip_norm"] = args.tenc_grad_clip_norm
+ except:
+ logger.warning("Exception setting tenc weight decay")
+ traceback.print_exc()
- # Train!
- total_batch_size = (
- train_batch_size * accelerator.num_processes * gradient_accumulation_steps
- )
- max_train_epochs = args.num_train_epochs
- # we calculate our number of tenc training epochs
- text_encoder_epochs = round(max_train_epochs * stop_text_percentage)
- global_step = 0
- global_epoch = 0
- session_epoch = 0
- first_epoch = 0
- resume_step = 0
- last_model_save = 0
- last_image_save = 0
- resume_from_checkpoint = False
- new_hotness = os.path.join(
- args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}"
- )
- if os.path.exists(new_hotness):
- logger.debug(f"Resuming from checkpoint {new_hotness}")
+ noise_scheduler = get_noise_scheduler(args)
+ global to_delete
+ to_delete = [unet, text_encoder, text_encoder_two, tokenizer, tokenizer_two, optimizer, vae]
+ def cleanup_memory():
+ try:
+ if unet:
+ del unet
+ if text_encoder:
+ del text_encoder
+ if text_encoder_two:
+ del text_encoder_two
+ if tokenizer:
+ del tokenizer
+ if tokenizer_two:
+ del tokenizer_two
+ if optimizer:
+ del optimizer
+ if train_dataloader:
+ del train_dataloader
+ if train_dataset:
+ del train_dataset
+ if lr_scheduler:
+ del lr_scheduler
+ if vae:
+ del vae
+ if unet_lora_params:
+ del unet_lora_params
+ except:
+ pass
+ cleanup(True)
- try:
- import modules.shared
- no_safe = modules.shared.cmd_opts.disable_safe_unpickle
- modules.shared.cmd_opts.disable_safe_unpickle = True
- except:
- no_safe = False
+ if args.cache_latents:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ vae.requires_grad_(False)
+ vae.eval()
- try:
- import modules.shared
- accelerator.load_state(new_hotness)
- modules.shared.cmd_opts.disable_safe_unpickle = no_safe
- global_step = resume_step = args.revision
- resume_from_checkpoint = True
- first_epoch = args.lifetime_epoch
- global_epoch = args.lifetime_epoch
- except Exception as lex:
- logger.warning(f"Exception loading checkpoint: {lex}")
- logger.debug(" ***** Running training *****")
- if shared.force_cpu:
- logger.debug(f" TRAINING WITH CPU ONLY")
- logger.debug(f" Num batches each epoch = {len(train_dataset) // train_batch_size}")
- logger.debug(f" Num Epochs = {max_train_epochs}")
- logger.debug(f" Batch Size Per Device = {train_batch_size}")
- logger.debug(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
- logger.debug(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.debug(f" Text Encoder Epochs: {text_encoder_epochs}")
- logger.debug(f" Total optimization steps = {sched_train_steps}")
- logger.debug(f" Total training steps = {max_train_steps}")
- logger.debug(f" Resuming from checkpoint: {resume_from_checkpoint}")
- logger.debug(f" First resume epoch: {first_epoch}")
- logger.debug(f" First resume step: {resume_step}")
- logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}")
- logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}")
- logger.debug(f" EMA: {args.use_ema}")
- logger.debug(f" UNET: {args.train_unet}")
- logger.debug(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}")
- logger.debug(f" LR{' (Lora)' if args.use_lora else ''}: {learning_rate}")
- if stop_text_percentage > 0:
- logger.debug(f" Tenc LR{' (Lora)' if args.use_lora else ''}: {txt_learning_rate}")
- logger.debug(f" V2: {args.v2}")
-
- os.environ.__setattr__("CUDA_LAUNCH_BLOCKING", 1)
-
- def check_save(is_epoch_check=False):
- nonlocal last_model_save
- nonlocal last_image_save
- save_model_interval = args.save_embedding_every
- save_image_interval = args.save_preview_every
- save_completed = session_epoch >= max_train_epochs
- save_canceled = status.interrupted
- save_image = False
- save_model = False
- save_lora = False
-
- if not save_canceled and not save_completed:
- # Check to see if the number of epochs since last save is gt the interval
- if 0 < save_model_interval <= session_epoch - last_model_save:
- save_model = True
- if args.use_lora:
- save_lora = True
- last_model_save = session_epoch
+ if status.interrupted:
+ result.msg = "Training interrupted."
+ stop_profiler(profiler)
+ return result
- # Repeat for sample images
- if 0 < save_image_interval <= session_epoch - last_image_save:
- save_image = True
- last_image_save = session_epoch
+ printm("Loading dataset...")
+ pbar2.reset()
+ pbar2.set_description("Loading dataset")
+
+ with_prior_preservation = False
+ tokenizers = [tokenizer] if tokenizer_two is None else [tokenizer, tokenizer_two]
+ text_encoders = [text_encoder] if text_encoder_two is None else [text_encoder, text_encoder_two]
+ train_dataset = generate_dataset(
+ model_name=args.model_name,
+ instance_prompts=instance_prompts,
+ class_prompts=class_prompts,
+ batch_size=args.train_batch_size,
+ tokenizer=tokenizers,
+ text_encoder=text_encoders,
+ accelerator=accelerator,
+ vae=vae if args.cache_latents else None,
+ debug=False,
+ model_dir=args.model_dir,
+ max_token_length=args.max_token_length,
+ pbar=pbar2
+ )
+ if train_dataset.class_count > 0:
+ with_prior_preservation = True
+ pbar2.reset()
+ printm("Dataset loaded.")
+ tokenizer_max_length = tokenizer.model_max_length
+ if args.cache_latents:
+ printm("Unloading vae.")
+ del vae
+ # Preserve reference to vae for later checks
+ vae = None
+ # TODO: Try unloading tokenizers here?
+ del tokenizer
+ if tokenizer_two is not None:
+ del tokenizer_two
+ tokenizer = None
+ tokenizer2 = None
+
+ if status.interrupted:
+ result.msg = "Training interrupted."
+ stop_profiler(profiler)
+ return result
+
+ if train_dataset.__len__ == 0:
+ msg = "Please provide a directory with actual images in it."
+ logger.warning(msg)
+ status.textinfo = msg
+ update_status({"status": status})
+ cleanup_memory()
+ result.msg = msg
+ result.config = args
+ stop_profiler(profiler)
+ return result
+
+ def collate_fn_db(examples):
+ input_ids = [example["input_ids"] for example in examples]
+ pixel_values = [example["image"] for example in examples]
+ types = [example["is_class"] for example in examples]
+ weights = [
+ current_prior_loss_weight if example["is_class"] else 1.0
+ for example in examples
+ ]
+ loss_avg = 0
+ for weight in weights:
+ loss_avg += weight
+ loss_avg /= len(weights)
+ pixel_values = torch.stack(pixel_values)
+ if not args.cache_latents:
+ pixel_values = pixel_values.to(
+ memory_format=torch.contiguous_format
+ ).float()
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch_data = {
+ "input_ids": input_ids,
+ "images": pixel_values,
+ "types": types,
+ "loss_avg": loss_avg,
+ }
+ if "input_ids2" in examples[0]:
+ input_ids_2 = [example["input_ids2"] for example in examples]
+ input_ids_2 = torch.stack(input_ids_2)
+
+ batch_data["input_ids2"] = input_ids_2
+ batch_data["original_sizes_hw"] = torch.stack(
+ [torch.LongTensor(x["original_sizes_hw"]) for x in examples])
+ batch_data["crop_top_lefts"] = torch.stack([torch.LongTensor(x["crop_top_lefts"]) for x in examples])
+ batch_data["target_sizes_hw"] = torch.stack([torch.LongTensor(x["target_sizes_hw"]) for x in examples])
+ return batch_data
+
+ def collate_fn_sdxl(examples):
+ input_ids = [example["input_ids"] for example in examples if not example["is_class"]]
+ pixel_values = [example["image"] for example in examples if not example["is_class"]]
+ add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if
+ not example["is_class"]]
+ add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if
+ not example["is_class"]]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["input_ids"] for example in examples if example["is_class"]]
+ pixel_values += [example["image"] for example in examples if example["is_class"]]
+ add_text_embeds += [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples if
+ example["is_class"]]
+ add_time_ids += [example["instance_added_cond_kwargs"]["time_ids"] for example in examples if
+ example["is_class"]]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+ add_text_embeds = torch.cat(add_text_embeds, dim=0)
+ add_time_ids = torch.cat(add_time_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "images": pixel_values,
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+ }
- else:
- logger.debug("\nSave completed/canceled.")
- if global_step > 0:
- save_image = True
- save_model = True
- save_lora = True
+ return batch
- save_snapshot = False
+ sampler = BucketSampler(train_dataset, train_batch_size)
- if is_epoch_check:
- if shared.status.do_save_samples:
- save_image = True
- shared.status.do_save_samples = False
+ collate_fn = collate_fn_db
+ if args.model_type == "SDXL":
+ collate_fn = collate_fn_sdxl
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=1,
+ batch_sampler=sampler,
+ collate_fn=collate_fn,
+ num_workers=n_workers,
+ )
- if shared.status.do_save_model:
- if args.use_lora:
- save_lora = True
- save_model = True
- shared.status.do_save_model = False
+ max_train_steps = args.num_train_epochs * len(train_dataset)
+
+ # This is separate, because optimizer.step is only called once per "step" in training, so it's not
+ # affected by batch size
+ sched_train_steps = args.num_train_epochs * train_dataset.num_train_images
+
+ lr_scale_pos = args.lr_scale_pos
+ if class_prompts:
+ lr_scale_pos *= 2
+
+ lr_scheduler = UniversalScheduler(
+ name=args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ total_training_steps=sched_train_steps,
+ min_lr=args.learning_rate_min,
+ total_epochs=args.num_train_epochs,
+ num_cycles=args.lr_cycles,
+ power=args.lr_power,
+ factor=args.lr_factor,
+ scale_pos=lr_scale_pos,
+ unet_lr=learning_rate,
+ tenc_lr=txt_learning_rate,
+ )
- save_checkpoint = False
- if save_model:
- if save_canceled:
- if global_step > 0:
- logger.debug("Canceled, enabling saves.")
- save_snapshot = args.save_state_cancel
- save_checkpoint = args.save_ckpt_cancel
- elif save_completed:
- if global_step > 0:
- logger.debug("Completed, enabling saves.")
- save_snapshot = args.save_state_after
- save_checkpoint = args.save_ckpt_after
+ # create ema, fix OOM
+ if args.use_ema:
+ if stop_text_percentage != 0:
+ (
+ ema_model.model,
+ unet,
+ text_encoder,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ ema_model.model,
+ unet,
+ text_encoder,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
else:
- save_snapshot = args.save_state_during
- save_checkpoint = args.save_ckpt_during
- if save_checkpoint and args.use_lora:
- save_checkpoint = False
- save_lora = True
-
- if (
- save_checkpoint
- or save_snapshot
- or save_lora
- or save_image
- or save_model
- ):
- save_weights(
- save_image,
- save_model,
- save_snapshot,
- save_checkpoint,
- save_lora
- )
+ (
+ ema_model.model,
+ unet,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ ema_model.model, unet, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ if stop_text_percentage != 0:
+ (
+ unet,
+ text_encoder,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
- return save_model, save_image
+ if not args.cache_latents and vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
- def save_weights(
- save_image, save_diffusers, save_snapshot, save_checkpoint, save_lora
- ):
- global last_samples
- global last_prompts
- nonlocal vae
- nonlocal pbar2
+ if stop_text_percentage == 0:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ # Afterwards we recalculate our number of training epochs
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers will initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth")
- printm(" Saving weights.")
- pbar2.reset()
- pbar2.set_description("Saving weights/samples...")
- pbar2.set_postfix(refresh=True)
+ # Train!
+ total_batch_size = (
+ train_batch_size * accelerator.num_processes * gradient_accumulation_steps
+ )
+ max_train_epochs = args.num_train_epochs
+ # we calculate our number of tenc training epochs
+ text_encoder_epochs = round(max_train_epochs * stop_text_percentage)
+ global_step = 0
+ global_epoch = 0
+ session_epoch = 0
+ first_epoch = 0
+ resume_step = 0
+ last_model_save = 0
+ last_image_save = 0
+ resume_from_checkpoint = False
+ new_hotness = os.path.join(
+ args.model_dir, "checkpoints", f"checkpoint-{args.snapshot}"
+ )
+ if os.path.exists(new_hotness):
+ logger.debug(f"Resuming from checkpoint {new_hotness}")
- # Create the pipeline using the trained modules and save it.
- if accelerator.is_main_process:
- printm("Pre-cleanup.")
- torch_rng_state = None
- cuda_gpu_rng_state = None
- cuda_cpu_rng_state = None
- # Save random states so sample generation doesn't impact training.
- if shared.device.type == 'cuda':
- torch_rng_state = torch.get_rng_state()
- cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda")
- cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu")
-
- optim_to(profiler, optimizer)
-
- if profiler is None:
- cleanup()
+ try:
+ import modules.shared
+ no_safe = modules.shared.cmd_opts.disable_safe_unpickle
+ modules.shared.cmd_opts.disable_safe_unpickle = True
+ except:
+ no_safe = False
- if vae is None:
- printm("Loading vae.")
- vae = create_vae()
+ try:
+ import modules.shared
+ accelerator.load_state(new_hotness)
+ modules.shared.cmd_opts.disable_safe_unpickle = no_safe
+ global_step = resume_step = args.revision
+ resume_from_checkpoint = True
+ first_epoch = args.lifetime_epoch
+ global_epoch = args.lifetime_epoch
+ except Exception as lex:
+ logger.warning(f"Exception loading checkpoint: {lex}")
+ logger.debug(" ***** Running training *****")
+ if shared.force_cpu:
+ logger.debug(f" TRAINING WITH CPU ONLY")
+ logger.debug(f" Num batches each epoch = {len(train_dataset) // train_batch_size}")
+ logger.debug(f" Num Epochs = {max_train_epochs}")
+ logger.debug(f" Batch Size Per Device = {train_batch_size}")
+ logger.debug(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
+ logger.debug(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.debug(f" Text Encoder Epochs: {text_encoder_epochs}")
+ logger.debug(f" Total optimization steps = {sched_train_steps}")
+ logger.debug(f" Total training steps = {max_train_steps}")
+ logger.debug(f" Resuming from checkpoint: {resume_from_checkpoint}")
+ logger.debug(f" First resume epoch: {first_epoch}")
+ logger.debug(f" First resume step: {resume_step}")
+ logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}")
+ logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}")
+ logger.debug(f" EMA: {args.use_ema}")
+ logger.debug(f" UNET: {args.train_unet}")
+ logger.debug(f" Freeze CLIP Normalization Layers: {args.freeze_clip_normalization}")
+ logger.debug(f" LR{' (Lora)' if args.use_lora else ''}: {learning_rate}")
+ if stop_text_percentage > 0:
+ logger.debug(f" Tenc LR{' (Lora)' if args.use_lora else ''}: {txt_learning_rate}")
+ logger.debug(f" V2: {args.v2}")
+
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+ def check_save(is_epoch_check=False):
+ nonlocal last_model_save
+ nonlocal last_image_save
+ save_model_interval = args.save_embedding_every
+ save_image_interval = args.save_preview_every
+ save_completed = session_epoch >= max_train_epochs
+ save_canceled = status.interrupted
+ save_image = False
+ save_model = False
+ save_lora = False
+
+ if not save_canceled and not save_completed:
+ # Check to see if the number of epochs since last save is gt the interval
+ if 0 < save_model_interval <= session_epoch - last_model_save:
+ save_model = True
+ if args.use_lora:
+ save_lora = True
+ last_model_save = session_epoch
+
+ # Repeat for sample images
+ if 0 < save_image_interval <= session_epoch - last_image_save:
+ save_image = True
+ last_image_save = session_epoch
- printm("Creating pipeline.")
- if args.model_type == "SDXL":
- s_pipeline = StableDiffusionXLPipeline.from_pretrained(
- args.get_pretrained_model_name_or_path(),
- unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
- text_encoder=accelerator.unwrap_model(
- text_encoder, keep_fp32_wrapper=True
- ),
- text_encoder_2=accelerator.unwrap_model(
- text_encoder_two, keep_fp32_wrapper=True
- ),
- vae=vae.to(accelerator.device),
- torch_dtype=weight_dtype,
- revision=args.revision,
- )
- xformerify(s_pipeline.unet,use_lora=args.use_lora)
else:
- s_pipeline = DiffusionPipeline.from_pretrained(
- args.get_pretrained_model_name_or_path(),
- unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
- text_encoder=accelerator.unwrap_model(
- text_encoder, keep_fp32_wrapper=True
- ),
- vae=vae,
- torch_dtype=weight_dtype,
- revision=args.revision,
- )
- xformerify(s_pipeline.unet,use_lora=args.use_lora)
- xformerify(s_pipeline.vae,use_lora=args.use_lora)
+ logger.debug("\nSave completed/canceled.")
+ if global_step > 0:
+ save_image = True
+ save_model = True
+ if args.use_lora:
+ save_lora = True
- weights_dir = args.get_pretrained_model_name_or_path()
+ save_snapshot = False
- if user_model_dir != "":
- loras_dir = os.path.join(user_model_dir, "Lora")
- else:
- model_dir = shared.models_path
- loras_dir = os.path.join(model_dir, "Lora")
- delete_tmp_lora = False
- # Update the temp path if we just need to save an image
- if save_image:
- logger.debug("Save image is set.")
- if args.use_lora:
- if not save_lora:
- logger.debug("Saving lora weights instead of checkpoint, using temp dir.")
+ if is_epoch_check:
+ if shared.status.do_save_samples:
+ save_image = True
+ shared.status.do_save_samples = False
+
+ if shared.status.do_save_model:
+ if args.use_lora:
save_lora = True
- delete_tmp_lora = True
- save_checkpoint = False
- save_diffusers = False
- os.makedirs(loras_dir, exist_ok=True)
- elif not save_diffusers:
- logger.debug("Saving checkpoint, using temp dir.")
- save_diffusers = True
- weights_dir = f"{weights_dir}_temp"
- os.makedirs(weights_dir, exist_ok=True)
+ save_model = True
+ shared.status.do_save_model = False
+
+ save_checkpoint = False
+ if save_model:
+ if save_canceled:
+ if global_step > 0:
+ logger.debug("Canceled, enabling saves.")
+ save_snapshot = args.save_state_cancel
+ save_checkpoint = args.save_ckpt_cancel
+ elif save_completed:
+ if global_step > 0:
+ logger.debug("Completed, enabling saves.")
+ save_snapshot = args.save_state_after
+ save_checkpoint = args.save_ckpt_after
else:
- logger.debug(f"Save checkpoint: {save_checkpoint} save lora {save_lora}.")
- # Is inference_mode() needed here to prevent issues when saving?
- logger.debug(f"Loras dir: {loras_dir}")
+ save_snapshot = args.save_state_during
+ save_checkpoint = args.save_ckpt_during
+ if save_checkpoint and args.use_lora:
+ save_checkpoint = False
+ save_lora = True
+ if not args.use_lora:
+ save_lora = False
- # setup pt path
- if args.custom_model_name == "":
- lora_model_name = args.model_name
- else:
- lora_model_name = args.custom_model_name
-
- lora_save_file = os.path.join(loras_dir, f"{lora_model_name}_{args.revision}.safetensors")
-
- with accelerator.autocast(), torch.inference_mode():
-
- def lora_save_function(weights, filename):
- metadata = args.export_ss_metadata()
- logger.debug(f"Saving lora to {filename}")
- safetensors.torch.save_file(weights, filename, metadata=metadata)
-
- if save_lora:
- # TODO: Add a version for the lora model?
- pbar2.reset(1)
- pbar2.set_description("Saving Lora Weights...")
- # setup directory
- logger.debug(f"Saving lora to {lora_save_file}")
- unet_lora_layers_to_save = unet_attn_processors_state_dict(unet)
- text_encoder_one_lora_layers_to_save = None
- text_encoder_two_lora_layers_to_save = None
- if args.stop_text_encoder != 0:
- text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder)
- if args.model_type == "SDXL":
- if args.stop_text_encoder != 0:
- text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder_two)
- StableDiffusionXLPipeline.save_lora_weights(
- loras_dir,
- unet_lora_layers=unet_lora_layers_to_save,
- text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
- text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
- weight_name=lora_save_file,
- safe_serialization=True,
- save_function=lora_save_function
- )
- scheduler_args = {}
+ if (
+ save_checkpoint
+ or save_snapshot
+ or save_lora
+ or save_image
+ or save_model
+ ):
+ save_weights(
+ save_image,
+ save_model,
+ save_snapshot,
+ save_checkpoint,
+ save_lora
+ )
- if "variance_type" in s_pipeline.scheduler.config:
- variance_type = s_pipeline.scheduler.config.variance_type
+ return save_model, save_image
- if variance_type in ["learned", "learned_range"]:
- variance_type = "fixed_small"
+ def save_weights(
+ save_image, save_diffusers, save_snapshot, save_checkpoint, save_lora
+ ):
+ global last_samples
+ global last_prompts
+ nonlocal vae
+ nonlocal pbar2
- scheduler_args["variance_type"] = variance_type
+ printm(" Saving weights.")
+ pbar2.reset()
+ pbar2.set_description("Saving weights/samples...")
+ pbar2.set_postfix(refresh=True)
+
+ # Create the pipeline using the trained modules and save it.
+ if accelerator.is_main_process:
+ printm("Pre-cleanup.")
+ torch_rng_state = None
+ cuda_gpu_rng_state = None
+ cuda_cpu_rng_state = None
+ # Save random states so sample generation doesn't impact training.
+ if shared.device.type == 'cuda':
+ torch_rng_state = torch.get_rng_state()
+ cuda_gpu_rng_state = torch.cuda.get_rng_state(device="cuda")
+ cuda_cpu_rng_state = torch.cuda.get_rng_state(device="cpu")
+
+ optim_to(profiler, optimizer)
+
+ if profiler is None:
+ cleanup()
+
+ if vae is None:
+ printm("Loading vae.")
+ vae = create_vae()
+
+ printm("Creating pipeline.")
+ if args.model_type == "SDXL":
+ s_pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.get_pretrained_model_name_or_path(),
+ unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
+ text_encoder=accelerator.unwrap_model(
+ text_encoder, keep_fp32_wrapper=True
+ ),
+ text_encoder_2=accelerator.unwrap_model(
+ text_encoder_two, keep_fp32_wrapper=True
+ ),
+ vae=vae.to(accelerator.device),
+ torch_dtype=weight_dtype,
+ revision=args.revision,
+ )
+ xformerify(s_pipeline.unet,use_lora=args.use_lora)
+ else:
+ s_pipeline = DiffusionPipeline.from_pretrained(
+ args.get_pretrained_model_name_or_path(),
+ unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
+ text_encoder=accelerator.unwrap_model(
+ text_encoder, keep_fp32_wrapper=True
+ ),
+ vae=vae,
+ torch_dtype=weight_dtype,
+ revision=args.revision,
+ )
+ xformerify(s_pipeline.unet,use_lora=args.use_lora)
+ xformerify(s_pipeline.vae,use_lora=args.use_lora)
- s_pipeline.scheduler = UniPCMultistepScheduler.from_config(s_pipeline.scheduler.config, **scheduler_args)
- save_lora = False
- save_model = False
+ weights_dir = args.get_pretrained_model_name_or_path()
+
+ if user_model_dir != "":
+ loras_dir = os.path.join(user_model_dir, "Lora")
+ else:
+ model_dir = shared.models_path
+ loras_dir = os.path.join(model_dir, "Lora")
+ delete_tmp_lora = False
+ # Update the temp path if we just need to save an image
+ if save_image:
+ logger.debug("Save image is set.")
+ if args.use_lora:
+ if not save_lora:
+ logger.debug("Saving lora weights instead of checkpoint, using temp dir.")
+ save_lora = True
+ delete_tmp_lora = True
+ save_checkpoint = False
+ save_diffusers = False
+ os.makedirs(loras_dir, exist_ok=True)
+ elif not save_diffusers:
+ logger.debug("Saving checkpoint, using temp dir.")
+ save_diffusers = True
+ weights_dir = f"{weights_dir}_temp"
+ os.makedirs(weights_dir, exist_ok=True)
else:
- StableDiffusionPipeline.save_lora_weights(
- loras_dir,
- unet_lora_layers=unet_lora_layers_to_save,
- text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
- weight_name=lora_save_file,
- safe_serialization=True
- )
- s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config(
- s_pipeline.scheduler.config)
- s_pipeline.scheduler.config.solver_type = "bh2"
- save_lora = False
- save_model = False
-
- elif save_diffusers:
- # We are saving weights, we need to ensure revision is saved
- if "_tmp" not in weights_dir:
- args.save()
- try:
- out_file = None
- status.textinfo = (
- f"Saving diffusion model at step {args.revision}..."
- )
- update_status({"status": status.textinfo})
+ save_lora = False
+ logger.debug(f"Save checkpoint: {save_checkpoint} save lora {save_lora}.")
+ # Is inference_mode() needed here to prevent issues when saving?
+ logger.debug(f"Loras dir: {loras_dir}")
+
+ # setup pt path
+ if args.custom_model_name == "":
+ lora_model_name = args.model_name
+ else:
+ lora_model_name = args.custom_model_name
+
+ lora_save_file = os.path.join(loras_dir, f"{lora_model_name}_{args.revision}.safetensors")
+
+ with accelerator.autocast(), torch.inference_mode():
+
+ def lora_save_function(weights, filename):
+ metadata = args.export_ss_metadata()
+ logger.debug(f"Saving lora to {filename}")
+ safetensors.torch.save_file(weights, filename, metadata=metadata)
+
+ if save_lora:
+ # TODO: Add a version for the lora model?
pbar2.reset(1)
+ pbar2.set_description("Saving Lora Weights...")
+ # setup directory
+ logger.debug(f"Saving lora to {lora_save_file}")
+ unet_lora_layers_to_save = unet_lora_state_dict(unet)
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+ if args.stop_text_encoder != 0:
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder)
+ if args.model_type == "SDXL":
+ if args.stop_text_encoder != 0:
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(text_encoder_two)
+ StableDiffusionXLPipeline.save_lora_weights(
+ loras_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ weight_name=lora_save_file,
+ safe_serialization=True,
+ save_function=lora_save_function
+ )
+ scheduler_args = {}
- pbar2.set_description("Saving diffusion model")
- s_pipeline.save_pretrained(
- weights_dir,
- safe_serialization=False,
- )
- if ema_model is not None:
- ema_model.save_pretrained(
- os.path.join(
- weights_dir,
- "ema_unet",
- ),
- safe_serialization=False,
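+                    # UniPC does not use a learned variance, so map learned variance types to fixed_small.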
+ if "variance_type" in s_pipeline.scheduler.config:
+ variance_type = s_pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ s_pipeline.scheduler = UniPCMultistepScheduler.from_config(s_pipeline.scheduler.config, **scheduler_args)
+ save_lora = False
+ save_model = False
+ else:
+ StableDiffusionPipeline.save_lora_weights(
+ loras_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ weight_name=lora_save_file,
+ safe_serialization=True
)
- pbar2.update()
+ s_pipeline.scheduler = get_scheduler_class("UniPCMultistep").from_config(
+ s_pipeline.scheduler.config)
+ s_pipeline.scheduler.config.solver_type = "bh2"
+ save_lora = False
+ save_model = False
- if save_snapshot:
- pbar2.reset(1)
- pbar2.set_description("Saving Snapshot")
+ elif save_diffusers:
+                # We are saving weights, so we need to ensure the revision is saved
+ if "_tmp" not in weights_dir:
+ args.save()
+ try:
+ out_file = None
status.textinfo = (
- f"Saving snapshot at step {args.revision}..."
+ f"Saving diffusion model at step {args.revision}..."
)
update_status({"status": status.textinfo})
- accelerator.save_state(
- os.path.join(
- args.model_dir,
- "checkpoints",
- f"checkpoint-{args.revision}",
- )
+ pbar2.reset(1)
+
+ pbar2.set_description("Saving diffusion model")
+ s_pipeline.save_pretrained(
+ weights_dir,
+ safe_serialization=False,
)
+ if ema_model is not None:
+ ema_model.save_pretrained(
+ os.path.join(
+ weights_dir,
+ "ema_unet",
+ ),
+ safe_serialization=False,
+ )
pbar2.update()
- # We should save this regardless, because it's our fallback if no snapshot exists.
+ if save_snapshot:
+ pbar2.reset(1)
+ pbar2.set_description("Saving Snapshot")
+ status.textinfo = (
+ f"Saving snapshot at step {args.revision}..."
+ )
+ update_status({"status": status.textinfo})
+ accelerator.save_state(
+ os.path.join(
+ args.model_dir,
+ "checkpoints",
+ f"checkpoint-{args.revision}",
+ )
+ )
+ pbar2.update()
- # package pt into checkpoint
- if save_checkpoint:
- pbar2.reset(1)
- pbar2.set_description("Compiling Checkpoint")
- snap_rev = str(args.revision) if save_snapshot else ""
- if export_diffusers:
- copy_diffusion_model(args.model_name, os.path.join(user_model_dir, "diffusers"))
- else:
- if args.model_type == "SDXL":
- compile_checkpoint_xl(args.model_name, reload_models=False,
- lora_file_name=out_file,
- log=False, snap_rev=snap_rev, pbar=pbar2)
+ # We should save this regardless, because it's our fallback if no snapshot exists.
+
+ # package pt into checkpoint
+ if save_checkpoint:
+ pbar2.reset(1)
+ pbar2.set_description("Compiling Checkpoint")
+ snap_rev = str(args.revision) if save_snapshot else ""
+ if export_diffusers:
+ copy_diffusion_model(args.model_name, os.path.join(user_model_dir, "diffusers"))
else:
- compile_checkpoint(args.model_name, reload_models=False,
- lora_file_name=out_file,
- log=False, snap_rev=snap_rev, pbar=pbar2)
- printm("Restored, moved to acc.device.")
- pbar2.update()
+ if args.model_type == "SDXL":
+ compile_checkpoint_xl(args.model_name, reload_models=False,
+ lora_file_name=out_file,
+ log=False, snap_rev=snap_rev, pbar=pbar2)
+ else:
+ compile_checkpoint(args.model_name, reload_models=False,
+ lora_file_name=out_file,
+ log=False, snap_rev=snap_rev, pbar=pbar2)
+ printm("Restored, moved to acc.device.")
+ pbar2.update()
+
+ except Exception as ex:
+ logger.warning(f"Exception saving checkpoint/model: {ex}")
+ traceback.print_exc()
+ pass
+ save_dir = args.model_dir
+
+ if save_image:
+ logger.debug("Saving images...")
+ # Get the path to a temporary directory
+ del s_pipeline
+ logger.debug(f"Loading image pipeline from {weights_dir}...")
+ if args.model_type == "SDXL":
+ s_pipeline = StableDiffusionXLPipeline.from_pretrained(
+ weights_dir, vae=vae, revision=args.revision,
+ torch_dtype=weight_dtype
+ )
+ else:
+ s_pipeline = StableDiffusionPipeline.from_pretrained(
+ weights_dir, vae=vae, revision=args.revision,
+ torch_dtype=weight_dtype
+ )
+ if args.tomesd:
+ tomesd.apply_patch(s_pipeline, ratio=args.tomesd, use_rand=False)
+ if args.use_lora:
+ s_pipeline.load_lora_weights(lora_save_file)
- except Exception as ex:
- logger.warning(f"Exception saving checkpoint/model: {ex}")
- traceback.print_exc()
+ try:
+ s_pipeline.enable_vae_tiling()
+ s_pipeline.enable_vae_slicing()
+ s_pipeline.enable_sequential_cpu_offload()
+ s_pipeline.enable_xformers_memory_efficient_attention()
+ except:
pass
- save_dir = args.model_dir
- if save_image:
- logger.debug("Saving images...")
- # Get the path to a temporary directory
- del s_pipeline
- logger.debug(f"Loading image pipeline from {weights_dir}...")
- if args.model_type == "SDXL":
- s_pipeline = StableDiffusionXLPipeline.from_pretrained(
- weights_dir, vae=vae, revision=args.revision,
- torch_dtype=weight_dtype
- )
- else:
- s_pipeline = StableDiffusionPipeline.from_pretrained(
- weights_dir, vae=vae, revision=args.revision,
- torch_dtype=weight_dtype
+ samples = []
+ sample_prompts = []
+ last_samples = []
+ last_prompts = []
+ status.textinfo = (
+ f"Saving preview image(s) at step {args.revision}..."
)
- if args.tomesd:
- tomesd.apply_patch(s_pipeline, ratio=args.tomesd, use_rand=False)
- if args.use_lora:
- s_pipeline.load_lora_weights(lora_save_file)
+ update_status({"status": status.textinfo})
+ try:
+ s_pipeline.set_progress_bar_config(disable=True)
+ sample_dir = os.path.join(save_dir, "samples")
+ os.makedirs(sample_dir, exist_ok=True)
+
+ sd = SampleDataset(args)
+ prompts = sd.prompts
+ logger.debug(f"Generating {len(prompts)} samples...")
+
+ concepts = args.concepts()
+ if args.sanity_prompt:
+ epd = PromptData(
+ prompt=args.sanity_prompt,
+ seed=args.sanity_seed,
+ negative_prompt=concepts[
+ 0
+ ].save_sample_negative_prompt,
+ resolution=(args.resolution, args.resolution),
+ )
+ prompts.append(epd)
- try:
- s_pipeline.enable_vae_tiling()
- s_pipeline.enable_vae_slicing()
- s_pipeline.enable_sequential_cpu_offload()
- s_pipeline.enable_xformers_memory_efficient_attention()
- except:
- pass
+ prompt_lengths = len(prompts)
+ if args.disable_logging:
+ pbar2.reset(prompt_lengths)
+ else:
+ pbar2.reset(prompt_lengths + 2)
+ pbar2.set_description("Generating Samples")
+ ci = 0
+ for c in prompts:
+ c.out_dir = os.path.join(args.model_dir, "samples")
+ generator = torch.manual_seed(int(c.seed))
+ s_image = s_pipeline(
+ c.prompt,
+ num_inference_steps=c.steps,
+ guidance_scale=c.scale,
+ negative_prompt=c.negative_prompt,
+ height=c.resolution[1],
+ width=c.resolution[0],
+ generator=generator,
+ ).images[0]
+ sample_prompts.append(c.prompt)
+ image_name = db_save_image(
+ s_image,
+ c,
+ custom_name=f"sample_{args.revision}-{ci}",
+ )
+ shared.status.current_image = image_name
+ shared.status.sample_prompts = [c.prompt]
+ update_status({"images": [image_name], "prompts": [c.prompt]})
+ samples.append(image_name)
+ pbar2.update()
+ ci += 1
+ for sample in samples:
+ last_samples.append(sample)
+ for prompt in sample_prompts:
+ last_prompts.append(prompt)
+ del samples
+ del prompts
+ except:
+                logger.warning("Exception saving sample.")
+ traceback.print_exc()
+ pass
- samples = []
- sample_prompts = []
- last_samples = []
- last_prompts = []
- status.textinfo = (
- f"Saving preview image(s) at step {args.revision}..."
- )
- update_status({"status": status.textinfo})
- try:
- s_pipeline.set_progress_bar_config(disable=True)
- sample_dir = os.path.join(save_dir, "samples")
- os.makedirs(sample_dir, exist_ok=True)
-
- sd = SampleDataset(args)
- prompts = sd.prompts
- logger.debug(f"Generating {len(prompts)} samples...")
-
- concepts = args.concepts()
- if args.sanity_prompt:
- epd = PromptData(
- prompt=args.sanity_prompt,
- seed=args.sanity_seed,
- negative_prompt=concepts[
- 0
- ].save_sample_negative_prompt,
- resolution=(args.resolution, args.resolution),
- )
- prompts.append(epd)
+ del s_pipeline
+ printm("Starting cleanup.")
- prompt_lengths = len(prompts)
- if args.disable_logging:
- pbar2.reset(prompt_lengths)
- else:
- pbar2.reset(prompt_lengths + 2)
- pbar2.set_description("Generating Samples")
- ci = 0
- for c in prompts:
- c.out_dir = os.path.join(args.model_dir, "samples")
- generator = torch.manual_seed(int(c.seed))
- s_image = s_pipeline(
- c.prompt,
- num_inference_steps=c.steps,
- guidance_scale=c.scale,
- negative_prompt=c.negative_prompt,
- height=c.resolution[1],
- width=c.resolution[0],
- generator=generator,
- ).images[0]
- sample_prompts.append(c.prompt)
- image_name = db_save_image(
- s_image,
- c,
- custom_name=f"sample_{args.revision}-{ci}",
- )
- shared.status.current_image = image_name
- shared.status.sample_prompts = [c.prompt]
- update_status({"images": [image_name], "prompts": [c.prompt]})
- samples.append(image_name)
- pbar2.update()
- ci += 1
- for sample in samples:
- last_samples.append(sample)
- for prompt in sample_prompts:
- last_prompts.append(prompt)
- del samples
- del prompts
- except:
- logger.warning(f"Exception saving sample.")
- traceback.print_exc()
- pass
+ if os.path.isdir(loras_dir) and "_tmp" in loras_dir:
+ shutil.rmtree(loras_dir)
- del s_pipeline
- printm("Starting cleanup.")
+ if os.path.isdir(weights_dir) and "_tmp" in weights_dir:
+ shutil.rmtree(weights_dir)
- if os.path.isdir(loras_dir) and "_tmp" in loras_dir:
- shutil.rmtree(loras_dir)
+ if "generator" in locals():
+ del generator
- if os.path.isdir(weights_dir) and "_tmp" in weights_dir:
- shutil.rmtree(weights_dir)
+ if not args.disable_logging:
+ try:
+ printm("Parse logs.")
+ log_images, log_names = log_parser.parse_logs(model_name=args.model_name)
+ pbar2.update()
+ for log_image in log_images:
+ last_samples.append(log_image)
+ for log_name in log_names:
+ last_prompts.append(log_name)
+
+ del log_images
+ del log_names
+ except Exception as l:
+ traceback.print_exc()
+                    logger.warning(f"Exception parsing logs: {l}")
+ pass
+
+ send_training_update(
+ last_samples,
+ args.model_name,
+ last_prompts,
+ global_step,
+ args.revision
+ )
- if "generator" in locals():
- del generator
+ status.sample_prompts = last_prompts
+ status.current_image = last_samples
+ update_status({"images": last_samples, "prompts": last_prompts})
+ pbar2.update()
- if not args.disable_logging:
- try:
- printm("Parse logs.")
- log_images, log_names = log_parser.parse_logs(model_name=args.model_name)
- pbar2.update()
- for log_image in log_images:
- last_samples.append(log_image)
- for log_name in log_names:
- last_prompts.append(log_name)
-
- del log_images
- del log_names
- except Exception as l:
- traceback.print_exc()
- logger.warning(f"Exception parsing logz: {l}")
- pass
- send_training_update(
- last_samples,
- args.model_name,
- last_prompts,
- global_step,
- args.revision
- )
+ if args.cache_latents:
+ printm("Unloading vae.")
+ del vae
+                # Re-bind the name so later `vae is None` checks still work after the delete
+ vae = None
- status.sample_prompts = last_prompts
status.current_image = last_samples
- update_status({"images": last_samples, "prompts": last_prompts})
- pbar2.update()
+ update_status({"images": last_samples})
+ cleanup()
+ printm("Cleanup.")
+ optim_to(profiler, optimizer, accelerator.device)
- if args.cache_latents:
- printm("Unloading vae.")
- del vae
- # Preserve the reference again
- vae = None
+ # Restore all random states to avoid having sampling impact training.
+ if shared.device.type == 'cuda':
+ torch.set_rng_state(torch_rng_state)
+ torch.cuda.set_rng_state(cuda_cpu_rng_state, device="cpu")
+ torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda")
- status.current_image = last_samples
- update_status({"images": last_samples})
- cleanup()
- printm("Cleanup.")
+ cleanup()
- optim_to(profiler, optimizer, accelerator.device)
+            # Convert the saved LoRA to Kohya format unless it was only a temporary file generated for sampling
+ if os.path.isfile(lora_save_file) and not delete_tmp_lora:
+ meta = args.export_ss_metadata()
+ convert_diffusers_to_kohya_lora(lora_save_file, meta, args.lora_weight)
+ else:
+ if os.path.isfile(lora_save_file):
+ os.remove(lora_save_file)
- # Restore all random states to avoid having sampling impact training.
- if shared.device.type == 'cuda':
- torch.set_rng_state(torch_rng_state)
- torch.cuda.set_rng_state(cuda_cpu_rng_state, device="cpu")
- torch.cuda.set_rng_state(cuda_gpu_rng_state, device="cuda")
+ printm("Completed saving weights.")
+ pbar2.reset()
- cleanup()
+ # Only show the progress bar once on each machine, and do not send statuses to the new UI.
+ progress_bar = mytqdm(
+ range(global_step, max_train_steps),
+ disable=not accelerator.is_local_main_process,
+ position=0
+ )
+ progress_bar.set_description("Steps")
+ progress_bar.set_postfix(refresh=True)
+ args.revision = (
+ args.revision if isinstance(args.revision, int) else
+ int(args.revision) if str(args.revision).strip() else
+ 0
+ )
+ lifetime_step = args.revision
+ lifetime_epoch = args.epoch
+ status.job_count = max_train_steps
+ status.job_no = global_step
+ update_status({"progress_1_total": max_train_steps, "progress_1_job_current": global_step})
+ training_complete = False
+ msg = ""
- # Save the lora weights if we are saving the model
- if os.path.isfile(lora_save_file) and not delete_tmp_lora:
- meta = args.export_ss_metadata()
- convert_diffusers_to_kohya_lora(lora_save_file, meta, args.lora_weight)
- else:
- if os.path.isfile(lora_save_file):
- os.remove(lora_save_file)
+ last_tenc = 0 < text_encoder_epochs
+ if stop_text_percentage == 0:
+ last_tenc = False
+
+ cleanup()
+ stats = {
+ "loss": 0.0,
+ "prior_loss": 0.0,
+ "instance_loss": 0.0,
+ "unet_lr": learning_rate,
+ "tenc_lr": txt_learning_rate,
+ "session_epoch": 0,
+ "lifetime_epoch": args.epoch,
+ "total_session_epoch": args.num_train_epochs,
+ "total_lifetime_epoch": args.epoch + args.num_train_epochs,
+ "lifetime_step": args.revision,
+ "session_step": 0,
+ "total_session_step": max_train_steps,
+ "total_lifetime_step": args.revision + max_train_steps,
+ "steps_per_epoch": len(train_dataset),
+ "iterations_per_second": 0.0,
+ "vram": round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
+ }
+ for epoch in range(first_epoch, max_train_epochs):
+ if training_complete:
+ logger.debug("Training complete, breaking epoch.")
+ break
- printm("Completed saving weights.")
- pbar2.reset()
+ if args.train_unet:
+ unet.train()
+ elif args.use_lora and not args.lora_use_buggy_requires_grad:
+ set_lora_requires_grad(unet, False)
- # Only show the progress bar once on each machine, and do not send statuses to the new UI.
- progress_bar = mytqdm(
- range(global_step, max_train_steps),
- disable=not accelerator.is_local_main_process,
- position=0
- )
- progress_bar.set_description("Steps")
- progress_bar.set_postfix(refresh=True)
- args.revision = (
- args.revision if isinstance(args.revision, int) else
- int(args.revision) if str(args.revision).strip() else
- 0
- )
- lifetime_step = args.revision
- lifetime_epoch = args.epoch
- status.job_count = max_train_steps
- status.job_no = global_step
- update_status({"progress_1_total": max_train_steps, "progress_1_job_current": global_step})
- training_complete = False
- msg = ""
-
- last_tenc = 0 < text_encoder_epochs
- if stop_text_percentage == 0:
- last_tenc = False
-
- cleanup()
- stats = {
- "loss": 0.0,
- "prior_loss": 0.0,
- "instance_loss": 0.0,
- "unet_lr": learning_rate,
- "tenc_lr": txt_learning_rate,
- "session_epoch": 0,
- "lifetime_epoch": args.epoch,
- "total_session_epoch": args.num_train_epochs,
- "total_lifetime_epoch": args.epoch + args.num_train_epochs,
- "lifetime_step": args.revision,
- "session_step": 0,
- "total_session_step": max_train_steps,
- "total_lifetime_step": args.revision + max_train_steps,
- "steps_per_epoch": len(train_dataset),
- "iterations_per_second": 0.0,
- "vram": round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
- }
- for epoch in range(first_epoch, max_train_epochs):
- if training_complete:
- logger.debug("Training complete, breaking epoch.")
- break
-
- if args.train_unet:
- unet.train()
- elif args.use_lora and not args.lora_use_buggy_requires_grad:
- set_lora_requires_grad(unet, False)
-
- train_tenc = epoch < text_encoder_epochs
- if stop_text_percentage == 0:
- train_tenc = False
+ train_tenc = epoch < text_encoder_epochs
+ if stop_text_percentage == 0:
+ train_tenc = False
- if args.freeze_clip_normalization:
- text_encoder.eval()
- if args.model_type == "SDXL":
- text_encoder_two.eval()
- else:
- text_encoder.train(train_tenc)
- if args.model_type == "SDXL":
- text_encoder_two.train(train_tenc)
+ if args.freeze_clip_normalization:
+ text_encoder.eval()
+ if args.model_type == "SDXL":
+ text_encoder_two.eval()
+ else:
+ text_encoder.train(train_tenc)
+ if args.model_type == "SDXL":
+ text_encoder_two.train(train_tenc)
- if args.use_lora:
- if not args.lora_use_buggy_requires_grad:
- set_lora_requires_grad(text_encoder, train_tenc)
- # We need to enable gradients on an input for gradient checkpointing to work
- # This will not be optimized because it is not a param to optimizer
- text_encoder.text_model.embeddings.position_embedding.requires_grad_(train_tenc)
+ if args.use_lora:
+ if not args.lora_use_buggy_requires_grad:
+ set_lora_requires_grad(text_encoder, train_tenc)
+ # We need to enable gradients on an input for gradient checkpointing to work
+ # This will not be optimized because it is not a param to optimizer
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(train_tenc)
+ if args.model_type == "SDXL":
+ set_lora_requires_grad(text_encoder_two, train_tenc)
+ text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(train_tenc)
+ else:
+ text_encoder.requires_grad_(train_tenc)
if args.model_type == "SDXL":
- set_lora_requires_grad(text_encoder_two, train_tenc)
- text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(train_tenc)
- else:
- text_encoder.requires_grad_(train_tenc)
- if args.model_type == "SDXL":
- text_encoder_two.requires_grad_(train_tenc)
+ text_encoder_two.requires_grad_(train_tenc)
- if last_tenc != train_tenc:
- last_tenc = train_tenc
- cleanup()
+ if last_tenc != train_tenc:
+ last_tenc = train_tenc
+ cleanup()
- loss_total = 0
+ loss_total = 0
- current_prior_loss_weight = current_prior_loss(
- args, current_epoch=global_epoch
- )
-
- instance_loss = None
- prior_loss = None
+ current_prior_loss_weight = current_prior_loss(
+ args, current_epoch=global_epoch
+ )
- for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if (
- resume_from_checkpoint
- and epoch == first_epoch
- and step < resume_step
- ):
- progress_bar.update(train_batch_size)
- progress_bar.reset()
- status.job_count = max_train_steps
- status.job_no += train_batch_size
- stats["session_step"] += train_batch_size
- stats["lifetime_step"] += train_batch_size
- update_status(stats)
- continue
+ instance_loss = None
+ prior_loss = None
+
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if (
+ resume_from_checkpoint
+ and epoch == first_epoch
+ and step < resume_step
+ ):
+ progress_bar.update(train_batch_size)
+ progress_bar.reset()
+ status.job_count = max_train_steps
+ status.job_no += train_batch_size
+ stats["session_step"] += train_batch_size
+ stats["lifetime_step"] += train_batch_size
+ update_status(stats)
+ continue
+
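+                # ConditionalAccumulator (project helper) wraps the accelerator's gradient accumulation for whichever of these models are in use.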
+ with ConditionalAccumulator(accelerator, unet, text_encoder, text_encoder_two):
+ # Convert images to latent space
+ with torch.no_grad():
+ if args.cache_latents:
+ latents = batch["images"].to(accelerator.device)
+ else:
+ latents = vae.encode(
+ batch["images"].to(dtype=weight_dtype)
+ ).latent_dist.sample()
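+                        # 0.18215 is the SD 1.x/2.x VAE latent scaling factor.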
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the model input
+ noise = torch.randn_like(latents, device=latents.device)
+ if args.offset_noise != 0:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.offset_noise * torch.randn(
+ (latents.shape[0],
+ latents.shape[1],
+ 1,
+ 1),
+ device=latents.device
+ )
+ b_size, channels, height, width = latents.shape
- with ConditionalAccumulator(accelerator, unet, text_encoder, text_encoder_two):
- # Convert images to latent space
- with torch.no_grad():
- if args.cache_latents:
- latents = batch["images"].to(accelerator.device)
- else:
- latents = vae.encode(
- batch["images"].to(dtype=weight_dtype)
- ).latent_dist.sample()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the model input
- noise = torch.randn_like(latents, device=latents.device)
- if args.offset_noise != 0:
- # https://www.crosslabs.org//blog/diffusion-with-offset-noise
- noise += args.offset_noise * torch.randn(
- (latents.shape[0],
- latents.shape[1],
- 1,
- 1),
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ (b_size,),
device=latents.device
)
- b_size, channels, height, width = latents.shape
-
- # Sample a random timestep for each image
- timesteps = torch.randint(
- 0,
- noise_scheduler.config.num_train_timesteps,
- (b_size,),
- device=latents.device
- )
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
- pad_tokens = args.pad_tokens if train_tenc else False
- input_ids = batch["input_ids"]
- encoder_hidden_states = None
- if args.model_type != "SDXL" and text_encoder is not None:
- encoder_hidden_states = encode_hidden_state(
- text_encoder,
- batch["input_ids"],
- pad_tokens,
- b_size,
- args.max_token_length,
- tokenizer_max_length,
- args.clip_skip,
- )
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+ pad_tokens = args.pad_tokens if train_tenc else False
+ input_ids = batch["input_ids"]
+ encoder_hidden_states = None
+ if args.model_type != "SDXL" and text_encoder is not None:
+ encoder_hidden_states = encode_hidden_state(
+ text_encoder,
+ batch["input_ids"],
+ pad_tokens,
+ b_size,
+ args.max_token_length,
+ tokenizer_max_length,
+ args.clip_skip,
+ )
- if unet.config.in_channels > channels:
- needed_additional_channels = unet.config.in_channels - channels
- additional_latents = randn_tensor(
- (b_size, needed_additional_channels, height, width),
- device=noisy_latents.device,
- dtype=noisy_latents.dtype,
- )
- noisy_latents = torch.cat([additional_latents, noisy_latents], dim=1)
- # Get the target for loss depending on the prediction type
- if noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif noise_scheduler.config.prediction_type == "v_prediction":
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
- else:
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
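+                    # If the UNet expects more input channels than the latents provide (e.g. inpainting models), pad with extra random channels.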
+ if unet.config.in_channels > channels:
+ needed_additional_channels = unet.config.in_channels - channels
+ additional_latents = randn_tensor(
+ (b_size, needed_additional_channels, height, width),
+ device=noisy_latents.device,
+ dtype=noisy_latents.dtype,
+ )
+ noisy_latents = torch.cat([additional_latents, noisy_latents], dim=1)
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
- if args.model_type == "SDXL":
- with accelerator.autocast():
- model_pred = unet(
- noisy_latents, timesteps, batch["input_ids"],
- added_cond_kwargs=batch["unet_added_conditions"]
- ).sample
- else:
- # Predict the noise residual and compute loss
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- if args.model_type != "SDXL":
- # TODO: set a prior preservation flag and use that to ensure this ony happens in dreambooth
- if not args.split_loss and not with_prior_preservation:
- loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
- loss *= batch["loss_avg"]
+ if args.model_type == "SDXL":
+ with accelerator.autocast():
+ model_pred = unet(
+ noisy_latents, timesteps, batch["input_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"]
+ ).sample
else:
- # Predict the noise residual
- if model_pred.shape[1] == 6:
- model_pred, _ = torch.chunk(model_pred, 2, dim=1)
-
- if model_pred.shape[0] > 1 and with_prior_preservation:
- # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
- print("model shape:")
- print(model_pred.shape)
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
- target, target_prior = torch.chunk(target, 2, dim=0)
-
- # Compute instance loss
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.model_type != "SDXL":
+                        # TODO: set a prior preservation flag and use that to ensure this only happens in dreambooth
+ if not args.split_loss and not with_prior_preservation:
+ loss = instance_loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ loss *= batch["loss_avg"]
+ else:
+ # Predict the noise residual
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ if model_pred.shape[0] > 1 and with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                            logger.debug(f"model shape: {model_pred.shape}")
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(),
+ reduction="mean")
+ else:
+ # Compute loss
loss = instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ if with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
- # Compute prior loss
- prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(),
- reduction="mean")
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
else:
- # Compute loss
- loss = instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- else:
- if with_prior_preservation:
- # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
- model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
- target, target_prior = torch.chunk(target, 2, dim=0)
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- # Compute instance loss
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ accelerator.backward(loss)
- # Compute prior loss
- prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+ if accelerator.sync_gradients and not args.use_lora:
+ if train_tenc:
+ if args.model_type == "SDXL":
+ params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters(),
+ text_encoder_two.parameters())
+ else:
+ params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
+ else:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, 1)
+
+ optimizer.step()
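+                    # The scheduler is stepped by the batch size, mirroring how global_step and args.revision are incremented below.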
+ lr_scheduler.step(train_batch_size)
+ if args.use_ema and ema_model is not None:
+ ema_model.step(unet)
+ if profiler is not None:
+ profiler.step()
+
+ optimizer.zero_grad(set_to_none=args.gradient_set_to_none)
+
+ allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
+ cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
+ lr_data = lr_scheduler.get_last_lr()
+ last_lr = lr_data[0]
+ last_tenc_lr = 0
+ stats["lr_data"] = lr_data
+ try:
+ if len(optimizer.param_groups) > 1:
+ last_tenc_lr = optimizer.param_groups[1]["lr"] if train_tenc else 0
+ except:
+ logger.debug("Exception getting tenc lr")
+ pass
- # Add the prior loss to the instance loss.
- loss = loss + args.prior_loss_weight * prior_loss
- else:
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
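+                    # Adaptive optimizers (D-Adaptation family) report an effective LR of d * lr, so compute it for display.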
+ if 'adapt' in args.optimizer:
+ last_lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+ if len(optimizer.param_groups) > 1:
+ try:
+ last_tenc_lr = optimizer.param_groups[1]["d"] * optimizer.param_groups[1]["lr"]
+ except:
+                            logger.warning("Exception getting adaptive tenc lr")
+ traceback.print_exc()
- accelerator.backward(loss)
+ update_status(stats)
+ del latents
+ del encoder_hidden_states
+ del noise
+ del timesteps
+ del noisy_latents
+ del target
+
+ global_step += train_batch_size
+ args.revision += train_batch_size
+ status.job_no += train_batch_size
+ loss_step = loss.detach().item()
+ loss_total += loss_step
- if accelerator.sync_gradients and not args.use_lora:
- if train_tenc:
- if args.model_type == "SDXL":
- params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters(),
- text_encoder_two.parameters())
- else:
- params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
- else:
- params_to_clip = unet.parameters()
- accelerator.clip_grad_norm_(params_to_clip, 1)
-
- optimizer.step()
- lr_scheduler.step(train_batch_size)
- if args.use_ema and ema_model is not None:
- ema_model.step(unet)
- if profiler is not None:
- profiler.step()
-
- optimizer.zero_grad(set_to_none=args.gradient_set_to_none)
-
- allocated = round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)
- cached = round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)
- lr_data = lr_scheduler.get_last_lr()
- last_lr = lr_data[0]
- last_tenc_lr = 0
- stats["lr_data"] = lr_data
- try:
- if len(optimizer.param_groups) > 1:
- last_tenc_lr = optimizer.param_groups[1]["lr"] if train_tenc else 0
- except:
- logger.debug("Exception getting tenc lr")
- pass
+ stats["session_step"] += train_batch_size
+ stats["lifetime_step"] += train_batch_size
+ stats["loss"] = loss_step
- if 'adapt' in args.optimizer:
- last_lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
- if len(optimizer.param_groups) > 1:
- try:
- last_tenc_lr = optimizer.param_groups[1]["d"] * optimizer.param_groups[1]["lr"]
- except:
- logger.warning("Exception setting tenc weight decay")
- traceback.print_exc()
+ logs = {
+ "lr": float(last_lr),
+ "loss": float(loss_step),
+ "vram": float(cached),
+ }
- update_status(stats)
- del latents
- del encoder_hidden_states
- del noise
- del timesteps
- del noisy_latents
- del target
-
- global_step += train_batch_size
- args.revision += train_batch_size
- status.job_no += train_batch_size
- loss_step = loss.detach().item()
- loss_total += loss_step
-
- stats["session_step"] += train_batch_size
- stats["lifetime_step"] += train_batch_size
- stats["loss"] = loss_step
-
- logs = {
- "lr": float(last_lr),
- "loss": float(loss_step),
- "vram": float(cached),
- }
+ stats["vram"] = logs["vram"]
+ stats["unet_lr"] = '{:.2E}'.format(Decimal(last_lr))
+ stats["tenc_lr"] = '{:.2E}'.format(Decimal(last_tenc_lr))
- stats["vram"] = logs["vram"]
- stats["unet_lr"] = '{:.2E}'.format(Decimal(last_lr))
- stats["tenc_lr"] = '{:.2E}'.format(Decimal(last_tenc_lr))
+ if args.split_loss and with_prior_preservation and args.model_type != "SDXL":
+ logs["inst_loss"] = float(instance_loss.detach().item())
- if args.split_loss and with_prior_preservation and args.model_type != "SDXL":
- logs["inst_loss"] = float(instance_loss.detach().item())
-
- if prior_loss is not None:
- logs["prior_loss"] = float(prior_loss.detach().item())
+ if prior_loss is not None:
+ logs["prior_loss"] = float(prior_loss.detach().item())
+ else:
+ logs["prior_loss"] = None # or some other default value
+ stats["instance_loss"] = logs["inst_loss"]
+ stats["prior_loss"] = logs["prior_loss"]
+
+ if 'adapt' in args.optimizer:
+ status.textinfo2 = (
+ f"Loss: {'%.2f' % loss_step}, UNET DLR: {'{:.2E}'.format(Decimal(last_lr))}, TENC DLR: {'{:.2E}'.format(Decimal(last_tenc_lr))}, "
+ f"VRAM: {allocated}/{cached} GB"
+ )
else:
- logs["prior_loss"] = None # or some other default value
- stats["instance_loss"] = logs["inst_loss"]
- stats["prior_loss"] = logs["prior_loss"]
-
- if 'adapt' in args.optimizer:
- status.textinfo2 = (
- f"Loss: {'%.2f' % loss_step}, UNET DLR: {'{:.2E}'.format(Decimal(last_lr))}, TENC DLR: {'{:.2E}'.format(Decimal(last_tenc_lr))}, "
- f"VRAM: {allocated}/{cached} GB"
- )
- else:
- status.textinfo2 = (
- f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, "
- f"VRAM: {allocated}/{cached} GB"
- )
+ status.textinfo2 = (
+ f"Loss: {'%.2f' % loss_step}, LR: {'{:.2E}'.format(Decimal(last_lr))}, "
+ f"VRAM: {allocated}/{cached} GB"
+ )
- progress_bar.update(train_batch_size)
- rate = progress_bar.format_dict["rate"] if "rate" in progress_bar.format_dict else None
- if rate is None:
- rate_string = ""
- else:
- if rate > 1:
- rate_string = f"{rate:.2f} it/s"
+ progress_bar.update(train_batch_size)
+ rate = progress_bar.format_dict["rate"] if "rate" in progress_bar.format_dict else None
+ if rate is None:
+ rate_string = ""
else:
- rate_string = f"{1 / rate:.2f} s/it" if rate != 0 else "N/A"
- stats["iterations_per_second"] = rate_string
- progress_bar.set_postfix(**logs)
- accelerator.log(logs, step=args.revision)
+ if rate > 1:
+ rate_string = f"{rate:.2f} it/s"
+ else:
+ rate_string = f"{1 / rate:.2f} s/it" if rate != 0 else "N/A"
+ stats["iterations_per_second"] = rate_string
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=args.revision)
+
+ logs = {"epoch_loss": loss_total / len(train_dataloader)}
+ accelerator.log(logs, step=global_step)
+ stats["epoch_loss"] = '%.2f' % (loss_total / len(train_dataloader))
+
+ status.job_count = max_train_steps
+ status.job_no = global_step
+ stats["lifetime_step"] = args.revision
+ stats["session_step"] = global_step
+ # status0 = f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}"
+ # status1 = f"{args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}"
+ status.textinfo = (
+ f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}"
+ f" {args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}"
+ )
+ update_status(stats)
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
- accelerator.log(logs, step=global_step)
- stats["epoch_loss"] = '%.2f' % (loss_total / len(train_dataloader))
+ if math.isnan(loss_step):
+ logger.warning("Loss is NaN, your model is dead. Cancelling training.")
+ status.interrupted = True
+ if status_handler:
+                        status_handler.end("Training interrupted due to NaN loss.")
+
+ # Log completion message
+ if training_complete or status.interrupted:
+ shared.in_progress = False
+ shared.in_progress_step = 0
+ shared.in_progress_epoch = 0
+ logger.debug(" Training complete (step check).")
+ if status.interrupted:
+ state = "canceled"
+ else:
+ state = "complete"
+ status.textinfo = (
+ f"Training {state} {global_step}/{max_train_steps}, {args.revision}"
+ f" total."
+ )
+ if status_handler:
+ status_handler.end(status.textinfo)
+ break
+
+ accelerator.wait_for_everyone()
+
+ args.epoch += 1
+ global_epoch += 1
+ lifetime_epoch += 1
+ session_epoch += 1
+ stats["session_epoch"] += 1
+ stats["lifetime_epoch"] += 1
+ lr_scheduler.step(is_epoch=True)
status.job_count = max_train_steps
status.job_no = global_step
- stats["lifetime_step"] = args.revision
- stats["session_step"] = global_step
- # status0 = f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}"
- # status1 = f"{args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}"
- status.textinfo = (
- f"Steps: {global_step}/{max_train_steps} (Current), {rate_string}"
- f" {args.revision}/{lifetime_step + max_train_steps} (Lifetime), Epoch: {global_epoch}"
- )
update_status(stats)
+ check_save(True)
- if math.isnan(loss_step):
- logger.warning("Loss is NaN, your model is dead. Cancelling training.")
- status.interrupted = True
- if status_handler:
- status_handler.end("Training interrrupted due to NaN loss.")
+ if args.num_train_epochs > 1:
+ training_complete = session_epoch >= max_train_epochs
- # Log completion message
if training_complete or status.interrupted:
- shared.in_progress = False
- shared.in_progress_step = 0
- shared.in_progress_epoch = 0
logger.debug(" Training complete (step check).")
if status.interrupted:
state = "canceled"
@@ -1761,59 +1804,27 @@ def lora_save_function(weights, filename):
status_handler.end(status.textinfo)
break
- accelerator.wait_for_everyone()
-
- args.epoch += 1
- global_epoch += 1
- lifetime_epoch += 1
- session_epoch += 1
- stats["session_epoch"] += 1
- stats["lifetime_epoch"] += 1
- lr_scheduler.step(is_epoch=True)
- status.job_count = max_train_steps
- status.job_no = global_step
- update_status(stats)
- check_save(True)
-
- if args.num_train_epochs > 1:
- training_complete = session_epoch >= max_train_epochs
-
- if training_complete or status.interrupted:
- logger.debug(" Training complete (step check).")
- if status.interrupted:
- state = "canceled"
- else:
- state = "complete"
+ # Do this at the very END of the epoch, only after we're sure we're not done
+ if args.epoch_pause_frequency > 0 and args.epoch_pause_time > 0:
+ if not session_epoch % args.epoch_pause_frequency:
+ logger.debug(
+ f"Giving the GPU a break for {args.epoch_pause_time} seconds."
+ )
+ for i in range(args.epoch_pause_time):
+ if status.interrupted:
+ training_complete = True
+ logger.debug("Training complete, interrupted.")
+ if status_handler:
+                            status_handler.end("Training interrupted.")
+ break
+ time.sleep(1)
- status.textinfo = (
- f"Training {state} {global_step}/{max_train_steps}, {args.revision}"
- f" total."
- )
- if status_handler:
- status_handler.end(status.textinfo)
- break
-
- # Do this at the very END of the epoch, only after we're sure we're not done
- if args.epoch_pause_frequency > 0 and args.epoch_pause_time > 0:
- if not session_epoch % args.epoch_pause_frequency:
- logger.debug(
- f"Giving the GPU a break for {args.epoch_pause_time} seconds."
- )
- for i in range(args.epoch_pause_time):
- if status.interrupted:
- training_complete = True
- logger.debug("Training complete, interrupted.")
- if status_handler:
- status_handler.end("Training interrrupted.")
- break
- time.sleep(1)
-
- cleanup_memory()
- accelerator.end_training()
- result.msg = msg
- result.config = args
- result.samples = last_samples
- stop_profiler(profiler)
- return result
+ cleanup_memory()
+ accelerator.end_training()
+ result.msg = msg
+ result.config = args
+ result.samples = last_samples
+ stop_profiler(profiler)
+ return result
return inner_loop()
diff --git a/dreambooth/ui_functions.py b/dreambooth/ui_functions.py
index 1e6c5fa8..eb364411 100644
--- a/dreambooth/ui_functions.py
+++ b/dreambooth/ui_functions.py
@@ -11,6 +11,7 @@
import traceback
from collections import OrderedDict
+import gradio
import torch
import torch.utils.data.dataloader
from accelerate import find_executable_batch_size
@@ -639,7 +640,7 @@ def load_model_params(model_name):
if config is None:
print("Can't load config!")
msg = f"Error loading model params: '{model_name}'."
- return "", "", "", "", "", db_model_snapshots, msg
+ return gradio.update(visible=False), "", "", "", "", "", db_model_snapshots, msg
else:
snaps = get_model_snapshots(config)
snap_selection = config.revision if str(config.revision) in snaps else ""
@@ -648,13 +649,17 @@ def load_model_params(model_name):
loras = get_lora_models(config)
db_lora_models = gr_update(choices=loras)
msg = f"Selected model: '{model_name}'."
+ src_name = os.path.basename(config.src)
+ # Strip the extension
+ src_name = os.path.splitext(src_name)[0]
return (
- config.model_dir,
+ gradio.update(visible=True),
+ os.path.basename(config.model_dir),
config.revision,
config.epoch,
config.model_type,
"True" if config.has_ema and not config.use_lora else "False",
- config.src,
+ src_name,
config.shared_diffusers_path,
db_model_snapshots,
db_lora_models,
@@ -941,11 +946,13 @@ def create_model(
new_model_token="",
extract_ema=False,
train_unfrozen=False,
- model_type="v1"
+ model_type="v1x"
):
+ if not model_type:
+ model_type = "v1x"
printm("Extracting model.")
res = 512
- is_512 = model_type == "v1"
+ is_512 = model_type == "v1x"
if model_type == "v1x" or model_type=="v2x-512":
res = 512
elif model_type == "v2x":
diff --git a/dreambooth/utils/model_utils.py b/dreambooth/utils/model_utils.py
index 3461015d..56d04615 100644
--- a/dreambooth/utils/model_utils.py
+++ b/dreambooth/utils/model_utils.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import collections
+import contextlib
import json
import logging
import os
@@ -257,23 +258,40 @@ def get_checkpoint_match(search_string):
return None
+disable_safe_unpickle_count = 0
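+# Reference count so nested disable/enable pairs only restore safe unpickling once the outermost caller finishes.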
+
def disable_safe_unpickle():
+ global disable_safe_unpickle_count
try:
from modules import shared as auto_shared
- auto_shared.cmd_opts.disable_safe_unpickle = True
- torch.load = unsafe_torch_load
+ if not auto_shared.cmd_opts.disable_safe_unpickle:
+ auto_shared.cmd_opts.disable_safe_unpickle = True
+ torch.load = unsafe_torch_load
+ disable_safe_unpickle_count += 1
except:
pass
def enable_safe_unpickle():
+ global disable_safe_unpickle_count
try:
from modules import shared as auto_shared
- auto_shared.cmd_opts.disable_safe_unpickle = False
- torch.load = load
+ if disable_safe_unpickle_count > 0:
+ disable_safe_unpickle_count -= 1
+ if disable_safe_unpickle_count == 0 and auto_shared.cmd_opts.disable_safe_unpickle:
+ auto_shared.cmd_opts.disable_safe_unpickle = False
+ torch.load = load
except:
pass
+@contextlib.contextmanager
+def safe_unpickle_disabled():
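+    """Context manager that disables torch safe unpickling for the wrapped block and re-enables it afterwards."""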
+ disable_safe_unpickle()
+ try:
+ yield
+ finally:
+ enable_safe_unpickle()
+
def xformerify(obj, use_lora):
try:
diff --git a/helpers/image_builder.py b/helpers/image_builder.py
index 2863e2f8..4b8c4919 100644
--- a/helpers/image_builder.py
+++ b/helpers/image_builder.py
@@ -14,12 +14,11 @@
from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.dataclasses.prompt_data import PromptData
-from dreambooth.shared import disable_safe_unpickle
from dreambooth.utils import image_utils
from dreambooth.utils.image_utils import process_txt2img, get_scheduler_class
from dreambooth.utils.model_utils import get_checkpoint_match, \
reload_system_models, \
- enable_safe_unpickle, disable_safe_unpickle, unload_system_models
+ safe_unpickle_disabled, unload_system_models
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import _text_lora_path_ui, patch_pipe, tune_lora_scale, \
get_target_module
@@ -83,71 +82,69 @@ def __init__(
msg = f"Exception initializing accelerator: {e}"
print(msg)
torch_dtype = torch.float16 if shared.device.type == "cuda" else torch.float32
- disable_safe_unpickle()
+ with safe_unpickle_disabled():
- self.image_pipe = DiffusionPipeline.from_pretrained(config.get_pretrained_model_name_or_path(), torch_dtype=torch.float16)
+ self.image_pipe = DiffusionPipeline.from_pretrained(config.get_pretrained_model_name_or_path(), torch_dtype=torch.float16)
- if config.pretrained_vae_name_or_path:
- logging.getLogger(__name__).info("Using pretrained VAE.")
- self.image_pipe.vae = AutoencoderKL.from_pretrained(
- config.pretrained_vae_name_or_path or config.get_pretrained_model_name_or_path(),
- subfolder=None if config.pretrained_vae_name_or_path else "vae",
- revision=config.revision,
- torch_dtype=torch_dtype
- )
+ if config.pretrained_vae_name_or_path:
+ logging.getLogger(__name__).info("Using pretrained VAE.")
+ self.image_pipe.vae = AutoencoderKL.from_pretrained(
+ config.pretrained_vae_name_or_path or config.get_pretrained_model_name_or_path(),
+ subfolder=None if config.pretrained_vae_name_or_path else "vae",
+ revision=config.revision,
+ torch_dtype=torch_dtype
+ )
- if config.infer_ema:
- logging.getLogger(__name__).info("Using EMA model for inference.")
- ema_path = os.path.join(config.get_pretrained_model_name_or_path(), "ema_unet",
- "diffusion_pytorch_model.safetensors")
- if os.path.isfile(ema_path):
- self.image_pipe.unet = UNet2DConditionModel.from_pretrained(ema_path, torch_dtype=torch.float16),
+ if config.infer_ema:
+ logging.getLogger(__name__).info("Using EMA model for inference.")
+ ema_path = os.path.join(config.get_pretrained_model_name_or_path(), "ema_unet",
+ "diffusion_pytorch_model.safetensors")
+ if os.path.isfile(ema_path):
+                self.image_pipe.unet = UNet2DConditionModel.from_pretrained(ema_path, torch_dtype=torch.float16)
- self.image_pipe.enable_model_cpu_offload()
- self.image_pipe.unet.set_attn_processor(AttnProcessor2_0())
- if os.name != "nt":
- self.image_pipe.unet = torch.compile(self.image_pipe.unet)
- self.image_pipe.enable_xformers_memory_efficient_attention()
- self.image_pipe.vae.enable_slicing()
- tomesd.apply_patch(self.image_pipe, ratio=0.5)
- self.image_pipe.scheduler.config["solver_type"] = "bh2"
- self.image_pipe.progress_bar = self.progress_bar
+ self.image_pipe.enable_model_cpu_offload()
+ self.image_pipe.unet.set_attn_processor(AttnProcessor2_0())
+ if os.name != "nt":
+ self.image_pipe.unet = torch.compile(self.image_pipe.unet)
+ self.image_pipe.enable_xformers_memory_efficient_attention()
+ self.image_pipe.vae.enable_slicing()
+ tomesd.apply_patch(self.image_pipe, ratio=0.5)
+ self.image_pipe.scheduler.config["solver_type"] = "bh2"
+ self.image_pipe.progress_bar = self.progress_bar
- if scheduler is None:
- scheduler = config.scheduler
+ if scheduler is None:
+ scheduler = config.scheduler
- print(f"Using scheduler: {scheduler}")
- scheduler_class = get_scheduler_class(scheduler)
+ print(f"Using scheduler: {scheduler}")
+ scheduler_class = get_scheduler_class(scheduler)
- self.image_pipe.scheduler = scheduler_class.from_config(self.image_pipe.scheduler.config)
+ self.image_pipe.scheduler = scheduler_class.from_config(self.image_pipe.scheduler.config)
- if "UniPC" in scheduler:
- self.image_pipe.scheduler.config.solver_type = "bh2"
+ if "UniPC" in scheduler:
+ self.image_pipe.scheduler.config.solver_type = "bh2"
- self.image_pipe.to(accelerator.device)
- new_hotness = os.path.join(config.model_dir, "checkpoints", f"checkpoint-{config.revision}")
- if os.path.exists(new_hotness):
- accelerator.print(f"Resuming from checkpoint {new_hotness}")
- disable_safe_unpickle()
- accelerator.load_state(new_hotness)
- enable_safe_unpickle()
+ self.image_pipe.to(accelerator.device)
+ new_hotness = os.path.join(config.model_dir, "checkpoints", f"checkpoint-{config.revision}")
+ if os.path.exists(new_hotness):
+ accelerator.print(f"Resuming from checkpoint {new_hotness}")
+ accelerator.load_state(new_hotness)
- if config.use_lora and lora_model:
- lora_model_path = shared.ui_lora_models_path
- if os.path.exists(lora_model_path):
- patch_pipe(
- pipe=self.image_pipe,
- maybe_unet_path=lora_model_path,
- unet_target_replace_module=get_target_module("module", config.use_lora_extended),
- token=None,
- r=lora_unet_rank,
- r_txt=lora_txt_rank
- )
- tune_lora_scale(self.image_pipe.unet, config.lora_weight)
+ if config.use_lora and lora_model:
+ lora_model_path = shared.ui_lora_models_path
+ if os.path.exists(lora_model_path):
+ patch_pipe(
+ pipe=self.image_pipe,
+ maybe_unet_path=lora_model_path,
+ unet_target_replace_module=get_target_module("module", config.use_lora_extended),
+ token=None,
+ r=lora_unet_rank,
+ r_txt=lora_txt_rank
+ )
+ tune_lora_scale(self.image_pipe.unet, config.lora_weight)
lora_txt_path = _text_lora_path_ui(lora_model_path)
if os.path.exists(lora_txt_path):
- tune_lora_scale(self.image_pipe.text_encoder, config.lora_txt_weight)
+ tune_lora_scale(self.image_pipe.text_encoder, config.lora_weight)
else:
try:
diff --git a/javascript/dreambooth.js b/javascript/dreambooth.js
index 361be54b..35e4ca00 100644
--- a/javascript/dreambooth.js
+++ b/javascript/dreambooth.js
@@ -7,6 +7,7 @@ let locked = false;
let listenersSet = false;
let timeouts = [];
let listeners = {};
+let elementsHidden = false;
function save_config() {
let btn = gradioApp().getElementById("db_save_config");
@@ -22,7 +23,7 @@ function save_config() {
}
function toggleComponents(enable, disableAll) {
- const elements = ['DbTopRow', 'SettingsPanel'];
+ const elements = ["DbTopRow", "TabConcepts", "TabSettings", "TabSave", "TabGenerate", "TabDebug"];
if (disableAll) {
console.log("Disabling all DB elements!");
elements.push("ModelPanel")
@@ -53,124 +54,7 @@ function toggleComponents(enable, disableAll) {
});
}
-// Disconnect a gradio mutation observer, update the element value, and reconnect the observer?
-function updateInputValue(elements, newValue) {
- const savedListeners = [];
- const savedObservers = [];
-
- elements.forEach((element) => {
- // Save any existing listeners and remove them
- const listeners = [];
- const events = ['change', 'input'];
- events.forEach((event) => {
- if (element['on' + event]) {
- listeners.push({
- event,
- listener: element['on' + event],
- });
- element['on' + event] = null;
- }
- const eventListeners = element.getEventListeners?.(event);
- if (eventListeners) {
- eventListeners.forEach(({ listener }) => {
- listeners.push({
- event,
- listener,
- });
- element.removeEventListener(event, listener);
- });
- }
- });
- savedListeners.push(listeners);
-
- // Save any existing MutationObservers and disconnect them
- const observer = new MutationObserver(() => {
- });
- if (observer && element.tagName === 'INPUT') {
- observer.observe(element, {
- attributes: true,
- attributeFilter: ['value'],
- });
- savedObservers.push(observer);
- observer.disconnect();
- } else {
- savedObservers.push(null);
- }
-
- // Update the value of the element
- element.value = newValue;
- });
-
- // Restore any saved listeners and MutationObservers
- savedListeners.forEach((listeners, i) => {
- const element = elements[i];
- listeners.forEach(({ event, listener }) => {
- if (listener) {
- element.addEventListener(event, listener);
- }
- });
- });
-
- savedObservers.forEach((observer, i) => {
- const element = elements[i];
- if (observer) {
- observer.observe(element, {
- attributes: true,
- attributeFilter: ['value'],
- });
- }
- });
-}
-
-
-// Fix steps on sliders. God this is a lot of work for one stupid thing...
-function handleNumberInputs() {
- const numberInputs = gradioApp()
- .querySelector('#tab_dreambooth_interface')
- ?.querySelectorAll('input[type="number"]');
- numberInputs?.forEach((numberInput) => {
- const step = Number(numberInput.step) || 1;
- const parentDiv = numberInput.parentElement;
- const labelFor = parentDiv.querySelector('label');
- if (labelFor) {
- const tgt = labelFor.getAttribute("for");
- if (listeners[tgt]) return;
- const rangeInput = getRealElement(tgt);
- if (rangeInput && rangeInput.type === 'range') {
- let timeouts = [];
- listeners[tgt] = true;
- numberInput.oninput = () => {
- if (timeouts[tgt]) {
- clearTimeout(timeouts[tgt]);
- }
- timeouts[tgt] = setTimeout(() => {
- let value = Number(numberInput.value) || 0;
- const min = parseFloat(rangeInput.min) || 0;
- const max = parseFloat(rangeInput.max) || 100;
- if (value < min) {
- value = min;
- } else if (value > max) {
- value = max;
- }
- const remainder = value % step;
- if (remainder !== 0) {
- value -= remainder;
- if (remainder >= step / 2) {
- value += step;
- }
- }
- if (value !== numberInput.value) {
- numberInput.value = value;
- }
- }, 500);
- };
-
- }
- }
- });
-}
-
-
+// Don't delete this, it's used by the UI
function check_save() {
let do_save = true;
if (params_loaded === false) {
@@ -207,7 +91,7 @@ function update_params() {
let btn = gradioApp().getElementById("db_update_params");
if (btn == null) return;
btn.click();
- }, 500);
+ }, 100);
}
function getRealElement(selector) {
@@ -455,8 +339,22 @@ let db_titles = {
"Weight Decay": "Values closer to 0 closely match your training dataset, and values closer to 1 generalize more and deviate from your training dataset. Default is 1e-2, values lower than 0.1 are recommended. For D-Adaptation values between 0.02 and 0.04 are recommended",
}
+function hideElements() {
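+    // Click the db_hide_advanced button once per page load so the advanced sections start collapsed.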
+ if (!elementsHidden) {
+ let btn = gradioApp().getElementById("db_hide_advanced");
+ if (btn == null) return;
+ elementsHidden = true;
+ console.log("Hiding advanced elements!");
+ btn.click();
+ }
+}
+
// Do a thing when the UI updates
onUiUpdate(function () {
+ setTimeout(function () {
+ hideElements();
+ },100);
+
let db_active = document.getElementById("db_active");
if (db_active) {
db_active.parentElement.style.display = "none";
@@ -545,11 +443,6 @@ onUiUpdate(function () {
observer.observe(btn, options);
});
- try {
- handleNumberInputs();
- } catch (e) {
- console.log("Gotcha: ", e);
- }
});
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 47af5797..01cb1e1c 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -8,7 +8,7 @@
from safetensors.torch import save_file as safe_save
from torch import dtype
-from dreambooth.utils.model_utils import disable_safe_unpickle, enable_safe_unpickle
+from dreambooth.utils.model_utils import safe_unpickle_disabled
class LoraInjectedLinear(nn.Module):
@@ -215,12 +215,13 @@ def inject_trainable_lora(
if target_replace_module is None:
target_replace_module = DEFAULT_TARGET_REPLACE
- disable_safe_unpickle()
+
require_grad_params = []
names = []
if loras is not None:
- loras = torch.load(loras)
+ with safe_unpickle_disabled():
+ loras = torch.load(loras)
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear]
@@ -252,7 +253,6 @@ def inject_trainable_lora(
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)
- enable_safe_unpickle()
return require_grad_params, names
@@ -267,12 +267,12 @@ def inject_trainable_lora_extended(
"""
if target_replace_module is None:
target_replace_module = UNET_EXTENDED_TARGET_REPLACE
- disable_safe_unpickle()
require_grad_params = []
names = []
if loras is not None:
- loras = torch.load(loras)
+ with safe_unpickle_disabled():
+ loras = torch.load(loras)
for _module, name, _child_module in _find_modules(
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
@@ -326,7 +326,6 @@ def inject_trainable_lora_extended(
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)
- enable_safe_unpickle()
return require_grad_params, names
@@ -458,9 +457,9 @@ def convert_loras_to_safeloras_with_embeds(
for name, (path, target_replace_module, r) in modelmap.items():
metadata[name] = json.dumps(list(target_replace_module))
- disable_safe_unpickle()
- lora = torch.load(path)
- enable_safe_unpickle()
+ with safe_unpickle_disabled():
+ lora = torch.load(path)
+
for i, weight in enumerate(lora):
is_up = i % 2 == 0
i = i // 2
@@ -903,9 +902,8 @@ def load_learned_embed_in_clip(
token: Optional[Union[str, List[str]]] = None,
idempotent=False,
):
- disable_safe_unpickle()
- learned_embeds = torch.load(learned_embeds_path)
- enable_safe_unpickle()
+ with safe_unpickle_disabled():
+ learned_embeds = torch.load(learned_embeds_path)
apply_learned_embed_in_clip(
learned_embeds, text_encoder, tokenizer, token, idempotent
)
@@ -941,30 +939,30 @@ def patch_pipe(
ti_path = _ti_lora_path(unet_path)
text_path = _text_lora_path_ui(unet_path)
- disable_safe_unpickle()
- if patch_unet:
- print("LoRA : Patching Unet")
- lora_patch = get_target_module(
- "patch",
- bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE)
- )
+ with safe_unpickle_disabled():
+ if patch_unet:
+ print("LoRA : Patching Unet")
+ lora_patch = get_target_module(
+ "patch",
+ bool(unet_target_replace_module == UNET_EXTENDED_TARGET_REPLACE)
+ )
- lora_patch(
- pipe.unet,
- torch.load(unet_path),
- r=r,
- target_replace_module=unet_target_replace_module,
- )
+ lora_patch(
+ pipe.unet,
+ torch.load(unet_path),
+ r=r,
+ target_replace_module=unet_target_replace_module,
+ )
- if patch_text:
- print("LoRA : Patching text encoder")
- monkeypatch_or_replace_lora(
- pipe.text_encoder,
- torch.load(text_path),
- target_replace_module=text_target_replace_module,
- r=r_txt,
- )
- enable_safe_unpickle()
+ if patch_text:
+ print("LoRA : Patching text encoder")
+ monkeypatch_or_replace_lora(
+ pipe.text_encoder,
+ torch.load(text_path),
+ target_replace_module=text_target_replace_module,
+ r=r_txt,
+ )
+
if patch_ti:
print("LoRA : Patching token input")
token = load_learned_embed_in_clip(
diff --git a/postinstall.py b/postinstall.py
index 886c0b88..8e421ee8 100644
--- a/postinstall.py
+++ b/postinstall.py
@@ -4,8 +4,13 @@
import subprocess
import sys
from dataclasses import dataclass
+from typing import Optional
import git
+from packaging import version as pv
+
+from importlib import metadata
+
from packaging.version import Version
from dreambooth import shared
@@ -56,6 +61,32 @@ def pip_install(*args):
print(line)
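+# Illustrative usage: is_installed("diffusers", "0.22.1") requires an exact version match,
+# while is_installed("diffusers", "0.22.1", check_strict=False) also accepts any newer version.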
+def is_installed(pkg: str, version: Optional[str] = None, check_strict: bool = True) -> bool:
+ try:
+ # Retrieve the package version from the installed package metadata
+ installed_version = metadata.version(pkg)
+
+ # If version is not specified, just return True as the package is installed
+ if version is None:
+ return True
+
+ # Compare the installed version with the required version
+ if check_strict:
+ # Strict comparison (must be an exact match)
+ return pv.parse(installed_version) == pv.parse(version)
+ else:
+ # Non-strict comparison (installed version must be greater than or equal to the required version)
+ return pv.parse(installed_version) >= pv.parse(version)
+
+ except metadata.PackageNotFoundError:
+ # The package is not installed
+ return False
+ except Exception as e:
+ # Any other exceptions encountered
+ print(f"Error: {e}")
+ return False
+
+
def install_requirements():
dreambooth_skip_install = os.environ.get("DREAMBOOTH_SKIP_INSTALL", False)
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
@@ -63,26 +94,52 @@ def install_requirements():
if dreambooth_skip_install or req_file == req_file_startup_arg:
return
-
+ print("Checking Dreambooth requirements...")
has_diffusers = importlib.util.find_spec("diffusers") is not None
has_tqdm = importlib.util.find_spec("tqdm") is not None
transformers_version = importlib_metadata.version("transformers")
+ strict = True
+ non_strict_separators = ["==", ">=", "<=", ">", "<", "~="]
+ # Load the requirements file
+    with open(req_file, "r") as f:
+ reqs = f.readlines()
+
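+    # tensorboard is pinned here instead of requirements.txt; macOS needs the older 2.11.x build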
+    if sys.platform == "darwin":
+ reqs.append("tensorboard==2.11.2")
+ else:
+ reqs.append("tensorboard==2.13.0")
+
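+    # Check each requirement against the installed version and only call pip when it is missing or mismatched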
+ for line in reqs:
+ try:
+ package = line.strip()
+ if package and not package.startswith("#"):
+ package_version = None
+ strict = True
+ for separator in non_strict_separators:
+ if separator in package:
+ strict = False
+                        package, package_version = package.split(separator)
+ package = package.strip()
+ package_version = package_version.strip()
+ break
+ if not is_installed(package, package_version, strict):
+ print(f"[Dreambooth] {package} v{package_version} is not installed.")
+ pip_install(line)
+ else:
+ print(f"[Dreambooth] {package} v{package_version} is already installed.")
- try:
- pip_install("-r", req_file)
-
- if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
- print()
- print("Does your project take forever to startup?")
- print("Repetitive dependency installation may be the reason.")
- print("Automatic1111's base project sets strict requirements on outdated dependencies.")
- print("If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.")
- print()
- except subprocess.CalledProcessError as grepexc:
- error_msg = grepexc.stdout.decode()
- print_requirement_installation_error(error_msg)
- raise grepexc
+ except subprocess.CalledProcessError as grepexc:
+ error_msg = grepexc.stdout.decode()
+ print_requirement_installation_error(error_msg)
+ if has_diffusers and has_tqdm and Version(transformers_version) < Version("4.26.1"):
+ print()
+ print("Does your project take forever to startup?")
+ print("Repetitive dependency installation may be the reason.")
+ print("Automatic1111's base project sets strict requirements on outdated dependencies.")
+ print(
+ "If an extension is using a newer version, the dependency is uninstalled and reinstalled twice every startup.")
+ print()
def check_xformers():
"""
@@ -95,13 +152,22 @@ def check_xformers():
try:
torch_version = importlib_metadata.version("torch")
is_torch_1 = Version(torch_version) < Version("2")
+        is_torch_2_1 = Version(torch_version) < Version("2.1")
if is_torch_1:
print_xformers_torch1_instructions(xformers_version)
+ # Torch 2.0.1 is not available on PyPI for xformers version 22
+ elif is_torch_2_1:
+ os_string = "win_amd64" if os.name == "nt" else "manylinux2014_x86_64"
+ # Get the version of python
+ py_string = f"cp{sys.version_info.major}{sys.version_info.minor}-cp{sys.version_info.major}{sys.version_info.minor}"
+ wheel_url = f"https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-{py_string}-{os_string}.whl"
+ pip_install(wheel_url, "--upgrade", "--no-deps")
else:
- pip_install("--force-reinstall", "xformers")
+        pip_install("xformers==0.0.21", "--index-url=https://download.pytorch.org/whl/cu118")
except subprocess.CalledProcessError as grepexc:
error_msg = grepexc.stdout.decode()
- print_xformers_installation_error(error_msg)
+ if "WARNING: Ignoring invalid distribution" not in error_msg:
+ print_xformers_installation_error(error_msg)
except:
pass
@@ -113,15 +179,23 @@ def check_bitsandbytes():
bitsandbytes_version = importlib_metadata.version("bitsandbytes")
if os.name == "nt":
if bitsandbytes_version != "0.41.1":
- try:
- print("Installing bitsandbytes")
- pip_install("--force-install","==prefer-binary","--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui","bitsandbytes==0.41.1")
- except:
- print("Bitsandbytes 0.41.1 installation failed.")
- print("Some features such as 8bit optimizers will be unavailable")
- print("Please install manually with")
- print("'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-install'")
- pass
+ venv_path = os.environ.get("VIRTUAL_ENV", None)
+ # Check for the dll in venv/lib/site-packages/bitsandbytes/libbitsandbytes_cuda118.dll
+ # If it doesn't exist, append the requirement
+ if not venv_path:
+ print("Could not find the virtual environment path. Skipping bitsandbytes installation.")
+ else:
+ win_dll = os.path.join(venv_path, "lib", "site-packages", "bitsandbytes", "libbitsandbytes_cuda118.dll")
+ if not os.path.exists(win_dll):
+ try:
+ print("Installing bitsandbytes")
+                    pip_install("--force-reinstall", "--prefer-binary", "--extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui", "bitsandbytes==0.41.1")
+ except:
+ print("Bitsandbytes 0.41.1 installation failed.")
+ print("Some features such as 8bit optimizers will be unavailable")
+ print("Please install manually with")
+                    print("'python -m pip install bitsandbytes==0.41.1 --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary --force-reinstall'")
+ pass
else:
if bitsandbytes_version != "0.41.1":
try:
@@ -150,12 +224,12 @@ def check_versions():
#Probably a bad idea but update ALL the dependencies
dependencies = [
- Dependency(module="xformers", version="0.0.21", required=False),
+ Dependency(module="xformers", version="0.0.22", required=False),
Dependency(module="torch", version="1.13.1" if is_mac else "2.0.1+cu118"),
Dependency(module="torchvision", version="0.14.1" if is_mac else "0.15.2+cu118"),
- Dependency(module="accelerate", version="0.22.0"),
- Dependency(module="diffusers", version="0.20.1"),
- Dependency(module="transformers", version="4.25.1"),
+ Dependency(module="accelerate", version="0.21.0"),
+ Dependency(module="diffusers", version="0.22.1"),
+ Dependency(module="transformers", version="4.30.2"),
Dependency(module="bitsandbytes", version="0.41.1"),
]
diff --git a/preprocess/preprocess_utils.py b/preprocess/preprocess_utils.py
new file mode 100644
index 00000000..5d4c203a
--- /dev/null
+++ b/preprocess/preprocess_utils.py
@@ -0,0 +1,50 @@
+import os
+from typing import Tuple, List, Dict
+
+import gradio as gr
+
+from dreambooth.utils.image_utils import FilenameTextGetter
+
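+# Cache of the most recent scan: a list of {'image': path, 'text': caption} dicts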
+image_data = []
+
+def load_image_data(input_path: str, recurse: bool = False) -> List[Dict[str,str]]:
+ if not os.path.exists(input_path):
+ print(f"Input path {input_path} does not exist")
+ return []
+ global image_data
+ results = []
+ from dreambooth.utils.image_utils import list_features, is_image
+    # Get a list from PIL of all the image formats it supports
+    pil_features = list_features()
+
+ for root, dirs, files in os.walk(input_path):
+ for file in files:
+ full_path = os.path.join(root, file)
+ print(f"Checking {full_path}")
+ if is_image(full_path, pil_features):
+ results.append(full_path)
+ if not recurse:
+ break
+
+ output = []
+ text_getter = FilenameTextGetter()
+ for img_path in results:
+ file_text = text_getter.read_text(img_path)
+ output.append({'image': img_path, 'text': file_text})
+ image_data = output
+ return output
+
+def check_preprocess_path(input_path: str, recurse: bool = False) -> Tuple[gr.update, gr.update]:
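+    # Scan input_path for images and return gradio updates for the status text and the gallery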
+ output_status = gr.update(visible=True)
+ output_gallery = gr.update(visible=True)
+ results = load_image_data(input_path, recurse)
+ if len(results) == 0:
+ return output_status, output_gallery
+ else:
+ images = [r['image'] for r in results]
+        output_status = gr.update(visible=True, value=f'Found {len(results)} images')
+ output_gallery = gr.update(visible=True, value=images)
+ return output_status, output_gallery
+
+def load_image_caption(evt: gr.SelectData): # SelectData is a subclass of EventData
+ return gr.update(value=f"You selected {evt.value} at {evt.index} from {evt.target}")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index cc4c3193..2b0d986c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,21 +1,12 @@
-accelerate~=0.23.0
+accelerate==0.21.0
bitsandbytes~=0.41.1
dadaptation==3.1
-diffusers~=0.21.2
-discord-webhook~=1.1.0
-fastapi~=0.94.1
-gitpython~=3.1.31
-pytorch_optimizer==2.11.1
-Pillow==9.5.0
-tqdm==4.65.0
+diffusers~=0.22.1
+discord-webhook==1.3.0
+fastapi
+gitpython==3.1.40
+pytorch_optimizer==2.12.0
+Pillow
+tqdm
tomesd~=0.1.2
-transformers~=4.32.1; # > 4.26.x causes issues (db extension #1110)
-# Get prebuilt Windows wheels from jllllll
-bitsandbytes~=0.41.1; sys_platform == 'win32' or platform_machine == 'AMD64' \
---extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui --prefer-binary
-# Get Linux and MacOS wheels from PyPi
-bitsandbytes~=0.41.1; sys_platform != 'win32' or platform_machine != 'AMD64' --prefer-binary
-# Tensor
-tensorboard==2.13.0; sys_platform != 'darwin' or platform_machine != 'arm64'
-# Tensor MacOS
-tensorboard==2.11.2; sys_platform == 'darwin' and platform_machine == 'arm64'
+transformers~=4.30.2
\ No newline at end of file
diff --git a/scripts/main.py b/scripts/main.py
index 0d6df7b2..9ab47c4b 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -53,13 +53,16 @@
from helpers.version_helper import check_updates
from modules import script_callbacks, sd_models
from modules.ui import gr_show, create_refresh_button
+from preprocess.preprocess_utils import check_preprocess_path, load_image_caption
+preprocess_params = []
params_to_save = []
params_to_load = []
refresh_symbol = "\U0001f504" # 🔄
delete_symbol = "\U0001F5D1" # 🗑️
update_symbol = "\U0001F51D" # 🠝
log_parser = LogParser()
+show_advanced = True
def read_metadata_from_safetensors(filename):
@@ -281,283 +284,507 @@ def on_ui_tabs():
with gr.Blocks() as dreambooth_interface:
# Top button row
with gr.Row(equal_height=True, elem_id="DbTopRow"):
- db_load_params = gr.Button(value="Load Settings", elem_id="db_load_params")
- db_save_params = gr.Button(value="Save Settings", elem_id="db_save_config")
+ db_load_params = gr.Button(value="Load Settings", elem_id="db_load_params", size="sm")
+ db_save_params = gr.Button(value="Save Settings", elem_id="db_save_config", size="sm")
db_train_model = gr.Button(
- value="Train", variant="primary", elem_id="db_train"
+ value="Train", variant="primary", elem_id="db_train", size="sm"
)
db_generate_checkpoint = gr.Button(
- value="Generate Ckpt", elem_id="db_gen_ckpt"
+ value="Generate Ckpt", elem_id="db_gen_ckpt", size="sm"
)
db_generate_checkpoint_during = gr.Button(
- value="Save Weights", elem_id="db_gen_ckpt_during"
+ value="Save Weights", elem_id="db_gen_ckpt_during", size="sm"
)
db_train_sample = gr.Button(
- value="Generate Samples", elem_id="db_train_sample"
+ value="Generate Samples", elem_id="db_train_sample", size="sm"
)
- db_cancel = gr.Button(value="Cancel", elem_id="db_cancel")
+ db_cancel = gr.Button(value="Cancel", elem_id="db_cancel", size="sm")
with gr.Row():
gr.HTML(value="Select or create a model to begin.", elem_id="hint_row")
- with gr.Row().style(equal_height=False):
- with gr.Column(variant="panel", elem_id="ModelPanel"):
- with gr.Column():
- gr.HTML(value="Model")
- with gr.Tab("Select"):
- with gr.Row():
- db_model_name = gr.Dropdown(
- label="Model", choices=sorted(get_db_models())
- )
- create_refresh_button(
- db_model_name,
- get_db_models,
- lambda: {"choices": sorted(get_db_models())},
- "refresh_db_models",
- )
- with gr.Row():
- db_snapshot = gr.Dropdown(
- label="Snapshot to Resume",
- choices=sorted(get_model_snapshots()),
- )
- create_refresh_button(
- db_snapshot,
- get_model_snapshots,
- lambda: {"choices": sorted(get_model_snapshots())},
- "refresh_db_snapshots",
- )
- with gr.Row(visible=False) as lora_model_row:
- db_lora_model_name = gr.Dropdown(
- label="Lora Model", choices=get_sorted_lora_models()
- )
- create_refresh_button(
- db_lora_model_name,
- get_sorted_lora_models,
- lambda: {"choices": get_sorted_lora_models()},
- "refresh_lora_models",
- )
- with gr.Row():
+ with gr.Row(elem_id="ModelDetailRow", visible=False, variant="compact") as db_model_info:
+ with gr.Column():
+ with gr.Row(variant="compact"):
+ with gr.Column():
+ with gr.Row(variant="compact"):
gr.HTML(value="Loaded Model:")
db_model_path = gr.HTML()
- with gr.Row():
- gr.HTML(value="Model Revision:")
- db_revision = gr.HTML(elem_id="db_revision")
- with gr.Row():
- gr.HTML(value="Model Epoch:")
- db_epochs = gr.HTML(elem_id="db_epochs")
- with gr.Row():
- gr.HTML(value="Model type:")
- db_model_type = gr.HTML(elem_id="db_model_type")
- with gr.Row():
- gr.HTML(value="Has EMA:")
- db_has_ema = gr.HTML(elem_id="db_has_ema")
- with gr.Row():
+ with gr.Row(variant="compact"):
gr.HTML(value="Source Checkpoint:")
db_src = gr.HTML()
- with gr.Row(visible=False):
- gr.HTML(value="Experimental Shared Source:")
- db_shared_diffusers_path = gr.HTML()
- with gr.Tab("Create"):
- with gr.Column():
- db_create_model = gr.Button(
- value="Create Model", variant="primary"
- )
- db_new_model_name = gr.Textbox(label="Name")
- with gr.Row():
- db_create_from_hub = gr.Checkbox(
- label="Create From Hub", value=False
- )
- db_model_type_select=gr.Dropdown(label="Model Type", choices=["v1x", "v2x-512", "v2x", "SDXL", "ControlNet"])
- db_use_shared_src = gr.Checkbox(
- label="Experimental Shared Src", value=False, visible=False
- )
- with gr.Column(visible=False) as hub_row:
- db_new_model_url = gr.Textbox(
- label="Model Path",
- placeholder="runwayml/stable-diffusion-v1-5",
- )
- db_new_model_token = gr.Textbox(
- label="HuggingFace Token", value=""
- )
- with gr.Column(visible=True) as local_row:
+ with gr.Column():
+ with gr.Row(variant="compact"):
+ gr.HTML(value="Model Epoch:")
+ db_epochs = gr.HTML(elem_id="db_epochs")
+ with gr.Row(variant="compact"):
+ gr.HTML(value="Model Revision:")
+ db_revision = gr.HTML(elem_id="db_revision")
+ with gr.Column():
+ with gr.Row(variant="compact"):
+ gr.HTML(value="Model type:")
+ db_model_type = gr.HTML(elem_id="db_model_type")
+ with gr.Row(variant="compact"):
+ gr.HTML(value="Has EMA:")
+ db_has_ema = gr.HTML(elem_id="db_has_ema")
+ with gr.Row(variant="compact", visible=False):
+ gr.HTML(value="Experimental Shared Source:")
+ db_shared_diffusers_path = gr.HTML()
+ with gr.Row(equal_height=False):
+ with gr.Column(variant="panel", elem_id="SettingsPanel"):
+ with gr.Row():
+ with gr.Column(scale=1, min_width=100, elem_classes="halfElement"):
+ gr.HTML(value="Settings")
+ with gr.Column(scale=1, min_width=100, elem_classes="halfElement"):
+ db_show_advanced = gr.Button(value="Show Advanced", size="sm", elem_classes="advBtn", visible=False)
+ db_hide_advanced = gr.Button(value="Hide Advanced", variant="primary", size="sm", elem_id="db_hide_advanced", elem_classes="advBtn")
+ with gr.Tab("Model", elem_id="ModelPanel"):
+ with gr.Column():
+ with gr.Tab("Select"):
with gr.Row():
- db_new_model_src = gr.Dropdown(
- label="Source Checkpoint",
- choices=sorted(get_sd_models()),
+ db_model_name = gr.Dropdown(
+ label="Model", choices=sorted(get_db_models())
)
create_refresh_button(
- db_new_model_src,
- get_sd_models,
- lambda: {"choices": sorted(get_sd_models())},
- "refresh_sd_models",
+ db_model_name,
+ get_db_models,
+ lambda: {"choices": sorted(get_db_models())},
+ "refresh_db_models",
)
- with gr.Column(visible=False) as shared_row:
- with gr.Row():
- db_new_model_shared_src = gr.Dropdown(
- label="EXPERIMENTAL: LoRA Shared Diffusers Source",
- choices=sorted(get_shared_models()),
- value="",
- visible=False
+ with gr.Row() as db_snapshot_row:
+ db_snapshot = gr.Dropdown(
+ label="Snapshot to Resume",
+ choices=sorted(get_model_snapshots()),
)
create_refresh_button(
- db_new_model_shared_src,
- get_shared_models,
- lambda: {"choices": sorted(get_shared_models())},
- "refresh_shared_models",
+ db_snapshot,
+ get_model_snapshots,
+ lambda: {"choices": sorted(get_model_snapshots())},
+ "refresh_db_snapshots",
)
- db_new_model_extract_ema = gr.Checkbox(
- label="Extract EMA Weights", value=False
- )
- db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False)
- with gr.Column():
- with gr.Accordion(open=False, label="Resources"):
- with gr.Column():
- gr.HTML(
- value="Beginners guide",
- )
- gr.HTML(
- value="Release notes",
- )
- with gr.Column(variant="panel", elem_id="SettingsPanel"):
- gr.HTML(value="Input")
- with gr.Tab("Settings", elem_id="TabSettings"):
- db_performance_wizard = gr.Button(value="Performance Wizard (WIP)")
- with gr.Accordion(open=True, label="Basic"):
- with gr.Column():
- gr.HTML(value="General")
- db_use_lora = gr.Checkbox(label="Use LORA", value=False)
- db_use_lora_extended = gr.Checkbox(
- label="Use Lora Extended",
- value=False,
- visible=False,
- )
- db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False)
- db_train_inpainting = gr.Checkbox(
- label="Train Inpainting Model",
- value=False,
- visible=False,
- )
- with gr.Column():
- gr.HTML(value="Intervals")
- db_num_train_epochs = gr.Slider(
- label="Training Steps Per Image (Epochs)",
- value=100,
- maximum=1000,
- step=1,
- )
- db_epoch_pause_frequency = gr.Slider(
- label="Pause After N Epochs",
- value=0,
- maximum=100,
- step=1,
- )
- db_epoch_pause_time = gr.Slider(
- label="Amount of time to pause between Epochs (s)",
- value=0,
- maximum=3600,
- step=1,
- )
- db_save_embedding_every = gr.Slider(
- label="Save Model Frequency (Epochs)",
- value=25,
- maximum=1000,
- step=1,
- )
- db_save_preview_every = gr.Slider(
- label="Save Preview(s) Frequency (Epochs)",
- value=5,
- maximum=1000,
- step=1,
- )
-
- with gr.Column():
- gr.HTML(value="Batching")
- db_train_batch_size = gr.Slider(
- label="Batch Size",
- value=1,
- minimum=1,
- maximum=100,
- step=1,
- )
- db_gradient_accumulation_steps = gr.Slider(
- label="Gradient Accumulation Steps",
- value=1,
- minimum=1,
- maximum=100,
- step=1,
- )
- db_sample_batch_size = gr.Slider(
- label="Class Batch Size",
- minimum=1,
- maximum=100,
- value=1,
- step=1,
- )
- db_gradient_set_to_none = gr.Checkbox(
- label="Set Gradients to None When Zeroing", value=True
- )
- db_gradient_checkpointing = gr.Checkbox(
- label="Gradient Checkpointing", value=True
- )
-
- with gr.Column():
- gr.HTML(value="Learning Rate")
- with gr.Row(visible=False) as lora_lr_row:
- db_lora_learning_rate = gr.Number(
- label="Lora UNET Learning Rate", value=1e-4
+ with gr.Row(visible=False) as lora_model_row:
+ db_lora_model_name = gr.Dropdown(
+ label="Lora Model", choices=get_sorted_lora_models()
)
- db_lora_txt_learning_rate = gr.Number(
- label="Lora Text Encoder Learning Rate", value=5e-5
+ create_refresh_button(
+ db_lora_model_name,
+ get_sorted_lora_models,
+ lambda: {"choices": get_sorted_lora_models()},
+ "refresh_lora_models",
)
- with gr.Row() as standard_lr_row:
- db_learning_rate = gr.Number(
- label="Learning Rate", value=2e-6
+ with gr.Tab("Create"):
+ with gr.Column():
+ db_create_model = gr.Button(
+ value="Create Model", variant="primary"
)
- db_txt_learning_rate = gr.Number(
- label="Text Encoder Learning Rate", value=1e-6
+ db_new_model_name = gr.Textbox(label="Name")
+ with gr.Row():
+ db_create_from_hub = gr.Checkbox(
+ label="Create From Hub", value=False
)
-
- db_lr_scheduler = gr.Dropdown(
- label="Learning Rate Scheduler",
- value="constant_with_warmup",
- choices=list_schedulers(),
- )
- db_learning_rate_min = gr.Number(
- label="Min Learning Rate", value=1e-6, visible=False
- )
- db_lr_cycles = gr.Number(
- label="Number of Hard Resets",
- value=1,
- precision=0,
- visible=False,
+ db_model_type_select = gr.Dropdown(label="Model Type",
+ choices=["v1x", "v2x-512", "v2x", "SDXL",
+ "ControlNet"], value="v1x")
+ db_use_shared_src = gr.Checkbox(
+ label="Experimental Shared Src", value=False, visible=False
+ )
+ with gr.Column(visible=False) as hub_row:
+ db_new_model_url = gr.Textbox(
+ label="Model Path",
+ placeholder="runwayml/stable-diffusion-v1-5",
+ )
+ db_new_model_token = gr.Textbox(
+ label="HuggingFace Token", value=""
+ )
+ with gr.Column(visible=True) as local_row:
+ with gr.Row():
+ db_new_model_src = gr.Dropdown(
+ label="Source Checkpoint",
+ choices=sorted(get_sd_models()),
+ )
+ create_refresh_button(
+ db_new_model_src,
+ get_sd_models,
+ lambda: {"choices": sorted(get_sd_models())},
+ "refresh_sd_models",
+ )
+ with gr.Column(visible=False) as shared_row:
+ with gr.Row():
+ db_new_model_shared_src = gr.Dropdown(
+ label="EXPERIMENTAL: LoRA Shared Diffusers Source",
+ choices=sorted(get_shared_models()),
+ value="",
+ visible=False
+ )
+ create_refresh_button(
+ db_new_model_shared_src,
+ get_shared_models,
+ lambda: {"choices": sorted(get_shared_models())},
+ "refresh_shared_models",
+ )
+ db_new_model_extract_ema = gr.Checkbox(
+ label="Extract EMA Weights", value=False
)
- db_lr_factor = gr.Number(
- label="Constant/Linear Starting Factor",
- value=0.5,
- precision=2,
- visible=False,
+ db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=True)
+ with gr.Column():
+ with gr.Accordion(open=False, label="Resources"):
+ with gr.Column():
+ gr.HTML(
+ value="Beginners guide",
+ )
+ gr.HTML(
+ value="Release notes",
+ )
+ # with gr.Tab("Preprocess", elem_id="PreprocessPanel", visible=False):
+ # with gr.Row():
+ # with gr.Column(scale=2, variant="compact"):
+ # db_preprocess_path = gr.Textbox(
+ # label="Image Path", value="", placeholder="Enter the path to your images"
+ # )
+ # with gr.Column(variant="compact"):
+ # db_preprocess_recursive = gr.Checkbox(
+ # label="Recursive", value=False, container=True, elem_classes=["singleCheckbox"]
+ # )
+ # with gr.Row():
+ # with gr.Tab("Auto-Caption"):
+ # with gr.Row():
+ # gr.HTML(value="Auto-Caption")
+ # with gr.Tab("Edit Captions"):
+ # with gr.Row():
+ # db_preprocess_autosave = gr.Checkbox(
+ # label="Autosave", value=False
+ # )
+ # with gr.Row():
+ # gr.HTML(value="Edit Captions")
+ # with gr.Tab("Edit Images"):
+ # with gr.Row():
+ # gr.HTML(value="Edit Images")
+ # with gr.Row():
+ # db_preprocess = gr.Button(
+ # value="Preprocess", variant="primary"
+ # )
+ # db_preprocess_all = gr.Button(
+ # value="Preprocess All", variant="primary"
+ # )
+ # with gr.Row():
+ # db_preprocess_all = gr.Button(
+ # value="Preprocess All", variant="primary"
+ # )
+ with gr.Tab("Concepts", elem_id="TabConcepts") as concept_tab:
+ with gr.Column(variant="panel"):
+ with gr.Accordion(open=False, label="Concept 1"):
+ (
+ c1_instance_data_dir,
+ c1_class_data_dir,
+ c1_instance_prompt,
+ c1_class_prompt,
+ c1_save_sample_prompt,
+ c1_save_sample_template,
+ c1_instance_token,
+ c1_class_token,
+ c1_num_class_images_per,
+ c1_class_negative_prompt,
+ c1_class_guidance_scale,
+ c1_class_infer_steps,
+ c1_save_sample_negative_prompt,
+ c1_n_save_sample,
+ c1_sample_seed,
+ c1_save_guidance_scale,
+ c1_save_infer_steps,
+ ) = build_concept_panel(1)
+
+ with gr.Accordion(open=False, label="Concept 2"):
+ (
+ c2_instance_data_dir,
+ c2_class_data_dir,
+ c2_instance_prompt,
+ c2_class_prompt,
+ c2_save_sample_prompt,
+ c2_save_sample_template,
+ c2_instance_token,
+ c2_class_token,
+ c2_num_class_images_per,
+ c2_class_negative_prompt,
+ c2_class_guidance_scale,
+ c2_class_infer_steps,
+ c2_save_sample_negative_prompt,
+ c2_n_save_sample,
+ c2_sample_seed,
+ c2_save_guidance_scale,
+ c2_save_infer_steps,
+ ) = build_concept_panel(2)
+
+ with gr.Accordion(open=False, label="Concept 3"):
+ (
+ c3_instance_data_dir,
+ c3_class_data_dir,
+ c3_instance_prompt,
+ c3_class_prompt,
+ c3_save_sample_prompt,
+ c3_save_sample_template,
+ c3_instance_token,
+ c3_class_token,
+ c3_num_class_images_per,
+ c3_class_negative_prompt,
+ c3_class_guidance_scale,
+ c3_class_infer_steps,
+ c3_save_sample_negative_prompt,
+ c3_n_save_sample,
+ c3_sample_seed,
+ c3_save_guidance_scale,
+ c3_save_infer_steps,
+ ) = build_concept_panel(3)
+
+ with gr.Accordion(open=False, label="Concept 4"):
+ (
+ c4_instance_data_dir,
+ c4_class_data_dir,
+ c4_instance_prompt,
+ c4_class_prompt,
+ c4_save_sample_prompt,
+ c4_save_sample_template,
+ c4_instance_token,
+ c4_class_token,
+ c4_num_class_images_per,
+ c4_class_negative_prompt,
+ c4_class_guidance_scale,
+ c4_class_infer_steps,
+ c4_save_sample_negative_prompt,
+ c4_n_save_sample,
+ c4_sample_seed,
+ c4_save_guidance_scale,
+ c4_save_infer_steps,
+ ) = build_concept_panel(4)
+ with gr.Tab("Parameters", elem_id="TabSettings"):
+ db_performance_wizard = gr.Button(value="Performance Wizard (WIP)", visible=False)
+ with gr.Accordion(open=False, label="Performance"):
+ db_use_ema = gr.Checkbox(
+ label="Use EMA", value=False
+ )
+ db_optimizer = gr.Dropdown(
+ label="Optimizer",
+ value="8bit AdamW",
+ choices=list_optimizer(),
+ )
+ db_mixed_precision = gr.Dropdown(
+ label="Mixed Precision",
+ value=select_precision(),
+ choices=list_precisions(),
+ )
+ db_full_mixed_precision = gr.Checkbox(
+ label="Full Mixed Precision", value=True
+ )
+ db_attention = gr.Dropdown(
+ label="Memory Attention",
+ value=select_attention(),
+ choices=list_attention(),
+ )
+ db_cache_latents = gr.Checkbox(
+ label="Cache Latents", value=True
+ )
+ db_train_unet = gr.Checkbox(
+ label="Train UNET", value=True
+ )
+ db_stop_text_encoder = gr.Slider(
+ label="Step Ratio of Text Encoder Training",
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=1.0,
+ visible=True,
+ )
+ db_offset_noise = gr.Slider(
+ label="Offset Noise",
+ minimum=-1,
+ maximum=1,
+ step=0.01,
+ value=0,
+ )
+ db_freeze_clip_normalization = gr.Checkbox(
+ label="Freeze CLIP Normalization Layers",
+ visible=True,
+ value=False,
+ )
+ db_clip_skip = gr.Slider(
+ label="Clip Skip",
+ value=2,
+ minimum=1,
+ maximum=12,
+ step=1,
+ )
+ db_weight_decay = gr.Slider(
+ label="Weight Decay",
+ minimum=0,
+ maximum=1,
+ step=0.001,
+ value=0.01,
+ visible=True,
+ )
+ db_tenc_weight_decay = gr.Slider(
+ label="TENC Weight Decay",
+ minimum=0,
+ maximum=1,
+ step=0.001,
+ value=0.01,
+ visible=True,
+ )
+ db_tenc_grad_clip_norm = gr.Slider(
+ label="TENC Gradient Clip Norm",
+ minimum=0,
+ maximum=128,
+ step=0.25,
+ value=0,
+ visible=True,
+ )
+ db_min_snr_gamma = gr.Slider(
+ label="Min SNR Gamma",
+ minimum=0,
+ maximum=10,
+ step=0.1,
+ visible=True,
+ )
+ db_pad_tokens = gr.Checkbox(
+ label="Pad Tokens", value=True
+ )
+ db_strict_tokens = gr.Checkbox(
+ label="Strict Tokens", value=False
+ )
+ db_shuffle_tags = gr.Checkbox(
+ label="Shuffle Tags", value=True
+ )
+ db_max_token_length = gr.Slider(
+ label="Max Token Length",
+ minimum=75,
+ maximum=300,
+ step=75,
+ )
+ with gr.Accordion(open=False, label="Intervals"):
+ db_num_train_epochs = gr.Slider(
+ label="Training Steps Per Image (Epochs)",
+ value=100,
+ maximum=1000,
+ step=1,
+ )
+ db_epoch_pause_frequency = gr.Slider(
+ label="Pause After N Epochs",
+ value=0,
+ maximum=100,
+ step=1,
+ )
+ db_epoch_pause_time = gr.Slider(
+ label="Amount of time to pause between Epochs (s)",
+ value=0,
+ maximum=3600,
+ step=1,
+ )
+ db_save_embedding_every = gr.Slider(
+ label="Save Model Frequency (Epochs)",
+ value=25,
+ maximum=1000,
+ step=1,
+ )
+ db_save_preview_every = gr.Slider(
+ label="Save Preview(s) Frequency (Epochs)",
+ value=5,
+ maximum=1000,
+ step=1,
+ )
+ with gr.Accordion(open=False, label="Batch Sizes") as db_batch_size_view:
+ db_train_batch_size = gr.Slider(
+ label="Batch Size",
+ value=1,
+ minimum=1,
+ maximum=100,
+ step=1,
+ )
+ db_gradient_accumulation_steps = gr.Slider(
+ label="Gradient Accumulation Steps",
+ value=1,
+ minimum=1,
+ maximum=100,
+ step=1,
+ )
+ db_sample_batch_size = gr.Slider(
+ label="Class Batch Size",
+ minimum=1,
+ maximum=100,
+ value=1,
+ step=1,
+ )
+ db_gradient_set_to_none = gr.Checkbox(
+ label="Set Gradients to None When Zeroing", value=True
+ )
+ db_gradient_checkpointing = gr.Checkbox(
+ label="Gradient Checkpointing", value=True
+ )
+ with gr.Accordion(open=False, label="Learning Rate"):
+ with gr.Row(visible=False) as lora_lr_row:
+ db_lora_learning_rate = gr.Number(
+ label="Lora UNET Learning Rate", value=1e-4
)
- db_lr_power = gr.Number(
- label="Polynomial Power",
- value=1.0,
- precision=1,
- visible=False,
+ db_lora_txt_learning_rate = gr.Number(
+ label="Lora Text Encoder Learning Rate", value=5e-5
)
- db_lr_scale_pos = gr.Slider(
- label="Scale Position",
- value=0.5,
- minimum=0,
- maximum=1,
- step=0.05,
- visible=False,
+ with gr.Row() as standard_lr_row:
+ db_learning_rate = gr.Number(
+ label="Learning Rate", value=2e-6
)
- db_lr_warmup_steps = gr.Slider(
- label="Learning Rate Warmup Steps",
- value=500,
- step=5,
- maximum=1000,
+ db_txt_learning_rate = gr.Number(
+ label="Text Encoder Learning Rate", value=1e-6
)
+ db_lr_scheduler = gr.Dropdown(
+ label="Learning Rate Scheduler",
+ value="constant_with_warmup",
+ choices=list_schedulers(),
+ )
+ db_learning_rate_min = gr.Number(
+ label="Min Learning Rate", value=1e-6, visible=False
+ )
+ db_lr_cycles = gr.Number(
+ label="Number of Hard Resets",
+ value=1,
+ precision=0,
+ visible=False,
+ )
+ db_lr_factor = gr.Number(
+ label="Constant/Linear Starting Factor",
+ value=0.5,
+ precision=2,
+ visible=False,
+ )
+ db_lr_power = gr.Number(
+ label="Polynomial Power",
+ value=1.0,
+ precision=1,
+ visible=False,
+ )
+ db_lr_scale_pos = gr.Slider(
+ label="Scale Position",
+ value=0.5,
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ visible=False,
+ )
+ db_lr_warmup_steps = gr.Slider(
+ label="Learning Rate Warmup Steps",
+ value=500,
+ step=5,
+ maximum=1000,
+ )
+ with gr.Accordion(open=False, label="Lora"):
+ db_use_lora = gr.Checkbox(label="Use LORA", value=False)
+ db_use_lora_extended = gr.Checkbox(
+ label="Use Lora Extended",
+ value=False,
+ visible=False,
+ )
+ db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False)
+ db_train_inpainting = gr.Checkbox(
+ label="Train Inpainting Model",
+ value=False,
+ visible=False,
+ )
with gr.Column(visible=False) as lora_rank_col:
- gr.HTML("Lora")
db_lora_unet_rank = gr.Slider(
label="Lora UNET Rank",
value=4,
@@ -579,154 +806,208 @@ def on_ui_tabs():
maximum=1,
step=0.1,
)
-
+ with gr.Accordion(open=False, label="Image Processing"):
+ db_resolution = gr.Slider(
+ label="Max Resolution",
+ step=64,
+ minimum=128,
+ value=512,
+ maximum=2048,
+ elem_id="max_res",
+ )
+ db_hflip = gr.Checkbox(
+ label="Apply Horizontal Flip", value=False
+ )
+ db_dynamic_img_norm = gr.Checkbox(
+ label="Dynamic Image Normalization", value=False
+ )
+ with gr.Accordion(open=False, label="Prior Loss") as db_prior_loss_view:
+ db_prior_loss_scale = gr.Checkbox(
+ label="Scale Prior Loss", value=False
+ )
+ db_prior_loss_weight = gr.Slider(
+ label="Prior Loss Weight",
+ minimum=0.01,
+ maximum=1,
+ step=0.01,
+ value=0.75,
+ )
+ db_prior_loss_target = gr.Number(
+ label="Prior Loss Target",
+ value=100,
+ visible=False,
+ )
+ db_prior_loss_weight_min = gr.Slider(
+ label="Minimum Prior Loss Weight",
+ minimum=0.01,
+ maximum=1,
+ step=0.01,
+ value=0.1,
+ visible=False,
+ )
+                        with gr.Accordion(open=False, label="Saving", elem_id="TabSave") as db_save_tab:
with gr.Column():
- gr.HTML(value="Image Processing")
- db_resolution = gr.Slider(
- label="Max Resolution",
- step=64,
- minimum=128,
- value=512,
- maximum=2048,
- elem_id="max_res",
+ gr.HTML("General")
+ db_custom_model_name = gr.Textbox(
+ label="Custom Model Name",
+ value="",
+ placeholder="Enter a model name for saving checkpoints and lora models.",
)
- db_hflip = gr.Checkbox(
- label="Apply Horizontal Flip", value=False
+ db_save_safetensors = gr.Checkbox(
+ label="Save in .safetensors format",
+ value=True,
+ visible=False,
)
- db_dynamic_img_norm = gr.Checkbox(
- label="Dynamic Image Normalization", value=False
+ db_save_ema = gr.Checkbox(
+ label="Save EMA Weights to Generated Models", value=True
)
-
- with gr.Column():
- gr.HTML(value="Tuning")
- db_use_ema = gr.Checkbox(
- label="Use EMA", value=False
+ db_infer_ema = gr.Checkbox(
+ label="Use EMA Weights for Inference", value=False
)
- db_optimizer = gr.Dropdown(
- label="Optimizer",
- value="8bit AdamW",
- choices=list_optimizer(),
+ with gr.Column():
+ gr.HTML("Checkpoints")
+ db_half_model = gr.Checkbox(label="Half Model", value=False)
+ db_use_subdir = gr.Checkbox(
+ label="Save Checkpoint to Subdirectory", value=True
)
- db_mixed_precision = gr.Dropdown(
- label="Mixed Precision",
- value=select_precision(),
- choices=list_precisions(),
+ db_save_ckpt_during = gr.Checkbox(
+ label="Generate a .ckpt file when saving during training."
)
- db_full_mixed_precision = gr.Checkbox(
- label="Full Mixed Precision", value=True
+ db_save_ckpt_after = gr.Checkbox(
+ label="Generate a .ckpt file when training completes.",
+ value=True,
)
- db_attention = gr.Dropdown(
- label="Memory Attention",
- value=select_attention(),
- choices=list_attention(),
+ db_save_ckpt_cancel = gr.Checkbox(
+ label="Generate a .ckpt file when training is canceled."
)
- db_cache_latents = gr.Checkbox(
- label="Cache Latents", value=True
+ with gr.Column(visible=False) as lora_save_col:
+ db_save_lora_during = gr.Checkbox(
+ label="Generate lora weights when saving during training."
)
- db_train_unet = gr.Checkbox(
- label="Train UNET", value=True
+ db_save_lora_after = gr.Checkbox(
+ label="Generate lora weights when training completes.",
+ value=True,
)
- db_stop_text_encoder = gr.Slider(
- label="Step Ratio of Text Encoder Training",
- minimum=0,
- maximum=1,
- step=0.05,
- value=1.0,
- visible=True,
+ db_save_lora_cancel = gr.Checkbox(
+ label="Generate lora weights when training is canceled."
)
- db_offset_noise = gr.Slider(
- label="Offset Noise",
- minimum=-1,
- maximum=1,
- step=0.01,
- value=0,
+ db_save_lora_for_extra_net = gr.Checkbox(
+ label="Generate lora weights for extra networks."
)
- db_freeze_clip_normalization = gr.Checkbox(
- label="Freeze CLIP Normalization Layers",
- visible=True,
- value=False,
+ with gr.Column():
+ gr.HTML("Diffusion Weights (training snapshots)")
+ db_save_state_during = gr.Checkbox(
+ label="Save separate diffusers snapshots when saving during training."
)
- db_clip_skip = gr.Slider(
- label="Clip Skip",
- value=2,
- minimum=1,
- maximum=12,
- step=1,
+ db_save_state_after = gr.Checkbox(
+ label="Save separate diffusers snapshots when training completes."
)
- db_weight_decay = gr.Slider(
- label="Weight Decay",
- minimum=0,
- maximum=1,
- step=0.001,
- value=0.01,
- visible=True,
+ db_save_state_cancel = gr.Checkbox(
+ label="Save separate diffusers snapshots when training is canceled."
)
- db_tenc_weight_decay = gr.Slider(
- label="TENC Weight Decay",
- minimum=0,
- maximum=1,
- step=0.001,
- value=0.01,
- visible=True,
+ with gr.Accordion(open=False, label="Image Generation", elem_id="TabGenerate") as db_generate_tab:
+ gr.HTML(value="Class Generation Schedulers")
+ db_class_gen_method = gr.Dropdown(
+ label="Image Generation Library",
+ value="Native Diffusers",
+ choices=[
+ "A1111 txt2img (Euler a)",
+ "Native Diffusers",
+ ]
+ )
+ db_scheduler = gr.Dropdown(
+ label="Image Generation Scheduler",
+ value="DEISMultistep",
+ choices=get_scheduler_names(),
+ )
+ gr.HTML(value="Manual Class Generation")
+ with gr.Column():
+ db_generate_classes = gr.Button(value="Generate Class Images")
+ db_generate_graph = gr.Button(value="Generate Graph")
+ db_graph_smoothing = gr.Slider(
+ value=50,
+ label="Graph Smoothing Steps",
+ minimum=10,
+ maximum=500,
)
- db_tenc_grad_clip_norm = gr.Slider(
- label="TENC Gradient Clip Norm",
- minimum=0,
- maximum=128,
- step=0.25,
- value=0,
- visible=True,
+ db_debug_buckets = gr.Button(value="Debug Buckets")
+ db_bucket_epochs = gr.Slider(
+ value=10,
+ step=1,
+ minimum=1,
+ maximum=1000,
+ label="Epochs to Simulate",
)
- db_min_snr_gamma = gr.Slider(
- label="Min SNR Gamma",
- minimum=0,
- maximum=10,
- step=0.1,
- visible=True,
+ db_bucket_batch = gr.Slider(
+ value=1,
+ step=1,
+ minimum=1,
+ maximum=500,
+ label="Batch Size to Simulate",
)
- db_pad_tokens = gr.Checkbox(
- label="Pad Tokens", value=True
+ db_generate_sample = gr.Button(value="Generate Sample Images")
+ db_sample_prompt = gr.Textbox(label="Sample Prompt")
+ db_sample_negative = gr.Textbox(label="Sample Negative Prompt")
+ db_sample_prompt_file = gr.Textbox(label="Sample Prompt File")
+ db_sample_width = gr.Slider(
+ label="Sample Width",
+ value=512,
+ step=64,
+ minimum=128,
+ maximum=2048,
)
- db_strict_tokens = gr.Checkbox(
- label="Strict Tokens", value=False
+ db_sample_height = gr.Slider(
+ label="Sample Height",
+ value=512,
+ step=64,
+ minimum=128,
+ maximum=2048,
)
- db_shuffle_tags = gr.Checkbox(
- label="Shuffle Tags", value=True
+ db_sample_seed = gr.Number(
+ label="Sample Seed", value=-1, precision=0
)
- db_max_token_length = gr.Slider(
- label="Max Token Length",
- minimum=75,
- maximum=300,
- step=75,
+ db_num_samples = gr.Slider(
+ label="Number of Samples to Generate",
+ value=1,
+ minimum=1,
+ maximum=1000,
+ step=1,
)
- with gr.Column():
- gr.HTML(value="Prior Loss")
- db_prior_loss_scale = gr.Checkbox(
- label="Scale Prior Loss", value=False
+ db_gen_sample_batch_size = gr.Slider(
+ label="Sample Batch Size",
+ value=1,
+ step=1,
+ minimum=1,
+ maximum=100,
+ interactive=True,
)
- db_prior_loss_weight = gr.Slider(
- label="Prior Loss Weight",
- minimum=0.01,
- maximum=1,
- step=0.01,
- value=0.75,
+ db_sample_steps = gr.Slider(
+ label="Sample Steps",
+ value=20,
+ minimum=1,
+ maximum=500,
+ step=1,
)
- db_prior_loss_target = gr.Number(
- label="Prior Loss Target",
- value=100,
- visible=False,
+ db_sample_scale = gr.Slider(
+ label="Sample CFG Scale",
+ value=7.5,
+ step=0.1,
+ minimum=1,
+ maximum=20,
)
- db_prior_loss_weight_min = gr.Slider(
- label="Minimum Prior Loss Weight",
- minimum=0.01,
- maximum=1,
- step=0.01,
- value=0.1,
- visible=False,
+ with gr.Column(variant="panel", visible=has_face_swap()):
+ db_swap_faces = gr.Checkbox(label="Swap Sample Faces")
+ db_swap_prompt = gr.Textbox(label="Swap Prompt")
+ db_swap_negative = gr.Textbox(label="Swap Negative Prompt")
+ db_swap_steps = gr.Slider(label="Swap Steps", value=40)
+ db_swap_batch = gr.Slider(label="Swap Batch", value=40)
+
+ db_sample_txt2img = gr.Checkbox(
+ label="Use txt2img",
+ value=False,
+ visible=False # db_sample_txt2img not implemented yet
)
-
- with gr.Accordion(open=False, label="Advanced"):
- with gr.Row():
+ with gr.Accordion(open=False, label="Extras"):
with gr.Column():
gr.HTML(value="Sanity Samples")
db_sanity_prompt = gr.Textbox(
@@ -741,8 +1022,7 @@ def on_ui_tabs():
db_sanity_seed = gr.Number(
label="Sanity Sample Seed", value=420420
)
-
- with gr.Column():
+ with gr.Column() as db_misc_view:
gr.HTML(value="Miscellaneous")
db_pretrained_vae_name_or_path = gr.Textbox(
label="Pretrained VAE Name or Path",
@@ -765,9 +1045,9 @@ def on_ui_tabs():
)
db_clear_secret = gr.Button(
value=delete_symbol, elem_id="clear_secret"
- )
-
- with gr.Column():
+ )
+ with gr.Column() as db_hook_view:
+ gr.HTML(value="Webhooks")
# In the future change this to something more generic and list the supported types
# from DreamboothWebhookTarget enum; for now, Discord is what I use ;)
# Add options to include notifications on training complete and exceptions that halt training
@@ -779,319 +1059,64 @@ def on_ui_tabs():
notification_webhook_test_btn = gr.Button(
value="Save and Test Webhook"
)
+ with gr.Column() as db_test_tab:
+ gr.HTML(value="Experimental Settings")
+ db_tomesd = gr.Slider(
+ value=0,
+ label="Token Merging (ToMe)",
+ minimum=0,
+ maximum=1,
+ step=0.1,
+ )
+ db_split_loss = gr.Checkbox(
+ label="Calculate Split Loss", value=True
+ )
+ db_disable_class_matching = gr.Checkbox(label="Disable Class Matching")
+ db_disable_logging = gr.Checkbox(label="Disable Logging")
+ db_deterministic = gr.Checkbox(label="Deterministic")
+ db_ema_predict = gr.Checkbox(label="Use EMA for prediction")
+ db_lora_use_buggy_requires_grad = gr.Checkbox(label="LoRA use buggy requires grad")
+ db_noise_scheduler = gr.Dropdown(
+ label="Noise scheduler",
+ value="DDPM",
+ choices=[
+ "DDPM",
+ "DEIS",
+ "UniPC"
+ ]
+ )
+ db_update_extension = gr.Button(
+ value="Update Extension and Restart"
+ )
- with gr.Row():
- with gr.Column(scale=2):
- gr.HTML(value="")
- with gr.Tab("Concepts", elem_id="TabConcepts") as concept_tab:
- with gr.Column(variant="panel"):
- with gr.Row():
- db_train_wizard_person = gr.Button(
- value="Training Wizard (Person)"
- )
- db_train_wizard_object = gr.Button(
- value="Training Wizard (Object/Style)"
- )
- with gr.Tab("Concept 1"):
- (
- c1_instance_data_dir,
- c1_class_data_dir,
- c1_instance_prompt,
- c1_class_prompt,
- c1_save_sample_prompt,
- c1_save_sample_template,
- c1_instance_token,
- c1_class_token,
- c1_num_class_images_per,
- c1_class_negative_prompt,
- c1_class_guidance_scale,
- c1_class_infer_steps,
- c1_save_sample_negative_prompt,
- c1_n_save_sample,
- c1_sample_seed,
- c1_save_guidance_scale,
- c1_save_infer_steps,
- ) = build_concept_panel(1)
-
- with gr.Tab("Concept 2"):
- (
- c2_instance_data_dir,
- c2_class_data_dir,
- c2_instance_prompt,
- c2_class_prompt,
- c2_save_sample_prompt,
- c2_save_sample_template,
- c2_instance_token,
- c2_class_token,
- c2_num_class_images_per,
- c2_class_negative_prompt,
- c2_class_guidance_scale,
- c2_class_infer_steps,
- c2_save_sample_negative_prompt,
- c2_n_save_sample,
- c2_sample_seed,
- c2_save_guidance_scale,
- c2_save_infer_steps,
- ) = build_concept_panel(2)
-
- with gr.Tab("Concept 3"):
- (
- c3_instance_data_dir,
- c3_class_data_dir,
- c3_instance_prompt,
- c3_class_prompt,
- c3_save_sample_prompt,
- c3_save_sample_template,
- c3_instance_token,
- c3_class_token,
- c3_num_class_images_per,
- c3_class_negative_prompt,
- c3_class_guidance_scale,
- c3_class_infer_steps,
- c3_save_sample_negative_prompt,
- c3_n_save_sample,
- c3_sample_seed,
- c3_save_guidance_scale,
- c3_save_infer_steps,
- ) = build_concept_panel(3)
-
- with gr.Tab("Concept 4"):
- (
- c4_instance_data_dir,
- c4_class_data_dir,
- c4_instance_prompt,
- c4_class_prompt,
- c4_save_sample_prompt,
- c4_save_sample_template,
- c4_instance_token,
- c4_class_token,
- c4_num_class_images_per,
- c4_class_negative_prompt,
- c4_class_guidance_scale,
- c4_class_infer_steps,
- c4_save_sample_negative_prompt,
- c4_n_save_sample,
- c4_sample_seed,
- c4_save_guidance_scale,
- c4_save_infer_steps,
- ) = build_concept_panel(4)
- with gr.Tab("Saving", elme_id="TabSave"):
- with gr.Column():
- gr.HTML("General")
- db_custom_model_name = gr.Textbox(
- label="Custom Model Name",
- value="",
- placeholder="Enter a model name for saving checkpoints and lora models.",
- )
- db_save_safetensors = gr.Checkbox(
- label="Save in .safetensors format",
- value=True,
+ with gr.Column(variant="panel"):
+ gr.HTML(value="Bucket Cropping")
+ db_crop_src_path = gr.Textbox(label="Source Path")
+ db_crop_dst_path = gr.Textbox(label="Dest Path")
+ db_crop_max_res = gr.Slider(
+ label="Max Res", value=512, step=64, maximum=2048
+ )
+ db_crop_bucket_step = gr.Slider(
+ label="Bucket Steps", value=8, step=8, maximum=512
+ )
+ db_crop_dry = gr.Checkbox(label="Dry Run")
+ db_start_crop = gr.Button("Start Cropping")
+ with gr.Column(variant="panel"):
+ with gr.Row():
+ with gr.Column(scale=1, min_width=110):
+ gr.HTML(value="Output")
+ with gr.Column(scale=1, min_width=110):
+ db_check_progress_initial = gr.Button(
+ value=update_symbol,
+ elem_id="db_check_progress_initial",
visible=False,
)
- db_save_ema = gr.Checkbox(
- label="Save EMA Weights to Generated Models", value=True
- )
- db_infer_ema = gr.Checkbox(
- label="Use EMA Weights for Inference", value=False
- )
- with gr.Column():
- gr.HTML("Checkpoints")
- db_half_model = gr.Checkbox(label="Half Model", value=False)
- db_use_subdir = gr.Checkbox(
- label="Save Checkpoint to Subdirectory", value=True
- )
- db_save_ckpt_during = gr.Checkbox(
- label="Generate a .ckpt file when saving during training."
- )
- db_save_ckpt_after = gr.Checkbox(
- label="Generate a .ckpt file when training completes.",
- value=True,
- )
- db_save_ckpt_cancel = gr.Checkbox(
- label="Generate a .ckpt file when training is canceled."
- )
- with gr.Column(visible=False) as lora_save_col:
- db_save_lora_during = gr.Checkbox(
- label="Generate lora weights when saving during training."
- )
- db_save_lora_after = gr.Checkbox(
- label="Generate lora weights when training completes.",
- value=True,
- )
- db_save_lora_cancel = gr.Checkbox(
- label="Generate lora weights when training is canceled."
- )
- db_save_lora_for_extra_net = gr.Checkbox(
- label="Generate lora weights for extra networks."
- )
- with gr.Column():
- gr.HTML("Diffusion Weights (training snapshots)")
- db_save_state_during = gr.Checkbox(
- label="Save separate diffusers snapshots when saving during training."
- )
- db_save_state_after = gr.Checkbox(
- label="Save separate diffusers snapshots when training completes."
- )
- db_save_state_cancel = gr.Checkbox(
- label="Save separate diffusers snapshots when training is canceled."
- )
- with gr.Tab("Generate", elem_id="TabGenerate"):
- gr.HTML(value="Class Generation Schedulers")
- db_class_gen_method = gr.Dropdown(
- label="Image Generation Library",
- value="Native Diffusers",
- choices=[
- "A1111 txt2img (Euler a)",
- "Native Diffusers",
- ]
- )
- db_scheduler = gr.Dropdown(
- label="Image Generation Scheduler",
- value="DEISMultistep",
- choices=get_scheduler_names(),
- )
- gr.HTML(value="Manual Class Generation")
- with gr.Column():
- db_generate_classes = gr.Button(value="Generate Class Images")
- db_generate_graph = gr.Button(value="Generate Graph")
- db_graph_smoothing = gr.Slider(
- value=50,
- label="Graph Smoothing Steps",
- minimum=10,
- maximum=500,
- )
- db_debug_buckets = gr.Button(value="Debug Buckets")
- db_bucket_epochs = gr.Slider(
- value=10,
- step=1,
- minimum=1,
- maximum=1000,
- label="Epochs to Simulate",
- )
- db_bucket_batch = gr.Slider(
- value=1,
- step=1,
- minimum=1,
- maximum=500,
- label="Batch Size to Simulate",
- )
- db_generate_sample = gr.Button(value="Generate Sample Images")
- db_sample_prompt = gr.Textbox(label="Sample Prompt")
- db_sample_negative = gr.Textbox(label="Sample Negative Prompt")
- db_sample_prompt_file = gr.Textbox(label="Sample Prompt File")
- db_sample_width = gr.Slider(
- label="Sample Width",
- value=512,
- step=64,
- minimum=128,
- maximum=2048,
- )
- db_sample_height = gr.Slider(
- label="Sample Height",
- value=512,
- step=64,
- minimum=128,
- maximum=2048,
- )
- db_sample_seed = gr.Number(
- label="Sample Seed", value=-1, precision=0
- )
- db_num_samples = gr.Slider(
- label="Number of Samples to Generate",
- value=1,
- minimum=1,
- maximum=1000,
- step=1,
- )
- db_gen_sample_batch_size = gr.Slider(
- label="Sample Batch Size",
- value=1,
- step=1,
- minimum=1,
- maximum=100,
- interactive=True,
- )
- db_sample_steps = gr.Slider(
- label="Sample Steps",
- value=20,
- minimum=1,
- maximum=500,
- step=1,
- )
- db_sample_scale = gr.Slider(
- label="Sample CFG Scale",
- value=7.5,
- step=0.1,
- minimum=1,
- maximum=20,
- )
- with gr.Column(variant="panel", visible=has_face_swap()):
- db_swap_faces = gr.Checkbox(label="Swap Sample Faces")
- db_swap_prompt = gr.Textbox(label="Swap Prompt")
- db_swap_negative = gr.Textbox(label="Swap Negative Prompt")
- db_swap_steps = gr.Slider(label="Swap Steps", value=40)
- db_swap_batch = gr.Slider(label="Swap Batch", value=40)
-
- db_sample_txt2img = gr.Checkbox(
- label="Use txt2img",
- value=False,
- visible=False # db_sample_txt2img not implemented yet
- )
- with gr.Tab("Testing", elem_id="TabDebug"):
- gr.HTML(value="Experimental Settings")
- db_tomesd = gr.Slider(
- value=0,
- label="Token Merging (ToMe)",
- minimum=0,
- maximum=1,
- step=0.1,
- )
- db_split_loss = gr.Checkbox(
- label="Calculate Split Loss", value=True
- )
- db_disable_class_matching = gr.Checkbox(label="Disable Class Matching")
- db_disable_logging = gr.Checkbox(label="Disable Logging")
- db_deterministic = gr.Checkbox(label="Deterministic")
- db_ema_predict = gr.Checkbox(label="Use EMA for prediction")
- db_lora_use_buggy_requires_grad = gr.Checkbox(label="LoRA use buggy requires grad")
- db_noise_scheduler = gr.Dropdown(
- label="Noise scheduler",
- value="DDPM",
- choices=[
- "DDPM",
- "DEIS",
- "UniPC"
- ]
- )
- db_update_extension = gr.Button(
- value="Update Extension and Restart"
- )
+ # These two should be updated while doing things
+ db_active = gr.Checkbox(elem_id="db_active", value=False, visible=False)
- with gr.Column(variant="panel"):
- gr.HTML(value="Bucket Cropping")
- db_crop_src_path = gr.Textbox(label="Source Path")
- db_crop_dst_path = gr.Textbox(label="Dest Path")
- db_crop_max_res = gr.Slider(
- label="Max Res", value=512, step=64, maximum=2048
- )
- db_crop_bucket_step = gr.Slider(
- label="Bucket Steps", value=8, step=8, maximum=512
+ ui_check_progress_initial = gr.Button(
+ value="Refresh", elem_id="ui_check_progress_initial", elem_classes="advBtn", size="sm"
)
- db_crop_dry = gr.Checkbox(label="Dry Run")
- db_start_crop = gr.Button("Start Cropping")
- with gr.Column(variant="panel"):
- gr.HTML(value="Output")
- db_check_progress_initial = gr.Button(
- value=update_symbol,
- elem_id="db_check_progress_initial",
- visible=False,
- )
- # These two should be updated while doing things
- db_active = gr.Checkbox(elem_id="db_active", value=False, visible=False)
-
- ui_check_progress_initial = gr.Button(
- value=update_symbol, elem_id="ui_check_progress_initial"
- )
db_status = gr.HTML(elem_id="db_status", value="")
db_progressbar = gr.HTML(elem_id="db_progressbar")
db_gallery = gr.Gallery(
@@ -1267,7 +1292,7 @@ def update_model_options(model_type):
fn=lambda: check_progress_call(),
show_progress=False,
inputs=[],
- outputs=progress_elements,
+ outputs=progress_elements
)
db_check_progress_initial.click(
@@ -1310,6 +1335,106 @@ def format_updates():
with gr.Column():
change_log = gr.HTML(format_updates(), elem_id="change_log")
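+                # Every component in this list is hidden or shown by the Show/Hide Advanced buttons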
+ advanced_elements = [
+ db_snapshot_row,
+ db_create_from_hub,
+ db_new_model_extract_ema,
+ db_train_unfrozen,
+ db_use_ema,
+ db_freeze_clip_normalization,
+ db_full_mixed_precision,
+ db_offset_noise,
+ db_weight_decay,
+ db_tenc_weight_decay,
+ db_tenc_grad_clip_norm,
+ db_min_snr_gamma,
+ db_pad_tokens,
+ db_strict_tokens,
+ db_max_token_length,
+ db_epoch_pause_frequency,
+ db_epoch_pause_time,
+ db_batch_size_view,
+ db_lr_scheduler,
+ db_lr_warmup_steps,
+ db_hflip,
+ db_prior_loss_view,
+ db_misc_view,
+ db_hook_view,
+ db_save_tab,
+ db_generate_tab,
+ db_test_tab,
+ db_dynamic_img_norm,
+ db_tomesd,
+ db_split_loss,
+ db_disable_class_matching,
+ db_disable_logging,
+ db_deterministic,
+ db_ema_predict,
+ db_lora_use_buggy_requires_grad,
+ db_noise_scheduler,
+ c1_class_guidance_scale,
+ c1_class_infer_steps,
+ c1_save_sample_negative_prompt,
+ c1_sample_seed,
+ c1_save_guidance_scale,
+ c1_save_infer_steps,
+ c2_class_guidance_scale,
+ c2_class_infer_steps,
+ c2_save_sample_negative_prompt,
+ c2_sample_seed,
+ c2_save_guidance_scale,
+ c2_save_infer_steps,
+ c3_class_guidance_scale,
+ c3_class_infer_steps,
+ c3_save_sample_negative_prompt,
+ c3_sample_seed,
+ c3_save_guidance_scale,
+ c3_save_infer_steps,
+ c4_class_guidance_scale,
+ c4_class_infer_steps,
+ c4_save_sample_negative_prompt,
+ c4_sample_seed,
+ c4_save_guidance_scale,
+ c4_save_infer_steps,
+ ]
+
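+                # Flip the global flag, swap the Show/Hide buttons (first two outputs), then toggle every advanced element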
+ def toggle_advanced():
+ global show_advanced
+                    show_advanced = not show_advanced
+ outputs = [gr.update(visible=True), gr.update(visible=False)]
+ print(f"Advanced elements visible: {show_advanced}")
+ for _ in advanced_elements:
+ outputs.append(gr.update(visible=show_advanced))
+
+ return outputs
+ # Merge db_show advanced, db_hide_advanced, and advanced elements into one list
+ db_show_advanced.click(
+ fn=toggle_advanced,
+ inputs=[],
+ outputs=[db_hide_advanced, db_show_advanced, *advanced_elements]
+ )
+
+ db_hide_advanced.click(
+ fn=toggle_advanced,
+ inputs=[],
+ outputs=[db_show_advanced, db_hide_advanced, *advanced_elements]
+ )
+
+ global preprocess_params
+
+ # preprocess_params = [
+ # db_preprocess_path,
+ # db_preprocess_recursive
+ # ]
+ #
+ # db_preprocess_path.change(
+ # fn=check_preprocess_path,
+ # inputs=[db_preprocess_path, db_preprocess_recursive],
+ # outputs=[db_status, db_gallery]
+ # )
+
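+    # Selecting a gallery image reports the selection details in the status line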
+ db_gallery.select(load_image_caption, None, db_status)
+
global params_to_save
global params_to_load
@@ -1532,8 +1657,8 @@ def toggle_loss_items(scale):
outputs=[hub_row, local_row],
)
- def toggle_shared_row(shared_row):
- return gr.update(visible=shared_row), gr.update(value="")
+ def toggle_shared_row(row):
+ return gr.update(visible=row), gr.update(value="")
db_use_shared_src.change(
fn=toggle_shared_row,
@@ -1657,6 +1782,7 @@ def class_gen_method_changed(method):
fn=load_model_params,
inputs=[db_model_name],
outputs=[
+ db_model_info,
db_model_path,
db_revision,
db_epochs,
@@ -1712,33 +1838,6 @@ def class_gen_method_changed(method):
],
)
- db_train_wizard_person.click(
- fn=training_wizard_person,
- _js="db_start_twizard",
- inputs=[db_model_name],
- outputs=[
- db_num_train_epochs,
- c1_num_class_images_per,
- c2_num_class_images_per,
- c3_num_class_images_per,
- c4_num_class_images_per,
- db_status,
- ],
- )
-
- db_train_wizard_object.click(
- fn=training_wizard,
- _js="db_start_twizard",
- inputs=[db_model_name],
- outputs=[
- db_num_train_epochs,
- c1_num_class_images_per,
- c2_num_class_images_per,
- c3_num_class_images_per,
- c4_num_class_images_per,
- db_status,
- ],
- )
db_generate_sample.click(
fn=wrap_gpu_call(generate_samples),
@@ -1836,53 +1935,59 @@ def set_gen_sample():
def build_concept_panel(concept: int):
- with gr.Column():
- gr.HTML(value="Directories")
+ with gr.Tab(label="Instance Images"):
instance_data_dir = gr.Textbox(
- label="Dataset Directory",
+ label="Directory",
placeholder="Path to directory with input images",
elem_id=f"idd{concept}",
)
+ instance_prompt = gr.Textbox(label="Prompt", value="[filewords]")
+ gr.HTML(value="Use [filewords] here to read prompts from caption files/filename, or a prompt to describe your training images.
"
+ "If using [filewords], your instance and class tokens will be inserted into the prompt as necessary for training.", elem_classes="hintHtml")
+ instance_token = gr.Textbox(label="Instance Token")
+ gr.HTML(value="If using [filewords] above, this is the unique word used for your subject, like 'fydodog' or 'ohwx'.",
+ elem_classes="hintHtml")
+ class_token = gr.Textbox(label="Class Token")
+ gr.HTML(value="If using [filewords] above, this is the generic word used for your subject, like 'dog' or 'person'.",
+ elem_classes="hintHtml")
+
+ with gr.Tab(label="Class Images"):
class_data_dir = gr.Textbox(
- label="Classification Dataset Directory",
+ label="Directory",
placeholder="(Optional) Path to directory with "
"classification/regularization images",
elem_id=f"cdd{concept}",
)
- with gr.Column():
- gr.HTML(value="Filewords")
- instance_token = gr.Textbox(
- label="Instance Token",
- placeholder="When using [filewords], this is the subject to use when building prompts.",
- )
- class_token = gr.Textbox(
- label="Class Token",
- placeholder="When using [filewords], this is the class to use when building prompts.",
- )
+ class_prompt = gr.Textbox(label="Prompt", value="[filewords]")
+ gr.HTML(
+ value="Use [filewords] here to read prompts from caption files/filename, or a prompt to describe your training images.
"
+ "If using [filewords], your class token will be inserted into the file prompts if it is not found.",
+ elem_classes="hintHtml")
- with gr.Column():
- gr.HTML(value="Training Prompts")
- instance_prompt = gr.Textbox(
- label="Instance Prompt",
- placeholder="Optionally use [filewords] to read image "
- "captions from files.",
+ class_negative_prompt = gr.Textbox(
+ label="Negative Prompt"
)
- class_prompt = gr.Textbox(
- label="Class Prompt",
- placeholder="Optionally use [filewords] to read image "
- "captions from files.",
+ num_class_images_per = gr.Slider(
+ label="Class Images Per Instance Image", value=0, precision=0
)
- class_negative_prompt = gr.Textbox(
- label="Classification Image Negative Prompt"
+ gr.HTML(value="For every instance image, this many classification images will be used/generated. Leave at 0 to disable.",
+ elem_classes="hintHtml")
+ class_guidance_scale = gr.Slider(
+ label="Classification CFG Scale", value=7.5, maximum=12, minimum=1, step=0.1
+ )
+ class_infer_steps = gr.Slider(
+ label="Classification Steps", value=40, minimum=10, maximum=200, step=1
)
- with gr.Column():
- gr.HTML(value="Sample Prompts")
+ with gr.Tab(label="Sample Images"):
save_sample_prompt = gr.Textbox(
label="Sample Image Prompt",
- placeholder="Leave blank to use instance prompt. "
- "Optionally use [filewords] to base "
- "sample captions on instance images.",
+ value='[filewords]'
)
+ gr.HTML(
+ value="A prompt to generate samples from, or use [filewords] here to randomly select prompts from the existing instance prompt(s).
"
+ "If using [filewords], your instance token will be inserted into the file prompts if it is not found.",
+ elem_classes="hintHtml")
+
save_sample_negative_prompt = gr.Textbox(
label="Sample Negative Prompt"
)
@@ -1890,21 +1995,8 @@ def build_concept_panel(concept: int):
label="Sample Prompt Template File",
placeholder="Enter the path to a txt file containing sample prompts.",
)
-
- with gr.Column():
- gr.HTML("Class Image Generation")
- num_class_images_per = gr.Slider(
- label="Class Images Per Instance Image", value=0, precision=0
- )
- class_guidance_scale = gr.Slider(
- label="Classification CFG Scale", value=7.5, maximum=12, minimum=1, step=0.1
- )
- class_infer_steps = gr.Slider(
- label="Classification Steps", value=40, minimum=10, maximum=200, step=1
- )
-
- with gr.Column():
- gr.HTML("Sample Image Generation")
+ gr.HTML(value="When enabled the above prompt and negative prompt will be ignored.",
+ elem_classes="hintHtml")
n_save_sample = gr.Slider(
label="Number of Samples to Generate", value=1, maximum=100, step=1
)
diff --git a/style.css b/style.css
index 65eb2230..eeba49cc 100644
--- a/style.css
+++ b/style.css
@@ -7,6 +7,10 @@
max-height: 20px;
}
+#modelDetailRow {
+ text-align: center;
+}
+
button:disabled, input:disabled {
background: #4444 !important;
color: #6666 !important;
@@ -39,6 +43,52 @@ button:disabled, input:disabled {
.commitDiv {
border: 1px grey solid;
+ padding: 5px;
+}
+
+/* Targets the first .commitDiv in its parent */
+.commitDiv:first-child {
+ border-top-left-radius: 5px;
+ border-top-right-radius: 5px;
+}
+
+/* Targets the last .commitDiv in its parent */
+.commitDiv:last-child {
+ border-bottom-left-radius: 5px;
+ border-bottom-right-radius: 5px;
+}
+
+/* Removes border-radius for every other .commitDiv */
+.commitDiv:not(:first-child):not(:last-child) {
+ border-radius: 0;
+}
+
+.commitDiv h3 {
+ margin-top: 0!important;
+}
+
+.singleCheckbox {
+ margin-top: 29px !important;
+}
+
+button.advBtn {
+ width: 110px !important;
+ position: absolute;
+ right: 0;
+}
+
+.halfElement {
+ max-width: 50% !important;
+}
+
+.hideAdvanced {
+ display: none !important;
+ width: 0 !important;
+ max-width: 0 !important;
+ height: 0 !important;
+ max-height: 0 !important;
+ margin: 0 !important;
+ padding: 0 !important;
}
#change_modal.active {
@@ -54,7 +104,7 @@ button:disabled, input:disabled {
margin: 0 auto;
z-index: 10000;
border: 1px solid white;
- overflow-y: scroll;
+ overflow-y: auto;
overflow-x: hidden;
}
@@ -66,10 +116,14 @@ button:disabled, input:disabled {
#close_modal {
- min-width: 10px;
- max-width: 10px;
- min-height: 10px;
- max-height: 35px;
+ min-width: 30px;
+ max-width: 30px;
+ min-height: 30px;
+ max-height: 30px;
+ height: 30px;
+ width: 30px;
+ padding: 3px;
+ font-family: monospace;
position: absolute;
right: 11px;
}
@@ -108,18 +162,13 @@ button:disabled, input:disabled {
}
-#refresh_db_models, #refresh_lora_models, #refresh_sd_models, #refresh_secret, #clear_secret, #ui_check_progress_initial {
+#refresh_db_models, #refresh_lora_models, #refresh_sd_models, #refresh_secret, #clear_secret {
margin-top: 0.75em;
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
}
-#ui_check_progress_initial {
- position: absolute;
- top: -3px;
- right: 10px;
-}
#db_gen_ckpt_during, #db_train_sample {
display: none;
@@ -163,4 +212,14 @@ button:disabled, input:disabled {
.hyperlink {
text-decoration: underline
+}
+
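+/* Muted styling for the gr.HTML hint text shown under the concept panel inputs */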
+.hintHtml {
+ padding-bottom: 10px;
+ padding-top: 5px;
+ color: #b3b3b3;
+}
+
+#TabSettings .gap, #TabConcepts .gap {
+ display: block !important;
}
\ No newline at end of file