From 4a741a8a4aa68731808ab596afb03f26abc4ab44 Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Tue, 31 Dec 2024 08:24:37 -0500
Subject: [PATCH] Add support for sd3 lora disable_mmap_load_safetensors

---
 kohya_gui/class_sd3.py | 5 +++++
 kohya_gui/lora_gui.py  | 9 +++++++++
 2 files changed, 14 insertions(+)

diff --git a/kohya_gui/class_sd3.py b/kohya_gui/class_sd3.py
index 9d0ac3f5..feeaf3c5 100644
--- a/kohya_gui/class_sd3.py
+++ b/kohya_gui/class_sd3.py
@@ -202,6 +202,11 @@ def noise_offset_type_change(
                     info="Enables the fusing of the optimizer step into the backward pass for each parameter. Only Adafactor optimizer is supported.",
                     interactive=True,
                 )
+                self.disable_mmap_load_safetensors = gr.Checkbox(
+                    label="Disable mmap load safetensors",
+                    info="Disable memory mapping when loading the model's .safetensors in SD3.",
+                    value=self.config.get("sd3.disable_mmap_load_safetensors", False),
+                )
 
         self.sd3_checkbox.change(
             lambda sd3_checkbox: gr.Accordion(visible=sd3_checkbox),
diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py
index 236cfdfb..07264db2 100644
--- a/kohya_gui/lora_gui.py
+++ b/kohya_gui/lora_gui.py
@@ -310,6 +310,7 @@ def save_configuration(
     sd3_fused_backward_pass,
     clip_g,
     sd3_clip_l,
+    sd3_disable_mmap_load_safetensors,
     logit_mean,
     logit_std,
     mode_scale,
@@ -598,6 +599,7 @@ def open_configuration(
     sd3_fused_backward_pass,
     clip_g,
     sd3_clip_l,
+    sd3_disable_mmap_load_safetensors,
     logit_mean,
     logit_std,
     mode_scale,
@@ -920,6 +922,7 @@ def train_model(
     sd3_fused_backward_pass,
     clip_g,
     sd3_clip_l,
+    sd3_disable_mmap_load_safetensors,
     logit_mean,
     logit_std,
     mode_scale,
@@ -1443,6 +1446,10 @@ def train_model(
         t5xxl_value = t5xxl
     elif sd3_checkbox:
         t5xxl_value = sd3_t5xxl
+
+    disable_mmap_load_safetensors_value = None
+    if sd3_checkbox:
+        disable_mmap_load_safetensors_value = sd3_disable_mmap_load_safetensors
 
     config_toml_data = {
         "adaptive_noise_scale": (
@@ -1477,6 +1484,7 @@ def train_model(
         "debiased_estimation_loss": debiased_estimation_loss,
         "dynamo_backend": dynamo_backend,
         "dim_from_weights": dim_from_weights,
+        "disable_mmap_load_safetensors": disable_mmap_load_safetensors_value,
         "enable_bucket": enable_bucket,
         "epoch": int(epoch),
         "flip_aug": flip_aug,
@@ -2879,6 +2887,7 @@ def update_LoRA_settings(
             sd3_training.sd3_cache_text_encoder_outputs_to_disk,
             sd3_training.clip_g,
             sd3_training.clip_l,
+            sd3_training.disable_mmap_load_safetensors,
             sd3_training.logit_mean,
             sd3_training.logit_std,
             sd3_training.mode_scale,
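
For context on how the new value travels from the GUI into the training config: train_model() only forwards the checkbox when the SD3 tab is active, leaving the value as None otherwise, presumably so it is filtered out with the other unset options before the TOML is written. Below is a minimal, self-contained sketch of that gating pattern; build_config and the None-filtering shown here are illustrative stand-ins rather than the actual lora_gui.py code, and the toml import assumes the package kohya_gui already depends on.

# Minimal sketch (assumed helper, not actual lora_gui.py code) of the gating
# pattern this patch adds: the checkbox value is only forwarded when the SD3
# tab is active, and unset (None) keys are dropped before the TOML is written.
import toml  # assumption: the same `toml` package kohya_gui already uses


def build_config(sd3_checkbox: bool, sd3_disable_mmap_load_safetensors: bool) -> dict:
    # Mirror of the train_model() gating added by the patch.
    disable_mmap_load_safetensors_value = None
    if sd3_checkbox:
        disable_mmap_load_safetensors_value = sd3_disable_mmap_load_safetensors

    config_toml_data = {
        "disable_mmap_load_safetensors": disable_mmap_load_safetensors_value,
        "epoch": 1,  # stand-in for the many other training options
    }
    # Drop unset entries so the flag only reaches the TOML when the SD3 tab
    # supplied it (illustrative; the GUI applies its own filtering).
    return {k: v for k, v in config_toml_data.items() if v is not None}


if __name__ == "__main__":
    print(toml.dumps(build_config(True, True)))
    # disable_mmap_load_safetensors = true
    # epoch = 1
    print(toml.dumps(build_config(False, True)))
    # epoch = 1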