Skip to content

Commit

Permalink
Add support for fused_backward_pass for sd3 finetuning
Browse files — browse the repository at this point in the history
  • Loading branch information
bmaltais committed Dec 31, 2024
1 parent fce89ad commit 7068c7d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 7 additions & 0 deletions kohya_gui/class_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ def noise_offset_type_change(
info="Cache text encoder outputs to disk to speed up inference",
interactive=True,
)
with gr.Row():
self.sd3_fused_backward_pass = gr.Checkbox(
label="Fused Backward Pass",
value=self.config.get("sd3.fused_backward_pass", False),
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
interactive=True,
)

self.sd3_checkbox.change(
lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
Expand Down
7 changes: 5 additions & 2 deletions kohya_gui/finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def save_configuration(
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
Expand Down Expand Up @@ -423,6 +424,7 @@ def open_configuration(
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
Expand Down Expand Up @@ -644,6 +646,7 @@ def train_model(
# SD3 parameters
sd3_cache_text_encoder_outputs,
sd3_cache_text_encoder_outputs_to_disk,
sd3_fused_backward_pass,
clip_g,
clip_l,
logit_mean,
Expand Down Expand Up @@ -969,7 +972,7 @@ def train_model(
"fp8_base": fp8_base,
"full_bf16": full_bf16,
"full_fp16": full_fp16,
"fused_backward_pass": fused_backward_pass if not flux1_checkbox else flux_fused_backward_pass,
"fused_backward_pass": sd3_fused_backward_pass if sd3_checkbox else flux_fused_backward_pass if flux1_checkbox else fused_backward_pass,
"fused_optimizer_groups": (
int(fused_optimizer_groups) if fused_optimizer_groups > 0 else None
),
Expand Down Expand Up @@ -1118,7 +1121,6 @@ def train_model(
"blockwise_fused_optimizers": (
blockwise_fused_optimizers if flux1_checkbox else None
),
# "flux_fused_backward_pass": see previous assignment of fused_backward_pass in above code
"cpu_offload_checkpointing": (
cpu_offload_checkpointing if flux1_checkbox else None
),
Expand Down Expand Up @@ -1529,6 +1531,7 @@ def list_presets(path):
sd3_training.t5xxl_device,
sd3_training.t5xxl_dtype,
sd3_training.sd3_text_encoder_batch_size,
sd3_training.sd3_fused_backward_pass,
sd3_training.weighting_scheme,
source_model.sd3_checkbox,
# Flux1 parameters
Expand Down

0 comments on commit 7068c7d

Please sign in to comment.