Add support for sd3 lora disable_mmap_load_safetensors
bmaltais committed Dec 31, 2024
1 parent 3c860c4 commit 4a741a8
Showing 2 changed files with 14 additions and 0 deletions.
5 changes: 5 additions & 0 deletions kohya_gui/class_sd3.py
@@ -202,6 +202,11 @@ def noise_offset_type_change(
info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
interactive=True,
)
self.disable_mmap_load_safetensors = gr.Checkbox(
label="Disable mmap load safe tensors",
info="Disable memory mapping when loading the model's .safetensors in SDXL.",
value=self.config.get("sd3.disable_mmap_load_safetensors", False),
)
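For context, a minimal sketch of what this option controls on the trainer side (not the GUI code above), assuming the standard safetensors API: load_file() memory-maps the checkpoint, while reading the raw bytes and calling load() deserializes the tensors from an in-memory buffer instead. The helper name below is illustrative, not the actual sd-scripts implementation.

```python
# Illustrative sketch only; the real training script may structure this differently.
from safetensors.torch import load, load_file


def load_checkpoint(path: str, disable_mmap: bool = False):
    if disable_mmap:
        # Read the whole file into RAM first, then parse the tensors from the
        # in-memory buffer (no memory mapping involved).
        with open(path, "rb") as f:
            return load(f.read())
    # Default: memory-mapped load of the .safetensors file.
    return load_file(path)
```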

self.sd3_checkbox.change(
lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
9 changes: 9 additions & 0 deletions kohya_gui/lora_gui.py
@@ -310,6 +310,7 @@ def save_configuration(
sd3_fused_backward_pass,
clip_g,
sd3_clip_l,
sd3_disable_mmap_load_safetensors,
logit_mean,
logit_std,
mode_scale,
@@ -598,6 +599,7 @@ def open_configuration(
sd3_fused_backward_pass,
clip_g,
sd3_clip_l,
sd3_disable_mmap_load_safetensors,
logit_mean,
logit_std,
mode_scale,
@@ -920,6 +922,7 @@ def train_model(
sd3_fused_backward_pass,
clip_g,
sd3_clip_l,
sd3_disable_mmap_load_safetensors,
logit_mean,
logit_std,
mode_scale,
@@ -1443,6 +1446,10 @@ def train_model(
t5xxl_value = t5xxl
elif sd3_checkbox:
t5xxl_value = sd3_t5xxl

disable_mmap_load_safetensors_value = None
if sd3_checkbox:
disable_mmap_load_safetensors_value = sd3_disable_mmap_load_safetensors

config_toml_data = {
"adaptive_noise_scale": (
@@ -1477,6 +1484,7 @@
"debiased_estimation_loss": debiased_estimation_loss,
"dynamo_backend": dynamo_backend,
"dim_from_weights": dim_from_weights,
"disable_mmap_load_safetensors": disable_mmap_load_safetensors_value,
"enable_bucket": enable_bucket,
"epoch": int(epoch),
"flip_aug": flip_aug,
@@ -2879,6 +2887,7 @@ def update_LoRA_settings(
sd3_training.sd3_cache_text_encoder_outputs_to_disk,
sd3_training.clip_g,
sd3_training.clip_l,
sd3_training.disable_mmap_load_safetensors,
sd3_training.logit_mean,
sd3_training.logit_std,
sd3_training.mode_scale,
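Gradio passes component values to a click handler positionally, which is why the new checkbox must sit at the same index in this settings list as the sd3_disable_mmap_load_safetensors parameter added to save_configuration, open_configuration, and train_model above. A minimal, self-contained sketch of that positional contract (component and handler names are illustrative, not the real lora_gui.py wiring):

```python
# Illustrative only: the order of the inputs list must match the handler's
# parameter order, because Gradio binds them by position.
import gradio as gr


def train_model(clip_l, disable_mmap_load_safetensors, logit_mean):
    return f"mmap disabled: {disable_mmap_load_safetensors}, logit_mean: {logit_mean}"


with gr.Blocks() as demo:
    clip_l = gr.Textbox(label="clip_l path")
    disable_mmap = gr.Checkbox(label="Disable mmap load safe tensors")
    logit_mean = gr.Number(label="logit_mean", value=0.0)
    out = gr.Textbox(label="result")
    btn = gr.Button("Train")
    # This list's order mirrors train_model's signature.
    btn.click(train_model, inputs=[clip_l, disable_mmap, logit_mean], outputs=out)

if __name__ == "__main__":
    demo.launch()
```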
