From 9a7f82464541bff49bc05dd02cb2d0abb3cccf2c Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:22:11 +0300 Subject: [PATCH] [Flux] Add advanced training script + support textual inversion inference (#9434) * add ostris trainer to README & add cache latents of vae * add ostris trainer to README & add cache latents of vae * style * readme * add test for latent caching * add ostris noise scheduler https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 * style * fix import * style * fix tests * style * --change upcasting of transformer? * update readme according to main * add pivotal tuning for CLIP * fix imports, encode_prompt call,add TextualInversionLoaderMixin to FluxPipeline for inference * TextualInversionLoaderMixin support for FluxPipeline for inference * move changes to advanced flux script, revert canonical * add latent caching to canonical script * revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160 * revert changes to canonical script to keep it separate from https://github.com/huggingface/diffusers/pull/9160 * style * remove redundant line and change code block placement to align with logic * add initializer_token arg * add transformer frac for range support from pure textual inversion to the orig pivotal tuning * support pure textual inversion - wip * adjustments to support pure textual inversion and transformer optimization in only part of the epochs * fix logic when using initializer token * fix pure_textual_inversion_condition * fix ti/pivotal loading of last validation run * remove embeddings loading for ti in final training run (to avoid adding huggingface hub dependency) * support pivotal for t5 * adapt pivotal for T5 encoder * adapt pivotal for T5 encoder and support in flux pipeline * t5 pivotal support + support fo pivotal for clip only or both * fix param chaining * fix param chaining * README first draft * readme * readme * readme * style * fix import * style * add fix from https://github.com/huggingface/diffusers/pull/9419 * add to readme, change function names * te lr changes * readme * change concept tokens logic * fix indices * change arg name * style * dummy test * revert dummy test * reorder pivoting * add warning in case the token abstraction is not the instance prompt * experimental - wip - specific block training * fix documentation and token abstraction processing * remove transformer block specification feature (for now) * style * fix copies * fix indexing issue when --initializer_concept has different amounts * add if TextualInversionLoaderMixin to all flux pipelines * style * fix import * fix imports * address review comments - remove necessary prints & comments, use pin_memory=True, use free_memory utils, unify warning and prints * style * logger info fix * make lora target modules configurable and change the default * make lora target modules configurable and change the default * style * make lora target modules configurable and change the default, add notes to readme * style * add tests * style * fix repo id * add updated requirements for advanced flux * fix indices of t5 pivotal tuning embeddings * fix path in test * remove `pin_memory` * fix filename of embedding * fix filename of embedding --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- .../README_flux.md | 353 +++ .../requirements_flux.txt | 8 + 
 .../test_dreambooth_lora_flux_advanced.py | 283 ++
 .../train_dreambooth_lora_flux_advanced.py | 2463 +++++++++++++++++
 src/diffusers/pipelines/flux/pipeline_flux.py | 15 +-
 .../flux/pipeline_flux_controlnet.py | 8 +-
 ...pipeline_flux_controlnet_image_to_image.py | 8 +-
 .../pipeline_flux_controlnet_inpainting.py | 8 +-
 .../pipelines/flux/pipeline_flux_img2img.py | 8 +-
 .../pipelines/flux/pipeline_flux_inpaint.py | 8 +-
 10 files changed, 3155 insertions(+), 7 deletions(-)
 create mode 100644 examples/advanced_diffusion_training/README_flux.md
 create mode 100644 examples/advanced_diffusion_training/requirements_flux.txt
 create mode 100644 examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py
 create mode 100644 examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md
new file mode 100644
index 000000000000..e755fd8b61e0
--- /dev/null
+++ b/examples/advanced_diffusion_training/README_flux.md
@@ -0,0 +1,353 @@
+# Advanced diffusion training examples
+
+## Train Dreambooth LoRA with Flux.1 Dev
+> [!TIP]
+> 💡 This example follows some of the techniques and recommended practices covered in the community-derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
+> As many of these techniques are architecture-agnostic and generally relevant to fine-tuning diffusion models, we suggest taking a look 🤗
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Flux or Stable Diffusion given just a few (3~5) images of a subject.
+
+LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow control over the extent to which the model is adapted towards new training images via a `scale` parameter.
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
+the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
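[Editorial aside, not part of this PR] To make the rank-decomposition idea above concrete, here is a minimal PyTorch sketch of a frozen linear layer augmented with a trainable low-rank update whose influence is controlled by a `scale`-like knob. Class and argument names are illustrative only.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustration only: frozen base layer + trainable rank-decomposition update."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 8.0, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # pretrained weights stay frozen
            p.requires_grad_(False)
        # the only trainable parameters: two small matrices of rank `rank`
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank
        self.scale = scale  # analogous to the `scale` knob mentioned above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # W x  +  scale * (alpha / rank) * B A x
        return self.base(x) + self.scale * self.scaling * (x @ self.lora_A.T @ self.lora_B.T)


layer = LoRALinear(nn.Linear(64, 64))
out = layer(torch.randn(2, 64))  # shape (2, 64)
```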
+
+The `train_dreambooth_lora_flux_advanced.py` script shows how to implement DreamBooth LoRA, combining the training process shown in `train_dreambooth_lora_flux.py` with
+advanced features and techniques, inspired by and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
+[ostris](https://x.com/ostrisai): [ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira): [SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
+
+> [!NOTE]
+> 💡 If this is your first time training a Dreambooth LoRA, congrats! 🥳
+> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/advanced_diffusion_training` folder and run
+```bash
+pip install -r requirements_flux.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can bring dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore
+applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers` - in which you can specify, in a comma-separated string,
+the exact modules for LoRA training. Here are some examples of target modules you can provide:
+- for attention-only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in the ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
+> [!NOTE]
+> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
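[Editorial aside, not part of this PR] As a rough sketch of how such a comma-separated `--lora_layers` value can be turned into PEFT target modules - the exact parsing and defaults used by the training script may differ, and the `rank`/`lora_alpha` values below are illustrative:

```python
# Sketch only: feeding a comma-separated layer list into a PEFT LoraConfig.
from peft import LoraConfig

lora_layers = "attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"  # hypothetical --lora_layers value
target_modules = [name.strip() for name in lora_layers.split(",")]

transformer_lora_config = LoraConfig(
    r=8,                           # corresponds to --rank
    lora_alpha=8,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
# transformer.add_adapter(transformer_lora_config)  # attach the adapter to the Flux transformer
```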
+
+### Pivotal Tuning (and more)
+**Training with text encoder(s)**
+
+Alongside the transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
+available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported.
+[Pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
+we insert new tokens into the text encoders of the model, instead of reusing existing ones.
+We then optimize the newly-inserted token embeddings to represent the new concept.
+
+To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimization, use `--train_text_encoder`).
+Please keep the following points in mind:
+
+* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2). By default, `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.
+To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
+* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion, we introduce `--train_transformer_frac`, which controls the fraction of epochs for which the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`; to trigger a textual inversion run, set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment with different settings and share the results!
+* **token initializer** - similar to the original textual inversion work, you can specify a concept of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the newly inserted tokens are initialized randomly. You can specify a concept in `--initializer_concept` such that the starting point for the trained embeddings will be the embeddings associated with your chosen concept.
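[Editorial aside, not part of this PR] A conceptual sketch of the token-insertion step described above - not the training script's actual code - showing how new tokens could be added to a CLIP tokenizer/text encoder and optionally seeded from an initializer concept. The token names and the `initializer_concept` value are hypothetical.

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Conceptual sketch only; the training script's real implementation may differ.
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder")

new_tokens = ["<s0>", "<s1>"]            # tokens inserted for one concept (illustrative names)
tokenizer.add_tokens(new_tokens)
text_encoder.resize_token_embeddings(len(tokenizer))

embeddings = text_encoder.get_input_embeddings().weight
new_ids = tokenizer.convert_tokens_to_ids(new_tokens)

initializer_concept = "icon"             # hypothetical --initializer_concept value
init_ids = tokenizer.encode(initializer_concept, add_special_tokens=False)
with torch.no_grad():
    for i, token_id in enumerate(new_ids):
        # seed each new embedding from the initializer concept; without an initializer
        # the newly added rows keep their default (random) initialization
        embeddings[token_id] = embeddings[init_ids[i % len(init_ids)]].clone()

# During training, only these new embedding rows are optimized (plus LoRA layers,
# depending on --train_transformer_frac / --train_text_encoder_ti_frac).
```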
+
+## Training examples
+
+Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./3d_icon"
+snapshot_download(
+    "LinoyTsaban/3d_icon",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```
+
+Let's review some of the advanced features we're going to be using for this example:
+- **custom captions**:
+To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
+```bash
+pip install datasets
+```
+
+Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
+
+```
+--dataset_name=./3d_icon
+--caption_column=prompt
+```
+
+You can also load a dataset straight from the hub by specifying its name in `dataset_name`.
+Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.
+
+- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
+- **pivotal tuning**
+
+### Example #1: Pivotal tuning
+**Now, we can launch training:**
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export DATASET_NAME="./3d_icon"
+export OUTPUT_DIR="3d-icon-Flux-LoRA"
+
+accelerate launch train_dreambooth_lora_flux_advanced.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --dataset_name=$DATASET_NAME \
+  --instance_prompt="3d icon in the style of TOK" \
+  --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
+  --output_dir=$OUTPUT_DIR \
+  --caption_column="prompt" \
+  --mixed_precision="bf16" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --repeats=1 \
+  --report_to="wandb"\
+  --gradient_accumulation_steps=1 \
+  --gradient_checkpointing \
+  --learning_rate=1.0 \
+  --text_encoder_lr=1.0 \
+  --optimizer="prodigy"\
+  --train_text_encoder_ti\
+  --train_text_encoder_ti_frac=0.5\
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --rank=8 \
+  --max_train_steps=1000 \
+  --checkpointing_steps=2000 \
+  --seed="0" \
+  --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Example #2: Pivotal tuning with T5
+Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize
+the T5 embeddings as well.
We can do this by simply adding `--enable_t5_ti` to the previous configuration: +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --enable_t5_ti\ + --train_text_encoder_ti_frac=0.5\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` + +### Example #3: Textual Inversion +To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we +can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is +trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train +run. +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --enable_t5_ti\ + --train_text_encoder_ti_frac=0.5\ + --train_transformer_frac=0\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=1000 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` +### Inference - pivotal tuning + +Once training is done, we can perform inference like so: +1. starting with loading the transformer lora weights +```python +import torch +from huggingface_hub import hf_hub_download, upload_file +from diffusers import AutoPipelineForText2Image +from safetensors.torch import load_file + +username = "linoyts" +repo_id = f"{username}/3d-icon-Flux-LoRA" + +pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') + + +pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors") +``` +2. 
now we load the pivotal tuning embeddings
+💡 note that if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
+
+```python
+text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+
+embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
+
+state_dict = load_file(embedding_path)
+# load embeddings of text_encoder 1 (CLIP ViT-L/14)
+pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
+pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+```
+
+3. let's generate images
+
+```python
+instance_token = "<s0><s1>"
+prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
+
+image = pipe(prompt=prompt, num_inference_steps=25, joint_attention_kwargs={"scale": 1.0}).images[0]
+image.save("llama.png")
+```
+
+### Inference - pure textual inversion
+In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion training run is a
+`.safetensors` file containing the trained embeddings for the new tokens, either for the CLIP encoder only or for both encoders (CLIP and T5).
+
+1. starting with loading the embeddings.
+💡 note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
+
+```python
+import torch
+from huggingface_hub import hf_hub_download, upload_file
+from diffusers import AutoPipelineForText2Image
+from safetensors.torch import load_file
+
+username = "linoyts"
+repo_id = f"{username}/3d-icon-Flux-LoRA"
+
+pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
+
+text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+
+embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
+
+state_dict = load_file(embedding_path)
+# load embeddings of text_encoder 1 (CLIP ViT-L/14)
+pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
+pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+```
+2. let's generate images
+
+```python
+instance_token = "<s0><s1>"
+prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
+
+image = pipe(prompt=prompt, num_inference_steps=25, joint_attention_kwargs={"scale": 1.0}).images[0]
+image.save("llama.png")
+```
+
+### Comfy UI / AUTOMATIC1111 Inference
+The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
+
+**AUTOMATIC1111 / SD.Next** \
+In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
+- *Embedding*: the embedding is the same for diffusers and WebUI.
You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory. + +You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls `. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`. + +**ComfyUI** \ +In ComfyUI we will load a LoRA and a textual embedding at the same time. +- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/). diff --git a/examples/advanced_diffusion_training/requirements_flux.txt b/examples/advanced_diffusion_training/requirements_flux.txt new file mode 100644 index 000000000000..dbc124ff6526 --- /dev/null +++ b/examples/advanced_diffusion_training/requirements_flux.txt @@ -0,0 +1,8 @@ +accelerate>=0.31.0 +torchvision +transformers>=4.41.2 +ftfy +tensorboard +Jinja2 +peft>=0.11.1 +sentencepiece \ No newline at end of file diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py new file mode 100644 index 000000000000..e29c99821303 --- /dev/null +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -0,0 +1,283 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import sys +import tempfile + +import safetensors + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "photo" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" + script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py" + + def test_dreambooth_lora_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_text_encoder_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. 
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + starts_with_expected_prefix = all( + (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_expected_prefix) + + def test_dreambooth_lora_pivotal_tuning_flux_clip(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder_ti + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + # make sure embeddings were also saved + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # make sure the state_dict has the correct naming in the parameters. + textual_inversion_state_dict = safetensors.torch.load_file( + os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors") + ) + is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys()) + self.assertTrue(is_clip) + + # when performing pivotal tuning, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder_ti + --enable_t5_ti + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + # make sure embeddings were also saved + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # make sure the state_dict has the correct naming in the parameters. 
+ textual_inversion_state_dict = safetensors.torch.load_file( + os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors") + ) + is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys()) + self.assertTrue(is_te) + + # when performing pivotal tuning, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + 
""".split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py new file mode 100644 index 000000000000..3db6896228de --- /dev/null +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -0,0 +1,2463 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import itertools +import logging +import math +import os +import random +import re +import shutil +from contextlib import nullcontext +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from safetensors.torch import save_file +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + train_text_encoder_ti=False, + enable_t5_ti=False, + pure_textual_inversion=False, + token_abstraction_dict=None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + trigger_str = f"You should use {instance_prompt} to trigger the image generation." 
+ + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + diffusers_load_lora = "" + diffusers_imports_pivotal = "" + diffusers_example_pivotal = "" + if not pure_textual_inversion: + diffusers_load_lora = ( + f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')""" + ) + if train_text_encoder_ti: + embeddings_filename = f"{repo_folder}_emb" + ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) + trigger_str = ( + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) + diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + """ + if enable_t5_ti: + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + pipeline.load_textual_inversion(state_dict["t5"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) + """ + else: + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + """ + if token_abstraction_dict: + for key, value in token_abstraction_dict.items(): + tokens = "".join(value) + trigger_str += f""" + to trigger concept `{key}` → use `{tokens}` in your prompt \n + """ + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +Pivotal tuning was enabled: {train_text_encoder_ti}. + +## Trigger words + +{trigger_str} + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +{diffusers_imports_pivotal} +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') +{diffusers_load_lora} +{diffusers_example_pivotal} +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). 
+""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + autocast_ctx = nullcontext() + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--token_abstraction", + type=str, + default="TOK", + help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " + "'TOK,TOK2,TOK3' etc.", + ) + + parser.add_argument( + "--num_new_tokens_per_abstraction", + type=int, + default=None, + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", + ) + parser.add_argument( + "--initializer_concept", + type=str, + default=None, + help="the concept to use to initialize the new inserted tokens when training with " + "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " + "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. 
" + "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_text_encoder_ti", + action="store_true", + help=("Whether to use pivotal tuning / textual inversion"), + ) + parser.add_argument( + "--enable_t5_ti", + action="store_true", + help=( + "Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)" + ), + ) + + parser.add_argument( + "--train_text_encoder_ti_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform text encoder tuning"), + ) + parser.add_argument( + "--train_transformer_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform transformer tuning"), + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params" + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. " + 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' + ), + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + if args.train_text_encoder and args.train_text_encoder_ti: + raise ValueError( + "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " + "For full LoRA text encoder training check --train_text_encoder, for textual " + "inversion training check `--train_text_encoder_ti`" + ) + if args.train_transformer_frac < 1 and not args.train_text_encoder_ti: + raise ValueError( + "--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled." + ) + if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: + raise ValueError( + "--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " + "This contradicts with --max_train_steps, please specify different values or set both to 1." 
+ ) + if args.enable_t5_ti and not args.train_text_encoder_ti: + logger.warning("You need not use --enable_t5_ti without --train_text_encoder_ti.") + + if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction: + logger.warning( + "When specifying --initializer_concept, the number of tokens per abstraction is detrimned " + "by the initializer token. --num_new_tokens_per_abstraction will be ignored" + ) + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + if args.class_data_dir is not None: + logger.warning("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + logger.warning("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py +class TokenEmbeddingsHandler: + def __init__(self, text_encoders, tokenizers): + self.text_encoders = text_encoders + self.tokenizers = tokenizers + + self.train_ids: Optional[torch.Tensor] = None + self.train_ids_t5: Optional[torch.Tensor] = None + self.inserting_toks: Optional[List[str]] = None + self.embeddings_settings = {} + + def initialize_new_tokens(self, inserting_toks: List[str]): + idx = 0 + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." + assert all( + isinstance(tok, str) for tok in inserting_toks + ), "All elements in inserting_toks should be strings." 
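+ # Descriptive note on the block below: the new concept tokens are registered as additional
+ # special tokens, the token embedding matrix is resized to match, the new token ids are
+ # recorded per encoder (CLIP at idx 0, T5 at idx 1), and the fresh embedding rows are
+ # initialized either randomly (scaled by the existing embedding std) or copied from the
+ # ids of --initializer_concept when it is provided.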
+ + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Convert the token abstractions to ids + if idx == 0: + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + else: + self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + std_token_embedding = embeds.weight.data.std() + + logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + + train_ids = self.train_ids if idx == 0 else self.train_ids_t5 + # if initializer_concept are not provided, token embeddings are initialized randomly + if args.initializer_concept is None: + hidden_size = ( + text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size + ) + embeds.weight.data[train_ids] = ( + torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype) + * std_token_embedding + ) + else: + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False) + for token_idx, token_id in enumerate(train_ids): + embeds.weight.data[token_id] = (embeds.weight.data)[ + initializer_token_ids[token_idx % len(initializer_token_ids)] + ].clone() + + self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + # makes sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates + + logger.info(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert self.train_ids is not None, "Initialize new tokens before saving embeddings." + tensors = {} + # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl + idx_to_text_encoder_name = {0: "clip_l", 1: "t5"} + for idx, text_encoder in enumerate(self.text_encoders): + train_ids = self.train_ids if idx == 0 else self.train_ids_t5 + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same." 
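+ # Only the embedding rows of the newly inserted tokens are exported below.
+ # Illustrative inference usage (assumed variable names, not part of this script):
+ #   state_dict = safetensors.torch.load_file(embedding_path)
+ #   pipeline.load_textual_inversion(state_dict["clip_l"], token=new_tokens,
+ #       text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)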
+ new_token_embeddings = embeds.weight.data[train_ids] + + # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), + # Note: When loading with diffusers, any name can work - simply specify in inference + tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings + # tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + embeds.weight.data[index_no_updates] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = embeds.weight.data[index_updates] + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + embeds.weight.data[index_updates] = new_embeddings + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + train_text_encoder_ti, + token_abstraction_dict=None, # token mapping for textual inversion + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + self.token_abstraction_dict = token_abstraction_dict + self.train_text_encoder_ti = train_text_encoder_ti + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + if self.train_text_encoder_ti: + # replace instances of 
--token_abstraction in caption with the new tokens: "" etc. + for token_abs, token_replacement in self.token_abstraction_dict.items(): + caption = caption.replace(token_abs, "".join(token_replacement)) + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # the given instance prompt is used for all images + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def _get_t5_prompt_embeds( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _get_clip_prompt_embeds( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = 
[prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _get_clip_prompt_embeds( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None, + ) + + prompt_embeds = _get_t5_prompt_embeds( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[1].device, + text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None, + ) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: +# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 +class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + + # generate the multiplier based on cosmap loss weighing + # this is only used on linear timesteps for now + + # cosine map weighing is higher in the middle and lower at the ends + # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 + # cosmap_weighing = 2 / (math.pi * bot) + + # sigma sqrt weighing is significantly higher at the end and lower at the beginning + sigma_sqrt_weighing = (self.sigmas**-2.0).float() + # clip at 1e4 (1e6 is too high) + sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) + # bring to a mean of 1 + sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() + + # Create linear timesteps from 1000 to 0 + timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") + + self.linear_timesteps = timesteps + # self.linear_timesteps_weights = cosmap_weighing + self.linear_timesteps_weights = sigma_sqrt_weighing + + # self.sigmas = self.get_sigmas(timesteps, 
n_dim=1, dtype=torch.float32, device='cpu') + pass + + def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + # Get the weights for the timesteps + weights = self.linear_timesteps_weights[step_indices].flatten() + + return weights + + def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: + sigmas = self.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + + return sigma + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 + ## Add noise according to flow matching. + ## zt = (1 - texp) * x + texp * z1 + + # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # timestep needs to be in [0, 1], we store them in [0, 1000] + # noisy_sample = (1 - timestep) * latent + timestep * noise + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + + # n_dim = original_samples.ndim + # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) + # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + return noisy_model_input + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, linear=False): + if linear: + timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + else: + # distribute them closer to center. Inference distributes them as a bias toward first + # Generate values from 0 to 1 + t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) + + # Scale and reverse the values to go from 1000 to 0 + timesteps = (1 - t) * 1000 + + # Sort the timesteps in descending order + timesteps, _ = torch.sort(timesteps, descending=True) + + self.timesteps = timesteps.to(device=device) + + return timesteps + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + model_id = args.hub_model_id or Path(args.output_dir).name + repo_id = None + if args.push_to_hub: + repo_id = create_repo( + repo_id=model_id, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = 
CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + if args.train_text_encoder_ti: + # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, + # TOK2" -> ["TOK", "TOK2"] etc. + token_abstraction_list = [place_holder.strip() for place_holder in re.split(r",\s*", args.token_abstraction)] + logger.info(f"list of token identifiers: {token_abstraction_list}") + + if args.initializer_concept is None: + num_new_tokens_per_abstraction = ( + 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction + ) + # if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction + else: + token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False) + num_new_tokens_per_abstraction = len(token_ids) + if args.enable_t5_ti: + token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False) + num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5)) + logger.info( + f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}" + ) + + token_abstraction_dict = {} + token_idx = 0 + for i, token in enumerate(token_abstraction_list): + token_abstraction_dict[token] = [f"" for j in range(num_new_tokens_per_abstraction)] + token_idx += num_new_tokens_per_abstraction - 1 + + # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. + for token_abs, token_replacement in token_abstraction_dict.items(): + new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + if args.instance_prompt == new_instance_prompt: + logger.warning( + "Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified " + "--token_abstraction. 
This may lead to incorrect optimization of text embeddings during pivotal tuning" + ) + args.instance_prompt = new_instance_prompt + if args.with_prior_preservation: + args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) + if args.validation_prompt: + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + + # initialize the new tokens for textual inversion + text_encoders = [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one] + tokenizers = [tokenizer_one, tokenizer_two] if args.enable_t5_ti else [tokenizer_one] + embedding_handler = TokenEmbeddingsHandler(text_encoders, tokenizers) + inserting_toks = [] + for new_tok in token_abstraction_dict.values(): + inserting_toks.extend(new_tok) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): 
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors") + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_one_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. 
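+ # (fp16 runs only) Upcasting the trainable LoRA / token-embedding parameters keeps optimizer
+ # updates numerically stable, while the frozen base weights stay in the half-precision
+ # weight_dtype set above, so the extra memory cost is small.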
+ if args.mixed_precision == "fp16": + models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder_one]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] # CLIP + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well + text_lora_parameters_two = [] + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False + + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training + freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + + # if --train_text_encoder_ti and train_transformer_frac == 0 where essentially performing textual inversion + # and not training transformer LoRA layers + pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0 + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + if not freeze_text_encoder: + # different learning rate for text encoder and transformer + text_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr, + } + if not args.enable_t5_ti: + # pure textual inversion - only clip + if pure_textual_inversion: + params_to_optimize = [ + text_parameters_one_with_lr, + ] + te_idx = 0 + else: # regular te training or regular pivotal for clip + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + ] + te_idx = 1 + elif args.enable_t5_ti: + # pivotal tuning of clip & t5 + text_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr, + } + # pure textual inversion - only clip & t5 + if pure_textual_inversion: + params_to_optimize = [text_parameters_one_with_lr, text_parameters_two_with_lr] + te_idx = 0 + else: # regular pivotal tuning of clip & t5 + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + text_parameters_two_with_lr, + ] + te_idx = 1 + else: + params_to_optimize = [ + transformer_parameters_with_lr, + ] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of 
optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if not freeze_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters to be + # --learning_rate + + params_to_optimize[te_idx]["lr"] = args.learning_rate + params_to_optimize[-1]["lr"] = args.learning_rate + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if freeze_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance 
prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if freeze_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders, text_encoder_one, text_encoder_two + free_memory() + + # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion + add_special_tokens_clip = True if args.train_text_encoder_ti else False + add_special_tokens_t5 = True if (args.train_text_encoder_ti and args.enable_t5_ti) else False + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) + # we need to tokenize and encode the batch prompts on all training steps + else: + tokens_one = tokenize_prompt( + tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip + ) + tokens_two = tokenize_prompt( + tokenizer_two, + args.instance_prompt, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, + ) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt( + tokenizer_one, + args.class_prompt, + max_sequence_length=77, + add_special_tokens=add_special_tokens_clip, + ) + class_tokens_two = tokenize_prompt( + tokenizer_two, + args.class_prompt, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, + ) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + free_memory() + + # Scheduler and math around the number of training steps. 
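+ # If --max_train_steps is not set, it is derived from --num_train_epochs. For example
+ # (hypothetical numbers): 25 batches per epoch with gradient_accumulation_steps=2 gives
+ # ceil(25 / 2) = 13 update steps per epoch, so num_train_epochs=3 yields max_train_steps=39.
+ # These values are recomputed after accelerator.prepare(), since the dataloader length may change.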
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if not freeze_text_encoder: + if args.enable_t5_ti: + ( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler + ) + + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora-advanced" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.") + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + if args.train_text_encoder: + num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs) + elif args.train_text_encoder_ti: # args.train_text_encoder_ti + num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) + num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs) + + # flag used for textual inversion + pivoted_te = False + pivoted_tr = False + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + # if performing any kind of optimization of text_encoder params + if args.train_text_encoder or args.train_text_encoder_ti: + if epoch == num_train_epochs_text_encoder: + # flag to stop text encoder optimization + logger.info(f"PIVOT TE {epoch}") + pivoted_te = True + else: + # still optimizing the text encoder + if args.train_text_encoder: + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + elif args.train_text_encoder_ti: # textual inversion / pivotal tuning + text_encoder_one.train() + if args.enable_t5_ti: + text_encoder_two.train() + + if epoch == num_train_epochs_transformer: + # flag to stop transformer optimization + logger.info(f"PIVOT TRANSFORMER {epoch}") + pivoted_tr = True + + for step, batch in enumerate(train_dataloader): + if pivoted_te: + # stopping optimization of text_encoder params + optimizer.param_groups[te_idx]["lr"] = 0.0 + optimizer.param_groups[-1]["lr"] = 0.0 + elif pivoted_tr and not pure_textual_inversion: + logger.info(f"PIVOT TRANSFORMER {epoch}") + optimizer.param_groups[0]["lr"] = 0.0 + + with accelerator.accumulate(transformer): + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt( + tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens_clip + ) + tokens_two = tokenize_prompt( + tokenizer_two, + prompts, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, + ) + + if not freeze_text_encoder: + prompt_embeds, 
pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=prompts, + ) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + + latent_image_ids = FluxPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( + model_pred, + height=int(model_input.shape[2] * vae_scale_factor / 2), + width=int(model_input.shape[3] * vae_scale_factor / 2), + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
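+ # collate_fn concatenates the class examples after the instance examples along the batch
+ # dimension, so chunking in half separates the instance prediction/target from the prior
+ # (class) prediction/target.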
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + if not freeze_text_encoder: + if args.train_text_encoder: + params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + elif pure_textual_inversion: + params_to_clip = itertools.chain( + text_encoder_one.parameters(), text_encoder_two.parameters() + ) + else: + params_to_clip = itertools.chain( + transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() + ) + else: + params_to_clip = itertools.chain(transformer.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # every step, we reset the embeddings to the original embeddings. + if args.train_text_encoder_ti: + embedding_handler.retract_embeddings() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + if freeze_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + 
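+ # Note: the unwrapped transformer (and the text encoders, when they are being trained) still carry the LoRA / textual-inversion weights from the current run, + # so this intermediate validation pipeline reflects the in-memory training state without reloading weights from disk.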
transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + torch_dtype=weight_dtype, + ) + if freeze_text_encoder: + del text_encoder_one, text_encoder_two + free_memory() + elif args.train_text_encoder: + del text_encoder_two + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + else: + text_encoder_lora_layers = None + + if not pure_textual_inversion: + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + if args.train_text_encoder_ti: + embeddings_path = f"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors" + embedding_handler.save_embeddings(embeddings_path) + + # Final inference + # Load previous pipeline + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if not pure_textual_inversion: + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + torch_dtype=weight_dtype, + is_final_validation=True, + ) + + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + enable_t5_ti=args.enable_t5_ti, + pure_textual_inversion=pure_textual_inversion, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index bb214885da1c..1424965a4baa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -137,7 +137,12 @@ def 
retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): r""" The Flux pipeline for text-to-image generation. @@ -212,6 +217,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -255,6 +263,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index f4018a82ad69..73471f086b72 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -25,7 +25,7 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -238,6 +238,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -281,6 +284,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index a7f7c66a2cad..29cb56ef9696 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -11,7 +11,7 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -251,6 +251,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -295,6 +298,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self,
TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 50d2fcaa7fa5..d4b16415fcd5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -12,7 +12,7 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -261,6 +261,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -305,6 +308,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index bee4f6ce52e7..0f6e38e66cae 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -235,6 +235,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -279,6 +282,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 460336700241..202335967032 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers 
import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -239,6 +239,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -283,6 +286,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length",