OpenVINO Latent Consistency Model C++ Image Generation Pipeline

This is a pure C++ text-to-image pipeline for the SD v1.5 Latent Consistency Model (LCM) with the LCM scheduler, driven by the OpenVINO native API. It also supports LoRA weights in safetensors format and OpenVINO Tokenizers: loading openvino_tokenizers into ov::Core enables tokenization. The common folder contains schedulers for image generation and imwrite() for saving BMP images. This demo has been tested on Linux only. There is also a Jupyter notebook that provides an example of image generation in Python.
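The imwrite() helper in the common folder saves the result as a BMP file. To illustrate what such a writer involves, here is a minimal standalone sketch that encodes an 8-bit RGB buffer as an uncompressed 24-bit BMP. The names encode_bmp and write_bmp are hypothetical (this is not the sample's imwrite() API), and the sketch assumes tightly packed RGB pixels with the top row first.

```cpp
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Append `value` as a little-endian integer of `bytes` bytes.
static void put_le(std::vector<uint8_t>& out, uint32_t value, int bytes) {
    for (int i = 0; i < bytes; ++i)
        out.push_back(static_cast<uint8_t>((value >> (8 * i)) & 0xFF));
}

// Hypothetical helper: encode a packed 8-bit RGB image (top row first)
// as an uncompressed 24-bit BMP held in memory.
std::vector<uint8_t> encode_bmp(const std::vector<uint8_t>& rgb,
                                uint32_t width, uint32_t height) {
    const uint32_t row_size = (width * 3 + 3) & ~3u;  // rows padded to 4 bytes
    const uint32_t pixel_bytes = row_size * height;
    const uint32_t header_bytes = 14 + 40;            // file header + info header

    std::vector<uint8_t> out;
    out.reserve(header_bytes + pixel_bytes);
    // BITMAPFILEHEADER
    out.push_back('B');
    out.push_back('M');
    put_le(out, header_bytes + pixel_bytes, 4);  // total file size
    put_le(out, 0, 4);                           // two reserved 16-bit fields
    put_le(out, header_bytes, 4);                // offset of the pixel data
    // BITMAPINFOHEADER
    put_le(out, 40, 4);          // info header size
    put_le(out, width, 4);
    put_le(out, height, 4);      // positive height: rows are stored bottom-up
    put_le(out, 1, 2);           // color planes
    put_le(out, 24, 2);          // bits per pixel
    put_le(out, 0, 4);           // BI_RGB, no compression
    put_le(out, pixel_bytes, 4);
    put_le(out, 2835, 4);        // ~72 DPI horizontal (pixels per meter)
    put_le(out, 2835, 4);        // ~72 DPI vertical
    put_le(out, 0, 4);           // palette colors
    put_le(out, 0, 4);           // important colors

    for (uint32_t y = 0; y < height; ++y) {
        const uint32_t src_row = height - 1 - y;  // write the bottom row first
        for (uint32_t x = 0; x < width; ++x) {
            const size_t i = (static_cast<size_t>(src_row) * width + x) * 3;
            out.push_back(rgb[i + 2]);  // B
            out.push_back(rgb[i + 1]);  // G
            out.push_back(rgb[i + 0]);  // R
        }
        for (uint32_t p = width * 3; p < row_size; ++p) out.push_back(0);  // row padding
    }
    return out;
}

// Hypothetical helper: encode and write the BMP to disk; returns false on I/O failure.
bool write_bmp(const std::string& path, const std::vector<uint8_t>& rgb,
               uint32_t width, uint32_t height) {
    const std::vector<uint8_t> bytes = encode_bmp(rgb, width, height);
    std::ofstream file(path, std::ios::binary);
    file.write(reinterpret_cast<const char*>(bytes.data()),
               static_cast<std::streamsize>(bytes.size()));
    return static_cast<bool>(file);
}
```

BMP stores rows bottom-up in BGR order, each padded to a multiple of four bytes. For a 512x512 output each row is 1536 bytes, so no padding is needed.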

Software Requirements

Linux

  • CMake 3.23 or higher
  • GCC 7.5 or higher
  • Python 3.8 or higher
  • Git

Windows

  • CMake 3.23 or higher
  • Microsoft Visual Studio 2019 or higher, version 16.3 or later
  • Python 3.8 or higher
  • Git for Windows

macOS

  • CMake 3.23 or higher
  • Clang compiler and other command line tools from Xcode 10.1 or higher:
    xcode-select --install
  • Python 3.8 or higher
  • Git

Build Instructions

Step 1: Clone the Repository

git clone --recursive https://github.com/openvinotoolkit/openvino.genai.git
cd ./openvino.genai/image_generation/lcm_dreamshaper_v7/cpp/

Note

This tutorial assumes that the current working directory is <openvino.genai repo>/image_generation/lcm_dreamshaper_v7/cpp/ and all paths are relative to this folder.

Step 2: Install OpenVINO from Archive

Follow the install instructions, selecting the OpenVINO Archives distribution. The path to the OpenVINO install directory is referred to as <INSTALL_DIR> throughout this document.

Step 3: Obtain Latent Consistency Model

  1. Install dependencies to import models from Hugging Face:

    python -m pip install -r ../../requirements.txt
    python -m pip install ../../../thirdparty/openvino_tokenizers/[transformers]
  2. Download the model from Hugging Face and convert it to OpenVINO IR via the optimum-intel CLI.

    Example command for downloading the SimianLuo/LCM_Dreamshaper_v7 model and exporting it with FP16 precision:

    optimum-cli export openvino --model SimianLuo/LCM_Dreamshaper_v7 --weight-format fp16 models/lcm_dreamshaper_v7/FP16

    You can also choose a different precision and export an FP32 or INT8 model.

    Refer to the official 🤗 Optimum and optimum-intel documentation for more details.

    If https://huggingface.co/ is down, the script won't be able to download the model.
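With their defaults, the -m, --modelPath (./models/lcm_dreamshaper_v7) and -t, --type (FP16) options point at the folder produced by the export command. The sketch below shows how those two values could combine into per-submodel IR paths so a run can fail early when an export is incomplete. It assumes optimum-cli's usual layout of one subfolder per submodel (text_encoder, unet, vae_decoder), each holding openvino_model.xml; the helper names ir_path and missing_irs are hypothetical, and the sample's own path handling may differ.

```cpp
#include <filesystem>
#include <initializer_list>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical helper, assumed layout: <modelPath>/<type>/<submodel>/openvino_model.xml
fs::path ir_path(const fs::path& model_path, const std::string& type,
                 const std::string& submodel) {
    return model_path / type / submodel / "openvino_model.xml";
}

// Hypothetical helper: list the expected IR files that do not exist on disk.
std::vector<fs::path> missing_irs(const fs::path& model_path, const std::string& type) {
    std::vector<fs::path> missing;
    for (const char* sub : {"text_encoder", "unet", "vae_decoder"}) {
        const fs::path p = ir_path(model_path, type, sub);
        if (!fs::exists(p)) missing.push_back(p);
    }
    return missing;
}
```

For example, missing_irs("./models/lcm_dreamshaper_v7", "FP16") returns an empty list once the FP16 export has finished, under the layout assumption above.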

(Optional) Enable LoRA Weights with Safetensors

Low-Rank Adaptation (LoRA) is a technique for making fine-tuning of diffusion models and Large Language Models (LLMs) cheaper: instead of updating all weights, it trains small low-rank matrices that are added to existing layers. In the case of Stable Diffusion fine-tuning, LoRA can be applied to the cross-attention layers that relate the image representations to the prompts that describe them.

LoRA weights can be enabled for the UNet model of the Stable Diffusion pipeline to generate images in different styles.

In this sample, LoRA weights are used in the safetensors format. Safetensors is a serialization format developed by Hugging Face that is specifically designed for efficiently storing and loading large tensors. It provides a lightweight and efficient way to serialize tensors, making it easier to store and load machine learning models.
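The sample reads these files with safetensors.h. To show what the format looks like on disk, here is a minimal standalone sketch (not the safetensors.h API) that extracts the JSON header: a .safetensors file starts with an 8-byte little-endian length N, followed by N bytes of JSON describing each tensor, followed by the raw tensor data. The names read_safetensors_header and SafetensorsHeader are hypothetical, and parsing the JSON itself is left to a JSON library.

```cpp
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>

// Layout of a .safetensors file:
//   [8 bytes]  N, unsigned 64-bit little-endian: size of the JSON header
//   [N bytes]  UTF-8 JSON: tensor name -> {"dtype", "shape", "data_offsets"}
//   [rest]     raw tensor bytes; data_offsets are relative to this region
struct SafetensorsHeader {
    std::string json;         // raw JSON header text
    uint64_t data_start = 0;  // file offset where the tensor bytes begin
};

SafetensorsHeader read_safetensors_header(std::istream& in) {
    unsigned char len_bytes[8];
    if (!in.read(reinterpret_cast<char*>(len_bytes), 8))
        throw std::runtime_error("file too short for the safetensors header length");
    uint64_t n = 0;
    for (int i = 7; i >= 0; --i) n = (n << 8) | len_bytes[i];  // little-endian decode

    SafetensorsHeader header;
    header.json.resize(static_cast<size_t>(n));
    if (!in.read(&header.json[0], static_cast<std::streamsize>(n)))
        throw std::runtime_error("truncated safetensors JSON header");
    header.data_start = 8 + n;
    return header;
}

SafetensorsHeader read_safetensors_header(const std::string& path) {
    std::ifstream file(path, std::ios::binary);
    if (!file) throw std::runtime_error("cannot open " + path);
    return read_safetensors_header(file);
}
```

Because the header is plain JSON, a loader can look up a tensor's dtype, shape and byte range without reading the whole file, which is part of what makes the format efficient to load.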

The LoRA safetensors model is loaded via safetensors.h. The layer names and weights are modified with the Eigen library and inserted into the SD models with an ov::pass::MatcherPass in the file common/diffusers/src/lora.cpp.
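The usual LoRA update adds a low-rank product to an existing layer weight: W' = W + α·(up × down), where up is [out x rank], down is [rank x in], and α sets the strength (presumably the -a, --alpha value). The sample performs this with Eigen inside an ov::pass::MatcherPass; the plain C++ sketch below only shows the arithmetic on one weight matrix and is not the sample's code. Matrix and merge_lora are hypothetical names, and some LoRA files also store a per-layer alpha that rescales the product by alpha/rank, which this sketch ignores.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Minimal row-major dense matrix.
struct Matrix {
    size_t rows = 0, cols = 0;
    std::vector<float> data;
    Matrix(size_t r, size_t c, float fill = 0.0f) : rows(r), cols(c), data(r * c, fill) {}
    float& at(size_t r, size_t c) { return data[r * cols + c]; }
    float at(size_t r, size_t c) const { return data[r * cols + c]; }
};

// In-place LoRA merge: W += alpha * (up x down)
//   W:    [out x in]    original layer weight
//   up:   [out x rank]  often stored as "lora_up" in safetensors LoRA files
//   down: [rank x in]   often stored as "lora_down"
void merge_lora(Matrix& W, const Matrix& up, const Matrix& down, float alpha) {
    if (up.rows != W.rows || down.cols != W.cols || up.cols != down.rows)
        throw std::invalid_argument("LoRA shapes do not match the layer weight");
    for (size_t i = 0; i < W.rows; ++i)
        for (size_t j = 0; j < W.cols; ++j) {
            float delta = 0.0f;
            for (size_t k = 0; k < up.cols; ++k) delta += up.at(i, k) * down.at(k, j);
            W.at(i, j) += alpha * delta;
        }
}
```

Since rank is small compared to the layer dimensions, the up and down matrices are far smaller than W, which is why LoRA files are compact; merging once before compilation means inference runs at the original model's cost.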

There are various LoRA models on https://civitai.com/tag/lora and on Hugging Face; you can choose your own LoRA model in safetensors format. For example, you can use the LoRA soulcard model. Download the LoRA safetensors file and put it into the models directory. When running the built sample, provide the path to the LoRA model with the -l, --loraPath argument.

Step 4: Build the LCM Application

  1. Set up the environment.

    Linux and macOS:

    source <INSTALL_DIR>/setupvars.sh

    Windows Command Prompt:

    call <INSTALL_DIR>\setupvars.bat

    Windows PowerShell:

    . <INSTALL_DIR>/setupvars.ps1
  2. Build the application:

    cmake -DCMAKE_BUILD_TYPE=Release -S . -B build
    cmake --build build --config Release --parallel

Step 5: Run Pipeline

./build/lcm_dreamshaper [-p <posPrompt>] [-s <seed>] [--height <height>] [--width <width>] [-d <device>] [-r] [-a <alpha>] [-h] [-m <modelPath>] [-t <modelType>] [--guidanceScale <guidanceScale>] [--dynamic]

Usage:
  lcm_dreamshaper [OPTION...]
  • -p, --posPrompt arg Initial positive prompt for LCM (default: "a beautiful pink unicorn")
  • -d, --device arg AUTO, CPU, or GPU. Doesn't apply to the tokenizer model: OpenVINO Tokenizers can be inferred on a CPU device only (default: CPU)
  • --step arg Number of diffusion steps (default: 4)
  • -s, --seed arg Random seed used to generate the latent (default: 42)
  • --guidanceScale arg A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality (default: 8.0)
  • --num arg Number of output images (default: 1)
  • --height arg Height of output image (default: 512)
  • --width arg Width of output image (default: 512)
  • -c, --useCache Use model caching
  • -r, --readNPLatent Read numpy generated latents from file, only supported for one output image
  • -m, --modelPath arg Specify path to LCM model IRs (default: ./models/lcm_dreamshaper_v7)
  • -t, --type arg Specify the type of LCM model IRs (e.g., FP32, FP16 or INT8) (default: FP16)
  • --dynamic Use dynamic shapes for the model inputs
  • -l, --loraPath arg Specify path to LoRA file (*.safetensors) (default: )
  • -a, --alpha arg Specify alpha for LoRA (default: 0.75)
  • -h, --help Print usage
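The --step option sets how many denoising steps the LCM scheduler runs. As an illustration, here is a sketch of how an LCM scheduler can select its timesteps, assuming the skipping-step scheme of early diffusers LCMScheduler versions with 1000 training timesteps and 50 original inference steps. The function name lcm_timesteps is hypothetical, and the scheduler in the common folder may use a different spacing.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical sketch of LCM timestep selection (skipping-step scheme).
std::vector<int64_t> lcm_timesteps(int64_t num_inference_steps,
                                   int64_t num_train_timesteps = 1000,
                                   int64_t original_inference_steps = 50) {
    if (num_inference_steps <= 0 || num_inference_steps > original_inference_steps)
        throw std::invalid_argument("num_inference_steps out of range");

    // Candidate timesteps of the original schedule: c-1, 2c-1, ..., N*c-1.
    const int64_t c = num_train_timesteps / original_inference_steps;
    std::vector<int64_t> origin;
    for (int64_t i = 1; i <= original_inference_steps; ++i) origin.push_back(i * c - 1);

    // Walk from the noisiest candidate, taking every skipping_step-th one.
    const int64_t skipping_step = original_inference_steps / num_inference_steps;
    std::vector<int64_t> timesteps;
    for (int64_t idx = static_cast<int64_t>(origin.size()) - 1;
         idx >= 0 && static_cast<int64_t>(timesteps.size()) < num_inference_steps;
         idx -= skipping_step)
        timesteps.push_back(origin[static_cast<size_t>(idx)]);
    return timesteps;
}
```

Under these assumptions, the default --step 4 yields the timesteps 999, 759, 519 and 279, so each step jumps across a large part of the noise schedule; that is what lets an LCM produce an image in a handful of UNet calls.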

Note

The tokenizer model will always be loaded to CPU: OpenVINO Tokenizers can be inferred on a CPU device only.

Examples

Positive prompt: a beautiful pink unicorn

To read the initial latent and the scheduler noise generated by NumPy from a file, instead of generating them with the C++ standard library, use the -r, --readNPLatent argument. This aligns the C++ results with the Python pipeline.

  • Generate image with random data generated by Python: ./build/lcm_dreamshaper -r


  • Generate image with C++ lib generated latent and noise: ./build/lcm_dreamshaper


  • Generate image with the soulcard LoRA and C++ generated latent and noise: ./build/lcm_dreamshaper -l path/to/soulcard.safetensors


Benchmark:

When comparing generation quality, note that C++ random generation with MT19937 produces results that differ from numpy.random.randn() and diffusers.utils.randn_tensor. Use -r, --readNPLatent to align with Python (the latent file supports a 512x512 output image only).
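For reference, here is a minimal sketch of drawing an initial latent with MT19937 in standard C++. It assumes the usual SD v1.5 latent shape of 1x4x(height/8)x(width/8) filled with standard normal samples; random_latent is a hypothetical name and the sample's generator code may differ. The algorithm behind std::normal_distribution is not fixed by the C++ standard and generally differs from NumPy's, so even with the same seed the values will not match numpy.random.randn().

```cpp
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical sketch: fill a 1x4x(H/8)x(W/8) latent with N(0, 1) samples.
std::vector<float> random_latent(uint32_t seed, uint32_t height = 512, uint32_t width = 512) {
    const size_t channels = 4;  // SD v1.5 VAE latent channels
    const size_t count = channels * (height / 8) * (width / 8);
    std::mt19937 gen(seed);                                // MT19937 engine
    std::normal_distribution<float> normal(0.0f, 1.0f);    // standard normal
    std::vector<float> latent(count);
    for (float& v : latent) v = normal(gen);
    return latent;
}
```

The same seed always reproduces the same latent with a given standard library build, so C++ runs are repeatable, but cross-language comparisons still need the -r, --readNPLatent file.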

Notes

Guidance Scale

Guidance scale controls how similar the generated image will be to the prompt. A higher guidance scale means the model will try to generate an image that follows the prompt more strictly. A lower guidance scale means the model will have more creativity. guidance_scale is a way to increase the adherence to the conditional signal that guides the generation (text, in this case) as well as overall sample quality. It is also known as classifier-free guidance.
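For an LCM, guidance is distilled into the model, so the guidance scale is not applied through a second unconditional UNet pass. In diffusers' LatentConsistencyModelPipeline it is instead turned into a sinusoidal embedding that is passed to the UNet as timestep_cond, using w = guidance_scale − 1 scaled by 1000. The sketch below follows that diffusers formulation; the function name guidance_scale_embedding and the 256-dimensional default (assumed to match the UNet's time_cond_proj_dim) are assumptions, and this C++ sample's exact handling may differ.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Sinusoidal guidance-scale embedding in the style of diffusers'
// get_guidance_scale_embedding(). Layout: [sin(...) x half, cos(...) x half].
std::vector<float> guidance_scale_embedding(float guidance_scale, size_t embedding_dim = 256) {
    if (embedding_dim < 4) throw std::invalid_argument("embedding_dim too small");
    const float w = (guidance_scale - 1.0f) * 1000.0f;  // diffusers subtracts 1, then scales
    const size_t half = embedding_dim / 2;
    const float log_step = std::log(10000.0f) / static_cast<float>(half - 1);
    std::vector<float> emb(embedding_dim, 0.0f);  // an odd dimension keeps a trailing zero
    for (size_t i = 0; i < half; ++i) {
        const float arg = w * std::exp(-log_step * static_cast<float>(i));  // geometric frequencies
        emb[i] = std::sin(arg);
        emb[half + i] = std::cos(arg);
    }
    return emb;
}
```

With the default --guidanceScale 8.0, w is 7000 under this formulation; a scale of 1.0 gives w = 0, an embedding of all zeros (sin half) and ones (cos half).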

Negative Prompt

Negative prompts do not work with LCM because they have no effect on the denoising process. When an LCM is distilled from a latent diffusion model (LDM) via latent consistency distillation (Algorithm 1) with guided distillation, the forward pass of the LCM learns to approximate sampling from the LDM using classifier-free guidance (CFG) with the unconditional prompt "" (the empty string). Because of this, LCMs currently do not support negative prompts.
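For contrast, a regular Stable Diffusion pipeline applies classifier-free guidance at every step by combining two noise predictions, and a negative prompt takes the place of the empty prompt in the unconditional branch:

```latex
\hat{\epsilon}_\theta(z_t, c) = \epsilon_\theta(z_t, \varnothing)
  + w \,\bigl(\epsilon_\theta(z_t, c) - \epsilon_\theta(z_t, \varnothing)\bigr)
```

Here z_t is the noisy latent at timestep t, c is the prompt embedding, ∅ is the empty-prompt (or negative-prompt) embedding, and w is the guidance scale. In an LCM this combination, with ∅ fixed to the empty prompt, is absorbed into the weights during guided distillation, so at inference time there is no separate unconditional prediction for a negative prompt to replace.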

LoRA Weights Enabling

Refer to the OpenVINO blog to get more information on enabling LoRA weights.