From 35f5a57ae0f928cf86e1a8095750c097587ada4d Mon Sep 17 00:00:00 2001 From: Jasmine Collins Date: Tue, 3 Oct 2023 17:15:43 -0700 Subject: [PATCH] Full SDXL Model (#67) * random crop * zero init trick * add intentionally buggy clipping * fix docstring and update diffusers version * fix attention clipping, add to sdxl, fix xformers import when not installed * big sdxl commit, no style check * fix style and pyright * print sdxl statement * add sdxl logic to generate * allow setting SDXLTextEncoder device * sdxltextencoder edits * split conditioning * remove prints * microconditioning and cleaning up comments * fix style * fix dropout dtype * rm local streaming * Update diffusion/datasets/image_caption.py Co-authored-by: Landan Seguin * use RandomCrop, fix LogDiffusionImages bug * have tokenizers pass dict output * add to layers.py docs * override prediction_type in inference_noise_scheulder * Update diffusion/models/stable_diffusion.py Co-authored-by: Landan Seguin * fix style * log_diffusion_images.py fix * pass tokenized prompts as batch_size x 2 x max_length shape * stack tokenizer output to match * fix negative prompt classifier free guidance * _prepare_text_embeddings fix * add negative_prompt_embeds to zero_out_negative_prompt check --------- Co-authored-by: Landan Seguin --- diffusion/callbacks/log_diffusion_images.py | 14 +- diffusion/datasets/image_caption.py | 86 +++++++- diffusion/datasets/laion/transforms.py | 37 +++- diffusion/models/layers.py | 224 ++++++++++++++++++++ diffusion/models/models.py | 154 +++++++++++--- diffusion/models/stable_diffusion.py | 141 ++++++++++-- setup.py | 2 +- 7 files changed, 589 insertions(+), 69 deletions(-) create mode 100644 diffusion/models/layers.py diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 54c36a72..04f2fb2f 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -27,8 +27,8 @@ class LogDiffusionImages(Callback): the text prompt, usually at the expense of lower image quality. Default: ``0.0``. text_key (str, optional): Key in the batch to use for text prompts. Default: ``'captions'``. - tokenized_prompts (torch.LongTensor, optional): Batch of pre-tokenized prompts - to use for evaluation. Default: ``None``. + tokenized_prompts (torch.LongTensor or List[torch.LongTensor], optional): Batch of pre-tokenized prompts + to use for evaluation. If SDXL, this will be a list of two pre-tokenized prompts Default: ``None``. seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation. Default: ``1138``. use_table (bool): Whether to make a table of the images or not. Default: ``False``. 
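For reference, the prompt-tensor shapes this docstring describes (and that the eval_batch_end change in the next hunk produces) can be sketched outside the callback. This is an illustrative snippet rather than part of the patch; it assumes the SD2 and SDXL tokenizers published on the HuggingFace Hub and CLIP's usual 77-token max length.

import torch
from transformers import CLIPTokenizer

prompts = ['a photo of a dog', 'a watercolor landscape']

# SD2-style: a single tokenizer, per-prompt outputs concatenated -> [B, max_length]
tok = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-2-base', subfolder='tokenizer')
sd2_ids = torch.cat(
    [tok(p, padding='max_length', truncation=True, return_tensors='pt')['input_ids'] for p in prompts])
print(sd2_ids.shape)  # torch.Size([2, 77])

# SDXL-style: two tokenizers per prompt, concatenated then stacked -> [B, 2, max_length]
tok_1 = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer')
tok_2 = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer_2')
sdxl_ids = torch.stack([
    torch.cat([
        tok_1(p, padding='max_length', truncation=True, return_tensors='pt')['input_ids'],
        tok_2(p, padding='max_length', truncation=True, return_tensors='pt')['input_ids'],
    ]) for p in prompts
])
print(sdxl_ids.shape)  # torch.Size([2, 2, 77])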
@@ -63,13 +63,17 @@ def eval_batch_end(self, state: State, logger: Logger): model = state.model if self.tokenized_prompts is None: - tokenized_prompts = [ + self.tokenized_prompts = [ model.tokenizer(p, padding='max_length', truncation=True, return_tensors='pt')['input_ids'] # type: ignore for p in self.prompts ] - self.tokenized_prompts = torch.cat(tokenized_prompts) - self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) + if model.sdxl: + self.tokenized_prompts = torch.stack([torch.cat(tp) for tp in self.tokenized_prompts + ]) # [B, 2, max_length] + else: + self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore + self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore # Generate images with get_precision_context(state.precision): diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index a8405b4f..24cec6dd 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -3,6 +3,7 @@ """Streaming Image-Caption dataset.""" +import logging import random from io import BytesIO from typing import Callable, Dict, List, Optional, Sequence, Union @@ -14,7 +15,10 @@ from torchvision import transforms from transformers import AutoTokenizer -from diffusion.datasets.laion.transforms import LargestCenterSquare +from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare, RandomCropSquareReturnTransform +from diffusion.models.models import SDXLTokenizer + +log = logging.getLogger(__name__) # Disable PIL max image size limit Image.MAX_IMAGE_PIXELS = None @@ -29,6 +33,8 @@ class StreamingImageCaptionDataset(StreamingDataset): remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``. local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``. tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``. + caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``. + microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``. caption_selection (str): If there are multiple captions, specifies how to select a single caption. 'first' selects the first caption in the list and 'random' selects a random caption in the list. If there is only one caption, this argument is ignored. Default: ``'first'``. @@ -36,6 +42,7 @@ class StreamingImageCaptionDataset(StreamingDataset): image_size (Optional[int]): The size to resize the image to. Default: ``None``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. + sdxl (bool): Whether or not we're training SDXL. Default: `False`. 
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader """ @@ -46,11 +53,13 @@ def __init__( local: Optional[str] = None, tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base', caption_drop_prob: float = 0.0, + microcond_drop_prob: float = 0.0, caption_selection: str = 'first', transform: Optional[Callable] = None, image_size: Optional[int] = None, image_key: str = 'image', caption_key: str = 'caption', + sdxl: bool = False, **streaming_kwargs, ) -> None: @@ -65,8 +74,15 @@ def __init__( raise ValueError(f'Invalid caption selection: {caption_selection}. Must be one of [random, first]') self.transform = transform - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer') + self.sdxl = sdxl + if self.sdxl: + self.tokenizer = SDXLTokenizer(tokenizer_name_or_path) + self.sdxl_crop = RandomCropSquareReturnTransform(image_size) + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer') + self.sdxl_crop = None self.caption_drop_prob = caption_drop_prob + self.microcond_drop_prob = microcond_drop_prob self.caption_selection = caption_selection self.image_size = image_size self.image_key = image_key @@ -81,6 +97,25 @@ def __getitem__(self, index): img = Image.open(BytesIO(sample[self.image_key])) if img.mode != 'RGB': img = img.convert('RGB') + + out = {} + # Image transforms + if self.sdxl and self.sdxl_crop: + img, crop_top, crop_left, image_height, image_width = self.sdxl_crop(img) + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) + out['cond_original_size'] = torch.tensor([image_width, image_height]) + out['cond_target_size'] = torch.tensor([self.image_size, self.image_size]) + + # Microconditioning dropout as in Stability repo + # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_original_size'] = out['cond_original_size'] * 0 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_target_size'] = out['cond_target_size'] * 0 + else: + crop_top, crop_left, image_height, image_width = None, None, None, None if self.transform is not None: img = self.transform(img) @@ -93,13 +128,21 @@ def __getitem__(self, index): caption = caption[0] if isinstance(caption, List) and self.caption_selection == 'random': caption = random.sample(caption, k=1)[0] + + max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore tokenized_caption = self.tokenizer(caption, padding='max_length', - max_length=self.tokenizer.model_max_length, + max_length=max_length, truncation=True, - return_tensors='pt')['input_ids'][0] - - return {'image': img, 'captions': tokenized_caption} + return_tensors='pt')['input_ids'] + if self.sdxl: + tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption] + tokenized_caption = torch.stack(tokenized_caption) + else: + tokenized_caption = tokenized_caption.squeeze() + out['image'] = img + out['captions'] = tokenized_caption + return out def build_streaming_image_caption_dataloader( @@ -108,11 +151,13 @@ def build_streaming_image_caption_dataloader( batch_size: int, tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base', caption_drop_prob: float = 0.0, + microcond_drop_prob: float = 0.0, resize_size: int = 
256,
     caption_selection: str = 'first',
     transform: Optional[List[Callable]] = None,
     image_key: str = 'image',
     caption_key: str = 'caption',
+    rand_crop: bool = False,
     streaming_kwargs: Optional[Dict] = None,
     dataloader_kwargs: Optional[Dict] = None,
 ):
@@ -124,6 +169,7 @@ def build_streaming_image_caption_dataloader(
         batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
         tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
         caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
+        microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
         resize_size (int): The size to resize the image to. Default: ``256``.
         caption_selection (str): If there are multiple captions, specifies how to select a single caption.
             'first' selects the first caption in the list and 'random' selects a random caption in the list.
@@ -131,6 +177,7 @@ def build_streaming_image_caption_dataloader(
         transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
+        rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
         streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
         dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
     """
@@ -156,13 +203,28 @@ def build_streaming_image_caption_dataloader(
     for r, l in zip(remote, local):
         streams.append(Stream(remote=r, local=l))
 
+    # Infer SDXL from tokenizer path
+    if tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0':
+        log.info('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.')
+        sdxl = True
+    else:
+        sdxl = False
+
     # Setup the transforms to apply
+    crop_transform = RandomCropSquare(resize_size) if rand_crop else LargestCenterSquare(resize_size)
     if transform is None:
-        transform = [
-            LargestCenterSquare(resize_size),
-            transforms.ToTensor(),
-            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # # Normalize from 0 to 1 to -1 to 1
-        ]
+        if sdxl:
+            # Crop will return parameters so do separately
+            transform = [
+                transforms.ToTensor(),
+                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            ]
+        else:
+            transform = [
+                crop_transform,
+                transforms.ToTensor(),
+                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize from 0 to 1 to -1 to 1
+            ]
         transform = transforms.Compose(transform)
     assert isinstance(transform, Callable)
@@ -170,12 +232,14 @@ def build_streaming_image_caption_dataloader(
         streams=streams,
         tokenizer_name_or_path=tokenizer_name_or_path,
         caption_drop_prob=caption_drop_prob,
+        microcond_drop_prob=microcond_drop_prob,
         caption_selection=caption_selection,
         transform=transform,
         image_size=resize_size,
         image_key=image_key,
         caption_key=caption_key,
         batch_size=batch_size,
+        sdxl=sdxl,
         **streaming_kwargs,
     )
diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py
index a0a142d8..86ed78ad 100644
--- a/diffusion/datasets/laion/transforms.py
+++ b/diffusion/datasets/laion/transforms.py
@@ -1,9 +1,11 @@
 # Copyright 2022 MosaicML Diffusion authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Transforms for the laion dataset."""
+"""Transforms for the training and eval dataset."""
 
 import torchvision.transforms as transforms
+from torchvision.transforms import RandomCrop
+from torchvision.transforms.functional import crop
 
 
 class LargestCenterSquare:
@@ -19,3 +21,36 @@ def __call__(self, img):
         # Then take a center crop to a square.
         img = self.center_crop(img)
         return img
+
+
+class RandomCropSquare:
+    """Randomly crop square of a PIL image."""
+
+    def __init__(self, size):
+        self.size = size
+        self.random_crop = RandomCrop(size)
+
+    def __call__(self, img):
+        # First, resize the image such that the smallest side is self.size while preserving aspect ratio.
+        img = transforms.functional.resize(img, self.size, antialias=True)
+        # Then take a random crop to a square.
+        c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
+        img = crop(img, c_top, c_left, h, w)
+        return img
+
+
+class RandomCropSquareReturnTransform:
+    """Randomly crop square of a PIL image and return the crop parameters."""
+
+    def __init__(self, size):
+        self.size = size
+        self.random_crop = RandomCrop(size)
+
+    def __call__(self, img):
+        # First, resize the image such that the smallest side is self.size while preserving aspect ratio.
+        orig_w, orig_h = img.size
+        img = transforms.functional.resize(img, self.size, antialias=True)
+        # Then take a random crop to a square & return crop params.
+        c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
+        img = crop(img, c_top, c_left, h, w)
+        return img, c_top, c_left, orig_h, orig_w
diff --git a/diffusion/models/layers.py b/diffusion/models/layers.py
new file mode 100644
index 00000000..dbd33221
--- /dev/null
+++ b/diffusion/models/layers.py
@@ -0,0 +1,224 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Helpful layers and functions for UNet construction."""
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+try:
+    import xformers  # type: ignore
+except:
+    pass
+
+
+def zero_module(module):
+    """Zero out the parameters of a module and return it."""
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+class ClippedAttnProcessor2_0:
+    """Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+
+    Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py#L977 to
+    allow clipping QKV values.
+
+    Args:
+        clip_val (float, defaults to 6.0): Amount to clip query, key, and value by.
+ """ + + def __init__(self, clip_val=6.0): + if not hasattr(F, 'scaled_dot_product_attention'): + raise ImportError('AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.') + self.clip_val = clip_val + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale: float = 1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + else: + channel, height, width = None, None, None + + batch_size, sequence_length, _ = (hidden_states.shape + if encoder_hidden_states is None else encoder_hidden_states.shape) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = query.clamp(min=-self.clip_val, max=self.clip_val) + key = key.clamp(min=-self.clip_val, max=self.clip_val) + value = value.clamp(min=-self.clip_val, max=self.clip_val) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention(query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class ClippedXFormersAttnProcessor: + """Processor for implementing memory efficient attention using xFormers. + + Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py#L888 to + allow clipping QKV values. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. 
It is recommended to set to `None`, and allow xFormers to choose the best + operator. + clip_val (float, defaults to 6.0): Amount to clip query, key, and value by. + """ + + def __init__(self, clip_val=6.0, attention_op=None): + self.attention_op = attention_op + self.clip_val = clip_val + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + else: + channel, height, width = None, None, None + + batch_size, key_tokens, _ = (hidden_states.shape + if encoder_hidden_states is None else encoder_hidden_states.shape) + + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = query.clamp(min=-self.clip_val, max=self.clip_val) + key = key.clamp(min=-self.clip_val, max=self.clip_val) + value = value.clamp(min=-self.clip_val, max=self.clip_val) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + assert channel + assert height + assert width + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/diffusion/models/models.py b/diffusion/models/models.py index e2eed1a8..38e07c4c 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -3,16 +3,18 @@ """Constructors for diffusion models.""" +import logging from typing import List, Optional import torch from composer.devices import DeviceGPU -from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, 
UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from torchmetrics import MeanSquaredError from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.multimodal.clip_score import CLIPScore -from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig +from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.schedulers.schedulers import ContinuousTimeScheduler @@ -24,6 +26,8 @@ except: is_xformers_installed = False +log = logging.getLogger(__name__) + def stable_diffusion_2( model_name: str = 'stabilityai/stable-diffusion-2-base', @@ -37,6 +41,7 @@ def stable_diffusion_2( precomputed_latents: bool = False, encode_latents_in_fp16: bool = True, fsdp: bool = True, + clip_qkv: Optional[float] = None, ): """Stable diffusion v2 training setup. @@ -60,6 +65,7 @@ def stable_diffusion_2( precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None. """ if train_metrics is None: train_metrics = [MeanSquaredError()] @@ -120,11 +126,20 @@ def stable_diffusion_2( if is_xformers_installed: model.unet.enable_xformers_memory_efficient_attention() model.vae.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) + model.unet.set_attn_processor(attn_processor) + return model def stable_diffusion_xl( - model_name: str = 'stabilityai/stable-diffusion-2-base', + model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', pretrained: bool = True, @@ -137,6 +152,7 @@ def stable_diffusion_xl( precomputed_latents: bool = False, encode_latents_in_fp16: bool = True, fsdp: bool = True, + clip_qkv: Optional[float] = 6.0, ): """Stable diffusion 2 training setup + SDXL UNet and VAE. @@ -144,8 +160,8 @@ def stable_diffusion_xl( prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. Args: - model_name (str): Name of the model to load. Determines the text encoder, tokenizer, - and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'. + model_name (str): Name of the model to load. Determines the text encoders, tokenizers, + and noise scheduler. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. unet_model_name (str): Name of the UNet model to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. vae_model_name (str): Name of the VAE model to load. Defaults to @@ -166,6 +182,8 @@ def stable_diffusion_xl( precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. fsdp (bool): Whether to use FSDP. Defaults to True. 
+ clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to 6.0. Improves stability + of training. """ if train_metrics is None: train_metrics = [MeanSquaredError()] @@ -181,39 +199,34 @@ def stable_diffusion_xl( metric.requires_grad_(False) if pretrained: - raise NotImplementedError('Full SDXL pipeline not implemented yet.') + unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet') else: config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet') - # Currently not doing micro-conditioning, so set config appropriately - config[0]['addition_embed_type'] = None - config[0]['cross_attention_dim'] = 1024 unet = UNet2DConditionModel(**config[0]) - # Prevent fsdp from wrapping up_blocks and down_blocks because the forward pass calls length on these - unet.up_blocks._fsdp_wrap = False - unet.down_blocks._fsdp_wrap = False - for block in unet.up_blocks: - block._fsdp_wrap = True - for block in unet.down_blocks: - block._fsdp_wrap = True - - if encode_latents_in_fp16: - vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) - text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch.float16) - else: - vae = AutoencoderKL.from_pretrained(vae_model_name) - text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder') + # Zero initialization trick + for name, layer in unet.named_modules(): + # Final conv in ResNet blocks + if name.endswith('conv2'): + layer = zero_module(layer) + # proj_out in attention blocks + if name.endswith('to_out.0'): + layer = zero_module(layer) + # Last conv block out projection + unet.conv_out = zero_module(unet.conv_out) + + torch_dtype = torch.float16 if encode_latents_in_fp16 else None + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch_dtype) + + tokenizer = SDXLTokenizer(model_name) + text_encoder = SDXLTextEncoder(model_name, encode_latents_in_fp16) - tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler') - inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps, - beta_start=noise_scheduler.config.beta_start, - beta_end=noise_scheduler.config.beta_end, - beta_schedule=noise_scheduler.config.beta_schedule, - trained_betas=noise_scheduler.config.trained_betas, - clip_sample=noise_scheduler.config.clip_sample, - set_alpha_to_one=noise_scheduler.config.set_alpha_to_one, - prediction_type=prediction_type) + inference_noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder='scheduler') + inference_noise_scheduler.prediction_type = prediction_type model = StableDiffusion( unet=unet, @@ -231,12 +244,22 @@ def stable_diffusion_xl( precomputed_latents=precomputed_latents, encode_latents_in_fp16=encode_latents_in_fp16, fsdp=fsdp, + sdxl=True, ) if torch.cuda.is_available(): model = DeviceGPU().module_to_device(model) if is_xformers_installed: model.unet.enable_xformers_memory_efficient_attention() model.vae.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with 
clip_val %.1f' % (attn_processor.__class__, clip_qkv)) + model.unet.set_attn_processor(attn_processor) + return model @@ -354,3 +377,70 @@ def continuous_pixel_diffusion(clip_model_name: str = 'openai/clip-vit-large-pat if is_xformers_installed: model.model.enable_xformers_memory_efficient_attention() return model + + +class SDXLTextEncoder(torch.nn.Module): + """Wrapper around HuggingFace text encoders for SDXL. + + Creates two text encoders (a CLIPTextModel and CLIPTextModelWithProjection) that behave like one. + + Args: + model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + """ + + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True): + super().__init__() + torch_dtype = torch.float16 if encode_latents_in_fp16 else None + self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch_dtype) + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, + subfolder='text_encoder_2', + torch_dtype=torch_dtype) + + @property + def device(self): + return self.text_encoder.device + + def forward(self, tokenized_text): + # first text encoder + conditioning = self.text_encoder(tokenized_text[0], output_hidden_states=True).hidden_states[-2] + # second text encoder + text_encoder_2_out = self.text_encoder_2(tokenized_text[1], output_hidden_states=True) + pooled_conditioning = text_encoder_2_out[0] # (batch_size, 1280) + conditioning_2 = text_encoder_2_out.hidden_states[-2] # (batch_size, 77, 1280) + + conditioning = torch.concat([conditioning, conditioning_2], dim=-1) + return conditioning, pooled_conditioning + + +class SDXLTokenizer: + """Wrapper around HuggingFace tokenizers for SDXL. + + Tokenizes prompt with two tokenizers and returns the joined output. + + Args: + model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. + """ + + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): + self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') + self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') + + def __call__(self, prompt, padding, truncation, return_tensors, max_length=None): + tokenized_output = self.tokenizer( + prompt, + padding=padding, + max_length=self.tokenizer.model_max_length if max_length is None else max_length, + truncation=truncation, + return_tensors=return_tensors) + tokenized_output_2 = self.tokenizer_2( + prompt, + padding=padding, + max_length=self.tokenizer_2.model_max_length if max_length is None else max_length, + truncation=truncation, + return_tensors=return_tensors) + + # Add second tokenizer output to first tokenizer + for key in tokenized_output.keys(): + tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]] + return tokenized_output diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 01688c04..4435d0bc 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -62,6 +62,7 @@ class StableDiffusion(ComposerModel): Default: `False`. encode_latents_in_fp16 (bool): whether to encode latents in fp16. Default: `False`. + sdxl (bool): Whether or not we're training SDXL. Default: `False`. 
""" def __init__(self, @@ -84,7 +85,8 @@ def __init__(self, text_latents_key: str = 'caption_latents', precomputed_latents: bool = False, encode_latents_in_fp16: bool = False, - fsdp: bool = False): + fsdp: bool = False, + sdxl: bool = False): super().__init__() self.unet = unet self.vae = vae @@ -97,6 +99,11 @@ def __init__(self, self.image_key = image_key self.image_latents_key = image_latents_key self.precomputed_latents = precomputed_latents + self.sdxl = sdxl + if self.sdxl: + self.latent_scale = 0.13025 + else: + self.latent_scale = 0.18215 # setup metrics if train_metrics is None: @@ -154,12 +161,19 @@ def __init__(self, self.unet._fsdp_wrap = True def forward(self, batch): - latents, conditioning = None, None + latents, conditioning, conditioning_2, pooled_conditioning = None, None, None, None # Use latents if specified and available. When specified, they might not exist during eval if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch: + if self.sdxl: + raise NotImplementedError('SDXL not yet supported with precomputed latents') latents, conditioning = batch[self.image_latents_key], batch[self.text_latents_key] else: inputs, conditioning = batch[self.image_key], batch[self.text_key] + if self.sdxl: + # If SDXL, separate the conditioning ([B, 2, 77]) from each tokenizer + conditioning, conditioning_2 = conditioning[:, 0, :], conditioning[:, 1, :] + conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) + conditioning = conditioning.view(-1, conditioning.shape[-1]) if self.encode_latents_in_fp16: # Disable autocast context as models are in fp16 @@ -167,13 +181,20 @@ def forward(self, batch): # Encode the images to the latent space. # Encode prompt into conditioning vector latents = self.vae.encode(inputs.half())['latent_dist'].sample().data - conditioning = self.text_encoder(conditioning)[0] # Should be (batch_size, 77, 768) + if self.sdxl: + conditioning, pooled_conditioning = self.text_encoder([conditioning, conditioning_2]) + else: + conditioning = self.text_encoder(conditioning)[0] # Should be (batch_size, 77, 768) else: latents = self.vae.encode(inputs)['latent_dist'].sample().data - conditioning = self.text_encoder(conditioning)[0] + if self.sdxl: + conditioning, pooled_conditioning = self.text_encoder([conditioning, conditioning_2]) + else: + conditioning = self.text_encoder(conditioning)[0] + # Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515) - latents *= 0.18215 + latents *= self.latent_scale # Sample the diffusion timesteps timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device) @@ -190,8 +211,18 @@ def forward(self, batch): else: raise ValueError( f'prediction type must be one of sample, epsilon, or v_prediction. 
Got {self.prediction_type}') + + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + if self.sdxl: + add_time_ids = torch.cat( + [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1) + add_text_embeds = pooled_conditioning + added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids} + # Forward through the model - return self.unet(noised_latents, timesteps, conditioning)['sample'], targets, timesteps + return self.unet(noised_latents, timesteps, conditioning, + added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps def loss(self, outputs, batch): """Loss between unet output and added noise, typically mse.""" @@ -207,6 +238,18 @@ def eval_forward(self, batch, outputs=None): # Sample images from the prompts in the batch prompts = batch[self.text_key] height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1] + + # If SDXL, add eval-time micro-conditioning to batch + if self.sdxl: + device = self.unet.device + bsz = batch[self.image_key].shape[0] + # Set to resolution we are trying to generate + batch['cond_original_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device) + # No cropping + batch['cond_crops_coords_top_left'] = torch.tensor([[0., 0.]]).repeat(bsz, 1).to(device) + # Set to resolution we are trying to generate + batch['cond_target_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device) + generated_images = {} for guidance_scale in self.val_guidance_scales: gen_images = self.generate(tokenized_prompts=prompts, @@ -261,7 +304,16 @@ def update_metric(self, batch, outputs, metric): # CLIP metrics should be updated with the generated images at the desired guidance scale elif metric.__class__.__name__ == 'CLIPScore': # Convert the captions to a list of strings - captions = [self.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key]] + if self.sdxl: + # Decode captions with first tokenizer + captions = [ + self.tokenizer.tokenizer.decode(caption[0], skip_special_tokens=True) + for caption in batch[self.text_key] + ] + else: + captions = [ + self.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key] + ] generated_images = (outputs[3][metric.guidance_scale] * 255).to(torch.uint8) metric.update(generated_images, captions) else: @@ -280,9 +332,12 @@ def generate( width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 3.0, - num_images_per_prompt: Optional[int] = 1, + num_images_per_prompt: int = 1, seed: Optional[int] = None, progress_bar: Optional[bool] = True, + zero_out_negative_prompt: bool = True, + crop_params: Optional[list] = None, + size_params: Optional[list] = None, ): """Generates image from noise. @@ -296,7 +351,8 @@ def generate( (i.e., ignored if guidance_scale is less than 1). Must be the same length as list of prompts. Default: `None`. tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead - of string prompts. Default: `None`. + of string prompts. If SDXL, this will be a tensor of size [B, 2, max_length], + otherwise will be of shape [B, max_length]. Default: `None`. tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative prompts instead of string prompts. Default: `None`. prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead @@ -320,10 +376,16 @@ def generate( Default: `3.0`. 
num_images_per_prompt (int): The number of images to generate per prompt. Default: `1`. - progress_bar (bool): Wether to use the tqdm progress bar during generation. + progress_bar (bool): Whether to use the tqdm progress bar during generation. Default: `True`. seed (int): Random seed to use for generation. Set a seed for reproducible generation. Default: `None`. + zero_out_negative_prompt (bool): Whether or not to zero out negative prompt if it is + an empty string. Default: `True`. + crop_params (list, optional): Crop parameters to use when generating images with SDXL. + Default: `None`. + size_params (list, optional): Size parameters to use when generating images with SDXL. + Default: `None`. """ _check_prompt_given(prompt, tokenized_prompts, prompt_embeds) _check_prompt_lenths(prompt, negative_prompt) @@ -344,16 +406,27 @@ def generate( do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore - text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt) + text_embeddings, pooled_text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts, + prompt_embeds, num_images_per_prompt) batch_size = len(text_embeddings) # len prompts * num_images_per_prompt # classifier free guidance + negative prompts # negative prompt is given in place of the unconditional input in classifier free guidance + pooled_embeddings = None if do_classifier_free_guidance: - negative_prompt = negative_prompt or ([''] * (batch_size // num_images_per_prompt)) # type: ignore - unconditional_embeddings = self._prepare_text_embeddings(negative_prompt, tokenized_negative_prompts, - negative_prompt_embeds, num_images_per_prompt) + if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt: + # Negative prompt is empty and we want to zero it out + unconditional_embeddings = torch.zeros_like(text_embeddings) + pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None + else: + if not negative_prompt: + negative_prompt = [''] * (batch_size // num_images_per_prompt) # type: ignore + unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings( + negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt) + # concat uncond + prompt text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]) + if self.sdxl: + pooled_embeddings = torch.cat([pooled_unconditional_embeddings, pooled_text_embeddings]) # type: ignore # prepare for diffusion generation process latents = torch.randn( @@ -366,6 +439,19 @@ def generate( # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.inference_scheduler.init_noise_sigma + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + if self.sdxl and pooled_embeddings is not None: + if not crop_params: + crop_params = [0., 0.] 
+ if not size_params: + size_params = [width, height] + add_time_ids = torch.tensor([[width, height, *crop_params, *size_params]], dtype=torch.float, device=device) + add_time_ids = add_time_ids.repeat(pooled_embeddings.shape[0], 1) + add_text_embeds = pooled_embeddings + + added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids} + # backward diffusion process for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): if do_classifier_free_guidance: @@ -375,7 +461,10 @@ def generate( latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) # Model prediction - pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + pred = self.unet(latent_model_input, + t, + encoder_hidden_states=text_embeddings, + added_cond_kwargs=added_cond_kwargs).sample if do_classifier_free_guidance: # perform guidance. Note this is only techincally correct for prediction_type 'epsilon' @@ -387,7 +476,7 @@ def generate( # We now use the vae to decode the generated latents back into the image. # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.latent_scale * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) return image.detach() # (batch*num_images_per_prompt, channel, h, w) @@ -395,22 +484,36 @@ def generate( def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt): """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt.""" device = self.text_encoder.device + pooled_text_embeddings = None if prompt_embeds is None: + max_length = None if self.sdxl else self.tokenizer.model_max_length if tokenized_prompts is None: tokenized_prompts = self.tokenizer(prompt, padding='max_length', - max_length=self.tokenizer.model_max_length, + max_length=max_length, truncation=True, return_tensors='pt').input_ids - text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore + if self.sdxl: + tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1) + if self.sdxl: + text_embeddings, pooled_text_embeddings = self.text_encoder( + [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)]) # type: ignore + else: + text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore else: + if self.sdxl: + raise NotImplementedError('SDXL not yet supported with precomputed embeddings') text_embeddings = prompt_embeds # duplicate text embeddings for each generation per prompt bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # type: ignore text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - return text_embeddings + + if self.sdxl and pooled_text_embeddings is not None: + pooled_text_embeddings = pooled_text_embeddings.repeat(1, num_images_per_prompt) + pooled_text_embeddings = pooled_text_embeddings.view(bs_embed * num_images_per_prompt, -1) + return text_embeddings, pooled_text_embeddings def _check_prompt_lenths(prompt, negative_prompt): diff --git a/setup.py b/setup.py index 6c163e2d..4cb3034c 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ 'mosaicml-streaming>=0.4.0,<1.0', 'hydra-core>=1.2', 'hydra-colorlog>=1.1.0', - 'diffusers[torch]==0.19.3', + 'diffusers[torch]==0.21.0', 'transformers[torch]==4.31.0', 'wandb==0.15.4', 'xformers==0.0.21',
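To summarize the SDXL micro-conditioning plumbing added above (the dataset keys, StableDiffusion.forward, and eval_forward), here is a small illustrative sketch. The values are made up; only the key names, shapes, and concatenation order mirror the patch.

import torch

batch_size = 4
batch = {
    # (width, height) of the original image, as written by StreamingImageCaptionDataset
    'cond_original_size': torch.tensor([[1024., 768.]]).repeat(batch_size, 1),
    # (top, left) returned by RandomCropSquareReturnTransform
    'cond_crops_coords_top_left': torch.tensor([[0., 0.]]).repeat(batch_size, 1),
    # (image_size, image_size) training resolution
    'cond_target_size': torch.tensor([[256., 256.]]).repeat(batch_size, 1),
}

# Same concatenation order as StableDiffusion.forward()
add_time_ids = torch.cat(
    [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1)
print(add_time_ids.shape)  # torch.Size([4, 6]) -- one six-element micro-conditioning vector per sample

# The pooled output of the second (projection) text encoder rides along as 'text_embeds'
pooled_conditioning = torch.zeros(batch_size, 1280)
added_cond_kwargs = {'text_embeds': pooled_conditioning, 'time_ids': add_time_ids}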
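Finally, a rough end-to-end sketch of how the new pieces fit together once the patch is applied. This is a sketch under assumptions (a CUDA machine, a randomly initialized UNet, arbitrary sampling settings), not a documented recipe; the calls use only arguments present in the patched stable_diffusion_xl factory and generate() signatures.

import torch

from diffusion.models.models import stable_diffusion_xl

# Build the ComposerModel. pretrained=False constructs the SDXL UNet from its config
# (with the zero-init trick applied); clip_qkv=6.0 installs the clipped attention processors.
model = stable_diffusion_xl(
    pretrained=False,
    encode_latents_in_fp16=True,
    fsdp=False,
    clip_qkv=6.0,
)

# generate() tokenizes with both SDXL tokenizers, builds pooled text embeddings, and
# assembles the micro-conditioning time_ids from height/width (zero crop) unless
# crop_params / size_params are passed explicitly. An empty negative prompt is zeroed
# out because zero_out_negative_prompt defaults to True.
with torch.no_grad():
    images = model.generate(
        prompt=['a photo of an astronaut riding a horse'],
        height=1024,
        width=1024,
        guidance_scale=7.0,
        num_inference_steps=30,
        seed=42,
    )
print(images.shape)  # expected: (num_prompts * num_images_per_prompt, 3, 1024, 1024)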