From 2bbb40ce51d5be3ce8c3e1990d30455201f9e852 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 12 Jan 2025 14:29:50 -0500
Subject: [PATCH] Fix regularization images with validation

Add metadata recording for validation arguments
Add comments about the validation split for clarity of intention
---
 library/train_util.py | 33 +++++++++++++++++++++++++++++++--
 train_network.py      |  7 +++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 62aae37ef..6d3a772bb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -146,7 +146,12 @@
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
 
-def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
+def split_train_val(
+    paths: List[str],
+    is_training_dataset: bool,
+    validation_split: float,
+    validation_seed: int | None
+) -> List[str]:
     """
     Split the dataset into train and validation
 
@@ -1830,6 +1835,9 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
 
+    # The is_training_dataset flag defines the type of dataset, training or validation:
+    # if is_training_dataset is True -> training dataset
+    # if is_training_dataset is False -> validation dataset
     def __init__(
         self,
         subsets: Sequence[DreamBoothSubset],
@@ -1965,8 +1973,29 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                                 size_set_count += 1
                     logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
 
+            # We want to create a training and validation split. This should be improved in the future
+            # to allow a clearer distinction between training and validation. This can be seen as a
+            # short-term solution to limit what is necessary to implement validation datasets.
+            #
+            # We split the dataset for the subset based on whether we are doing a validation split.
+            # The self.is_training_dataset flag defines the type of dataset, training or validation:
+            # if self.is_training_dataset is True -> training dataset
+            # if self.is_training_dataset is False -> validation dataset
             if self.validation_split > 0.0:
-                img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
+                # For regularization images we do not want to split this dataset.
+                if subset.is_reg is True:
+                    # Skip any validation dataset for regularization images
+                    if self.is_training_dataset is False:
+                        img_paths = []
+                    # Otherwise the img_paths remain as the original img_paths and no split
+                    # is required for the training dataset of regularization images
+                else:
+                    img_paths = split_train_val(
+                        img_paths,
+                        self.is_training_dataset,
+                        self.validation_split,
+                        self.validation_seed
+                    )
 
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
diff --git a/train_network.py b/train_network.py
index 5ed92b7e2..605dbc60c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -898,6 +898,7 @@ def load_model_hook(models, input_dir):
 
         accelerator.print("running training / 学習開始")
         accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+        accelerator.print(f"  num validation images * repeats / 検証画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
         accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
         accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
         accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
@@ -917,6 +918,7 @@ def load_model_hook(models, input_dir):
             "ss_text_encoder_lr": text_encoder_lr,
             "ss_unet_lr": args.unet_lr,
             "ss_num_train_images": train_dataset_group.num_train_images,
+            "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
             "ss_num_reg_images": train_dataset_group.num_reg_images,
             "ss_num_batches_per_epoch": len(train_dataloader),
             "ss_num_epochs": num_train_epochs,
@@ -964,6 +966,11 @@ def load_model_hook(models, input_dir):
             "ss_huber_c": args.huber_c,
             "ss_fp8_base": bool(args.fp8_base),
             "ss_fp8_base_unet": bool(args.fp8_base_unet),
+            "ss_validation_seed": args.validation_seed,
+            "ss_validation_split": args.validation_split,
+            "ss_max_validation_steps": args.max_validation_steps,
+            "ss_validate_every_n_epochs": args.validate_every_n_epochs,
+            "ss_validate_every_n_steps": args.validate_every_n_steps,
         }
 
         self.update_metadata(metadata, args)  # architecture specific metadata
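
Note for reviewers: below is a minimal, self-contained sketch of the behavior this patch intends,
assuming split_train_val shuffles the paths deterministically with validation_seed and then slices
off the validation fraction. split_train_val_sketch and the sample path names are illustrative
stand-ins, not the library implementation.

import random
from typing import List, Optional


def split_train_val_sketch(
    paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: Optional[int]
) -> List[str]:
    # Deterministic shuffle so the training and validation partitions are disjoint and reproducible
    shuffled = list(paths)
    random.Random(validation_seed).shuffle(shuffled)
    split_index = len(shuffled) - int(len(shuffled) * validation_split)
    return shuffled[:split_index] if is_training_dataset else shuffled[split_index:]


paths = [f"img_{i:02d}.png" for i in range(10)]
train_paths = split_train_val_sketch(paths, True, 0.2, validation_seed=42)   # 8 images
val_paths = split_train_val_sketch(paths, False, 0.2, validation_seed=42)    # 2 images
assert set(train_paths).isdisjoint(val_paths)

# Regularization images mirror the subset.is_reg branch added above:
# the training dataset keeps every reg image, the validation dataset gets none.
reg_paths = [f"reg_{i:02d}.png" for i in range(4)]
train_reg_paths = reg_paths  # no split applied
val_reg_paths = []           # reg images are skipped for validation

The real split_train_val may compute the split point differently; the point is only that training
and validation see disjoint image lists, while regularization images are never split.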