Validation loss #1864

Open
wants to merge 77 commits into
base: sd3
Changes from 70 commits
Commits
5b19bda
Add validation loss
rockerBOO Nov 5, 2023
33c311e
new ratio code
rockerBOO Nov 5, 2023
3de9e6c
Add validation split of datasets
rockerBOO Nov 5, 2023
a93c524
Update args to validation_seed and validation_split
rockerBOO Nov 5, 2023
c892521
Add process_batch for train_network
rockerBOO Nov 5, 2023
e545fdf
Removed/cleanup a line
rockerBOO Nov 5, 2023
9c591bd
Remove unnecessary subset line from collate
rockerBOO Nov 5, 2023
569ca72
Set grad enabled if is_train and train_text_encoder
rockerBOO Nov 7, 2023
b558a5b
val
gesen2egee Mar 9, 2024
78cfb01
improve
gesen2egee Mar 10, 2024
923b761
Update train_network.py
gesen2egee Mar 10, 2024
47359b8
Update train_network.py
gesen2egee Mar 10, 2024
a51723c
fix timesteps
gesen2egee Mar 11, 2024
7d84ac2
only use train subset to val
gesen2egee Mar 11, 2024
befbec5
Update train_network.py
gesen2egee Mar 11, 2024
63e58f7
Update train_network.py
gesen2egee Mar 11, 2024
a6c41c6
Update train_network.py
gesen2egee Mar 11, 2024
bd7e229
fix
gesen2egee Mar 13, 2024
5d7ed0d
Merge remote-tracking branch 'kohya-ss/dev' into val
gesen2egee Mar 13, 2024
d05965d
Update train_network.py
gesen2egee Mar 13, 2024
b5e8045
fix control net
gesen2egee Mar 16, 2024
086f600
Merge branch 'main' into val
gesen2egee Apr 10, 2024
36d4023
Update config_util.py
gesen2egee Apr 10, 2024
229c5a3
Update train_util.py
gesen2egee Apr 10, 2024
3b251b7
Update config_util.py
gesen2egee Apr 10, 2024
459b125
Update config_util.py
gesen2egee Apr 10, 2024
89ad69b
Update train_util.py
gesen2egee Apr 11, 2024
fde8026
Update config_util.py
gesen2egee Apr 11, 2024
31507b9
Remove unnecessary is_train changes and use apply_debiased_estimation…
gesen2egee Aug 2, 2024
1db4951
Update train_db.py
gesen2egee Aug 4, 2024
6816217
Update train_db.py
gesen2egee Aug 4, 2024
96eb74f
Update train_db.py
gesen2egee Aug 4, 2024
b9bdd10
Update train_network.py
gesen2egee Aug 4, 2024
3d68754
Update train_db.py
gesen2egee Aug 4, 2024
a593e83
Update train_network.py
gesen2egee Aug 4, 2024
f6dbf7c
Update train_network.py
gesen2egee Aug 4, 2024
aa850aa
Update train_network.py
gesen2egee Aug 4, 2024
cdb2d9c
Update train_network.py
gesen2egee Aug 4, 2024
3028027
Update train_network.py
gesen2egee Oct 4, 2024
dece2c3
Update train_db.py
gesen2egee Oct 4, 2024
05bb918
Add Validation loss for LoRA training
hinablue Dec 27, 2024
62164e5
Change val loss calculate method
hinablue Dec 27, 2024
64bd531
Split val latents/batch and pick up val latents shape size which equa…
hinablue Dec 28, 2024
cb89e02
Change val latent loss compare
hinablue Dec 28, 2024
8743532
val
gesen2egee Mar 9, 2024
449c1c5
Adding modified train_util and config_util
rockerBOO Jan 2, 2025
7f6e124
Merge branch 'gesen2egee/val' into validation-loss-upstream
rockerBOO Jan 3, 2025
d23c732
Merge remote-tracking branch 'hina/feature/val-loss' into validation-…
rockerBOO Jan 3, 2025
7470173
Remove defunct code for train_controlnet.py
rockerBOO Jan 3, 2025
534059d
Typos and lingering is_train
rockerBOO Jan 3, 2025
c8c3569
Cleanup order, types, print to logger
rockerBOO Jan 3, 2025
fbfc275
Update text for train/reg with repeats
rockerBOO Jan 3, 2025
58bfa36
Add seed help clarifying info
rockerBOO Jan 3, 2025
6604b36
Remove duplicate assignment
rockerBOO Jan 3, 2025
0522070
Fix training, validation split, revert to using upstream implemenation
rockerBOO Jan 3, 2025
695f389
Move get_huber_threshold_if_needed
rockerBOO Jan 3, 2025
1f9ba40
Add step break for validation epoch. Remove unused variable
rockerBOO Jan 3, 2025
1c0ae30
Add missing functions for training batch
rockerBOO Jan 3, 2025
bbf6bbd
Use self.get_noise_pred_and_target and drop fixed timesteps
rockerBOO Jan 6, 2025
f4840ef
Revert train_db.py
rockerBOO Jan 6, 2025
1c63e7c
Cleanup unused code and formatting
rockerBOO Jan 6, 2025
c64d1a2
Add validate_every_n_epochs, change name validate_every_n_steps
rockerBOO Jan 6, 2025
f885029
Fix validate epoch, cleanup imports
rockerBOO Jan 6, 2025
fcb2ff0
Clean up some validation help documentation
rockerBOO Jan 6, 2025
742bee9
Set validation steps in multiple lines for readability
rockerBOO Jan 6, 2025
1231f51
Remove unused train_util code, fix accelerate.log for wandb, add init…
rockerBOO Jan 8, 2025
556f3f1
Fix documentation, remove unused function, fix bucket reso for sd1.5,…
rockerBOO Jan 8, 2025
9fde0d7
Handle tuple return from generate_dataset_group_by_blueprint
rockerBOO Jan 8, 2025
1e61392
Revert bucket_reso_steps to correct 64
rockerBOO Jan 8, 2025
d6f158d
Fix incorrect destructoring for load_abritrary_dataset
rockerBOO Jan 8, 2025
264167f
Apply is_training_dataset only to DreamBoothDataset. Add validation_s…
rockerBOO Jan 9, 2025
4c61adc
Add divergence to logs
rockerBOO Jan 12, 2025
2bbb40c
Fix regularization images with validation
rockerBOO Jan 12, 2025
0456858
Fix validate_every_n_steps always running first step
rockerBOO Jan 12, 2025
ee9265c
Fix validate_every_n_steps for gradient accumulation
rockerBOO Jan 12, 2025
25929dd
Remove Validating... print to fix output layout
rockerBOO Jan 12, 2025
b489082
Disable repeats for validation datasets
rockerBOO Jan 12, 2025
3 changes: 2 additions & 1 deletion fine_tune.py
@@ -91,9 +91,10 @@ def train(args):
}

blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
3 changes: 2 additions & 1 deletion flux_train.py
@@ -138,9 +138,10 @@ def train(args):
}

blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
3 changes: 2 additions & 1 deletion flux_train_control_net.py
@@ -126,9 +126,10 @@ def train(args):
}

blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
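The three training scripts above make the same change: the blueprint branch now unpacks two values, and the fallback branch sets `val_dataset_group = None` so the name is defined on both paths. A minimal sketch of that calling pattern follows. `generate_groups`, `load_arbitrary`, and `load` are hypothetical stand-ins for illustration only, not functions from this repository.

```python
from typing import List, Optional, Tuple


def generate_groups(split: float) -> Tuple[List[str], Optional[List[str]]]:
    """Hypothetical stand-in for a loader that returns (train, validation)."""
    train = ["train_a", "train_b"]
    # The validation group only exists when a positive split is requested.
    val = ["val_a"] if split > 0.0 else None
    return train, val


def load_arbitrary() -> List[str]:
    """Hypothetical stand-in for a loader that returns only a training group."""
    return ["train_a", "train_b"]


def load(use_config: bool, split: float = 0.0) -> Tuple[List[str], Optional[List[str]]]:
    if use_config:
        # Unpack both values; binding the tuple to one name would silently
        # hand a tuple to code that expects a dataset group.
        train_group, val_group = generate_groups(split)
    else:
        train_group = load_arbitrary()
        val_group = None  # keep the name defined on every branch
    return train_group, val_group
```

Code further down can then guard any validation work with `if val_group is not None:`.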
7 changes: 5 additions & 2 deletions flux_train_network.py
@@ -339,6 +339,7 @@ def get_noise_pred_and_target(
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -375,7 +376,7 @@ def get_noise_pred_and_target(
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
@@ -420,7 +421,9 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)

with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""

return model_pred
187 changes: 98 additions & 89 deletions library/config_util.py
@@ -73,6 +73,8 @@ class BaseSubsetParams:
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0


@dataclass
@@ -102,6 +104,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0


@dataclass
@@ -113,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0



@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -234,6 +237,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
}
@@ -462,8 +467,7 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):

return default_value


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

for dataset_blueprint in dataset_group_blueprint.datasets:
@@ -478,103 +482,108 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)

if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_separator: {subset.caption_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""
),
" ",
)

if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

logger.info(f"{info}")
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

def print_info(_datasets, dataset_type: str):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of {dataset_type} {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""), " ")

if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")

logger.info(info)

print_info(datasets, "Dataset")

if len(val_datasets) > 0:
print_info(val_datasets, "Validation Dataset")

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no

for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
logger.info(f"[Prepare dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

for i, dataset in enumerate(val_datasets):
logger.info(f"[Prepare validation dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

return DatasetGroup(datasets)
return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
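The hunk above adds `validation_split` and `validation_seed` parameters and builds a validation dataset only when the split is positive, returning `None` otherwise. The code that actually divides the images is not visible in this diff, so the following is only a sketch of how a seeded hold-out could behave. `split_train_val` and `build_groups` are hypothetical names, and the shuffle-then-slice strategy is an assumption rather than the PR's implementation.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_train_val(
    items: Sequence[str], validation_split: float, validation_seed: Optional[int]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: hold out a fraction of items for validation."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be in [0.0, 1.0)")
    shuffled = list(items)
    # A private Random instance makes the split reproducible for a given
    # seed without touching the global random state used during training.
    # With validation_seed=None the split differs from run to run.
    random.Random(validation_seed).shuffle(shuffled)
    val_count = round(len(shuffled) * validation_split)
    return shuffled[val_count:], shuffled[:val_count]


def build_groups(
    items: Sequence[str], validation_split: float, validation_seed: Optional[int]
) -> Tuple[List[str], Optional[List[str]]]:
    """Mirror the (train, optional validation) return shape shown in the diff."""
    train_items, val_items = split_train_val(items, validation_split, validation_seed)
    return train_items, (val_items if val_items else None)
```

Using a fixed seed matters for validation loss in particular: if the held-out set changed between runs, loss curves from different runs would not be comparable.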
18 changes: 10 additions & 8 deletions library/custom_train_functions.py
@@ -1,7 +1,9 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from .utils import setup_logging

@@ -63,7 +65,7 @@ def enforce_zero_terminal_snr(betas):
noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
@@ -74,13 +76,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
return loss


def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
return loss


def get_snr_scale(timesteps, noise_scheduler):
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
@@ -89,14 +91,14 @@ def get_snr_scale(timesteps, noise_scheduler):
return scale


def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
scale = get_snr_scale(timesteps, noise_scheduler)
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss


def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
if v_prediction:
@@ -453,7 +455,7 @@ def get_weighted_text_embeddings(


# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
@@ -466,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):


# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
if noise_offset is None:
return noise
if adaptive_noise_scale is not None:
@@ -482,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def apply_masked_loss(loss, batch):
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
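The `get_snr_scale` lines visible in this file clamp the per-timestep SNR at 1000 before computing `snr_t / (snr_t + 1)`, and `scale_v_prediction_loss_like_noise_prediction` multiplies the loss by that scale. Below is a plain-Python sketch of the same arithmetic on lists instead of tensors, based only on the lines shown in the hunks. `snr_scale` and `scale_losses` are hypothetical names, and the real functions look up SNR values from `noise_scheduler.all_snr`.

```python
import math
from typing import List, Sequence


def snr_scale(snr_values: Sequence[float], max_snr: float = 1000.0) -> List[float]:
    """Return snr / (snr + 1) per entry, clamping snr at max_snr first."""
    scales = []
    for snr in snr_values:
        # At timestep 0 the SNR can be infinite; inf / (inf + 1) is NaN,
        # so the value is capped before dividing.
        clamped = min(snr, max_snr)
        scales.append(clamped / (clamped + 1.0))
    return scales


def scale_losses(losses: Sequence[float], snr_values: Sequence[float]) -> List[float]:
    """Multiply each per-sample loss by its SNR-derived scale."""
    return [loss * scale for loss, scale in zip(losses, snr_scale(snr_values))]
```

The scale approaches 1 for high-SNR (low-noise) timesteps and 0 for low-SNR ones, so the clamp only changes the result where the unclamped division would be undefined.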
2 changes: 1 addition & 1 deletion library/strategy_sd.py
@@ -40,7 +40,7 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]

def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []
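The annotation change above matters because `Tuple[X]` describes a tuple with exactly one element, while a function that returns both a token list and a weight list needs `Tuple[X, Y]`. A small sketch using only the standard `typing` module makes the difference visible. `split_tokens_and_weights` is a hypothetical function for illustration, not part of `strategy_sd.py`.

```python
from typing import List, Tuple, get_args

# One type argument: a tuple of length 1.
OneElement = Tuple[List[int]]
# Two type arguments: a tuple of length 2, matching a (tokens, weights) return.
TwoElements = Tuple[List[int], List[int]]


def split_tokens_and_weights(
    pairs: List[Tuple[int, float]]
) -> Tuple[List[int], List[float]]:
    """Hypothetical example: return two parallel lists as a 2-tuple."""
    tokens = [token for token, _ in pairs]
    weights = [weight for _, weight in pairs]
    return tokens, weights
```

Python does not enforce these hints at runtime, so the old annotation would not have raised an error. A static type checker, however, would flag callers that unpack two values from a declared 1-tuple.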