SD3 LoRA training MVP

bmaltais committed Dec 31, 2024
1 parent 730ed13 commit 3eec4c9

Showing 3 changed files with 123 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kohya_gui/class_advanced_training.py
@@ -172,7 +172,7 @@ def list_vae_files(path):
label="Huber scale",
value=self.config.get("advanced.huber_scale", 1.0),
minimum=0.0,
-maximum=1.0,
+maximum=10.0,
step=0.01,
info="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
)
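The hunk above widens the "Huber scale" slider's upper bound from 1.0 to 10.0; per the info string, the parameter is only consulted when loss_type is huber or smooth l1. For orientation, a Huber-style loss uses such a scale as the crossover between its quadratic and linear regimes. A minimal sketch (illustrative only, not the sd-scripts implementation):

import torch

def huber_loss(pred: torch.Tensor, target: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Quadratic near zero, linear in the tails; `scale` sets the crossover.
    err = pred - target
    abs_err = err.abs()
    quadratic = 0.5 * err ** 2                 # used where |err| <= scale
    linear = scale * (abs_err - 0.5 * scale)   # used where |err| >  scale
    return torch.where(abs_err <= scale, quadratic, linear).mean()

Raising the cap to 10.0 simply lets users push that crossover further out, keeping more of the error range in the quadratic regime.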
125 changes: 120 additions & 5 deletions kohya_gui/lora_gui.py
@@ -31,6 +31,7 @@
from .class_source_model import SourceModel
from .class_basic_training import BasicTraining
from .class_advanced_training import AdvancedTraining
+from .class_sd3 import sd3Training
from .class_sdxl_parameters import SDXLParameters
from .class_folders import Folders
from .class_command_executor import CommandExecutor
@@ -302,6 +303,24 @@ def save_configuration(
in_dims,
train_double_block_indices,
train_single_block_indices,

+# SD3 parameters
+sd3_cache_text_encoder_outputs,
+sd3_cache_text_encoder_outputs_to_disk,
+sd3_fused_backward_pass,
+clip_g,
+sd3_clip_l,
+logit_mean,
+logit_std,
+mode_scale,
+save_clip,
+save_t5xxl,
+sd3_t5xxl,
+t5xxl_device,
+t5xxl_dtype,
+sd3_text_encoder_batch_size,
+weighting_scheme,
+sd3_checkbox,
):
# Get list of function parameters and values
parameters = list(locals().items())
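The `parameters = list(locals().items())` idiom above is why the newly added SD3 arguments need no extra save/load plumbing: every named parameter is captured as a (name, value) pair at function entry and serialized by name. A simplified sketch of the pattern (parameter names are examples only):

def save_configuration(sd3_clip_l="", sd3_t5xxl="", sd3_checkbox=False):
    parameters = list(locals().items())  # [("sd3_clip_l", ""), ("sd3_t5xxl", ""), ...]
    # Each pair can be written straight into the JSON config.
    return {name: value for name, value in parameters}

assert "sd3_t5xxl" in save_configuration()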
@@ -573,6 +592,24 @@ def open_configuration(
train_double_block_indices,
train_single_block_indices,

+# SD3 parameters
+sd3_cache_text_encoder_outputs,
+sd3_cache_text_encoder_outputs_to_disk,
+sd3_fused_backward_pass,
+clip_g,
+sd3_clip_l,
+logit_mean,
+logit_std,
+mode_scale,
+save_clip,
+save_t5xxl,
+sd3_t5xxl,
+t5xxl_device,
+t5xxl_dtype,
+sd3_text_encoder_batch_size,
+weighting_scheme,
+sd3_checkbox,

##
training_preset,
):
@@ -876,6 +913,24 @@ def train_model(
in_dims,
train_double_block_indices,
train_single_block_indices,

+# SD3 parameters
+sd3_cache_text_encoder_outputs,
+sd3_cache_text_encoder_outputs_to_disk,
+sd3_fused_backward_pass,
+clip_g,
+sd3_clip_l,
+logit_mean,
+logit_std,
+mode_scale,
+save_clip,
+save_t5xxl,
+sd3_t5xxl,
+t5xxl_device,
+t5xxl_dtype,
+sd3_text_encoder_batch_size,
+weighting_scheme,
+sd3_checkbox,
):
# Get list of function parameters and values
parameters = list(locals().items())
@@ -1149,6 +1204,8 @@ def train_model(
run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train_network.py")
elif flux1_checkbox:
run_cmd.append(rf"{scriptdir}/sd-scripts/flux_train_network.py")
+elif sd3_checkbox:
+    run_cmd.append(rf"{scriptdir}/sd-scripts/sd3_train_network.py")
else:
run_cmd.append(rf"{scriptdir}/sd-scripts/train_network.py")

@@ -1374,6 +1431,18 @@ def train_model(

if text_encoder_lr_float != 0 or unet_lr_float != 0:
do_not_set_learning_rate = True

+clip_l_value = None
+if sd3_checkbox:
+    clip_l_value = sd3_clip_l
+elif flux1_checkbox:
+    clip_l_value = clip_l
+
+t5xxl_value = None
+if flux1_checkbox:
+    t5xxl_value = t5xxl
+elif sd3_checkbox:
+    t5xxl_value = sd3_t5xxl

config_toml_data = {
"adaptive_noise_scale": (
@@ -1390,6 +1459,13 @@ def train_model(
True
if (sdxl and sdxl_cache_text_encoder_outputs)
or (flux1_checkbox and flux1_cache_text_encoder_outputs)
+or (sd3_checkbox and sd3_cache_text_encoder_outputs)
else None
),
"cache_text_encoder_outputs_to_disk": (
True
if flux1_checkbox and flux1_cache_text_encoder_outputs_to_disk
or sd3_checkbox and sd3_cache_text_encoder_outputs_to_disk
else None
),
"caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs),
@@ -1554,14 +1630,31 @@ def train_model(
"wandb_run_name": wandb_run_name if wandb_run_name != "" else output_name,
"weighted_captions": weighted_captions,
"xformers": True if xformers == "xformers" else None,
-# Flux.1 specific parameters
-"cache_text_encoder_outputs_to_disk": (
-    flux1_cache_text_encoder_outputs_to_disk if flux1_checkbox else None
-),
+
+# SD3 only Parameters
+# "cache_text_encoder_outputs": see previous assignment above for code
+# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
+"clip_g": clip_g if sd3_checkbox else None,
+"clip_l": clip_l_value,
+"logit_mean": logit_mean if sd3_checkbox else None,
+"logit_std": logit_std if sd3_checkbox else None,
+"mode_scale": mode_scale if sd3_checkbox else None,
+"save_clip": save_clip if sd3_checkbox else None,
+"save_t5xxl": save_t5xxl if sd3_checkbox else None,
+# "t5xxl": see previous assignment above for code
+"t5xxl_device": t5xxl_device if sd3_checkbox else None,
+"t5xxl_dtype": t5xxl_dtype if sd3_checkbox else None,
+"text_encoder_batch_size": (
+    sd3_text_encoder_batch_size if sd3_checkbox else None
+),
+"weighting_scheme": weighting_scheme if sd3_checkbox else None,
+
+# Flux.1 specific parameters
+# "cache_text_encoder_outputs": see previous assignment above for code
+# "cache_text_encoder_outputs_to_disk": see previous assignment above for code
"ae": ae if flux1_checkbox else None,
-"clip_l": clip_l if flux1_checkbox else None,
-"t5xxl": t5xxl if flux1_checkbox else None,
+"t5xxl": t5xxl_value,
"discrete_flow_shift": float(discrete_flow_shift) if flux1_checkbox else None,
"model_prediction_type": model_prediction_type if flux1_checkbox else None,
"timestep_sampling": timestep_sampling if flux1_checkbox else None,
@@ -2454,7 +2547,11 @@ def update_LoRA_settings(
config=config,
flux1_checkbox=source_model.flux1_checkbox,
)


+# Add SD3 Parameters
+sd3_training = sd3Training(
+    headless=headless, config=config, sd3_checkbox=source_model.sd3_checkbox
+)
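The widgets created here are what settings_list references below. A hypothetical outline of the class, consistent with this call site and those attribute names (the real definition lives in kohya_gui/class_sd3.py and is not part of this commit):

import gradio as gr

class sd3Training:
    def __init__(self, headless=False, config=None, sd3_checkbox=None):
        self.config = config
        with gr.Accordion("SD3", open=False) as self.accordion:
            self.sd3_cache_text_encoder_outputs = gr.Checkbox(
                label="Cache text encoder outputs"
            )
            self.sd3_cache_text_encoder_outputs_to_disk = gr.Checkbox(
                label="Cache text encoder outputs to disk"
            )
            self.sd3_fused_backward_pass = gr.Checkbox(label="Fused backward pass")
            self.clip_g = gr.Textbox(label="clip_g path")
            self.clip_l = gr.Textbox(label="clip_l path")
            self.t5xxl = gr.Textbox(label="t5xxl path")
            # ... plus logit_mean, logit_std, mode_scale, save_clip, save_t5xxl,
            # t5xxl_device, t5xxl_dtype, sd3_text_encoder_batch_size and
            # weighting_scheme, with visibility gated on sd3_checkbox.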

with gr.Accordion("Advanced", open=False, elem_classes="advanced_background"):
# with gr.Accordion('Advanced Configuration', open=False):
@@ -2776,6 +2873,24 @@ def update_LoRA_settings(
flux1_training.in_dims,
flux1_training.train_double_block_indices,
flux1_training.train_single_block_indices,

+# SD3 Parameters
+sd3_training.sd3_cache_text_encoder_outputs,
+sd3_training.sd3_cache_text_encoder_outputs_to_disk,
+sd3_training.sd3_fused_backward_pass,
+sd3_training.clip_g,
+sd3_training.clip_l,
+sd3_training.logit_mean,
+sd3_training.logit_std,
+sd3_training.mode_scale,
+sd3_training.save_clip,
+sd3_training.save_t5xxl,
+sd3_training.t5xxl,
+sd3_training.t5xxl_device,
+sd3_training.t5xxl_dtype,
+sd3_training.sd3_text_encoder_batch_size,
+sd3_training.weighting_scheme,
+source_model.sd3_checkbox,
]

configuration.button_open_config.click(
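Gradio passes the `inputs` components to a callback positionally, so settings_list must mirror the parameter order of save_configuration, open_configuration, and train_model exactly; that is why sd3_fused_backward_pass sits third here, matching all three signatures. A minimal illustration of the binding (not from this repo):

import gradio as gr

def save_configuration(cache_outputs, fused_backward_pass):
    return f"cache={cache_outputs}, fused={fused_backward_pass}"

with gr.Blocks() as demo:
    cache_box = gr.Checkbox(label="Cache text encoder outputs")
    fused_box = gr.Checkbox(label="Fused backward pass")
    status = gr.Textbox(label="Status")
    # The inputs list order must match the function's parameter order.
    gr.Button("Save").click(save_configuration, inputs=[cache_box, fused_box], outputs=[status])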
2 changes: 2 additions & 0 deletions kohya_gui/sd_modeltype.py
@@ -46,6 +46,8 @@ def hasKeyPrefix(pfx):
self.model_type = ModelType.SD1
except:
pass

+# print(f"Model type: {self.model_type}")

def Is_SD1(self):
return self.model_type == ModelType.SD1
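sd_modeltype.py classifies a checkpoint by peeking at tensor-name prefixes in its safetensors header via the hasKeyPrefix helper shown above. A sketch of how an SD3 check of this kind typically looks (the prefix string and function name are assumptions, not this file's code):

from safetensors import safe_open

def looks_like_sd3(path: str) -> bool:
    # SD3 checkpoints carry MMDiT weights under joint_blocks; reading the
    # header is cheap because no tensor data is loaded.
    with safe_open(path, framework="pt") as f:
        return any(
            k.startswith("model.diffusion_model.joint_blocks") for k in f.keys()
        )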
